Skip to content

Commit

Permalink
Improvements to UTF-8 statistics truncation (#6870)
Browse files Browse the repository at this point in the history
* fix a few edge cases with utf-8 incrementing

* add todo

* simplify truncation

* add another test

* note case where string should render right to left

* rework entirely, also avoid UTF8 processing if not required by the schema

* more consistent naming

* modify some tests to truncate in the middle of a multibyte char

* add test and docstring

* document truncate_min_value too
  • Loading branch information
etseidl authored Dec 16, 2024
1 parent fc6936a commit 9ffa065
Showing 1 changed file with 236 additions and 57 deletions.
293 changes: 236 additions & 57 deletions parquet/src/column/writer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -878,24 +878,67 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> {
}
}

/// Returns `true` if this column's logical type is a UTF-8 string.
fn is_utf8(&self) -> bool {
self.get_descriptor().logical_type() == Some(LogicalType::String)
|| self.get_descriptor().converted_type() == ConvertedType::UTF8
}

/// Truncates a binary statistic to at most `truncation_length` bytes.
///
/// If truncation is not possible, returns `data`.
///
/// The `bool` in the returned tuple indicates whether truncation occurred or not.
///
/// UTF-8 Note:
/// If the column type indicates UTF-8, and `data` contains valid UTF-8, then the result will
/// also remain valid UTF-8, but may be less tnan `truncation_length` bytes to avoid splitting
/// on non-character boundaries.
fn truncate_min_value(&self, truncation_length: Option<usize>, data: &[u8]) -> (Vec<u8>, bool) {
truncation_length
.filter(|l| data.len() > *l)
.and_then(|l| match str::from_utf8(data) {
Ok(str_data) => truncate_utf8(str_data, l),
Err(_) => Some(data[..l].to_vec()),
})
.and_then(|l|
// don't do extra work if this column isn't UTF-8
if self.is_utf8() {
match str::from_utf8(data) {
Ok(str_data) => truncate_utf8(str_data, l),
Err(_) => Some(data[..l].to_vec()),
}
} else {
Some(data[..l].to_vec())
}
)
.map(|truncated| (truncated, true))
.unwrap_or_else(|| (data.to_vec(), false))
}

/// Truncates a binary statistic to at most `truncation_length` bytes, and then increment the
/// final byte(s) to yield a valid upper bound. This may result in a result of less than
/// `truncation_length` bytes if the last byte(s) overflows.
///
/// If truncation is not possible, returns `data`.
///
/// The `bool` in the returned tuple indicates whether truncation occurred or not.
///
/// UTF-8 Note:
/// If the column type indicates UTF-8, and `data` contains valid UTF-8, then the result will
/// also remain valid UTF-8 (but again may be less than `truncation_length` bytes). If `data`
/// does not contain valid UTF-8, then truncation will occur as if the column is non-string
/// binary.
fn truncate_max_value(&self, truncation_length: Option<usize>, data: &[u8]) -> (Vec<u8>, bool) {
truncation_length
.filter(|l| data.len() > *l)
.and_then(|l| match str::from_utf8(data) {
Ok(str_data) => truncate_utf8(str_data, l).and_then(increment_utf8),
Err(_) => increment(data[..l].to_vec()),
})
.and_then(|l|
// don't do extra work if this column isn't UTF-8
if self.is_utf8() {
match str::from_utf8(data) {
Ok(str_data) => truncate_and_increment_utf8(str_data, l),
Err(_) => increment(data[..l].to_vec()),
}
} else {
increment(data[..l].to_vec())
}
)
.map(|truncated| (truncated, true))
.unwrap_or_else(|| (data.to_vec(), false))
}
Expand Down Expand Up @@ -1418,13 +1461,50 @@ fn compare_greater_byte_array_decimals(a: &[u8], b: &[u8]) -> bool {
(a[1..]) > (b[1..])
}

/// Truncate a UTF8 slice to the longest prefix that is still a valid UTF8 string,
/// while being less than `length` bytes and non-empty
/// Truncate a UTF-8 slice to the longest prefix that is still a valid UTF-8 string,
/// while being less than `length` bytes and non-empty. Returns `None` if truncation
/// is not possible within those constraints.
///
/// The caller guarantees that data.len() > length.
fn truncate_utf8(data: &str, length: usize) -> Option<Vec<u8>> {
let split = (1..=length).rfind(|x| data.is_char_boundary(*x))?;
Some(data.as_bytes()[..split].to_vec())
}

/// Truncate a UTF-8 slice and increment it's final character. The returned value is the
/// longest such slice that is still a valid UTF-8 string while being less than `length`
/// bytes and non-empty. Returns `None` if no such transformation is possible.
///
/// The caller guarantees that data.len() > length.
fn truncate_and_increment_utf8(data: &str, length: usize) -> Option<Vec<u8>> {
// UTF-8 is max 4 bytes, so start search 3 back from desired length
let lower_bound = length.saturating_sub(3);
let split = (lower_bound..=length).rfind(|x| data.is_char_boundary(*x))?;
increment_utf8(data.get(..split)?)
}

/// Increment the final character in a UTF-8 string in such a way that the returned result
/// is still a valid UTF-8 string. The returned string may be shorter than the input if the
/// last character(s) cannot be incremented (due to overflow or producing invalid code points).
/// Returns `None` if the string cannot be incremented.
///
/// Note that this implementation will not promote an N-byte code point to (N+1) bytes.
fn increment_utf8(data: &str) -> Option<Vec<u8>> {
for (idx, original_char) in data.char_indices().rev() {
let original_len = original_char.len_utf8();
if let Some(next_char) = char::from_u32(original_char as u32 + 1) {
// do not allow increasing byte width of incremented char
if next_char.len_utf8() == original_len {
let mut result = data.as_bytes()[..idx + original_len].to_vec();
next_char.encode_utf8(&mut result[idx..]);
return Some(result);
}
}
}

None
}

/// Try and increment the bytes from right to left.
///
/// Returns `None` if all bytes are set to `u8::MAX`.
Expand All @@ -1441,29 +1521,15 @@ fn increment(mut data: Vec<u8>) -> Option<Vec<u8>> {
None
}

/// Try and increment the the string's bytes from right to left, returning when the result
/// is a valid UTF8 string. Returns `None` when it can't increment any byte.
fn increment_utf8(mut data: Vec<u8>) -> Option<Vec<u8>> {
for idx in (0..data.len()).rev() {
let original = data[idx];
let (byte, overflow) = original.overflowing_add(1);
if !overflow {
data[idx] = byte;
if str::from_utf8(&data).is_ok() {
return Some(data);
}
data[idx] = original;
}
}

None
}

#[cfg(test)]
mod tests {
use crate::file::properties::DEFAULT_COLUMN_INDEX_TRUNCATE_LENGTH;
use crate::{
file::{properties::DEFAULT_COLUMN_INDEX_TRUNCATE_LENGTH, writer::SerializedFileWriter},
schema::parser::parse_message_type,
};
use core::str;
use rand::distributions::uniform::SampleUniform;
use std::sync::Arc;
use std::{fs::File, sync::Arc};

use crate::column::{
page::PageReader,
Expand Down Expand Up @@ -3140,39 +3206,69 @@ mod tests {

#[test]
fn test_increment_utf8() {
let test_inc = |o: &str, expected: &str| {
if let Ok(v) = String::from_utf8(increment_utf8(o).unwrap()) {
// Got the expected result...
assert_eq!(v, expected);
// and it's greater than the original string
assert!(*v > *o);
// Also show that BinaryArray level comparison works here
let mut greater = ByteArray::new();
greater.set_data(Bytes::from(v));
let mut original = ByteArray::new();
original.set_data(Bytes::from(o.as_bytes().to_vec()));
assert!(greater > original);
} else {
panic!("Expected incremented UTF8 string to also be valid.");
}
};

// Basic ASCII case
let v = increment_utf8("hello".as_bytes().to_vec()).unwrap();
assert_eq!(&v, "hellp".as_bytes());
test_inc("hello", "hellp");

// 1-byte ending in max 1-byte
test_inc("a\u{7f}", "b");

// Also show that BinaryArray level comparison works here
let mut greater = ByteArray::new();
greater.set_data(Bytes::from(v));
let mut original = ByteArray::new();
original.set_data(Bytes::from("hello".as_bytes().to_vec()));
assert!(greater > original);
// 1-byte max should not truncate as it would need 2-byte code points
assert!(increment_utf8("\u{7f}\u{7f}").is_none());

// UTF8 string
let s = "❤️🧡💛💚💙💜";
let v = increment_utf8(s.as_bytes().to_vec()).unwrap();
test_inc("❤️🧡💛💚💙💜", "❤️🧡💛💚💙💝");

if let Ok(new) = String::from_utf8(v) {
assert_ne!(&new, s);
assert_eq!(new, "❤️🧡💛💚💙💝");
assert!(new.as_bytes().last().unwrap() > s.as_bytes().last().unwrap());
} else {
panic!("Expected incremented UTF8 string to also be valid.")
}
// 2-byte without overflow
test_inc("éééé", "éééê");

// Max UTF8 character - should be a No-Op
let s = char::MAX.to_string();
assert_eq!(s.len(), 4);
let v = increment_utf8(s.as_bytes().to_vec());
assert!(v.is_none());
// 2-byte that overflows lowest byte
test_inc("\u{ff}\u{ff}", "\u{ff}\u{100}");

// 2-byte ending in max 2-byte
test_inc("a\u{7ff}", "b");

// Max 2-byte should not truncate as it would need 3-byte code points
assert!(increment_utf8("\u{7ff}\u{7ff}").is_none());

// 3-byte without overflow [U+800, U+800] -> [U+800, U+801] (note that these
// characters should render right to left).
test_inc("ࠀࠀ", "ࠀࠁ");

// 3-byte ending in max 3-byte
test_inc("a\u{ffff}", "b");

// Max 3-byte should not truncate as it would need 4-byte code points
assert!(increment_utf8("\u{ffff}\u{ffff}").is_none());

// Handle multi-byte UTF8 characters
let s = "a\u{10ffff}";
let v = increment_utf8(s.as_bytes().to_vec());
assert_eq!(&v.unwrap(), "b\u{10ffff}".as_bytes());
// 4-byte without overflow
test_inc("𐀀𐀀", "𐀀𐀁");

// 4-byte ending in max unicode
test_inc("a\u{10ffff}", "b");

// Max 4-byte should not truncate
assert!(increment_utf8("\u{10ffff}\u{10ffff}").is_none());

// Skip over surrogate pair range (0xD800..=0xDFFF)
//test_inc("a\u{D7FF}", "a\u{e000}");
test_inc("a\u{D7FF}", "b");
}

#[test]
Expand All @@ -3182,7 +3278,6 @@ mod tests {
let r = truncate_utf8(data, data.as_bytes().len()).unwrap();
assert_eq!(r.len(), data.as_bytes().len());
assert_eq!(&r, data.as_bytes());
println!("len is {}", data.len());

// We slice it away from the UTF8 boundary
let r = truncate_utf8(data, 13).unwrap();
Expand All @@ -3192,6 +3287,90 @@ mod tests {
// One multi-byte code point, and a length shorter than it, so we can't slice it
let r = truncate_utf8("\u{0836}", 1);
assert!(r.is_none());

// Test truncate and increment for max bounds on UTF-8 statistics
// 7-bit (i.e. ASCII)
let r = truncate_and_increment_utf8("yyyyyyyyy", 8).unwrap();
assert_eq!(&r, "yyyyyyyz".as_bytes());

// 2-byte without overflow
let r = truncate_and_increment_utf8("ééééé", 7).unwrap();
assert_eq!(&r, "ééê".as_bytes());

// 2-byte that overflows lowest byte
let r = truncate_and_increment_utf8("\u{ff}\u{ff}\u{ff}\u{ff}\u{ff}", 8).unwrap();
assert_eq!(&r, "\u{ff}\u{ff}\u{ff}\u{100}".as_bytes());

// max 2-byte should not truncate as it would need 3-byte code points
let r = truncate_and_increment_utf8("߿߿߿߿߿", 8);
assert!(r.is_none());

// 3-byte without overflow [U+800, U+800, U+800] -> [U+800, U+801] (note that these
// characters should render right to left).
let r = truncate_and_increment_utf8("ࠀࠀࠀࠀ", 8).unwrap();
assert_eq!(&r, "ࠀࠁ".as_bytes());

// max 3-byte should not truncate as it would need 4-byte code points
let r = truncate_and_increment_utf8("\u{ffff}\u{ffff}\u{ffff}", 8);
assert!(r.is_none());

// 4-byte without overflow
let r = truncate_and_increment_utf8("𐀀𐀀𐀀𐀀", 9).unwrap();
assert_eq!(&r, "𐀀𐀁".as_bytes());

// max 4-byte should not truncate
let r = truncate_and_increment_utf8("\u{10ffff}\u{10ffff}", 8);
assert!(r.is_none());
}

#[test]
// Check fallback truncation of statistics that should be UTF-8, but aren't
// (see https://github.com/apache/arrow-rs/pull/6870).
fn test_byte_array_truncate_invalid_utf8_statistics() {
let message_type = "
message test_schema {
OPTIONAL BYTE_ARRAY a (UTF8);
}
";
let schema = Arc::new(parse_message_type(message_type).unwrap());

// Create Vec<ByteArray> containing non-UTF8 bytes
let data = vec![ByteArray::from(vec![128u8; 32]); 7];
let def_levels = [1, 1, 1, 1, 0, 1, 0, 1, 0, 1];
let file: File = tempfile::tempfile().unwrap();
let props = Arc::new(
WriterProperties::builder()
.set_statistics_enabled(EnabledStatistics::Chunk)
.set_statistics_truncate_length(Some(8))
.build(),
);

let mut writer = SerializedFileWriter::new(&file, schema, props).unwrap();
let mut row_group_writer = writer.next_row_group().unwrap();

let mut col_writer = row_group_writer.next_column().unwrap().unwrap();
col_writer
.typed::<ByteArrayType>()
.write_batch(&data, Some(&def_levels), None)
.unwrap();
col_writer.close().unwrap();
row_group_writer.close().unwrap();
let file_metadata = writer.close().unwrap();
assert!(file_metadata.row_groups[0].columns[0].meta_data.is_some());
let stats = file_metadata.row_groups[0].columns[0]
.meta_data
.as_ref()
.unwrap()
.statistics
.as_ref()
.unwrap();
assert!(!stats.is_max_value_exact.unwrap());
// Truncation of invalid UTF-8 should fall back to binary truncation, so last byte should
// be incremented by 1.
assert_eq!(
stats.max_value,
Some([128, 128, 128, 128, 128, 128, 128, 129].to_vec())
);
}

#[test]
Expand Down

0 comments on commit 9ffa065

Please sign in to comment.