Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

decode: use exact decoded length rather than estimation #227

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 86 additions & 8 deletions src/decode.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::engine::{general_purpose::STANDARD, DecodeEstimate, Engine};
use crate::engine::Engine;
#[cfg(any(feature = "alloc", feature = "std", test))]
use alloc::vec::Vec;
use core::fmt;
Expand Down Expand Up @@ -89,7 +89,7 @@ impl From<DecodeError> for DecodeSliceError {
#[deprecated(since = "0.21.0", note = "Use Engine::decode")]
#[cfg(any(feature = "alloc", feature = "std", test))]
pub fn decode<T: AsRef<[u8]>>(input: T) -> Result<Vec<u8>, DecodeError> {
STANDARD.decode(input)
crate::engine::general_purpose::STANDARD.decode(input)
}

/// Decode from string reference as octets using the specified [Engine].
Expand Down Expand Up @@ -130,6 +130,73 @@ pub fn decode_engine_slice<E: Engine, T: AsRef<[u8]>>(
engine.decode_slice(input, output)
}

/// Returns the decoded size of the `encoded` input assuming the input is valid
/// base64 string.
///
/// Assumes input is a valid base64-encoded string. Result is unspecified if it
/// isn’t.
///
/// If you don’t need a precise length of the decoded string, you can use
/// [`decoded_len_estimate`] function instead. It’s faster and provides an
/// estimate which is only at most two bytes off from the real length.
///
/// # Examples
///
/// ```
/// use base64::decoded_len;
///
/// assert_eq!(0, decoded_len(b""));
/// assert_eq!(1, decoded_len(b"AA"));
/// assert_eq!(2, decoded_len(b"AAA"));
/// assert_eq!(3, decoded_len(b"AAAA"));
/// assert_eq!(1, decoded_len(b"AA=="));
/// assert_eq!(2, decoded_len(b"AAA="));
/// ```
pub fn decoded_len(encoded: impl AsRef<[u8]>) -> usize {
let encoded = encoded.as_ref();
if encoded.len() < 2 {
return 0;
}
let is_pad = |idx| (encoded[encoded.len() - idx] == b'=') as usize;
let len = encoded.len() - is_pad(1) - is_pad(2);
match len % 4 {
0 => len / 4 * 3,
remainder => len / 4 * 3 + remainder - 1,
}
}

#[test]
fn test_decoded_len() {
for chunks in 0..25 {
let mut input = vec![b'A'; chunks * 4 + 4];
assert_eq!(chunks * 3 + 0, decoded_len(&input[..chunks * 4]));
assert_eq!(chunks * 3 + 1, decoded_len(&input[..chunks * 4 + 2]));
assert_eq!(chunks * 3 + 2, decoded_len(&input[..chunks * 4 + 3]));
assert_eq!(chunks * 3 + 3, decoded_len(&input[..chunks * 4 + 4]));

input[chunks * 4 + 3] = b'=';
assert_eq!(chunks * 3 + 1, decoded_len(&input[..chunks * 4 + 2]));
assert_eq!(chunks * 3 + 2, decoded_len(&input[..chunks * 4 + 3]));
assert_eq!(chunks * 3 + 2, decoded_len(&input[..chunks * 4 + 4]));
input[chunks * 4 + 2] = b'=';
assert_eq!(chunks * 3 + 1, decoded_len(&input[..chunks * 4 + 2]));
assert_eq!(chunks * 3 + 1, decoded_len(&input[..chunks * 4 + 3]));
assert_eq!(chunks * 3 + 1, decoded_len(&input[..chunks * 4 + 4]));
}

// Mustn’t panic or overflow if given bogus input.
for len in 1..100 {
let mut input = vec![b'A'; len];
let got = decoded_len(&input);
debug_assert!(got <= len);
for padding in 1..=len.min(10) {
input[len - padding] = b'=';
let got = decoded_len(&input);
debug_assert!(got <= len);
}
}
}

/// Returns a conservative estimate of the decoded size of `encoded_len` base64 symbols (rounded up
/// to the next group of 3 decoded bytes).
///
Expand All @@ -141,6 +208,7 @@ pub fn decode_engine_slice<E: Engine, T: AsRef<[u8]>>(
/// ```
/// use base64::decoded_len_estimate;
///
/// assert_eq!(0, decoded_len_estimate(0));
/// assert_eq!(3, decoded_len_estimate(1));
/// assert_eq!(3, decoded_len_estimate(2));
/// assert_eq!(3, decoded_len_estimate(3));
Expand All @@ -149,17 +217,27 @@ pub fn decode_engine_slice<E: Engine, T: AsRef<[u8]>>(
/// assert_eq!(6, decoded_len_estimate(5));
/// ```
pub fn decoded_len_estimate(encoded_len: usize) -> usize {
STANDARD
.internal_decoded_len_estimate(encoded_len)
.decoded_len_estimate()
(encoded_len / 4 + (encoded_len % 4 > 0) as usize) * 3
}

#[test]
fn test_decode_len_estimate() {
for chunks in 0..250 {
assert_eq!(chunks * 3, decoded_len_estimate(chunks * 4));
assert_eq!(chunks * 3 + 3, decoded_len_estimate(chunks * 4 + 1));
assert_eq!(chunks * 3 + 3, decoded_len_estimate(chunks * 4 + 2));
assert_eq!(chunks * 3 + 3, decoded_len_estimate(chunks * 4 + 3));
}
// Mustn’t panic or overflow.
assert_eq!(usize::MAX / 4 * 3 + 3, decoded_len_estimate(usize::MAX));
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{
alphabet,
engine::{general_purpose, Config, GeneralPurpose},
engine::general_purpose::{NO_PAD, STANDARD},
engine::{Config, GeneralPurpose},
tests::{assert_encode_sanity, random_engine},
};
use rand::{
Expand Down Expand Up @@ -245,7 +323,7 @@ mod tests {

#[test]
fn decode_engine_estimation_works_for_various_lengths() {
let engine = GeneralPurpose::new(&alphabet::STANDARD, general_purpose::NO_PAD);
let engine = GeneralPurpose::new(&crate::alphabet::STANDARD, NO_PAD);
for num_prefix_quads in 0..100 {
for suffix in &["AA", "AAA", "AAAA"] {
let mut prefix = "AAAA".repeat(num_prefix_quads);
Expand Down
25 changes: 8 additions & 17 deletions src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,24 +96,15 @@ pub(crate) fn encode_with_padding<E: Engine + ?Sized>(
/// input lengths in approximately the top quarter of the range of `usize`.
pub fn encoded_len(bytes_len: usize, padding: bool) -> Option<usize> {
let rem = bytes_len % 3;

let complete_input_chunks = bytes_len / 3;
let complete_chunk_output = complete_input_chunks.checked_mul(4);

if rem > 0 {
if padding {
complete_chunk_output.and_then(|c| c.checked_add(4))
} else {
let encoded_rem = match rem {
1 => 2,
2 => 3,
_ => unreachable!("Impossible remainder"),
};
complete_chunk_output.and_then(|c| c.checked_add(encoded_rem))
}
let chunks = bytes_len / 3 + (rem > 0 && padding) as usize;
let encoded_len = chunks.checked_mul(4)?;
Some(if !padding && rem > 0 {
// This doesn’t overflow. encoded_len is divisible by four thus it’s at
// most usize::MAX - 3. rem ≤ 2 so we’re adding at most three.
encoded_len + rem + 1
} else {
complete_chunk_output
}
encoded_len
})
}

/// Write padding characters.
Expand Down
72 changes: 4 additions & 68 deletions src/engine/general_purpose/decode.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodePaddingMode},
engine::{general_purpose::INVALID_VALUE, DecodePaddingMode},
DecodeError, PAD_BYTE,
};

Expand All @@ -21,30 +21,6 @@ const INPUT_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * INPUT_CHUNK_LEN;
const DECODED_BLOCK_LEN: usize =
CHUNKS_PER_FAST_LOOP_BLOCK * DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX;

#[doc(hidden)]
pub struct GeneralPurposeEstimate {
/// Total number of decode chunks, including a possibly partial last chunk
num_chunks: usize,
decoded_len_estimate: usize,
}

impl GeneralPurposeEstimate {
pub(crate) fn new(encoded_len: usize) -> Self {
// Formulas that won't overflow
Self {
num_chunks: encoded_len / INPUT_CHUNK_LEN
+ (encoded_len % INPUT_CHUNK_LEN > 0) as usize,
decoded_len_estimate: (encoded_len / 4 + (encoded_len % 4 > 0) as usize) * 3,
}
}
}

impl DecodeEstimate for GeneralPurposeEstimate {
fn decoded_len_estimate(&self) -> usize {
self.decoded_len_estimate
}
}

/// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs.
/// Returns the number of bytes written, or an error.
// We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is
Expand All @@ -53,12 +29,11 @@ impl DecodeEstimate for GeneralPurposeEstimate {
#[inline]
pub(crate) fn decode_helper(
input: &[u8],
estimate: GeneralPurposeEstimate,
output: &mut [u8],
decode_table: &[u8; 256],
decode_allow_trailing_bits: bool,
padding_mode: DecodePaddingMode,
) -> Result<usize, DecodeError> {
) -> Result<(), DecodeError> {
let remainder_len = input.len() % INPUT_CHUNK_LEN;

// Because the fast decode loop writes in groups of 8 bytes (unrolled to
Expand Down Expand Up @@ -99,7 +74,8 @@ pub(crate) fn decode_helper(
};

// rounded up to include partial chunks
let mut remaining_chunks = estimate.num_chunks;
let mut remaining_chunks =
input.len() / INPUT_CHUNK_LEN + (input.len() % INPUT_CHUNK_LEN > 0) as usize;

let mut input_index = 0;
let mut output_index = 0;
Expand Down Expand Up @@ -340,44 +316,4 @@ mod tests {
decode_chunk(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output);
}

#[test]
fn estimate_short_lengths() {
for (range, (num_chunks, decoded_len_estimate)) in [
(0..=0, (0, 0)),
(1..=4, (1, 3)),
(5..=8, (1, 6)),
(9..=12, (2, 9)),
(13..=16, (2, 12)),
(17..=20, (3, 15)),
] {
for encoded_len in range {
let estimate = GeneralPurposeEstimate::new(encoded_len);
assert_eq!(num_chunks, estimate.num_chunks);
assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate);
}
}
}

#[test]
fn estimate_via_u128_inflation() {
// cover both ends of usize
(0..1000)
.chain(usize::MAX - 1000..=usize::MAX)
.for_each(|encoded_len| {
// inflate to 128 bit type to be able to safely use the easy formulas
let len_128 = encoded_len as u128;

let estimate = GeneralPurposeEstimate::new(encoded_len);
assert_eq!(
((len_128 + (INPUT_CHUNK_LEN - 1) as u128) / (INPUT_CHUNK_LEN as u128))
as usize,
estimate.num_chunks
);
assert_eq!(
((len_128 + 3) / 4 * 3) as usize,
estimate.decoded_len_estimate
);
})
}
}
10 changes: 6 additions & 4 deletions src/engine/general_purpose/decode_suffix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ use crate::{
/// Decode the last 1-8 bytes, checking for trailing set bits and padding per the provided
/// parameters.
///
/// Returns the total number of bytes decoded, including the ones indicated as already written by
/// `output_index`.
/// Expects output to be large enough to fit decoded data exactly without any
/// unused space. In debug builds panics if final output length (`output_index`
/// plus any bytes written by this function) doesn’t equal length of the output.
pub(crate) fn decode_suffix(
input: &[u8],
input_index: usize,
Expand All @@ -16,7 +17,7 @@ pub(crate) fn decode_suffix(
decode_table: &[u8; 256],
decode_allow_trailing_bits: bool,
padding_mode: DecodePaddingMode,
) -> Result<usize, DecodeError> {
) -> Result<(), DecodeError> {
// Decode any leftovers that aren't a complete input block of 8 bytes.
// Use a u64 as a stack-resident 8 byte buffer.
let mut leftover_bits: u64 = 0;
Expand Down Expand Up @@ -157,5 +158,6 @@ pub(crate) fn decode_suffix(
leftover_bits_appended_to_buf += 8;
}

Ok(output_index)
debug_assert_eq!(output.len(), output_index);
Ok(())
}
14 changes: 1 addition & 13 deletions src/engine/general_purpose/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use core::convert::TryInto;

mod decode;
pub(crate) mod decode_suffix;
pub use decode::GeneralPurposeEstimate;

pub(crate) const INVALID_VALUE: u8 = 255;

Expand Down Expand Up @@ -40,7 +39,6 @@ impl GeneralPurpose {

impl super::Engine for GeneralPurpose {
type Config = GeneralPurposeConfig;
type DecodeEstimate = GeneralPurposeEstimate;

fn internal_encode(&self, input: &[u8], output: &mut [u8]) -> usize {
let mut input_index: usize = 0;
Expand Down Expand Up @@ -161,19 +159,9 @@ impl super::Engine for GeneralPurpose {
output_index
}

fn internal_decoded_len_estimate(&self, input_len: usize) -> Self::DecodeEstimate {
GeneralPurposeEstimate::new(input_len)
}

fn internal_decode(
&self,
input: &[u8],
output: &mut [u8],
estimate: Self::DecodeEstimate,
) -> Result<usize, DecodeError> {
fn internal_decode(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecodeError> {
decode::decode_helper(
input,
estimate,
output,
&self.decode_table,
self.config.decode_allow_trailing_bits,
Expand Down
Loading