Skip to content

Commit

Permalink
Specialize ASCII case for substr() (apache#12444)
Browse files Browse the repository at this point in the history
* Specialize ASCII case for substr()

* cleanup + don't validate ASCII for short prefix
  • Loading branch information
2010YOUY01 authored Sep 17, 2024
1 parent 269a473 commit 55707dc
Showing 1 changed file with 122 additions and 24 deletions.
146 changes: 122 additions & 24 deletions datafusion/functions/src/unicode/substr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@
// under the License.

use std::any::Any;
use std::cmp::max;
use std::sync::Arc;

use crate::string::common::StringArrayType;
use crate::utils::{make_scalar_function, utf8_to_str_type};
use arrow::array::{
make_view, Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, ByteView,
GenericStringArray, OffsetSizeTrait, StringViewArray,
make_view, Array, ArrayIter, ArrayRef, AsArray, ByteView, GenericStringArray,
Int64Array, OffsetSizeTrait, StringViewArray,
};
use arrow::datatypes::DataType;
use arrow_buffer::{NullBufferBuilder, ScalarBuffer};
use datafusion_common::cast::as_int64_array;
use datafusion_common::{exec_datafusion_err, exec_err, Result};
use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

Expand Down Expand Up @@ -119,19 +119,27 @@ pub fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
}

// Convert the given `start` and `count` to valid byte indices within `input` string
//
// Input `start` and `count` are equivalent to PostgreSQL's `substr(s, start, count)`
// `start` is 1-based, if `count` is not provided count to the end of the string
// Input indices are character-based, and return values are byte indices
// The input bounds can be outside string bounds, this function will return
// the intersection between input bounds and valid string bounds
// `input_ascii_only` is used to optimize this function if `input` is ASCII-only
//
// * Example
// 'Hi🌏' in-mem (`[]` for one char, `x` for one byte): [x][x][xxxx]
// `get_true_start_end('Hi🌏', 1, None) -> (0, 6)`
// `get_true_start_end('Hi🌏', 1, 1) -> (0, 1)`
// `get_true_start_end('Hi🌏', -10, 2) -> (0, 0)`
fn get_true_start_end(input: &str, start: i64, count: Option<u64>) -> (usize, usize) {
let start = start - 1;
fn get_true_start_end(
input: &str,
start: i64,
count: Option<u64>,
is_input_ascii_only: bool,
) -> (usize, usize) {
let start = start.checked_sub(1).unwrap_or(start);

let end = match count {
Some(count) => start + count as i64,
None => input.len() as i64,
Expand All @@ -142,6 +150,14 @@ fn get_true_start_end(input: &str, start: i64, count: Option<u64>) -> (usize, us
let end = end.clamp(0, input.len() as i64) as usize;
let count = end - start;

// If input is ASCII-only, byte-based indices are equal to char-based indices
if is_input_ascii_only {
return (start, end);
}

// Otherwise, calculate byte indices from char indices
// Note that this decoding is relatively expensive for a simple `substr` function,
// so the implementation decodes in a single pass (at the cost of some complexity)
let (mut st, mut ed) = (input.len(), input.len());
let mut start_counting = false;
let mut cnt = 0;
Expand Down Expand Up @@ -186,6 +202,53 @@ fn make_and_append_view(
null_builder.append_non_null();
}

// Decides whether `substr()` should validate that `string_array` is ASCII-only
// and, if so, use the byte-indexed fast path.
//
// String characters are variable-length encoded in UTF-8 and `substr()`'s
// arguments are character-based, so converting them into byte-based indices
// requires expensive decoding.
// However, checking if a string is ASCII-only is relatively cheap.
// If strings are ASCII-only, byte-based indices can be used instead.
//
// A common pattern to call `substr()` is taking a small prefix of a long
// string, such as `substr(long_str_with_1k_chars, 1, 32)`.
// In such a case the overhead of ASCII validation may not be worth it, so
// the validation is skipped for short prefixes for now.
//
// * Arguments
// `string_array`: the input strings, only inspected through `is_ascii()`
// `start`: the (1-based, character-based) start argument of `substr()`
// `count`: the optional count argument of `substr()`; `None` for `substr(s, start)`
//
// * Returns
// `true` only if the sampled prefix is not "short" and `string_array.is_ascii()`
// holds; `false` otherwise. When `count` is `None` the prefix is never
// considered short, so the result is `string_array.is_ascii()`.
fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>(
    string_array: &V,
    start: &Int64Array,
    count: Option<&Int64Array>,
) -> bool {
    let is_short_prefix = match count {
        Some(count) => {
            let short_prefix_threshold = 32.0;
            let n_sample = 10;

            // HACK: can be simplified if function has specialized
            // implementation for `ScalarValue` (implement without `make_scalar_function()`)
            //
            // Sample at most `n_sample` leading rows, tracking how many rows were
            // actually seen so the average stays correct for batches shorter
            // than `n_sample`.
            let (n_sampled, total_prefix_len) = start
                .iter()
                .zip(count.iter())
                .take(n_sample)
                .fold((0_usize, 0_i64), |(n_seen, total), (start, count)| {
                    let start = start.unwrap_or(0);
                    let count = count.unwrap_or(0);
                    // To get substring, need to decode from 0 to start+count
                    // instead of start to start+count.
                    // Saturate instead of overflowing on extreme arguments
                    // (e.g. `start == i64::MAX`): this is only a heuristic, and
                    // argument validation happens in the callers.
                    let prefix_len = start.saturating_add(count);
                    (n_seen + 1, total.saturating_add(prefix_len))
                });

            // With no rows to sample there is nothing to validate either, so
            // treat it as a short prefix and skip the ASCII check.
            n_sampled == 0
                || total_prefix_len as f64 / n_sampled as f64 <= short_prefix_threshold
        }
        None => false,
    };

    // Skip ASCII validation for short prefixes; otherwise the fast path is
    // enabled only if every string in the array is ASCII.
    !is_short_prefix && string_array.is_ascii()
}

// The decoding process refs the trait at: arrow/arrow-data/src/byte_view.rs:44
// From<u128> for ByteView
fn string_view_substr(
Expand All @@ -196,6 +259,14 @@ fn string_view_substr(
let mut null_builder = NullBufferBuilder::new(string_view_array.len());

let start_array = as_int64_array(&args[0])?;
let count_array_opt = if args.len() == 2 {
Some(as_int64_array(&args[1])?)
} else {
None
};

let enable_ascii_fast_path =
enable_ascii_fast_path(&string_view_array, start_array, count_array_opt);

// In either case of `substr(s, i)` or `substr(s, i, cnt)`
// If any input argument is `NULL`, the result is `NULL`
Expand All @@ -207,7 +278,8 @@ fn string_view_substr(
.zip(start_array.iter())
{
if let (Some(str), Some(start)) = (str_opt, start_opt) {
let (start, end) = get_true_start_end(str, start, None);
let (start, end) =
get_true_start_end(str, start, None, enable_ascii_fast_path);
let substr = &str[start..end];

make_and_append_view(
Expand All @@ -224,7 +296,7 @@ fn string_view_substr(
}
}
2 => {
let count_array = as_int64_array(&args[1])?;
let count_array = count_array_opt.unwrap();
for (((str_opt, raw_view), start_opt), count_opt) in string_view_array
.iter()
.zip(string_view_array.views().iter())
Expand All @@ -239,8 +311,17 @@ fn string_view_substr(
"negative substring length not allowed: substr(<str>, {start}, {count})"
);
} else {
let (start, end) =
get_true_start_end(str, start, Some(count as u64));
if start == i64::MIN {
return exec_err!(
"negative overflow when calculating skip value"
);
}
let (start, end) = get_true_start_end(
str,
start,
Some(count as u64),
enable_ascii_fast_path,
);
let substr = &str[start..end];

make_and_append_view(
Expand Down Expand Up @@ -283,23 +364,35 @@ fn string_view_substr(

fn string_substr<'a, V, T>(string_array: V, args: &[ArrayRef]) -> Result<ArrayRef>
where
V: ArrayAccessor<Item = &'a str>,
V: StringArrayType<'a>,
T: OffsetSizeTrait,
{
let start_array = as_int64_array(&args[0])?;
let count_array_opt = if args.len() == 2 {
Some(as_int64_array(&args[1])?)
} else {
None
};

let enable_ascii_fast_path =
enable_ascii_fast_path(&string_array, start_array, count_array_opt);

match args.len() {
1 => {
let iter = ArrayIter::new(string_array);
let start_array = as_int64_array(&args[0])?;

let result = iter
.zip(start_array.iter())
.map(|(string, start)| match (string, start) {
(Some(string), Some(start)) => {
if start <= 0 {
Some(string.to_string())
} else {
Some(string.chars().skip(start as usize - 1).collect())
}
let (start, end) = get_true_start_end(
string,
start,
None,
enable_ascii_fast_path,
); // start, end is byte-based
let substr = &string[start..end];
Some(substr.to_string())
}
_ => None,
})
Expand All @@ -308,8 +401,7 @@ where
}
2 => {
let iter = ArrayIter::new(string_array);
let start_array = as_int64_array(&args[0])?;
let count_array = as_int64_array(&args[1])?;
let count_array = count_array_opt.unwrap();

let result = iter
.zip(start_array.iter())
Expand All @@ -322,11 +414,17 @@ where
"negative substring length not allowed: substr(<str>, {start}, {count})"
)
} else {
let skip = max(0, start.checked_sub(1).ok_or_else(
|| exec_datafusion_err!("negative overflow when calculating skip value")
)?);
let count = max(0, count + (if start < 1 { start - 1 } else { 0 }));
Ok(Some(string.chars().skip(skip as usize).take(count as usize).collect::<String>()))
if start == i64::MIN {
return exec_err!("negative overflow when calculating skip value")
}
let (start, end) = get_true_start_end(
string,
start,
Some(count as u64),
enable_ascii_fast_path,
); // start, end is byte-based
let substr = &string[start..end];
Ok(Some(substr.to_string()))
}
}
_ => Ok(None),
Expand Down

0 comments on commit 55707dc

Please sign in to comment.