diff --git a/Cargo.lock b/Cargo.lock
index bd18f4c592f4..cd3efe3ae2d0 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1098,9 +1098,9 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495"
 
 [[package]]
 name = "bytes"
-version = "1.8.0"
+version = "1.9.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da"
+checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b"
 
 [[package]]
 name = "cacache"
@@ -3900,7 +3900,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4"
 dependencies = [
  "cfg-if",
- "windows-targets 0.48.5",
+ "windows-targets 0.52.6",
 ]
 
 [[package]]
@@ -5218,7 +5218,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15"
 dependencies = [
  "bytes",
- "heck 0.4.1",
+ "heck 0.5.0",
  "itertools 0.13.0",
  "log",
  "multimap",
@@ -6092,6 +6092,7 @@ name = "re_log_encoding"
 version = "0.22.0-alpha.1+dev"
 dependencies = [
  "arrow",
+ "bytes",
  "criterion",
  "ehttp",
  "js-sys",
@@ -6111,6 +6112,8 @@ dependencies = [
  "serde_test",
  "similar-asserts",
  "thiserror 1.0.65",
+ "tokio",
+ "tokio-stream",
  "wasm-bindgen",
  "wasm-bindgen-futures",
  "web-sys",
diff --git a/Cargo.toml b/Cargo.toml
index ae3606935004..a454ba22a3b0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -162,6 +162,7 @@ bit-vec = "0.8"
 bitflags = { version = "2.4", features = ["bytemuck"] }
 blackbox = "0.2.0"
 bytemuck = { version = "1.18", features = ["extern_crate_alloc"] }
+bytes = "1.0"
 camino = "1.1"
 cargo_metadata = "0.18"
 cargo-run-wasm = "0.3.2"
diff --git a/crates/store/re_log_encoding/Cargo.toml b/crates/store/re_log_encoding/Cargo.toml
index abe9378978f3..329a1a8e70a0 100644
--- a/crates/store/re_log_encoding/Cargo.toml
+++ b/crates/store/re_log_encoding/Cargo.toml
@@ -23,7 +23,14 @@ all-features = true
 default = []
 
 ## Enable loading data from an .rrd file.
-decoder = ["dep:rmp-serde", "dep:lz4_flex", "re_log_types/serde"]
+decoder = [
+  "dep:rmp-serde",
+  "dep:lz4_flex",
+  "re_log_types/serde",
+  "dep:tokio",
+  "dep:tokio-stream",
+  "dep:bytes",
+]
 
 ## Enable encoding of log messages to an .rrd file/stream.
 encoder = ["dep:rmp-serde", "dep:lz4_flex", "re_log_types/serde"]
@@ -57,9 +64,12 @@ parking_lot.workspace = true
 thiserror.workspace = true
 
 # Optional external dependencies:
+bytes = { workspace = true, optional = true }
 ehttp = { workspace = true, optional = true, features = ["streaming"] }
 lz4_flex = { workspace = true, optional = true }
 rmp-serde = { workspace = true, optional = true }
+tokio = { workspace = true, optional = true, features = ["io-util"] }
+tokio-stream = { workspace = true, optional = true }
 web-time = { workspace = true, optional = true }
 
 # Web dependencies:
diff --git a/crates/store/re_log_encoding/src/codec/file/decoder.rs b/crates/store/re_log_encoding/src/codec/file/decoder.rs
index a9fe39e652b2..a8721fd4882d 100644
--- a/crates/store/re_log_encoding/src/codec/file/decoder.rs
+++ b/crates/store/re_log_encoding/src/codec/file/decoder.rs
@@ -6,9 +6,6 @@ use re_log_types::LogMsg;
 use re_protos::missing_field;
 
 pub(crate) fn decode(data: &mut impl std::io::Read) -> Result<(u64, Option<LogMsg>), DecodeError> {
-    use re_protos::external::prost::Message;
-    use re_protos::log_msg::v0::{ArrowMsg, BlueprintActivationCommand, Encoding, SetStoreInfo};
-
     let mut read_bytes = 0u64;
     let header = MessageHeader::decode(data)?;
     read_bytes += std::mem::size_of::<MessageHeader>() as u64 + header.len;
@@ -16,13 +13,22 @@ pub(crate) fn decode(data: &mut impl std::io::Read) -> Result<(u64, Option<LogMsg>), DecodeError> {
     let mut buf = vec![0; header.len as usize];
     data.read_exact(&mut buf[..])?;
 
-    let msg = match header.kind {
+    let msg = decode_bytes(header.kind, &buf)?;
+
+    Ok((read_bytes, msg))
+}
+
+pub(crate) fn decode_bytes(message_kind: MessageKind, buf: &[u8]) -> Result<Option<LogMsg>, DecodeError> {
+    use re_protos::external::prost::Message;
+    use re_protos::log_msg::v0::{ArrowMsg, BlueprintActivationCommand, Encoding, SetStoreInfo};
+
+    let msg = match message_kind {
         MessageKind::SetStoreInfo => {
-            let set_store_info = SetStoreInfo::decode(&buf[..])?;
+            let set_store_info = SetStoreInfo::decode(buf)?;
             Some(LogMsg::SetStoreInfo(set_store_info.try_into()?))
         }
         MessageKind::ArrowMsg => {
-            let arrow_msg = ArrowMsg::decode(&buf[..])?;
+            let arrow_msg = ArrowMsg::decode(buf)?;
             if arrow_msg.encoding() != Encoding::ArrowIpc {
                 return Err(DecodeError::Codec(CodecError::UnsupportedEncoding));
             }
@@ -43,7 +49,7 @@ pub(crate) fn decode(data: &mut impl std::io::Read) -> Result<(u64, Option<LogMsg>), DecodeError> {
         }
 
         MessageKind::BlueprintActivationCommand => {
-            let blueprint_activation_command = BlueprintActivationCommand::decode(&buf[..])?;
+            let blueprint_activation_command = BlueprintActivationCommand::decode(buf)?;
             Some(LogMsg::BlueprintActivationCommand(
                 blueprint_activation_command.try_into()?,
             ))
@@ -51,5 +57,5 @@ pub(crate) fn decode(data: &mut impl std::io::Read) -> Result<(u64, Option<LogMsg>), DecodeError> {
         _ => None,
     };
 
-    Ok((read_bytes, msg))
+    Ok(msg)
 }
diff --git a/crates/store/re_log_encoding/src/codec/file/mod.rs b/crates/store/re_log_encoding/src/codec/file/mod.rs
index e491777409ca..3bc5276c1a33 100644
--- a/crates/store/re_log_encoding/src/codec/file/mod.rs
+++ b/crates/store/re_log_encoding/src/codec/file/mod.rs
@@ -48,6 +48,11 @@ impl MessageHeader {
         let mut buf = [0; std::mem::size_of::<Self>()];
         data.read_exact(&mut buf)?;
 
+        Self::from_bytes(&buf)
+    }
+
+    #[cfg(feature = "decoder")]
+    pub fn from_bytes(buf: &[u8]) -> Result<Self, DecodeError> {
         #[allow(clippy::unwrap_used)] // cannot fail
         let kind = u64::from_le_bytes(buf[0..8].try_into().unwrap());
         let kind = match kind {
diff --git a/crates/store/re_log_encoding/src/decoder/mod.rs b/crates/store/re_log_encoding/src/decoder/mod.rs
index 0c10b855b9fc..80a076bab187 100644
--- a/crates/store/re_log_encoding/src/decoder/mod.rs
+++ b/crates/store/re_log_encoding/src/decoder/mod.rs
@@ -1,6 +1,8 @@
 //! Decoding [`LogMsg`]:es from `.rrd` files/streams.
 
 pub mod stream;
+#[cfg(feature = "decoder")]
+pub mod streaming;
 
 use std::io::BufRead as _;
 use std::io::Read;
@@ -412,14 +414,14 @@ mod tests {
     };
 
     // TODO(#3741): remove this once we are all in on arrow-rs
-    fn strip_arrow_extensions_from_log_messages(log_msg: Vec<LogMsg>) -> Vec<LogMsg> {
+    pub fn strip_arrow_extensions_from_log_messages(log_msg: Vec<LogMsg>) -> Vec<LogMsg> {
         log_msg
             .into_iter()
             .map(LogMsg::strip_arrow_extension_types)
             .collect()
     }
 
-    fn fake_log_messages() -> Vec<LogMsg> {
+    pub fn fake_log_messages() -> Vec<LogMsg> {
         let store_id = StoreId::random(StoreKind::Blueprint);
 
         let arrow_msg = re_chunk::Chunk::builder("test_entity".into())
@@ -527,8 +529,6 @@ mod tests {
         ];
 
         for options in options {
-            println!("{options:?}");
-
             let mut data = vec![];
 
             // write "2 files" i.e. 2 streams that end with end-of-stream marker
diff --git a/crates/store/re_log_encoding/src/decoder/streaming.rs b/crates/store/re_log_encoding/src/decoder/streaming.rs
new file mode 100644
index 000000000000..aca4093d2bac
--- /dev/null
+++ b/crates/store/re_log_encoding/src/decoder/streaming.rs
@@ -0,0 +1,377 @@
+use std::pin::Pin;
+
+use bytes::{Buf, BytesMut};
+use re_build_info::CrateVersion;
+use re_log::external::log::warn;
+use re_log_types::LogMsg;
+use tokio::io::{AsyncBufRead, AsyncReadExt};
+use tokio_stream::Stream;
+
+use crate::{
+    codec::file::{self},
+    Compression, EncodingOptions, VersionPolicy,
+};
+
+use super::{read_options, DecodeError, FileHeader};
+
+pub struct StreamingDecoder<R: AsyncBufRead> {
+    version: CrateVersion,
+    options: EncodingOptions,
+    reader: R,
+    // Buffer used for decompressing data. This is a tiny optimization
+    // to (potentially) avoid an allocation for each (compressed) message.
+    uncompressed: Vec<u8>,
+    // Internal buffer for unprocessed bytes.
+    unprocessed_bytes: BytesMut,
+    // Flag to indicate that we expect more data to be read.
+    expect_more_data: bool,
+}
+
+impl<R: AsyncBufRead + Unpin> StreamingDecoder<R> {
+    pub async fn new(version_policy: VersionPolicy, mut reader: R) -> Result<Self, DecodeError> {
+        let mut data = [0_u8; FileHeader::SIZE];
+
+        reader
+            .read_exact(&mut data)
+            .await
+            .map_err(DecodeError::Read)?;
+
+        let (version, options) = read_options(version_policy, &data)?;
+
+        Ok(Self {
+            version,
+            options,
+            reader,
+            uncompressed: Vec::new(),
+            unprocessed_bytes: BytesMut::new(),
+            expect_more_data: false,
+        })
+    }
+
+    /// Returns true if `data` can be successfully decoded into a `FileHeader`.
+    fn peek_file_header(data: &[u8]) -> bool {
+        let mut read = std::io::Cursor::new(data);
+        FileHeader::decode(&mut read).is_ok()
+    }
+}
+
+/// `StreamingDecoder` relies on the underlying reader for the wakeup mechanism.
+/// The fact that we can have concatenated files or a corrupted file pushes us to keep
+/// the state of the decoder in the struct itself (through `unprocessed_bytes` and `expect_more_data`).
+impl<R: AsyncBufRead + Unpin> Stream for StreamingDecoder<R> {
+    type Item = Result<LogMsg, DecodeError>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        loop {
+            let Self {
+                options,
+                reader,
+                uncompressed,
+                unprocessed_bytes,
+                expect_more_data,
+                ..
+            } = &mut *self;
+
+            let serializer = options.serializer;
+            let compression = options.compression;
+            let mut buf_length = 0;
+
+            // poll_fill_buf() implicitly handles the EOF case, so we don't need to check for it
+            match Pin::new(reader).poll_fill_buf(cx) {
+                std::task::Poll::Ready(Ok([])) => {
+                    if unprocessed_bytes.is_empty() {
+                        return std::task::Poll::Ready(None);
+                    }
+                    // There is more unprocessed data, but nothing left in the underlying
+                    // byte stream - this indicates a corrupted stream.
+                    if *expect_more_data {
+                        warn!(
+                            "There are {} bytes of unprocessed data, but not enough to decode a full message",
+                            unprocessed_bytes.len()
+                        );
+                        return std::task::Poll::Ready(None);
+                    }
+                }
+
+                std::task::Poll::Ready(Ok(buf)) => {
+                    unprocessed_bytes.extend_from_slice(buf);
+                    buf_length = buf.len();
+                }
+
+                std::task::Poll::Ready(Err(err)) => {
+                    return std::task::Poll::Ready(Some(Err(DecodeError::Read(err))));
+                }
+
+                std::task::Poll::Pending => return std::task::Poll::Pending,
+            };
+
+            // Check if this is the start of a new concatenated file.
+            if unprocessed_bytes.len() >= FileHeader::SIZE
+                && Self::peek_file_header(&unprocessed_bytes[..FileHeader::SIZE])
+            {
+                let data = &unprocessed_bytes[..FileHeader::SIZE];
+                // We've found another file header in the middle of the stream; time to switch
+                // gears and start over on this new file.
+                match read_options(VersionPolicy::Warn, data) {
+                    Ok((version, options)) => {
+                        self.version = CrateVersion::max(self.version, version);
+                        self.options = options;
+
+                        Pin::new(&mut self.reader).consume(buf_length);
+                        self.unprocessed_bytes.advance(FileHeader::SIZE);
+
+                        continue;
+                    }
+                    Err(err) => return std::task::Poll::Ready(Some(Err(err))),
+                }
+            }
+
+            let (msg, processed_length) = match serializer {
+                crate::Serializer::MsgPack => {
+                    let header_size = super::MessageHeader::SIZE;
+                    if unprocessed_bytes.len() < header_size {
+                        // Not enough data to read the header, need to wait for more
+                        self.expect_more_data = true;
+                        Pin::new(&mut self.reader).consume(buf_length);
+
+                        continue;
+                    }
+                    let data = &unprocessed_bytes[..header_size];
+                    let header = super::MessageHeader::from_bytes(data);
+
+                    match header {
+                        super::MessageHeader::Data {
+                            compressed_len,
+                            uncompressed_len,
+                        } => {
+                            let uncompressed_len = uncompressed_len as usize;
+                            let compressed_len = compressed_len as usize;
+
+                            // read the data
+                            let (data, length) = match compression {
+                                Compression::Off => {
+                                    if unprocessed_bytes.len() < uncompressed_len + header_size {
+                                        self.expect_more_data = true;
+                                        Pin::new(&mut self.reader).consume(buf_length);
+
+                                        continue;
+                                    }
+
+                                    (
+                                        &unprocessed_bytes
+                                            [header_size..uncompressed_len + header_size],
+                                        uncompressed_len,
+                                    )
+                                }
+
+                                Compression::LZ4 => {
+                                    if unprocessed_bytes.len() < compressed_len + header_size {
+                                        // Not enough data to read the message, need to wait for more
+                                        self.expect_more_data = true;
+                                        Pin::new(&mut self.reader).consume(buf_length);
+
+                                        continue;
+                                    }
+
+                                    uncompressed
+                                        .resize(uncompressed.len().max(uncompressed_len), 0);
+                                    let data = &unprocessed_bytes
+                                        [header_size..compressed_len + header_size];
+                                    if let Err(err) =
+                                        lz4_flex::block::decompress_into(data, uncompressed)
+                                    {
+                                        return std::task::Poll::Ready(Some(Err(
+                                            DecodeError::Lz4(err),
+                                        )));
+                                    }
+
+                                    (&uncompressed[..], compressed_len)
+                                }
+                            };
+
+                            // decode the message
+                            let msg = rmp_serde::from_slice::<LogMsg>(data);
+
+                            match msg {
+                                Ok(msg) => (Some(msg), length + header_size),
+                                Err(err) => {
+                                    return std::task::Poll::Ready(Some(Err(
+                                        DecodeError::MsgPack(err),
+                                    )));
+                                }
+                            }
+                        }
+
+                        super::MessageHeader::EndOfStream => return std::task::Poll::Ready(None),
+                    }
+                }
+
+                crate::Serializer::Protobuf => {
+                    let header_size = std::mem::size_of::<file::MessageHeader>();
+                    if unprocessed_bytes.len() < header_size {
+                        // Not enough data to read the header, need to wait for more
+                        self.expect_more_data = true;
+                        Pin::new(&mut self.reader).consume(buf_length);
+
+                        continue;
+                    }
+                    let data = &unprocessed_bytes[..header_size];
+                    let header = file::MessageHeader::from_bytes(data)?;
+
+                    if unprocessed_bytes.len() < header.len as usize + header_size {
+                        // Not enough data to read the message, need to wait for more
+                        self.expect_more_data = true;
+                        Pin::new(&mut self.reader).consume(buf_length);
+
+                        continue;
+                    }
+
+                    // decode the message
+                    let data = &unprocessed_bytes[header_size..header_size + header.len as usize];
+                    let msg = file::decoder::decode_bytes(header.kind, data)?;
+
+                    (msg, header.len as usize + header_size)
+                }
+            };
+
+            let Some(mut msg) = msg else {
+                // We've reached the end of the stream (i.e. read the end-of-stream header),
+                // so check whether another file is concatenated after it.
+                if unprocessed_bytes.len() < processed_length + FileHeader::SIZE {
+                    return std::task::Poll::Ready(None);
+                }
+
+                let data =
+                    &unprocessed_bytes[processed_length..processed_length + FileHeader::SIZE];
+                if Self::peek_file_header(data) {
+                    re_log::debug!(
+                        "Reached end of stream, but it seems we have a concatenated file, continuing"
+                    );
+
+                    Pin::new(&mut self.reader).consume(buf_length);
+                    continue;
+                }
+
+                re_log::debug!("Reached end of stream, iterator complete");
+                return std::task::Poll::Ready(None);
+            };
+
+            if let LogMsg::SetStoreInfo(msg) = &mut msg {
+                // Propagate the protocol version from the header into the `StoreInfo` so that all
+                // parts of the app can easily access it.
+                msg.info.store_version = Some(self.version);
+            }
+
+            Pin::new(&mut self.reader).consume(buf_length);
+            self.unprocessed_bytes.advance(processed_length);
+            self.expect_more_data = false;
+
+            return std::task::Poll::Ready(Some(Ok(msg)));
+        }
+    }
+}
+
+#[cfg(all(test, feature = "decoder", feature = "encoder"))]
+mod tests {
+    use re_build_info::CrateVersion;
+    use tokio_stream::StreamExt;
+
+    use crate::{
+        decoder::{
+            streaming::StreamingDecoder,
+            tests::{fake_log_messages, strip_arrow_extensions_from_log_messages},
+        },
+        Compression, EncodingOptions, Serializer, VersionPolicy,
+    };
+
+    #[tokio::test]
+    async fn test_streaming_decoder_handles_corrupted_input_file() {
+        let rrd_version = CrateVersion::LOCAL;
+
+        let messages = fake_log_messages();
+
+        let options = [
+            EncodingOptions {
+                compression: Compression::Off,
+                serializer: Serializer::MsgPack,
+            },
+            EncodingOptions {
+                compression: Compression::LZ4,
+                serializer: Serializer::MsgPack,
+            },
+            EncodingOptions {
+                compression: Compression::Off,
+                serializer: Serializer::Protobuf,
+            },
+            EncodingOptions {
+                compression: Compression::LZ4,
+                serializer: Serializer::Protobuf,
+            },
+        ];
+
+        for options in options {
+            let mut data = vec![];
+            crate::encoder::encode_ref(rrd_version, options, messages.iter().map(Ok), &mut data)
+                .unwrap();
+
+            // We cut the input file by one byte to simulate a corrupted file and check that we
+            // don't end up in an infinite loop waiting for more data when there's none to be read.
+            let data = &data[..data.len() - 1];
+
+            let buf_reader = tokio::io::BufReader::new(std::io::Cursor::new(data));
+
+            let decoder = StreamingDecoder::new(VersionPolicy::Error, buf_reader)
+                .await
+                .unwrap();
+
+            let decoded_messages = strip_arrow_extensions_from_log_messages(
+                decoder.collect::<Result<Vec<_>, _>>().await.unwrap(),
+            );
+
+            similar_asserts::assert_eq!(decoded_messages, messages);
+        }
+    }
+
+    #[tokio::test]
+    async fn test_streaming_decoder_happy_paths() {
+        let rrd_version = CrateVersion::LOCAL;
+
+        let messages = fake_log_messages();
+
+        let options = [
+            EncodingOptions {
+                compression: Compression::Off,
+                serializer: Serializer::MsgPack,
+            },
+            EncodingOptions {
+                compression: Compression::LZ4,
+                serializer: Serializer::MsgPack,
+            },
+            EncodingOptions {
+                compression: Compression::Off,
+                serializer: Serializer::Protobuf,
+            },
+            EncodingOptions {
+                compression: Compression::LZ4,
+                serializer: Serializer::Protobuf,
+            },
+        ];
+
+        for options in options {
+            let mut data = vec![];
+            crate::encoder::encode_ref(rrd_version, options, messages.iter().map(Ok), &mut data)
+                .unwrap();
+
+            let buf_reader = tokio::io::BufReader::new(std::io::Cursor::new(data));
+
+            let decoder = StreamingDecoder::new(VersionPolicy::Error, buf_reader)
+                .await
+                .unwrap();
+
+            let decoded_messages = strip_arrow_extensions_from_log_messages(
+                decoder.collect::<Result<Vec<_>, _>>().await.unwrap(),
+            );
+
+            similar_asserts::assert_eq!(decoded_messages, messages);
+        }
+    }
+}
diff --git a/crates/store/re_log_encoding/src/lib.rs b/crates/store/re_log_encoding/src/lib.rs
index 709b96fcdb4f..3fad54e5a963 100644
--- a/crates/store/re_log_encoding/src/lib.rs
+++ b/crates/store/re_log_encoding/src/lib.rs
@@ -217,23 +217,28 @@ impl MessageHeader {
 
     #[cfg(feature = "decoder")]
     pub fn decode(read: &mut impl std::io::Read) -> Result<Self, decoder::DecodeError> {
-        fn u32_from_le_slice(bytes: &[u8]) -> u32 {
-            u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
-        }
-
         let mut buffer = [0_u8; Self::SIZE];
         read.read_exact(&mut buffer)
             .map_err(decoder::DecodeError::Read)?;
 
-        if u32_from_le_slice(&buffer[0..4]) == 0 && u32_from_le_slice(&buffer[4..]) == 0 {
-            Ok(Self::EndOfStream)
+        Ok(Self::from_bytes(&buffer))
+    }
+
+    #[cfg(feature = "decoder")]
+    pub fn from_bytes(data: &[u8]) -> Self {
+        fn u32_from_le_slice(bytes: &[u8]) -> u32 {
+            u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
+        }
+
+        if u32_from_le_slice(&data[0..4]) == 0 && u32_from_le_slice(&data[4..]) == 0 {
+            Self::EndOfStream
         } else {
-            let compressed = u32_from_le_slice(&buffer[0..4]);
-            let uncompressed = u32_from_le_slice(&buffer[4..]);
-            Ok(Self::Data {
+            let compressed = u32_from_le_slice(&data[0..4]);
+            let uncompressed = u32_from_le_slice(&data[4..]);
+            Self::Data {
                 compressed_len: compressed,
                 uncompressed_len: uncompressed,
-            })
+            }
         }
     }
 }
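For reviewers: a minimal sketch of how the new `StreamingDecoder` is meant to be consumed outside the tests, assuming the `decoder` feature is enabled and a `tokio` runtime with the `fs`, `macros`, and `rt-multi-thread` features is available; the `example.rrd` path is hypothetical.

```rust
// Hypothetical consumer of the new streaming decoder (not part of this diff).
use re_log_encoding::decoder::streaming::StreamingDecoder;
use re_log_encoding::VersionPolicy;
use tokio_stream::StreamExt as _;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Any `AsyncBufRead + Unpin` source works: a file, a socket, etc.
    // "example.rrd" is a placeholder path.
    let file = tokio::fs::File::open("example.rrd").await?;
    let reader = tokio::io::BufReader::new(file);

    // `new` reads and validates the file header up front, which is why it is async.
    let mut decoder = StreamingDecoder::new(VersionPolicy::Warn, reader).await?;

    // Messages are yielded one at a time; concatenated .rrd streams are
    // handled transparently by the decoder's poll loop.
    let mut count = 0;
    while let Some(msg) = decoder.next().await {
        let _msg: re_log_types::LogMsg = msg?;
        count += 1;
    }
    println!("decoded {count} messages");

    Ok(())
}
```

Unlike the existing blocking decoder, the stream keeps its own `unprocessed_bytes` buffer because `poll_fill_buf` may hand back less than a full message; the `expect_more_data` flag is what lets the corrupted-input test above terminate instead of spinning forever.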