Skip to content

Commit

Permalink
Implement native support StringViewArray for regexp_is_match and `r…
Browse files Browse the repository at this point in the history
…egexp_is_match_scalar` function, deprecate `regexp_is_match_utf8` and `regexp_is_match_utf8_scalar` (#6376)

* Implement native support StringViewArray for regex_is_match function

* Update test cases cover StringViewArray length more then 12 bytes

* Add StringView benchmark for regexp_is_match

Signed-off-by: Tai Le Manh <[email protected]>

* Implement native support StringViewArray for regex_is_match function

Signed-off-by: Tai Le Manh <[email protected]>

* Remove duplicate implementation, fix clippy, add docs

more

---------

Signed-off-by: Tai Le Manh <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
tlm365 and alamb authored Sep 21, 2024
1 parent c90713b commit d05cf6d
Show file tree
Hide file tree
Showing 4 changed files with 261 additions and 39 deletions.
2 changes: 1 addition & 1 deletion arrow-string/src/like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ fn like_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, Arr
///
/// This trait helps to abstract over the different types of string arrays
/// so that we don't need to duplicate the implementation for each type.
trait StringArrayType<'a>: ArrayAccessor<Item = &'a str> + Sized {
pub trait StringArrayType<'a>: ArrayAccessor<Item = &'a str> + Sized {
fn is_ascii(&self) -> bool;
fn iter(&self) -> ArrayIter<Self>;
}
Expand Down
228 changes: 203 additions & 25 deletions arrow-string/src/regexp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
//! Defines kernel to extract substrings based on a regular
//! expression of a \[Large\]StringArray
use crate::like::StringArrayType;

use arrow_array::builder::{BooleanBufferBuilder, GenericStringBuilder, ListBuilder};
use arrow_array::cast::AsArray;
use arrow_array::*;
use arrow_buffer::NullBuffer;
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::{ArrowError, DataType, Field};
use regex::Regex;

use std::collections::HashMap;
use std::sync::Arc;

Expand All @@ -35,16 +38,64 @@ use std::sync::Arc;
/// special search modes, such as case insensitive and multi-line mode.
/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags)
/// for more information.
#[deprecated(since = "54.0.0", note = "please use `regex_is_match` instead")]
pub fn regexp_is_match_utf8<OffsetSize: OffsetSizeTrait>(
array: &GenericStringArray<OffsetSize>,
regex_array: &GenericStringArray<OffsetSize>,
flags_array: Option<&GenericStringArray<OffsetSize>>,
) -> Result<BooleanArray, ArrowError> {
regexp_is_match(array, regex_array, flags_array)
}

/// Return BooleanArray indicating which strings in an array match an array of
/// regular expressions.
///
/// This is equivalent to the SQL `array ~ regex_array`, supporting
/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`].
///
/// If `regex_array` element has an empty value, the corresponding result value is always true.
///
/// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] flag,
/// which allow special search modes, such as case-insensitive and multi-line mode.
/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags)
/// for more information.
///
/// # See Also
/// * [`regexp_is_match_scalar`] for matching a single regular expression against an array of strings
/// * [`regexp_match`] for extracting groups from a string array based on a regular expression
///
/// # Example
/// ```
/// # use arrow_array::{StringArray, BooleanArray};
/// # use arrow_string::regexp::regexp_is_match;
/// // First array is the array of strings to match
/// let array = StringArray::from(vec!["Foo", "Bar", "FooBar", "Baz"]);
/// // Second array is the array of regular expressions to match against
/// let regex_array = StringArray::from(vec!["^Foo", "^Foo", "Bar$", "Baz"]);
/// // Third array is the array of flags to use for each regular expression, if desired
/// // (the type must be provided to satisfy type inference for the third parameter)
/// let flags_array: Option<&StringArray> = None;
/// // The result is a BooleanArray indicating when each string in `array`
/// // matches the corresponding regular expression in `regex_array`
/// let result = regexp_is_match(&array, &regex_array, flags_array).unwrap();
/// assert_eq!(result, BooleanArray::from(vec![true, false, true, true]));
/// ```
pub fn regexp_is_match<'a, S1, S2, S3>(
array: &'a S1,
regex_array: &'a S2,
flags_array: Option<&'a S3>,
) -> Result<BooleanArray, ArrowError>
where
&'a S1: StringArrayType<'a>,
&'a S2: StringArrayType<'a>,
&'a S3: StringArrayType<'a>,
{
if array.len() != regex_array.len() {
return Err(ArrowError::ComputeError(
"Cannot perform comparison operation on arrays of different length".to_string(),
));
}

let nulls = NullBuffer::union(array.nulls(), regex_array.nulls());

let mut patterns: HashMap<String, Regex> = HashMap::new();
Expand Down Expand Up @@ -107,25 +158,63 @@ pub fn regexp_is_match_utf8<OffsetSize: OffsetSizeTrait>(
.nulls(nulls)
.build_unchecked()
};

Ok(BooleanArray::from(data))
}

/// Perform SQL `array ~ regex_array` operation on [`StringArray`] /
/// [`LargeStringArray`] and a scalar.
///
/// See the documentation on [`regexp_is_match_utf8`] for more details.
#[deprecated(since = "54.0.0", note = "please use `regex_is_match_scalar` instead")]
pub fn regexp_is_match_utf8_scalar<OffsetSize: OffsetSizeTrait>(
array: &GenericStringArray<OffsetSize>,
regex: &str,
flag: Option<&str>,
) -> Result<BooleanArray, ArrowError> {
regexp_is_match_scalar(array, regex, flag)
}

/// Return BooleanArray indicating which strings in an array match a single regular expression.
///
/// This is equivalent to the SQL `array ~ regex_array`, supporting
/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] and a scalar.
///
/// See the documentation on [`regexp_is_match`] for more details on arguments
///
/// # See Also
/// * [`regexp_is_match`] for matching an array of regular expression against an array of strings
/// * [`regexp_match`] for extracting groups from a string array based on a regular expression
///
/// # Example
/// ```
/// # use arrow_array::{StringArray, BooleanArray};
/// # use arrow_string::regexp::regexp_is_match_scalar;
/// // array of strings to match
/// let array = StringArray::from(vec!["Foo", "Bar", "FooBar", "Baz"]);
/// let regexp = "^Foo"; // regular expression to match against
/// let flags: Option<&str> = None; // flags can control the matching behavior
/// // The result is a BooleanArray indicating when each string in `array`
/// // matches the regular expression `regexp`
/// let result = regexp_is_match_scalar(&array, regexp, None).unwrap();
/// assert_eq!(result, BooleanArray::from(vec![true, false, true, false]));
/// ```
pub fn regexp_is_match_scalar<'a, S>(
array: &'a S,
regex: &str,
flag: Option<&str>,
) -> Result<BooleanArray, ArrowError>
where
&'a S: StringArrayType<'a>,
{
let null_bit_buffer = array.nulls().map(|x| x.inner().sliced());
let mut result = BooleanBufferBuilder::new(array.len());

let pattern = match flag {
Some(flag) => format!("(?{flag}){regex}"),
None => regex.to_string(),
};

if pattern.is_empty() {
result.append_n(array.len(), true);
} else {
Expand All @@ -150,6 +239,7 @@ pub fn regexp_is_match_utf8_scalar<OffsetSize: OffsetSizeTrait>(
vec![],
)
};

Ok(BooleanArray::from(data))
}

Expand Down Expand Up @@ -303,6 +393,9 @@ fn regexp_scalar_match<OffsetSize: OffsetSizeTrait>(
/// The flags parameter is an optional text string containing zero or more single-letter flags
/// that change the function's behavior.
///
/// # See Also
/// * [`regexp_is_match`] for matching (rather than extracting) a regular expression against an array of strings
///
/// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP
pub fn regexp_match(
array: &dyn Array,
Expand Down Expand Up @@ -517,8 +610,8 @@ mod tests {
($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => {
#[test]
fn $test_name() {
let left = StringArray::from($left);
let right = StringArray::from($right);
let left = $left;
let right = $right;
let res = $op(&left, &right, None).unwrap();
let expected = $expected;
assert_eq!(expected.len(), res.len());
Expand All @@ -531,9 +624,9 @@ mod tests {
($test_name:ident, $left:expr, $right:expr, $flag:expr, $op:expr, $expected:expr) => {
#[test]
fn $test_name() {
let left = StringArray::from($left);
let right = StringArray::from($right);
let flag = Some(StringArray::from($flag));
let left = $left;
let right = $right;
let flag = Some($flag);
let res = $op(&left, &right, flag.as_ref()).unwrap();
let expected = $expected;
assert_eq!(expected.len(), res.len());
Expand All @@ -549,7 +642,7 @@ mod tests {
($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => {
#[test]
fn $test_name() {
let left = StringArray::from($left);
let left = $left;
let res = $op(&left, $right, None).unwrap();
let expected = $expected;
assert_eq!(expected.len(), res.len());
Expand All @@ -569,7 +662,7 @@ mod tests {
($test_name:ident, $left:expr, $right:expr, $flag:expr, $op:expr, $expected:expr) => {
#[test]
fn $test_name() {
let left = StringArray::from($left);
let left = $left;
let flag = Some($flag);
let res = $op(&left, $right, flag).unwrap();
let expected = $expected;
Expand All @@ -590,41 +683,126 @@ mod tests {
}

test_flag_utf8!(
test_utf8_array_regexp_is_match,
vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"],
vec!["^ar", "^AR", "ow$", "OW$", "foo", ""],
regexp_is_match_utf8,
test_array_regexp_is_match_utf8,
StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
regexp_is_match::<StringArray, StringArray, StringArray>,
[true, false, true, false, false, true]
);
test_flag_utf8!(
test_utf8_array_regexp_is_match_insensitive,
vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"],
vec!["^ar", "^AR", "ow$", "OW$", "foo", ""],
vec!["i"; 6],
regexp_is_match_utf8,
test_array_regexp_is_match_utf8_insensitive,
StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
StringArray::from(vec!["i"; 6]),
regexp_is_match,
[true, true, true, true, false, true]
);

test_flag_utf8_scalar!(
test_utf8_array_regexp_is_match_scalar,
vec!["arrow", "ARROW", "parquet", "PARQUET"],
test_array_regexp_is_match_utf8_scalar,
StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
"^ar",
regexp_is_match_utf8_scalar,
regexp_is_match_scalar,
[true, false, false, false]
);
test_flag_utf8_scalar!(
test_utf8_array_regexp_is_match_empty_scalar,
vec!["arrow", "ARROW", "parquet", "PARQUET"],
test_array_regexp_is_match_utf8_scalar_empty,
StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
"",
regexp_is_match_utf8_scalar,
regexp_is_match_scalar,
[true, true, true, true]
);
test_flag_utf8_scalar!(
test_utf8_array_regexp_is_match_insensitive_scalar,
vec!["arrow", "ARROW", "parquet", "PARQUET"],
test_array_regexp_is_match_utf8_scalar_insensitive,
StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
"^ar",
"i",
regexp_is_match_utf8_scalar,
regexp_is_match_scalar,
[true, true, false, false]
);

test_flag_utf8!(
tes_array_regexp_is_match,
StringViewArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
StringViewArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
regexp_is_match::<StringViewArray, StringViewArray, StringViewArray>,
[true, false, true, false, false, true]
);
test_flag_utf8!(
test_array_regexp_is_match_2,
StringViewArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
regexp_is_match::<StringViewArray, GenericStringArray<i32>, GenericStringArray<i32>>,
[true, false, true, false, false, true]
);
test_flag_utf8!(
test_array_regexp_is_match_insensitive,
StringViewArray::from(vec![
"Official Rust implementation of Apache Arrow",
"apache/arrow-rs",
"apache/arrow-rs",
"parquet",
"parquet",
"row",
"row",
]),
StringViewArray::from(vec![
".*rust implement.*",
"^ap",
"^AP",
"et$",
"ET$",
"foo",
""
]),
StringViewArray::from(vec!["i"; 7]),
regexp_is_match::<StringViewArray, StringViewArray, StringViewArray>,
[true, true, true, true, true, false, true]
);
test_flag_utf8!(
test_array_regexp_is_match_insensitive_2,
LargeStringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
StringViewArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
StringArray::from(vec!["i"; 6]),
regexp_is_match::<GenericStringArray<i64>, StringViewArray, GenericStringArray<i32>>,
[true, true, true, true, false, true]
);

test_flag_utf8_scalar!(
test_array_regexp_is_match_scalar,
StringViewArray::from(vec![
"apache/arrow-rs",
"APACHE/ARROW-RS",
"parquet",
"PARQUET",
]),
"^ap",
regexp_is_match_scalar::<StringViewArray>,
[true, false, false, false]
);
test_flag_utf8_scalar!(
test_array_regexp_is_match_scalar_empty,
StringViewArray::from(vec![
"apache/arrow-rs",
"APACHE/ARROW-RS",
"parquet",
"PARQUET",
]),
"",
regexp_is_match_scalar::<StringViewArray>,
[true, true, true, true]
);
test_flag_utf8_scalar!(
test_array_regexp_is_match_scalar_insensitive,
StringViewArray::from(vec![
"apache/arrow-rs",
"APACHE/ARROW-RS",
"parquet",
"PARQUET",
]),
"^ap",
"i",
regexp_is_match_scalar::<StringViewArray>,
[true, true, false, false]
);
}
Loading

0 comments on commit d05cf6d

Please sign in to comment.