From 5b873131e5297cf5042ff1f4e387f266153b5ac8 Mon Sep 17 00:00:00 2001 From: Amos Wenger Date: Mon, 5 Feb 2024 21:37:35 +0100 Subject: [PATCH] Introduce process_till_header --- rc-zip/src/fsm/entry/mod.rs | 119 ++++++++++++++++++++---------------- 1 file changed, 65 insertions(+), 54 deletions(-) diff --git a/rc-zip/src/fsm/entry/mod.rs b/rc-zip/src/fsm/entry/mod.rs index 51d0c4d..b9e6f35 100644 --- a/rc-zip/src/fsm/entry/mod.rs +++ b/rc-zip/src/fsm/entry/mod.rs @@ -42,9 +42,6 @@ enum State { ReadLocalHeader, ReadData { - /// The entry metadata - entry: Entry, - /// Whether the entry has a data descriptor has_data_descriptor: bool, @@ -65,9 +62,6 @@ enum State { }, ReadDataDescriptor { - /// The entry metadata - entry: Entry, - /// Whether the entry is zip64 (because its compressed size or uncompressed size is u32::MAX) is_zip64: bool, @@ -76,9 +70,6 @@ enum State { }, Validate { - /// The entry metadata - entry: Entry, - /// Size we've decompressed + crc32 hash we've computed metrics: EntryReadMetrics, @@ -125,9 +116,58 @@ impl EntryFsm { } } - /// Like `process`, but only processes the header: - pub fn process_header_only(&mut self) -> Option<&LocalFileHeader> { - todo!() + /// Like `process`, but only processes the header. If this returns + /// `Ok(None)`, the caller should read more data and call this function + /// again. 
+ pub fn process_till_header(&mut self) -> Result<Option<&Entry>, Error> { + match &self.state { + State::ReadLocalHeader => { + self.internal_process_local_header()?; + } + _ => { + // already good + } + } + + // this will be `Some` if we've parsed the local header, otherwise `None` + Ok(self.entry.as_ref()) + } + + fn internal_process_local_header(&mut self) -> Result<bool, Error> { + assert!( + matches!(self.state, State::ReadLocalHeader), + "internal_process_local_header called in wrong state", + ); + + let mut input = Partial::new(self.buffer.data()); + match LocalFileHeader::parser.parse_next(&mut input) { + Ok(header) => { + let consumed = input.as_bytes().offset_from(&self.buffer.data()); + tracing::trace!(local_file_header = ?header, consumed, "parsed local file header"); + let decompressor = AnyDecompressor::new( + header.method, + self.entry.as_ref().map(|entry| entry.uncompressed_size), + )?; + + if self.entry.is_none() { + self.entry = Some(header.as_entry()?); + } + + self.state = State::ReadData { + is_zip64: header.compressed_size == u32::MAX + || header.uncompressed_size == u32::MAX, + has_data_descriptor: header.has_data_descriptor(), + compressed_bytes: 0, + uncompressed_bytes: 0, + hasher: crc32fast::Hasher::new(), + decompressor, + }; + self.buffer.consume(consumed); + Ok(true) + } + Err(ErrMode::Incomplete(_)) => Ok(false), + Err(_e) => Err(Error::Format(FormatError::InvalidLocalHeader)), + } } /// Process the input and write the output to the given buffer @@ -158,40 +198,10 @@ impl EntryFsm { use State as S; match &mut self.state { S::ReadLocalHeader => { - let mut input = Partial::new(self.buffer.data()); - match LocalFileHeader::parser.parse_next(&mut input) { - Ok(header) => { - let consumed = input.as_bytes().offset_from(&self.buffer.data()); - tracing::trace!(local_file_header = ?header, consumed, "parsed local file header"); - let decompressor = AnyDecompressor::new( - header.method, - self.entry.as_ref().map(|entry| entry.uncompressed_size), - )?; - - self.state = 
S::ReadData { - entry: match &self.entry { - Some(entry) => entry.clone(), - None => header.as_entry()?, - }, - is_zip64: header.compressed_size == u32::MAX - || header.uncompressed_size == u32::MAX, - has_data_descriptor: header.has_data_descriptor(), - compressed_bytes: 0, - uncompressed_bytes: 0, - hasher: crc32fast::Hasher::new(), - decompressor, - }; - self.buffer.consume(consumed); - self.process(out) - } - Err(ErrMode::Incomplete(_)) => { - Ok(FsmResult::Continue((self, Default::default()))) - } - Err(_e) => Err(Error::Format(FormatError::InvalidLocalHeader)), - } + self.internal_process_local_header()?; + self.process(out) } S::ReadData { - entry, compressed_bytes, uncompressed_bytes, hasher, @@ -202,6 +212,7 @@ impl EntryFsm { // don't feed the decompressor bytes beyond the entry's compressed size + let entry = self.entry.as_ref().unwrap(); let in_buf_max_len = cmp::min( in_buf.len(), entry.compressed_size as usize - *compressed_bytes as usize, @@ -237,16 +248,16 @@ impl EntryFsm { if outcome.bytes_written == 0 && self.eof { // we're done, let's read the data descriptor (if there's one) - transition!(self.state => (S::ReadData { entry, has_data_descriptor, is_zip64, uncompressed_bytes, hasher, .. }) { + transition!(self.state => (S::ReadData { has_data_descriptor, is_zip64, uncompressed_bytes, hasher, .. }) { let metrics = EntryReadMetrics { uncompressed_size: uncompressed_bytes, crc32: hasher.finalize(), }; if has_data_descriptor { - S::ReadDataDescriptor { entry, metrics, is_zip64 } + S::ReadDataDescriptor { metrics, is_zip64 } } else { - S::Validate { entry, metrics, descriptor: None } + S::Validate { metrics, descriptor: None } } }); return self.process(out); @@ -273,8 +284,8 @@ impl EntryFsm { self.buffer .consume(input.as_bytes().offset_from(&self.buffer.data())); trace!("data descriptor = {:#?}", descriptor); - transition!(self.state => (S::ReadDataDescriptor { metrics, entry, .. 
}) { - S::Validate { entry, metrics, descriptor: Some(descriptor) } + transition!(self.state => (S::ReadDataDescriptor { metrics, .. }) { + S::Validate { metrics, descriptor: Some(descriptor) } }); self.process(out) } @@ -285,17 +296,17 @@ impl EntryFsm { } } S::Validate { - entry, metrics, descriptor, } => { - let entry_crc32 = self.entry.as_ref().map(|e| e.crc32).unwrap_or_default(); - let expected_crc32 = if entry_crc32 != 0 { - entry_crc32 + let entry = self.entry.as_ref().unwrap(); + + let expected_crc32 = if entry.crc32 != 0 { + entry.crc32 } else if let Some(descriptor) = descriptor.as_ref() { descriptor.crc32 } else { - entry.crc32 + 0 }; if entry.uncompressed_size != metrics.uncompressed_size {