From 18c494b948d11688d4a2d9a6713fc3ae3141b28c Mon Sep 17 00:00:00 2001 From: jp0317 Date: Sat, 16 Nov 2024 05:43:41 +0000 Subject: [PATCH] Reduce panics --- parquet/examples/read_with_rowgroup.rs | 2 +- parquet/src/arrow/async_reader/mod.rs | 25 ++++++------ parquet/src/errors.rs | 7 ++++ parquet/src/file/metadata/mod.rs | 11 +++--- parquet/src/file/metadata/reader.rs | 25 ++++++------ parquet/src/file/serialized_reader.rs | 55 ++++++++++++++++++++++---- parquet/src/file/statistics.rs | 26 ++++++++++++ parquet/src/format.rs | 6 +++ parquet/src/schema/types.rs | 6 +++ parquet/src/thrift.rs | 45 +++++++++++++++++---- parquet/tests/arrow_reader/bad_data.rs | 2 +- 11 files changed, 163 insertions(+), 47 deletions(-) diff --git a/parquet/examples/read_with_rowgroup.rs b/parquet/examples/read_with_rowgroup.rs index 8cccc7fe14ac..09b5e90d7ea1 100644 --- a/parquet/examples/read_with_rowgroup.rs +++ b/parquet/examples/read_with_rowgroup.rs @@ -165,7 +165,7 @@ impl InMemoryRowGroup { let mut vs = std::mem::take(&mut self.column_chunks); for (leaf_idx, meta) in self.metadata.columns().iter().enumerate() { if self.mask.leaf_included(leaf_idx) { - let (start, len) = meta.byte_range(); + let (start, len) = meta.byte_range()?; let data = reader .get_bytes(start as usize..(start + len) as usize) .await?; diff --git a/parquet/src/arrow/async_reader/mod.rs b/parquet/src/arrow/async_reader/mod.rs index 8b315cc9f784..e7dbfb8b7411 100644 --- a/parquet/src/arrow/async_reader/mod.rs +++ b/parquet/src/arrow/async_reader/mod.rs @@ -745,11 +745,11 @@ impl<'a> InMemoryRowGroup<'a> { .filter(|&(idx, (chunk, _chunk_meta))| { chunk.is_none() && projection.leaf_included(idx) }) - .flat_map(|(idx, (_chunk, chunk_meta))| { + .flat_map(|(idx, (_chunk, chunk_meta))| -> Result>> { // If the first page does not start at the beginning of the column, // then we need to also fetch a dictionary page. let mut ranges = vec![]; - let (start, _len) = chunk_meta.byte_range(); + let (start, _len) = chunk_meta.byte_range()?; match offset_index[idx].page_locations.first() { Some(first) if first.offset as u64 != start => { ranges.push(start as usize..first.offset as usize); @@ -760,8 +760,11 @@ impl<'a> InMemoryRowGroup<'a> { ranges.extend(selection.scan_ranges(&offset_index[idx].page_locations)); page_start_offsets.push(ranges.iter().map(|range| range.start).collect()); - ranges + Ok(ranges) }) + .collect::>() + .into_iter() + .flat_map(|ranges| ranges) .collect(); let mut chunk_data = input.get_byte_ranges(fetch_ranges).await?.into_iter(); @@ -779,25 +782,25 @@ impl<'a> InMemoryRowGroup<'a> { } *chunk = Some(Arc::new(ColumnChunkData::Sparse { - length: self.metadata.column(idx).byte_range().1 as usize, + length: self.metadata.column(idx).byte_range()?.1 as usize, data: offsets.into_iter().zip(chunks.into_iter()).collect(), })) } } } else { - let fetch_ranges = self + let fetch_ranges: Result>> = self .column_chunks .iter() .enumerate() .filter(|&(idx, chunk)| chunk.is_none() && projection.leaf_included(idx)) .map(|(idx, _chunk)| { let column = self.metadata.column(idx); - let (start, length) = column.byte_range(); - start as usize..(start + length) as usize + let (start, length) = column.byte_range()?; + Ok(start as usize..(start + length) as usize) }) .collect(); - let mut chunk_data = input.get_byte_ranges(fetch_ranges).await?.into_iter(); + let mut chunk_data = input.get_byte_ranges(fetch_ranges?).await?.into_iter(); for (idx, chunk) in self.column_chunks.iter_mut().enumerate() { if chunk.is_some() || !projection.leaf_included(idx) { @@ -806,7 +809,7 @@ impl<'a> InMemoryRowGroup<'a> { if let Some(data) = chunk_data.next() { *chunk = Some(Arc::new(ColumnChunkData::Dense { - offset: self.metadata.column(idx).byte_range().0 as usize, + offset: self.metadata.column(idx).byte_range()?.0 as usize, data, })); } @@ -1008,8 +1011,8 @@ mod tests { assert_eq!(async_batches, sync_batches); let requests = requests.lock().unwrap(); - let (offset_1, length_1) = metadata.row_group(0).column(1).byte_range(); - let (offset_2, length_2) = metadata.row_group(0).column(2).byte_range(); + let (offset_1, length_1) = metadata.row_group(0).column(1).byte_range().unwrap(); + let (offset_2, length_2) = metadata.row_group(0).column(2).byte_range().unwrap(); assert_eq!( &requests[..], diff --git a/parquet/src/errors.rs b/parquet/src/errors.rs index 6adbffa2a2e5..f7fb1ead0ccc 100644 --- a/parquet/src/errors.rs +++ b/parquet/src/errors.rs @@ -17,6 +17,7 @@ //! Common Parquet errors and macros. +use core::num::TryFromIntError; use std::error::Error; use std::{cell, io, result, str}; @@ -76,6 +77,12 @@ impl Error for ParquetError { } } +impl From for ParquetError { + fn from(e: TryFromIntError) -> ParquetError { + ParquetError::General(format!("Integer overflow: {e}")) + } +} + impl From for ParquetError { fn from(e: io::Error) -> ParquetError { ParquetError::External(Box::new(e)) diff --git a/parquet/src/file/metadata/mod.rs b/parquet/src/file/metadata/mod.rs index 32b985710023..e548c57ecb9e 100644 --- a/parquet/src/file/metadata/mod.rs +++ b/parquet/src/file/metadata/mod.rs @@ -959,17 +959,16 @@ impl ColumnChunkMetaData { } /// Returns the offset and length in bytes of the column chunk within the file - pub fn byte_range(&self) -> (u64, u64) { + pub fn byte_range(&self) -> Result<(u64, u64)> { let col_start = match self.dictionary_page_offset() { Some(dictionary_page_offset) => dictionary_page_offset, None => self.data_page_offset(), }; let col_len = self.compressed_size(); - assert!( - col_start >= 0 && col_len >= 0, - "column start and length should not be negative" - ); - (col_start as u64, col_len as u64) + if col_start < 0 || col_len < 0 { + return Err(general_err!("column start and length should not be negative")); + } + Ok((col_start as u64, col_len as u64)) } /// Returns statistics that are set for this column chunk, diff --git a/parquet/src/file/metadata/reader.rs b/parquet/src/file/metadata/reader.rs index 2a927f15fb64..3bb8208814aa 100644 --- a/parquet/src/file/metadata/reader.rs +++ b/parquet/src/file/metadata/reader.rs @@ -617,7 +617,7 @@ impl ParquetMetaDataReader { for rg in t_file_metadata.row_groups { row_groups.push(RowGroupMetaData::from_thrift(schema_descr.clone(), rg)?); } - let column_orders = Self::parse_column_orders(t_file_metadata.column_orders, &schema_descr); + let column_orders = Self::parse_column_orders(t_file_metadata.column_orders, &schema_descr)?; let file_metadata = FileMetaData::new( t_file_metadata.version, @@ -635,15 +635,13 @@ impl ParquetMetaDataReader { fn parse_column_orders( t_column_orders: Option>, schema_descr: &SchemaDescriptor, - ) -> Option> { + ) -> Result>> { match t_column_orders { Some(orders) => { // Should always be the case - assert_eq!( - orders.len(), - schema_descr.num_columns(), - "Column order length mismatch" - ); + if orders.len() != schema_descr.num_columns() { + return Err(general_err!("Column order length mismatch")); + }; let mut res = Vec::new(); for (i, column) in schema_descr.columns().iter().enumerate() { match orders[i] { @@ -657,9 +655,9 @@ impl ParquetMetaDataReader { } } } - Some(res) + Ok(Some(res)) } - None => None, + None => Ok(None), } } } @@ -731,7 +729,7 @@ mod tests { ]); assert_eq!( - ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr), + ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr).unwrap(), Some(vec![ ColumnOrder::TYPE_DEFINED_ORDER(SortOrder::SIGNED), ColumnOrder::TYPE_DEFINED_ORDER(SortOrder::SIGNED) @@ -740,20 +738,21 @@ mod tests { // Test when no column orders are defined. assert_eq!( - ParquetMetaDataReader::parse_column_orders(None, &schema_descr), + ParquetMetaDataReader::parse_column_orders(None, &schema_descr).unwrap(), None ); } #[test] - #[should_panic(expected = "Column order length mismatch")] fn test_metadata_column_orders_len_mismatch() { let schema = SchemaType::group_type_builder("schema").build().unwrap(); let schema_descr = SchemaDescriptor::new(Arc::new(schema)); let t_column_orders = Some(vec![TColumnOrder::TYPEORDER(TypeDefinedOrder::new())]); - ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr); + let res = ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr); + assert!(res.is_err()); + assert!(format!("{:?}", res.unwrap_err()).contains("Column order length mismatch")); } #[test] diff --git a/parquet/src/file/serialized_reader.rs b/parquet/src/file/serialized_reader.rs index 3262d1fba704..06684af53fe7 100644 --- a/parquet/src/file/serialized_reader.rs +++ b/parquet/src/file/serialized_reader.rs @@ -435,7 +435,7 @@ pub(crate) fn decode_page( let is_sorted = dict_header.is_sorted.unwrap_or(false); Page::DictionaryPage { buf: buffer, - num_values: dict_header.num_values as u32, + num_values: dict_header.num_values.try_into()?, encoding: Encoding::try_from(dict_header.encoding)?, is_sorted, } @@ -446,7 +446,7 @@ pub(crate) fn decode_page( .ok_or_else(|| ParquetError::General("Missing V1 data page header".to_string()))?; Page::DataPage { buf: buffer, - num_values: header.num_values as u32, + num_values: header.num_values.try_into()?, encoding: Encoding::try_from(header.encoding)?, def_level_encoding: Encoding::try_from(header.definition_level_encoding)?, rep_level_encoding: Encoding::try_from(header.repetition_level_encoding)?, @@ -460,12 +460,12 @@ pub(crate) fn decode_page( let is_compressed = header.is_compressed.unwrap_or(true); Page::DataPageV2 { buf: buffer, - num_values: header.num_values as u32, + num_values: header.num_values.try_into()?, encoding: Encoding::try_from(header.encoding)?, - num_nulls: header.num_nulls as u32, - num_rows: header.num_rows as u32, - def_levels_byte_len: header.definition_levels_byte_length as u32, - rep_levels_byte_len: header.repetition_levels_byte_length as u32, + num_nulls: header.num_nulls.try_into()?, + num_rows: header.num_rows.try_into()?, + def_levels_byte_len: header.definition_levels_byte_length.try_into()?, + rep_levels_byte_len: header.repetition_levels_byte_length.try_into()?, is_compressed, statistics: statistics::from_thrift(physical_type, header.statistics)?, } @@ -535,7 +535,7 @@ impl SerializedPageReader { props: ReaderPropertiesPtr, ) -> Result { let decompressor = create_codec(meta.compression(), props.codec_options())?; - let (start, len) = meta.byte_range(); + let (start, len) = meta.byte_range()?; let state = match page_locations { Some(locations) => { @@ -578,6 +578,27 @@ impl Iterator for SerializedPageReader { } } +fn verify_page_header_len(header_len: usize, remaining_bytes: usize) -> Result<()> { + if header_len > remaining_bytes { + return Err(eof_err!("Invalid page header")); + } + Ok(()) +} + +fn verify_page_size( + compressed_size: i32, + uncompressed_size: i32, + remaining_bytes: usize, +) -> Result<()> { + // The page's compressed size should not exceed the remaining bytes that are + // available to read. The page's uncompressed size is the expected size + // after decompression, which can never be negative. + if compressed_size < 0 || compressed_size as usize > remaining_bytes || uncompressed_size < 0 { + return Err(eof_err!("Invalid page header")); + } + Ok(()) +} + impl PageReader for SerializedPageReader { fn get_next_page(&mut self) -> Result> { loop { @@ -596,10 +617,16 @@ impl PageReader for SerializedPageReader { *header } else { let (header_len, header) = read_page_header_len(&mut read)?; + verify_page_header_len(header_len, *remaining)?; *offset += header_len; *remaining -= header_len; header }; + verify_page_size( + header.compressed_page_size, + header.uncompressed_page_size, + *remaining, + )?; let data_len = header.compressed_page_size as usize; *offset += data_len; *remaining -= data_len; @@ -683,6 +710,7 @@ impl PageReader for SerializedPageReader { } else { let mut read = self.reader.get_read(*offset as u64)?; let (header_len, header) = read_page_header_len(&mut read)?; + verify_page_header_len(header_len, *remaining_bytes)?; *offset += header_len; *remaining_bytes -= header_len; let page_meta = if let Ok(page_meta) = (&header).try_into() { @@ -733,12 +761,23 @@ impl PageReader for SerializedPageReader { next_page_header, } => { if let Some(buffered_header) = next_page_header.take() { + verify_page_size( + buffered_header.compressed_page_size, + buffered_header.uncompressed_page_size, + *remaining_bytes, + )?; // The next page header has already been peeked, so just advance the offset *offset += buffered_header.compressed_page_size as usize; *remaining_bytes -= buffered_header.compressed_page_size as usize; } else { let mut read = self.reader.get_read(*offset as u64)?; let (header_len, header) = read_page_header_len(&mut read)?; + verify_page_header_len(header_len, *remaining_bytes)?; + verify_page_size( + header.compressed_page_size, + header.uncompressed_page_size, + *remaining_bytes, + )?; let data_page_size = header.compressed_page_size as usize; *offset += header_len + data_page_size; *remaining_bytes -= header_len + data_page_size; diff --git a/parquet/src/file/statistics.rs b/parquet/src/file/statistics.rs index 2e05b83369cf..d2382f2db6ed 100644 --- a/parquet/src/file/statistics.rs +++ b/parquet/src/file/statistics.rs @@ -157,6 +157,32 @@ pub fn from_thrift( stats.max_value }; + fn check_len(min: &Option>, max: &Option>, len: usize) -> Result<()> { + if let Some(min) = min { + if min.len() < len { + return Err(ParquetError::General(format!( + "Insufficient bytes to parse min statistic", + ))); + } + } + if let Some(max) = max { + if max.len() < len { + return Err(ParquetError::General(format!( + "Insufficient bytes to parse max statistic", + ))); + } + } + Ok(()) + } + + match physical_type { + Type::BOOLEAN => check_len(&min, &max, 1), + Type::INT32 | Type::FLOAT => check_len(&min, &max, 4), + Type::INT64 | Type::DOUBLE => check_len(&min, &max, 8), + Type::INT96 => check_len(&min, &max, 12), + _ => Ok(()) + }?; + // Values are encoded using PLAIN encoding definition, except that // variable-length byte arrays do not include a length prefix. // diff --git a/parquet/src/format.rs b/parquet/src/format.rs index 287d08b7a95c..3cfd79642a5d 100644 --- a/parquet/src/format.rs +++ b/parquet/src/format.rs @@ -1738,6 +1738,12 @@ impl crate::thrift::TSerializable for IntType { bit_width: f_1.expect("auto-generated code should have checked for presence of required fields"), is_signed: f_2.expect("auto-generated code should have checked for presence of required fields"), }; + if ret.bit_width != 8 && ret.bit_width != 16 && ret.bit_width != 32 && ret.bit_width != 64 { + return Err(thrift::Error::Protocol(ProtocolError::new( + ProtocolErrorKind::InvalidData, + "Bit width must be 8, 16, 32, or 64 for Integer logical type", + ))); + } Ok(ret) } fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { diff --git a/parquet/src/schema/types.rs b/parquet/src/schema/types.rs index b7ba95eb56bb..cd26dcd4609e 100644 --- a/parquet/src/schema/types.rs +++ b/parquet/src/schema/types.rs @@ -1122,6 +1122,10 @@ pub fn from_thrift(elements: &[SchemaElement]) -> Result { )); } + if !schema_nodes[0].is_group() { + return Err(general_err!("Expected root node to be a group type")); + } + Ok(schema_nodes.remove(0)) } @@ -1227,6 +1231,8 @@ fn from_thrift_helper(elements: &[SchemaElement], index: usize) -> Result<(usize if !is_root_node { builder = builder.with_repetition(rep); } + } else if !is_root_node { + return Err(general_err!("Repetition level must be defined for non-root types")); } Ok((next_index, Arc::new(builder.build().unwrap()))) } diff --git a/parquet/src/thrift.rs b/parquet/src/thrift.rs index ceb6b1c29fe8..19440f79c5f9 100644 --- a/parquet/src/thrift.rs +++ b/parquet/src/thrift.rs @@ -67,7 +67,17 @@ impl<'a> TCompactSliceInputProtocol<'a> { let mut shift = 0; loop { let byte = self.read_byte()?; - in_progress |= ((byte & 0x7F) as u64) << shift; + let val = (byte & 0x7F) as u64; + let val = val.checked_shl(shift).map_or_else( + || { + Err(thrift::Error::Protocol(thrift::ProtocolError { + kind: thrift::ProtocolErrorKind::InvalidData, + message: format!("cannot left-shift {} by {} bits", val, shift), + })) + }, + |res| Ok(res), + )?; + in_progress |= val; shift += 7; if byte & 0x80 == 0 { return Ok(in_progress); @@ -96,13 +106,22 @@ impl<'a> TCompactSliceInputProtocol<'a> { } } +macro_rules! thrift_unimplemented { + () => { + Err(thrift::Error::Protocol(thrift::ProtocolError { + kind: thrift::ProtocolErrorKind::NotImplemented, + message: "not implemented".to_string(), + })) + } +} + impl TInputProtocol for TCompactSliceInputProtocol<'_> { fn read_message_begin(&mut self) -> thrift::Result { - unimplemented!() + thrift_unimplemented!() } fn read_message_end(&mut self) -> thrift::Result<()> { - unimplemented!() + thrift_unimplemented!() } fn read_struct_begin(&mut self) -> thrift::Result> { @@ -147,7 +166,19 @@ impl TInputProtocol for TCompactSliceInputProtocol<'_> { ), _ => { if field_delta != 0 { - self.last_read_field_id += field_delta as i16; + self.last_read_field_id = + self.last_read_field_id.checked_add(field_delta as i16).map_or_else( + || { + Err(thrift::Error::Protocol(thrift::ProtocolError { + kind: thrift::ProtocolErrorKind::InvalidData, + message: format!( + "cannot add {} to {}", + field_delta, self.last_read_field_id + ), + })) + }, + |res| Ok(res), + )?; } else { self.last_read_field_id = self.read_i16()?; }; @@ -226,15 +257,15 @@ impl TInputProtocol for TCompactSliceInputProtocol<'_> { } fn read_set_begin(&mut self) -> thrift::Result { - unimplemented!() + thrift_unimplemented!() } fn read_set_end(&mut self) -> thrift::Result<()> { - unimplemented!() + thrift_unimplemented!() } fn read_map_begin(&mut self) -> thrift::Result { - unimplemented!() + thrift_unimplemented!() } fn read_map_end(&mut self) -> thrift::Result<()> { diff --git a/parquet/tests/arrow_reader/bad_data.rs b/parquet/tests/arrow_reader/bad_data.rs index 74342031432a..cfd61e82d32b 100644 --- a/parquet/tests/arrow_reader/bad_data.rs +++ b/parquet/tests/arrow_reader/bad_data.rs @@ -106,7 +106,7 @@ fn test_arrow_rs_gh_6229_dict_header() { let err = read_file("ARROW-RS-GH-6229-DICTHEADER.parquet").unwrap_err(); assert_eq!( err.to_string(), - "External: Parquet argument error: EOF: eof decoding byte array" + "External: Parquet argument error: Parquet error: Integer overflow: out of range integral type conversion attempted" ); }