Introduce process_till_header
fasterthanlime committed Feb 5, 2024
1 parent f579684 · commit 5b87313
Showing 1 changed file with 65 additions and 54 deletions.
119 changes: 65 additions & 54 deletions rc-zip/src/fsm/entry/mod.rs
@@ -42,9 +42,6 @@ enum State {
ReadLocalHeader,

ReadData {
/// The entry metadata
entry: Entry,

/// Whether the entry has a data descriptor
has_data_descriptor: bool,

@@ -65,9 +62,6 @@ enum State {
},

ReadDataDescriptor {
/// The entry metadata
entry: Entry,

/// Whether the entry is zip64 (because its compressed size or uncompressed size is u32::MAX)
is_zip64: bool,

@@ -76,9 +70,6 @@ enum State {
},

Validate {
/// The entry metadata
entry: Entry,

/// Size we've decompressed + crc32 hash we've computed
metrics: EntryReadMetrics,

Expand Down Expand Up @@ -125,9 +116,58 @@ impl EntryFsm {
}
}

/// Like `process`, but only processes the header:
pub fn process_header_only(&mut self) -> Option<&LocalFileHeader> {
todo!()
/// Like `process`, but only processes the header. If this returns
/// `Ok(None)`, the caller should read more data and call this function
/// again.
pub fn process_till_header(&mut self) -> Result<Option<&Entry>, Error> {
match &self.state {
State::ReadLocalHeader => {
self.internal_process_local_header()?;
}
_ => {
// already good
}
}

// this will be `Some` if we've parsed the local header; otherwise the caller needs to read more data and call us again
Ok(self.entry.as_ref())
}

fn internal_process_local_header(&mut self) -> Result<bool, Error> {
assert!(
matches!(self.state, State::ReadLocalHeader),
"internal_process_local_header called in wrong state",
);

let mut input = Partial::new(self.buffer.data());
match LocalFileHeader::parser.parse_next(&mut input) {
Ok(header) => {
let consumed = input.as_bytes().offset_from(&self.buffer.data());
tracing::trace!(local_file_header = ?header, consumed, "parsed local file header");
let decompressor = AnyDecompressor::new(
header.method,
self.entry.as_ref().map(|entry| entry.uncompressed_size),
)?;

if self.entry.is_none() {
self.entry = Some(header.as_entry()?);
}

self.state = State::ReadData {
is_zip64: header.compressed_size == u32::MAX
|| header.uncompressed_size == u32::MAX,
has_data_descriptor: header.has_data_descriptor(),
compressed_bytes: 0,
uncompressed_bytes: 0,
hasher: crc32fast::Hasher::new(),
decompressor,
};
self.buffer.consume(consumed);
Ok(true)
}
Err(ErrMode::Incomplete(_)) => Ok(false),
Err(_e) => Err(Error::Format(FormatError::InvalidLocalHeader)),
}
}

/// Process the input and write the output to the given buffer
@@ -158,40 +198,10 @@ impl EntryFsm {
use State as S;
match &mut self.state {
S::ReadLocalHeader => {
let mut input = Partial::new(self.buffer.data());
match LocalFileHeader::parser.parse_next(&mut input) {
Ok(header) => {
let consumed = input.as_bytes().offset_from(&self.buffer.data());
tracing::trace!(local_file_header = ?header, consumed, "parsed local file header");
let decompressor = AnyDecompressor::new(
header.method,
self.entry.as_ref().map(|entry| entry.uncompressed_size),
)?;

self.state = S::ReadData {
entry: match &self.entry {
Some(entry) => entry.clone(),
None => header.as_entry()?,
},
is_zip64: header.compressed_size == u32::MAX
|| header.uncompressed_size == u32::MAX,
has_data_descriptor: header.has_data_descriptor(),
compressed_bytes: 0,
uncompressed_bytes: 0,
hasher: crc32fast::Hasher::new(),
decompressor,
};
self.buffer.consume(consumed);
self.process(out)
}
Err(ErrMode::Incomplete(_)) => {
Ok(FsmResult::Continue((self, Default::default())))
}
Err(_e) => Err(Error::Format(FormatError::InvalidLocalHeader)),
}
self.internal_process_local_header()?;
self.process(out)
}
S::ReadData {
entry,
compressed_bytes,
uncompressed_bytes,
hasher,
@@ -202,6 +212,7 @@ impl EntryFsm {

// don't feed the decompressor bytes beyond the entry's compressed size

let entry = self.entry.as_ref().unwrap();
let in_buf_max_len = cmp::min(
in_buf.len(),
entry.compressed_size as usize - *compressed_bytes as usize,
@@ -237,16 +248,16 @@ impl EntryFsm {

if outcome.bytes_written == 0 && self.eof {
// we're done, let's read the data descriptor (if there's one)
transition!(self.state => (S::ReadData { entry, has_data_descriptor, is_zip64, uncompressed_bytes, hasher, .. }) {
transition!(self.state => (S::ReadData { has_data_descriptor, is_zip64, uncompressed_bytes, hasher, .. }) {
let metrics = EntryReadMetrics {
uncompressed_size: uncompressed_bytes,
crc32: hasher.finalize(),
};

if has_data_descriptor {
S::ReadDataDescriptor { entry, metrics, is_zip64 }
S::ReadDataDescriptor { metrics, is_zip64 }
} else {
S::Validate { entry, metrics, descriptor: None }
S::Validate { metrics, descriptor: None }
}
});
return self.process(out);
@@ -273,8 +284,8 @@ impl EntryFsm {
self.buffer
.consume(input.as_bytes().offset_from(&self.buffer.data()));
trace!("data descriptor = {:#?}", descriptor);
transition!(self.state => (S::ReadDataDescriptor { metrics, entry, .. }) {
S::Validate { entry, metrics, descriptor: Some(descriptor) }
transition!(self.state => (S::ReadDataDescriptor { metrics, .. }) {
S::Validate { metrics, descriptor: Some(descriptor) }
});
self.process(out)
}
@@ -285,17 +296,17 @@ impl EntryFsm {
}
}
S::Validate {
entry,
metrics,
descriptor,
} => {
let entry_crc32 = self.entry.as_ref().map(|e| e.crc32).unwrap_or_default();
let expected_crc32 = if entry_crc32 != 0 {
entry_crc32
let entry = self.entry.as_ref().unwrap();

let expected_crc32 = if entry.crc32 != 0 {
entry.crc32
} else if let Some(descriptor) = descriptor.as_ref() {
descriptor.crc32
} else {
entry.crc32
0
};

if entry.uncompressed_size != metrics.uncompressed_size {
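The doc comment on `process_till_header` spells out a pull-parser contract: `Ok(None)` means "not enough buffered data to finish the local header yet", so the caller reads more bytes, feeds them to the FSM, and calls again. Below is a minimal, self-contained sketch of that contract using toy types; `HeaderFsm`, `fill`, and the "TOY" format are stand-ins invented for illustration, not rc-zip's actual API.

// Toy illustration of the `Ok(None)` == "need more data" contract behind
// `process_till_header`. Every type and method here is a hypothetical stand-in.
#[derive(Debug)]
struct Header {
    uncompressed_size: u32,
}

#[derive(Debug)]
enum ParseError {
    BadMagic,
}

#[derive(Default)]
struct HeaderFsm {
    buffer: Vec<u8>,
    header: Option<Header>,
}

impl HeaderFsm {
    /// Append freshly read bytes to the internal buffer.
    fn fill(&mut self, chunk: &[u8]) {
        self.buffer.extend_from_slice(chunk);
    }

    /// Try to parse the header from what is buffered so far.
    /// `Ok(None)` means "incomplete input: read more and call again".
    fn process_till_header(&mut self) -> Result<Option<&Header>, ParseError> {
        if self.header.is_none() {
            // Toy format: 4-byte magic "TOY\0" followed by a little-endian u32 size.
            if self.buffer.len() < 8 {
                return Ok(None); // incomplete, playing the role of ErrMode::Incomplete above
            }
            if &self.buffer[..4] != b"TOY\0" {
                return Err(ParseError::BadMagic);
            }
            let size = u32::from_le_bytes(self.buffer[4..8].try_into().unwrap());
            self.header = Some(Header { uncompressed_size: size });
        }
        Ok(self.header.as_ref())
    }
}

fn main() -> Result<(), ParseError> {
    let mut fsm = HeaderFsm::default();
    // Simulate data arriving in small chunks, as it would from a file or socket.
    for chunk in [&b"TO"[..], &b"Y\0\x2a\x00"[..], &b"\x00\x00"[..]] {
        fsm.fill(chunk);
        if let Some(header) = fsm.process_till_header()? {
            println!("parsed header: {header:?}");
            break;
        }
        println!("need more data ({} bytes buffered)", fsm.buffer.len());
    }
    Ok(())
}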

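The other half of this commit is a state-shape cleanup: the `entry: Entry` field is dropped from the `ReadData`, `ReadDataDescriptor`, and `Validate` variants, and every arm now reads the single `self.entry` on the FSM instead (see the `self.entry.as_ref().unwrap()` lines above). A simplified sketch of that before/after shape, with made-up type names standing in for the real ones in rc-zip/src/fsm/entry/mod.rs:

// Simplified illustration of hoisting shared data out of enum variants.
// `Entry`, the state enums, and `EntryFsmAfter` are stand-ins, not the real types.
#![allow(dead_code)]

#[derive(Clone)]
struct Entry {
    uncompressed_size: u64,
}

// Before: every state that needs the entry carries its own copy, so each
// transition has to clone it (or move it) into the next variant.
enum StateBefore {
    ReadLocalHeader,
    ReadData { entry: Entry },
    Validate { entry: Entry },
}

// After: the entry lives once on the FSM; variants only keep what is unique to them.
enum StateAfter {
    ReadLocalHeader,
    ReadData,
    Validate,
}

struct EntryFsmAfter {
    state: StateAfter,
    // Parsed from the local header (or provided up front), shared by every state.
    entry: Option<Entry>,
}

impl EntryFsmAfter {
    fn validate(&self) -> bool {
        // Any arm can reach the entry without it being threaded through transitions.
        let entry = self.entry.as_ref().expect("entry parsed before validation");
        entry.uncompressed_size > 0
    }
}

fn main() {
    let fsm = EntryFsmAfter {
        state: StateAfter::Validate,
        entry: Some(Entry { uncompressed_size: 42 }),
    };
    assert!(fsm.validate());
}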