Skip to content

Commit

Permalink
rework entirely, also avoid UTF8 processing if not required by the sc…
Browse files Browse the repository at this point in the history
…hema
  • Loading branch information
etseidl committed Dec 13, 2024
1 parent c4d9474 commit 7a7fd0e
Showing 1 changed file with 86 additions and 65 deletions.
151 changes: 86 additions & 65 deletions parquet/src/column/writer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -878,24 +878,44 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> {
}
}

/// Returns `true` if this column's logical type is a UTF-8 string.
fn is_utf8(&self) -> bool {
self.get_descriptor().logical_type() == Some(LogicalType::String)
|| self.get_descriptor().converted_type() == ConvertedType::UTF8
}

fn truncate_min_value(&self, truncation_length: Option<usize>, data: &[u8]) -> (Vec<u8>, bool) {
truncation_length
.filter(|l| data.len() > *l)
.and_then(|l| match str::from_utf8(data) {
Ok(str_data) => truncate_utf8(str_data, l),
Err(_) => Some(data[..l].to_vec()),
})
.and_then(|l|
// don't do extra work if this column isn't UTF-8
if self.is_utf8() {
match str::from_utf8(data) {
Ok(str_data) => truncate_utf8(str_data, l),
Err(_) => Some(data[..l].to_vec()),
}
} else {
Some(data[..l].to_vec())
}
)
.map(|truncated| (truncated, true))
.unwrap_or_else(|| (data.to_vec(), false))
}

fn truncate_max_value(&self, truncation_length: Option<usize>, data: &[u8]) -> (Vec<u8>, bool) {
truncation_length
.filter(|l| data.len() > *l)
.and_then(|l| match str::from_utf8(data) {
Ok(str_data) => truncate_utf8(str_data, l).and_then(increment_utf8),
Err(_) => increment(data[..l].to_vec()),
})
.and_then(|l|
// don't do extra work if this column isn't UTF-8
if self.is_utf8() {
match str::from_utf8(data) {
Ok(str_data) => truncate_and_increment_utf8(str_data, l),
Err(_) => increment(data[..l].to_vec()),
}
} else {
increment(data[..l].to_vec())
}
)
.map(|truncated| (truncated, true))
.unwrap_or_else(|| (data.to_vec(), false))
}
Expand Down Expand Up @@ -1418,57 +1438,61 @@ fn compare_greater_byte_array_decimals(a: &[u8], b: &[u8]) -> bool {
(a[1..]) > (b[1..])
}

/// Truncate a UTF8 slice to the longest prefix that is still a valid UTF8 string,
/// while being less than `length` bytes and non-empty
/// Truncate a UTF-8 slice to the longest prefix that is still a valid UTF-8 string,
/// while being less than `length` bytes and non-empty. Returns `None` if truncation
/// is not possible within those constraints.
///
/// The caller guarantees that data.len() > length.
fn truncate_utf8(data: &str, length: usize) -> Option<Vec<u8>> {
let split = (1..=length).rfind(|x| data.is_char_boundary(*x))?;
Some(data.as_bytes()[..split].to_vec())
}

/// Try and increment the bytes from right to left.
/// Truncate a UTF-8 slice and increment it's final character. The returned value is the
/// longest such slice that is still a valid UTF-8 string while being less than `length`
/// bytes and non-empty. Returns `None` if no such transformation is possible.
///
/// Returns `None` if all bytes are set to `u8::MAX`.
fn increment(mut data: Vec<u8>) -> Option<Vec<u8>> {
for byte in data.iter_mut().rev() {
let (incremented, overflow) = byte.overflowing_add(1);
*byte = incremented;
/// The caller guarantees that data.len() > length.
fn truncate_and_increment_utf8(data: &str, length: usize) -> Option<Vec<u8>> {
// UTF-8 is max 4 bytes, so start search 3 back from desired length
let lower_bound = length.saturating_sub(3);
let split = (lower_bound..=length).rfind(|x| data.is_char_boundary(*x))?;
increment_utf8(data.get(..split)?)
}

if !overflow {
return Some(data);
/// Increment the final character in a UTF-8 string in such a way that the returned result
/// is still a valid UTF-8 string. The returned string may be shorter than the input if the
/// last character(s) cannot be incremented (due to overflow or producing invalid code points).
/// Returns `None` if the string cannot be incremented.
///
/// Note that this implementation will not promote an N-byte code point to (N+1) bytes.
fn increment_utf8(data: &str) -> Option<Vec<u8>> {
for (idx, code_point) in data.char_indices().rev() {
let curr_len = code_point.len_utf8();
let original = code_point as u32;
if let Some(next_char) = char::from_u32(original + 1) {
// do not allow increasing byte width of incremented char
if next_char.len_utf8() == curr_len {
let mut result = data.as_bytes()[..idx + curr_len].to_vec();
next_char.encode_utf8(&mut result[idx..]);
return Some(result);
}
}
}

None
}

/// Try and increment the the string's bytes from right to left, returning when the result
/// is a valid UTF8 string. Returns `None` when it can't increment any byte.
fn increment_utf8(mut data: Vec<u8>) -> Option<Vec<u8>> {
const UTF8_CONTINUATION: u8 = 0x80;
const UTF8_CONTINUATION_MASK: u8 = 0xc0;
/// Try and increment the bytes from right to left.
///
/// Returns `None` if all bytes are set to `u8::MAX`.
fn increment(mut data: Vec<u8>) -> Option<Vec<u8>> {
for byte in data.iter_mut().rev() {
let (incremented, overflow) = byte.overflowing_add(1);
*byte = incremented;

let mut len = data.len();
for idx in (0..data.len()).rev() {
let original = data[idx];
let (byte, overflow) = original.overflowing_add(1);
if !overflow {
data[idx] = byte;
if str::from_utf8(&data).is_ok() {
if len != data.len() {
data.truncate(len);
}
return Some(data);
}
// Incrementing "original" did not yield a valid unicode character, so it overflowed
// its available bits. If it was a continuation byte (b10xxxxxx) then set to min
// continuation (b10000000). Otherwise it was the first byte so set reset the first
// byte back to its original value (so data remains a valid string) and reduce "len".
if original & UTF8_CONTINUATION_MASK == UTF8_CONTINUATION {
data[idx] = UTF8_CONTINUATION;
} else {
data[idx] = original;
len = idx;
}
return Some(data);
}
}

Expand Down Expand Up @@ -3158,7 +3182,7 @@ mod tests {
#[test]
fn test_increment_utf8() {
let test_inc = |o: &str, expected: &str| {
if let Ok(v) = String::from_utf8(increment_utf8(o.as_bytes().to_vec()).unwrap()) {
if let Ok(v) = String::from_utf8(increment_utf8(o).unwrap()) {
// Got the expected result...
assert_eq!(v, expected);
// and it's greater than the original string
Expand All @@ -3181,7 +3205,7 @@ mod tests {
test_inc("a\u{7f}", "b");

// 1-byte max should not truncate as it would need 2-byte code points
assert!(increment_utf8("\u{7f}\u{7f}".as_bytes().to_vec()).is_none());
assert!(increment_utf8("\u{7f}\u{7f}").is_none());

// UTF8 string
test_inc("❤️🧡💛💚💙💜", "❤️🧡💛💚💙💝");
Expand All @@ -3196,7 +3220,7 @@ mod tests {
test_inc("a\u{7ff}", "b");

// Max 2-byte should not truncate as it would need 3-byte code points
assert!(increment_utf8("\u{7ff}\u{7ff}".as_bytes().to_vec()).is_none());
assert!(increment_utf8("\u{7ff}\u{7ff}").is_none());

// 3-byte without overflow [U+800, U+800] -> [U+800, U+801] (note that these
// characters should render right to left).
Expand All @@ -3206,7 +3230,7 @@ mod tests {
test_inc("a\u{ffff}", "b");

// Max 3-byte should not truncate as it would need 4-byte code points
assert!(increment_utf8("\u{ffff}\u{ffff}".as_bytes().to_vec()).is_none());
assert!(increment_utf8("\u{ffff}\u{ffff}").is_none());

// 4-byte without overflow
test_inc("𐀀𐀀", "𐀀𐀁");
Expand All @@ -3215,10 +3239,11 @@ mod tests {
test_inc("a\u{10ffff}", "b");

// Max 4-byte should not truncate
assert!(increment_utf8("\u{10ffff}\u{10ffff}".as_bytes().to_vec()).is_none());
assert!(increment_utf8("\u{10ffff}\u{10ffff}").is_none());

// Skip over surrogate pair range (0xD800..=0xDFFF)
test_inc("a\u{D7FF}", "a\u{e000}");
//test_inc("a\u{D7FF}", "a\u{e000}");
test_inc("a\u{D7FF}", "b");
}

#[test]
Expand All @@ -3238,42 +3263,38 @@ mod tests {
let r = truncate_utf8("\u{0836}", 1);
assert!(r.is_none());

// test truncate and increment for max bounds on utf-8 stats
// Test truncate and increment for max bounds on UTF-8 statistics
// 7-bit (i.e. ASCII)
let r = truncate_utf8("yyyyyyyyy", 8)
.and_then(increment_utf8)
.unwrap();
let r = truncate_and_increment_utf8("yyyyyyyyy", 8).unwrap();
assert_eq!(&r, "yyyyyyyz".as_bytes());

// 2-byte without overflow
let r = truncate_utf8("ééééé", 8).and_then(increment_utf8).unwrap();
let r = truncate_and_increment_utf8("ééééé", 8).unwrap();
assert_eq!(&r, "éééê".as_bytes());

// 2-byte that overflows lowest byte
let r = truncate_utf8("\u{ff}\u{ff}\u{ff}\u{ff}\u{ff}", 8)
.and_then(increment_utf8)
.unwrap();
let r = truncate_and_increment_utf8("\u{ff}\u{ff}\u{ff}\u{ff}\u{ff}", 8).unwrap();
assert_eq!(&r, "\u{ff}\u{ff}\u{ff}\u{100}".as_bytes());

// max 2-byte should not truncate as it would need 3-byte code points
let r = truncate_utf8("߿߿߿߿߿", 8).and_then(increment_utf8);
let r = truncate_and_increment_utf8("߿߿߿߿߿", 8);
assert!(r.is_none());

// 3-byte without overflow
let r = truncate_utf8("ࠀࠀࠀ", 8).and_then(increment_utf8).unwrap();
// 3-byte without overflow [U+800, U+800, U+800] -> [U+800, U+801] (note that these
// characters should render right to left).
let r = truncate_and_increment_utf8("ࠀࠀࠀ", 8).unwrap();
assert_eq!(&r, "ࠀࠁ".as_bytes());
assert_eq!(r.len(), 6);

// max 3-byte should not truncate as it would need 4-byte code points
let r = truncate_utf8("\u{ffff}\u{ffff}\u{ffff}", 8).and_then(increment_utf8);
let r = truncate_and_increment_utf8("\u{ffff}\u{ffff}\u{ffff}", 8);
assert!(r.is_none());

// 4-byte without overflow
let r = truncate_utf8("𐀀𐀀𐀀", 8).and_then(increment_utf8).unwrap();
let r = truncate_and_increment_utf8("𐀀𐀀𐀀", 8).unwrap();
assert_eq!(&r, "𐀀𐀁".as_bytes());

// max 4-byte should not truncate
let r = truncate_utf8("\u{10ffff}\u{10ffff}", 8).and_then(increment_utf8);
let r = truncate_and_increment_utf8("\u{10ffff}\u{10ffff}", 8);
assert!(r.is_none());
}

Expand Down

0 comments on commit 7a7fd0e

Please sign in to comment.