From 551539c7877f1a8925d0fcf0353926b35f486243 Mon Sep 17 00:00:00 2001 From: rkuklik Date: Fri, 22 Nov 2024 13:27:33 +0100 Subject: [PATCH 1/2] feat(write): remove `StrConsumer` and wrap `fmt` --- src/write/encoder_string_writer.rs | 207 ----------------------------- src/write/encoder_utf8.rs | 187 ++++++++++++++++++++++++++ src/write/mod.rs | 8 +- 3 files changed, 190 insertions(+), 212 deletions(-) delete mode 100644 src/write/encoder_string_writer.rs create mode 100644 src/write/encoder_utf8.rs diff --git a/src/write/encoder_string_writer.rs b/src/write/encoder_string_writer.rs deleted file mode 100644 index 83d4082..0000000 --- a/src/write/encoder_string_writer.rs +++ /dev/null @@ -1,207 +0,0 @@ -use super::encoder::EncoderWriter; -use crate::engine::Engine; -use std::io; - -/// A `Write` implementation that base64-encodes data using the provided config and accumulates the -/// resulting base64 utf8 `&str` in a [`StrConsumer`] implementation (typically `String`), which is -/// then exposed via `into_inner()`. 
-/// -/// # Examples -/// -/// Buffer base64 in a new String: -/// -/// ``` -/// use std::io::Write; -/// use base64::engine::general_purpose; -/// -/// let mut enc = base64::write::EncoderStringWriter::new(&general_purpose::STANDARD); -/// -/// enc.write_all(b"asdf").unwrap(); -/// -/// // get the resulting String -/// let b64_string = enc.into_inner(); -/// -/// assert_eq!("YXNkZg==", &b64_string); -/// ``` -/// -/// Or, append to an existing `String`, which implements `StrConsumer`: -/// -/// ``` -/// use std::io::Write; -/// use base64::engine::general_purpose; -/// -/// let mut buf = String::from("base64: "); -/// -/// let mut enc = base64::write::EncoderStringWriter::from_consumer( -/// &mut buf, -/// &general_purpose::STANDARD); -/// -/// enc.write_all(b"asdf").unwrap(); -/// -/// // release the &mut reference on buf -/// let _ = enc.into_inner(); -/// -/// assert_eq!("base64: YXNkZg==", &buf); -/// ``` -/// -/// # Performance -/// -/// Because it has to validate that the base64 is UTF-8, it is about 80% as fast as writing plain -/// bytes to a `io::Write`. -pub struct EncoderStringWriter<'e, E: Engine, S: StrConsumer> { - encoder: EncoderWriter<'e, E, Utf8SingleCodeUnitWriter>, -} - -impl<'e, E: Engine, S: StrConsumer> EncoderStringWriter<'e, E, S> { - /// Create a `EncoderStringWriter` that will append to the provided `StrConsumer`. - pub fn from_consumer(str_consumer: S, engine: &'e E) -> Self { - EncoderStringWriter { - encoder: EncoderWriter::new(Utf8SingleCodeUnitWriter { str_consumer }, engine), - } - } - - /// Encode all remaining buffered data, including any trailing incomplete input triples and - /// associated padding. - /// - /// Returns the base64-encoded form of the accumulated written data. 
- pub fn into_inner(mut self) -> S { - self.encoder - .finish() - .expect("Writing to a consumer should never fail") - .str_consumer - } -} - -impl<'e, E: Engine> EncoderStringWriter<'e, E, String> { - /// Create a `EncoderStringWriter` that will encode into a new `String` with the provided config. - pub fn new(engine: &'e E) -> Self { - EncoderStringWriter::from_consumer(String::new(), engine) - } -} - -impl<'e, E: Engine, S: StrConsumer> io::Write for EncoderStringWriter<'e, E, S> { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.encoder.write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - self.encoder.flush() - } -} - -/// An abstraction around consuming `str`s produced by base64 encoding. -pub trait StrConsumer { - /// Consume the base64 encoded data in `buf` - fn consume(&mut self, buf: &str); -} - -/// As for `io::Write`, `StrConsumer` is implemented automatically for `&mut S`. -impl StrConsumer for &mut S { - fn consume(&mut self, buf: &str) { - (**self).consume(buf); - } -} - -/// Pushes the str onto the end of the String -impl StrConsumer for String { - fn consume(&mut self, buf: &str) { - self.push_str(buf); - } -} - -/// A `Write` that only can handle bytes that are valid single-byte UTF-8 code units. -/// -/// This is safe because we only use it when writing base64, which is always valid UTF-8. 
-struct Utf8SingleCodeUnitWriter { - str_consumer: S, -} - -impl io::Write for Utf8SingleCodeUnitWriter { - fn write(&mut self, buf: &[u8]) -> io::Result { - // Because we expect all input to be valid utf-8 individual bytes, we can encode any buffer - // length - let s = std::str::from_utf8(buf).expect("Input must be valid UTF-8"); - - self.str_consumer.consume(s); - - Ok(buf.len()) - } - - fn flush(&mut self) -> io::Result<()> { - // no op - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use crate::{ - engine::Engine, tests::random_engine, write::encoder_string_writer::EncoderStringWriter, - }; - use rand::Rng; - use std::cmp; - use std::io::Write; - - #[test] - fn every_possible_split_of_input() { - let mut rng = rand::thread_rng(); - let mut orig_data = Vec::::new(); - let mut normal_encoded = String::new(); - - let size = 5_000; - - for i in 0..size { - orig_data.clear(); - normal_encoded.clear(); - - orig_data.resize(size, 0); - rng.fill(&mut orig_data[..]); - - let engine = random_engine(&mut rng); - engine.encode_string(&orig_data, &mut normal_encoded); - - let mut stream_encoder = EncoderStringWriter::new(&engine); - // Write the first i bytes, then the rest - stream_encoder.write_all(&orig_data[0..i]).unwrap(); - stream_encoder.write_all(&orig_data[i..]).unwrap(); - - let stream_encoded = stream_encoder.into_inner(); - - assert_eq!(normal_encoded, stream_encoded); - } - } - #[test] - fn incremental_writes() { - let mut rng = rand::thread_rng(); - let mut orig_data = Vec::::new(); - let mut normal_encoded = String::new(); - - let size = 5_000; - - for _ in 0..size { - orig_data.clear(); - normal_encoded.clear(); - - orig_data.resize(size, 0); - rng.fill(&mut orig_data[..]); - - let engine = random_engine(&mut rng); - engine.encode_string(&orig_data, &mut normal_encoded); - - let mut stream_encoder = EncoderStringWriter::new(&engine); - // write small nibbles of data - let mut offset = 0; - while offset < size { - let nibble_size = 
cmp::min(rng.gen_range(0..=64), size - offset); - let len = stream_encoder - .write(&orig_data[offset..offset + nibble_size]) - .unwrap(); - offset += len; - } - - let stream_encoded = stream_encoder.into_inner(); - - assert_eq!(normal_encoded, stream_encoded); - } - } -} diff --git a/src/write/encoder_utf8.rs b/src/write/encoder_utf8.rs new file mode 100644 index 0000000..a40ee06 --- /dev/null +++ b/src/write/encoder_utf8.rs @@ -0,0 +1,187 @@ +use std::fmt; +use std::io; +use std::str::from_utf8; + +use crate::engine::Engine; +use crate::write::EncoderWriter; + +/// An [`io::Write`] wrapper for types implementing [`fmt::Write`]. +/// +/// It need not be used directly; it serves as the writer parameter of [`EncoderWriter`]. +/// +/// # Examples +/// +/// Write base64 into a new [`String`]: +/// +/// ``` +/// use std::io::Write; +/// use base64::engine::general_purpose; +/// +/// let mut enc = base64::write::EncoderWriter::string(&general_purpose::STANDARD); +/// +/// enc.write_all(b"asdf").unwrap(); +/// +/// // get the resulting String +/// let b64_string = enc.formatter(); +/// +/// assert_eq!("YXNkZg==", &b64_string); +/// ``` +/// +/// Or, append to an existing [`String`], which implements [`fmt::Write`]: +/// +/// ``` +/// use std::io::Write; +/// use base64::engine::general_purpose; +/// +/// let mut buf = String::from("base64: "); +/// +/// let mut enc = base64::write::EncoderWriter::utf8( +/// &mut buf, +/// &general_purpose::STANDARD); +/// +/// enc.write_all(b"asdf").unwrap(); +/// +/// // release the &mut reference on buf +/// let _ = enc.formatter(); +/// +/// assert_eq!("base64: YXNkZg==", &buf); +/// ``` +/// +/// # Performance +/// +/// Because it has to validate that the base64 is UTF-8, it is about 80% as fast as writing plain +/// bytes to an `io::Write`. 
+pub struct Utf8Compat { + inner: W, +} + +impl Utf8Compat { + /// Create a wrapper implementing [`io::Write`] on top of a [`fmt::Write`] + pub fn new(writer: W) -> Self { + Self { inner: writer } + } + + /// Extract the underlying writer + pub fn writer(self) -> W { + self.inner + } +} + +impl From for Utf8Compat { + fn from(value: W) -> Self { + Self::new(value) + } +} + +impl io::Write for Utf8Compat { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.inner + .write_str(from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?) + .map_err(io::Error::other) + .map(|()| buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl<'e, E: Engine, W: fmt::Write> EncoderWriter<'e, E, Utf8Compat> { + /// Create an [`EncoderWriter`] that writes its output as `&str` to a [`fmt::Write`] + pub fn utf8(writer: W, engine: &'e E) -> Self { + Self::new(Utf8Compat::new(writer), engine) + } + + /// Encode all remaining buffered data, including any trailing incomplete input triples and + /// associated padding. + /// + /// Returns the underlying writer, which has received all of the base64-encoded output. 
+ pub fn formatter(mut self) -> W { + self.finish() + .expect("Writing to a consumer should never fail") + .writer() + } +} + +impl<'e, E: Engine> EncoderWriter<'e, E, Utf8Compat> { + /// Create [`EncoderWriter`] writing to [`String`] + pub fn string(engine: &'e E) -> Self { + EncoderWriter::utf8(String::new(), engine) + } +} + +#[cfg(test)] +mod tests { + use std::cmp; + use std::io::Write; + + use rand::Rng; + + use crate::engine::Engine; + use crate::tests::random_engine; + use crate::write::EncoderWriter; + + #[test] + fn every_possible_split_of_input() { + let mut rng = rand::thread_rng(); + let mut orig_data = Vec::::new(); + let mut normal_encoded = String::new(); + + let size = 5_000; + + for i in 0..size { + orig_data.clear(); + normal_encoded.clear(); + + orig_data.resize(size, 0); + rng.fill(&mut orig_data[..]); + + let engine = random_engine(&mut rng); + engine.encode_string(&orig_data, &mut normal_encoded); + + let mut stream_encoder = EncoderWriter::string(&engine); + // Write the first i bytes, then the rest + stream_encoder.write_all(&orig_data[0..i]).unwrap(); + stream_encoder.write_all(&orig_data[i..]).unwrap(); + + let stream_encoded = stream_encoder.formatter(); + + assert_eq!(normal_encoded, stream_encoded); + } + } + + #[test] + fn incremental_writes() { + let mut rng = rand::thread_rng(); + let mut orig_data = Vec::::new(); + let mut normal_encoded = String::new(); + + let size = 5_000; + + for _ in 0..size { + orig_data.clear(); + normal_encoded.clear(); + + orig_data.resize(size, 0); + rng.fill(&mut orig_data[..]); + + let engine = random_engine(&mut rng); + engine.encode_string(&orig_data, &mut normal_encoded); + + let mut stream_encoder = EncoderWriter::string(&engine); + // write small nibbles of data + let mut offset = 0; + while offset < size { + let nibble_size = cmp::min(rng.gen_range(0..=64), size - offset); + let len = stream_encoder + .write(&orig_data[offset..offset + nibble_size]) + .unwrap(); + offset += len; + } + + let 
stream_encoded = stream_encoder.formatter(); + + assert_eq!(normal_encoded, stream_encoded); + } + } +} diff --git a/src/write/mod.rs b/src/write/mod.rs index 2a617db..55d75fd 100644 --- a/src/write/mod.rs +++ b/src/write/mod.rs @@ -1,11 +1,9 @@ //! Implementations of `io::Write` to transparently handle base64. mod encoder; -mod encoder_string_writer; +mod encoder_utf8; -pub use self::{ - encoder::EncoderWriter, - encoder_string_writer::{EncoderStringWriter, StrConsumer}, -}; +pub use self::encoder::EncoderWriter; +pub use self::encoder_utf8::Utf8Compat; #[cfg(test)] mod encoder_tests; From 7c1ca93103e719510bfe99af4cfeac0a55e4f3da Mon Sep 17 00:00:00 2001 From: rkuklik Date: Fri, 22 Nov 2024 13:36:11 +0100 Subject: [PATCH 2/2] fix(write): benchmarks and msrv --- benches/benchmarks.rs | 8 ++++---- src/write/encoder_utf8.rs | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/benches/benchmarks.rs b/benches/benchmarks.rs index 8f04185..60a33d2 100644 --- a/benches/benchmarks.rs +++ b/benches/benchmarks.rs @@ -117,10 +117,10 @@ fn do_encode_bench_string_stream(b: &mut Bencher, &size: &usize) { fill(&mut v); b.iter(|| { - let mut stream_enc = write::EncoderStringWriter::new(&STANDARD); + let mut stream_enc = write::EncoderWriter::string(&STANDARD); stream_enc.write_all(&v).unwrap(); stream_enc.flush().unwrap(); - let _ = stream_enc.into_inner(); + let _ = stream_enc.formatter(); }); } @@ -131,10 +131,10 @@ fn do_encode_bench_string_reuse_buf_stream(b: &mut Bencher, &size: &usize) { let mut buf = String::new(); b.iter(|| { buf.clear(); - let mut stream_enc = write::EncoderStringWriter::from_consumer(&mut buf, &STANDARD); + let mut stream_enc = write::EncoderWriter::utf8(&mut buf, &STANDARD); stream_enc.write_all(&v).unwrap(); stream_enc.flush().unwrap(); - let _ = stream_enc.into_inner(); + let _ = stream_enc.formatter(); }); } diff --git a/src/write/encoder_utf8.rs b/src/write/encoder_utf8.rs index a40ee06..f82e40b 100644 --- a/src/write/encoder_utf8.rs 
+++ b/src/write/encoder_utf8.rs @@ -77,7 +77,7 @@ impl io::Write for Utf8Compat { fn write(&mut self, buf: &[u8]) -> io::Result { self.inner .write_str(from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?) - .map_err(io::Error::other) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) .map(|()| buf.len()) }