Skip to content

Commit

Permalink
add test and docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
etseidl committed Dec 16, 2024
1 parent 006a388 commit f251b00
Showing 1 changed file with 68 additions and 2 deletions.
70 changes: 68 additions & 2 deletions parquet/src/column/writer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,19 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> {
.unwrap_or_else(|| (data.to_vec(), false))
}

/// Truncates a binary statistic to at most `truncation_length` bytes, and then increment the
/// final byte(s) to yield a valid upper bound. This may result in a result of less than
/// `truncation_length` bytes if the last byte(s) overflows.
///
/// If truncation is not possible, returns `data`.
///
/// The `bool` in the returned tuple indicates whether truncation occurred or not.
///
/// UTF-8 Note:
/// If the column type indicates UTF-8, and `data` contains valid UTF-8, then the result will
/// also remain valid UTF-8 (but again may be less than `truncation_length` bytes). If `data`
/// does not contain valid UTF-8, then truncation will occur as if the column is non-string
/// binary.
fn truncate_max_value(&self, truncation_length: Option<usize>, data: &[u8]) -> (Vec<u8>, bool) {
truncation_length
.filter(|l| data.len() > *l)
Expand Down Expand Up @@ -1500,10 +1513,13 @@ fn increment(mut data: Vec<u8>) -> Option<Vec<u8>> {

#[cfg(test)]
mod tests {
use crate::file::properties::DEFAULT_COLUMN_INDEX_TRUNCATE_LENGTH;
use crate::{
file::{properties::DEFAULT_COLUMN_INDEX_TRUNCATE_LENGTH, writer::SerializedFileWriter},
schema::parser::parse_message_type,
};
use core::str;
use rand::distributions::uniform::SampleUniform;
use std::sync::Arc;
use std::{fs::File, sync::Arc};

use crate::column::{
page::PageReader,
Expand Down Expand Up @@ -3297,6 +3313,56 @@ mod tests {
assert!(r.is_none());
}

#[test]
// Check fallback truncation of statistics that should be UTF-8, but aren't
// (see https://github.com/apache/arrow-rs/pull/6870).
fn test_byte_array_truncate_invalid_utf8_statistics() {
let message_type = "
message test_schema {
OPTIONAL BYTE_ARRAY a (UTF8);
}
";
let schema = Arc::new(parse_message_type(message_type).unwrap());

// Create Vec<ByteArray> containing non-UTF8 bytes
let data = vec![ByteArray::from(vec![128u8; 32]); 7];
let def_levels = [1, 1, 1, 1, 0, 1, 0, 1, 0, 1];
let file: File = tempfile::tempfile().unwrap();
let props = Arc::new(
WriterProperties::builder()
.set_statistics_enabled(EnabledStatistics::Chunk)
.set_statistics_truncate_length(Some(8))
.build(),
);

let mut writer = SerializedFileWriter::new(&file, schema, props).unwrap();
let mut row_group_writer = writer.next_row_group().unwrap();

let mut col_writer = row_group_writer.next_column().unwrap().unwrap();
col_writer
.typed::<ByteArrayType>()
.write_batch(&data, Some(&def_levels), None)
.unwrap();
col_writer.close().unwrap();
row_group_writer.close().unwrap();
let file_metadata = writer.close().unwrap();
assert!(file_metadata.row_groups[0].columns[0].meta_data.is_some());
let stats = file_metadata.row_groups[0].columns[0]
.meta_data
.as_ref()
.unwrap()
.statistics
.as_ref()
.unwrap();
assert!(!stats.is_max_value_exact.unwrap());
// Truncation of invalid UTF-8 should fall back to binary truncation, so last byte should
// be incremented by 1.
assert_eq!(
stats.max_value,
Some([128, 128, 128, 128, 128, 128, 128, 129].to_vec())
);
}

#[test]
fn test_increment_max_binary_chars() {
let r = increment(vec![0xFF, 0xFE, 0xFD, 0xFF, 0xFF]);
Expand Down

0 comments on commit f251b00

Please sign in to comment.