use compile_fmt::{clip_ascii, compile_assert, compile_panic, fmt, Ascii};
use crate::wrappers::{SkipWhitespace, Skipper};
/// Stand-in for the `?` operator inside `const fn`s: evaluates to the `Ok` payload of the
/// supplied expression, or returns early from the enclosing function with the `Err` payload
/// passed through unchanged (no error conversion is performed).
macro_rules! const_try {
    ($fallible:expr) => {
        match $fallible {
            Err(error) => return Err(error),
            Ok(unwrapped) => unwrapped,
        }
    };
}
/// Internal error describing an input byte that a decoder could not map to a value.
///
/// It is never surfaced to callers as a value; the decoding loops convert it
/// into a panic via `DecodeError::panic()`.
#[derive(Debug)]
struct DecodeError {
/// The offending input byte, exactly as read from the input.
invalid_char: u8,
/// Alphabet of the encoding that rejected the byte. `None` is used by the hex decoder,
/// which has no explicit alphabet.
alphabet: Option<Ascii<'static>>,
}
impl DecodeError {
const fn invalid_char(invalid_char: u8, alphabet: Option<Ascii<'static>>) -> Self {
Self {
invalid_char,
alphabet,
}
}
const fn panic(self, input_pos: usize) -> ! {
if self.invalid_char.is_ascii() {
if let Some(alphabet) = self.alphabet {
compile_panic!(
"Character '", self.invalid_char as char => fmt::<char>(), "' at position ",
input_pos => fmt::<usize>(), " is not a part of \
the decoder alphabet '", alphabet => clip_ascii(64, ""), "'"
);
} else {
compile_panic!(
"Character '", self.invalid_char as char => fmt::<char>(), "' at position ",
input_pos => fmt::<usize>(), " is not a hex digit"
);
}
} else {
compile_panic!(
"Non-ASCII character with decimal code ", self.invalid_char => fmt::<u8>(),
" encountered at position ", input_pos => fmt::<usize>()
);
}
}
}
/// Encoding defined by an ASCII alphabet whose length is a power of two (from 2 to 64).
///
/// Each alphabet character stands for a fixed number of bits (`bits_per_char`); the value
/// of a character is its zero-based position in the alphabet.
#[derive(Debug, Clone, Copy)]
pub struct Encoding {
/// The alphabet the encoding was created from; kept for error messages.
alphabet: Ascii<'static>,
/// Lookup table indexed by ASCII code: the position of the character in the alphabet,
/// or `Encoding::NO_MAPPING` if the character is not a part of the alphabet.
table: [u8; 128],
/// Number of bits encoded by a single character, i.e. the base-2 logarithm of
/// the alphabet length (1 to 6).
bits_per_char: u8,
}
impl Encoding {
    /// Sentinel stored in `table` for ASCII characters that are not in the alphabet.
    /// It cannot clash with a real mapping because alphabet positions never exceed 63.
    const NO_MAPPING: u8 = u8::MAX;

    /// Base64 alphabet ending in `+` and `/`.
    const BASE64: Self =
        Self::new("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/");
    /// Base64 alphabet ending in `-` and `_` (the URL-safe variant).
    const BASE64_URL: Self =
        Self::new("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_");

    /// Creates an encoding from the provided `alphabet`. The value of each character is
    /// its zero-based position in `alphabet`.
    ///
    /// # Panics
    ///
    /// - Panics if the byte length of `alphabet` is not one of 2, 4, 8, 16, 32, or 64.
    /// - Panics if a character is mentioned in `alphabet` more than once.
    /// - `Ascii::new()` is additionally expected to reject non-ASCII alphabets
    ///   (NOTE(review): behavior of `compile_fmt`; confirm against its docs).
    // The cast of an alphabet position to `u8` cannot truncate: positions are below
    // the alphabet length, which is at most 64 as checked at the start of the function.
    #[allow(clippy::cast_possible_truncation)]
    pub const fn new(alphabet: &'static str) -> Self {
        let bits_per_char = match alphabet.len() {
            64 => 6,
            32 => 5,
            16 => 4,
            8 => 3,
            4 => 2,
            2 => 1,
            other => compile_panic!(
                "Invalid alphabet length ", other => fmt::<usize>(),
                "; must be one of 2, 4, 8, 16, 32, or 64"
            ),
        };

        let symbols = alphabet.as_bytes();
        // Converted before the table is filled, so that any validation performed by
        // `Ascii::new()` happens ahead of indexing the 128-entry table with raw bytes.
        let alphabet = Ascii::new(alphabet);

        let mut table = [Self::NO_MAPPING; 128];
        let mut position = 0;
        while position < symbols.len() {
            let symbol = symbols[position];
            let slot = symbol as usize;
            // A slot that is already occupied means the character was seen earlier.
            compile_assert!(
                table[slot] == Self::NO_MAPPING,
                "Alphabet character '", symbol as char => fmt::<char>(), "' is mentioned several times"
            );
            table[slot] = position as u8;
            position += 1;
        }

        Self {
            alphabet,
            table,
            bits_per_char,
        }
    }

    /// Returns the value (position in the alphabet) of `ascii_char`.
    ///
    /// # Errors
    ///
    /// Returns an error carrying this encoding's alphabet if `ascii_char` is not ASCII
    /// or is not a part of the alphabet.
    const fn lookup(&self, ascii_char: u8) -> Result<u8, DecodeError> {
        // Non-ASCII bytes would index past the table, so they are treated as unmapped.
        let mapping = if ascii_char.is_ascii() {
            self.table[ascii_char as usize]
        } else {
            Self::NO_MAPPING
        };

        if mapping == Self::NO_MAPPING {
            return Err(DecodeError::invalid_char(ascii_char, Some(self.alphabet)));
        }
        Ok(mapping)
    }
}
/// State of the hex decoder: the value of the first (high) digit of the current byte
/// if it has been read and the second digit is still pending, or `None` otherwise.
#[derive(Debug, Clone, Copy)]
struct HexDecoderState(Option<u8>);
impl HexDecoderState {
    /// Converts a hex digit (`0-9`, `a-f` or `A-F`) to its numeric value in `0..16`.
    ///
    /// # Errors
    ///
    /// Returns an error without an alphabet if `val` is not a hex digit.
    const fn byte_value(val: u8) -> Result<u8, DecodeError> {
        match val {
            b'0'..=b'9' => Ok(val - b'0'),
            b'a'..=b'f' => Ok(val - b'a' + 10),
            b'A'..=b'F' => Ok(val - b'A' + 10),
            _ => Err(DecodeError::invalid_char(val, None)),
        }
    }

    /// Creates the initial state with no pending digit.
    const fn new() -> Self {
        Self(None)
    }

    /// Feeds one input `byte` into the decoder.
    ///
    /// Returns the updated state together with a decoded output byte, which is produced
    /// on every second digit.
    ///
    /// # Errors
    ///
    /// Returns an error if `byte` is not a hex digit.
    // The combinator form suggested by the lint relies on closures,
    // which cannot be called in a `const fn`.
    #[allow(clippy::option_if_let_else)]
    const fn update(mut self, byte: u8) -> Result<(Self, Option<u8>), DecodeError> {
        let nibble = const_try!(Self::byte_value(byte));
        let output = match self.0 {
            // The high digit is pending: this digit completes the byte.
            Some(high) => {
                self.0 = None;
                Some((high << 4) + nibble)
            }
            // No pending digit: remember this one as the high half of the next byte.
            None => {
                self.0 = Some(nibble);
                None
            }
        };
        Ok((self, output))
    }

    /// Checks that there is no pending digit, i.e. that an even number of digits was consumed.
    const fn is_final(self) -> bool {
        matches!(self.0, None)
    }
}
/// State of a decoder driven by an `Encoding` lookup table.
#[derive(Debug, Clone, Copy)]
struct CustomDecoderState {
/// Encoding used to map input characters to their values.
table: Encoding,
/// Bits of the output byte being assembled, stored in the lowest `filled_bits` bits.
partial_byte: u8,
/// Number of bits currently accumulated in `partial_byte`.
filled_bits: u8,
}
impl CustomDecoderState {
    /// Creates the initial state (no accumulated bits) for the encoding given by `table`.
    const fn new(table: Encoding) -> Self {
        Self {
            table,
            filled_bits: 0,
            partial_byte: 0,
        }
    }

    /// Feeds one input `byte` into the decoder.
    ///
    /// Returns the updated state together with an output byte if the bits contributed by
    /// `byte` were enough to complete one.
    ///
    /// # Errors
    ///
    /// Returns an error if `byte` is not a part of the encoding alphabet.
    // A `match` on `Ord::cmp()` as suggested by the lint is not usable here
    // since trait methods cannot be called in a `const fn`.
    #[allow(clippy::comparison_chain)]
    const fn update(mut self, byte: u8) -> Result<(Self, Option<u8>), DecodeError> {
        let digit = const_try!(self.table.lookup(byte));
        let bits_per_char = self.table.bits_per_char;
        // Number of bits that would be accumulated if all bits of `digit` were appended.
        let total_bits = self.filled_bits + bits_per_char;

        let output = if total_bits < 8 {
            // Still short of a full byte: append the bits and wait for more input.
            self.partial_byte = (self.partial_byte << bits_per_char) + digit;
            self.filled_bits = total_bits;
            None
        } else if total_bits == 8 {
            // The new bits complete the byte exactly; start over from an empty state.
            let completed = (self.partial_byte << bits_per_char) + digit;
            self.partial_byte = 0;
            self.filled_bits = 0;
            Some(completed)
        } else {
            // The new bits straddle a byte boundary: the upper `free_bits` bits complete
            // the current byte, and the lower `spill_bits` bits start the next one.
            let spill_bits = total_bits - 8;
            let free_bits = bits_per_char - spill_bits;
            let completed = (self.partial_byte << free_bits) + (digit >> spill_bits);
            self.partial_byte = digit % (1 << spill_bits);
            self.filled_bits = spill_bits;
            Some(completed)
        };
        Ok((self, output))
    }

    /// Checks that no non-zero bits are left over.
    ///
    /// Only the value of the accumulated bits is inspected, not their count, so trailing
    /// zero bits that do not form a full byte are accepted.
    const fn is_final(&self) -> bool {
        matches!(self.partial_byte, 0)
    }
}
/// Internal state of a decoder while it consumes input bytes one at a time.
#[derive(Debug, Clone, Copy)]
enum DecoderState {
/// State of the hex decoder.
Hex(HexDecoderState),
/// State of a base64 decoder (either alphabet). Differs from `Custom` in that
/// `=` characters in the input are skipped.
Base64(CustomDecoderState),
/// State of a decoder with a user-supplied alphabet.
Custom(CustomDecoderState),
}
impl DecoderState {
    /// Feeds one input `byte` into the underlying decoder.
    ///
    /// Returns the updated state together with an output byte if one was completed.
    ///
    /// # Errors
    ///
    /// Returns an error if `byte` is not valid for the underlying decoder.
    const fn update(self, byte: u8) -> Result<(Self, Option<u8>), DecodeError> {
        match self {
            // Base64 padding leaves the state untouched and produces no output.
            Self::Base64(_) if byte == b'=' => Ok((self, None)),

            Self::Hex(hex) => match hex.update(byte) {
                Ok((next, output)) => Ok((Self::Hex(next), output)),
                Err(err) => Err(err),
            },
            Self::Base64(base64) => match base64.update(byte) {
                Ok((next, output)) => Ok((Self::Base64(next), output)),
                Err(err) => Err(err),
            },
            Self::Custom(custom) => match custom.update(byte) {
                Ok((next, output)) => Ok((Self::Custom(next), output)),
                Err(err) => Err(err),
            },
        }
    }

    /// Checks whether the underlying decoder is in a state in which the input may end.
    const fn is_final(&self) -> bool {
        match *self {
            Self::Base64(custom) | Self::Custom(custom) => custom.is_final(),
            Self::Hex(hex) => hex.is_final(),
        }
    }
}
/// Decoder of a human-friendly encoding, such as hex or base64, into bytes.
///
/// All decoding methods are `const fn`s, so the decoder can be used to define constants.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum Decoder {
/// Hexadecimal decoder. Both lower-case and upper-case digits are accepted.
Hex,
/// Base64 decoder using the alphabet ending in `+` and `/`.
/// `=` characters in the input are skipped.
Base64,
/// Base64 decoder using the URL-safe alphabet ending in `-` and `_`.
/// `=` characters in the input are skipped.
Base64Url,
/// Decoder based on a custom [`Encoding`].
Custom(Encoding),
}
impl Decoder {
    /// Creates a decoder for a custom encoding defined by `alphabet`.
    ///
    /// # Panics
    ///
    /// Panics in the same situations as [`Encoding::new()`].
    pub const fn custom(alphabet: &'static str) -> Self {
        let encoding = Encoding::new(alphabet);
        Self::Custom(encoding)
    }

    /// Wraps this decoder into [`SkipWhitespace`].
    pub const fn skip_whitespace(self) -> SkipWhitespace {
        SkipWhitespace(self)
    }

    /// Creates the initial decoding state for this decoder.
    const fn new_state(self) -> DecoderState {
        match self {
            Self::Custom(encoding) => DecoderState::Custom(CustomDecoderState::new(encoding)),
            // Both base64 flavors share the same state type and differ only in the alphabet.
            Self::Base64Url => DecoderState::Base64(CustomDecoderState::new(Encoding::BASE64_URL)),
            Self::Base64 => DecoderState::Base64(CustomDecoderState::new(Encoding::BASE64)),
            Self::Hex => DecoderState::Hex(HexDecoderState::new()),
        }
    }

    /// Decodes `input` into a byte array of length `N`.
    ///
    /// # Panics
    ///
    /// - Panics if `input` contains a character that is invalid for this decoder.
    /// - Panics if `input` decodes to a number of bytes other than `N`.
    /// - Panics if the input ends with left-over state (e.g., an odd number of hex digits).
    pub const fn decode<const N: usize>(self, input: &[u8]) -> [u8; N] {
        Self::do_decode::<N>(self, input, None)
    }

    /// Decodes `input` into a byte array of length `N`, consulting the optional `skipper`
    /// before each input byte is decoded.
    ///
    /// # Panics
    ///
    /// Panics in the same situations as [`Self::decode()`].
    pub(crate) const fn do_decode<const N: usize>(
        self,
        input: &[u8],
        skipper: Option<Skipper>,
    ) -> [u8; N] {
        let mut decoded = [0_u8; N];
        let mut state = self.new_state();
        let mut cursor = 0;
        let mut produced = 0;

        while cursor < input.len() {
            // If the skipper moves the position, re-check the loop condition
            // before decoding anything.
            if let Some(skipper) = skipper {
                let resume_at = skipper.skip(input, cursor);
                if resume_at != cursor {
                    cursor = resume_at;
                    continue;
                }
            }

            let (next_state, emitted) = match state.update(input[cursor]) {
                Err(err) => err.panic(cursor),
                Ok(step) => step,
            };
            state = next_state;
            cursor += 1;

            if let Some(byte) = emitted {
                // Keep counting past the end of the buffer (without writing), so that
                // the overflow message below can report the full decoded length.
                if produced < N {
                    decoded[produced] = byte;
                }
                produced += 1;
            }
        }

        compile_assert!(
            produced <= N,
            "Output overflow: the input decodes to ", produced => fmt::<usize>(),
            " bytes, while type inference implies ", N => fmt::<usize>(), ". \
            Either fix the input or change the output buffer length correspondingly"
        );
        compile_assert!(
            produced == N,
            "Output underflow: the input decodes to ", produced => fmt::<usize>(),
            " bytes, while type inference implies ", N => fmt::<usize>(), ". \
            Either fix the input or change the output buffer length correspondingly"
        );
        assert!(
            state.is_final(),
            "Left-over state after processing input. This usually means that the input \
             is incorrect (e.g., an odd number of hex digits)."
        );
        decoded
    }

    /// Computes the number of bytes that `input` decodes to, consulting the optional
    /// `skipper` before each input byte is decoded.
    ///
    /// Unlike [`Self::do_decode()`], this method does not check for left-over state
    /// at the end of the input.
    ///
    /// # Panics
    ///
    /// Panics if `input` contains a character that is invalid for this decoder.
    pub(crate) const fn do_decode_len(self, input: &[u8], skipper: Option<Skipper>) -> usize {
        let mut state = self.new_state();
        let mut cursor = 0;
        let mut produced = 0;

        while cursor < input.len() {
            // If the skipper moves the position, re-check the loop condition
            // before decoding anything.
            if let Some(skipper) = skipper {
                let resume_at = skipper.skip(input, cursor);
                if resume_at != cursor {
                    cursor = resume_at;
                    continue;
                }
            }

            let (next_state, emitted) = match state.update(input[cursor]) {
                Err(err) => err.panic(cursor),
                Ok(step) => step,
            };
            state = next_state;
            cursor += 1;

            if emitted.is_some() {
                produced += 1;
            }
        }
        produced
    }
}