From 65e683995116c4dca4099c094114a02c682e455e Mon Sep 17 00:00:00 2001 From: Tai Le Manh Date: Wed, 18 Sep 2024 00:05:22 +0700 Subject: [PATCH] Implement native support StringViewArray for regex_is_match function Signed-off-by: Tai Le Manh --- arrow-string/src/regexp.rs | 239 +++++++++++++++++++++++++++---------- 1 file changed, 179 insertions(+), 60 deletions(-) diff --git a/arrow-string/src/regexp.rs b/arrow-string/src/regexp.rs index 4924e33df485..e24b79000662 100644 --- a/arrow-string/src/regexp.rs +++ b/arrow-string/src/regexp.rs @@ -29,6 +29,89 @@ use regex::Regex; use std::collections::HashMap; use std::sync::Arc; +/// Perform SQL `array ~ regex_array` operation on [`StringArray`] / [`LargeStringArray`]. +/// If `regex_array` element has an empty value, the corresponding result value is always true. +/// +/// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] flag, which allow +/// special search modes, such as case insensitive and multi-line mode. +/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags) +/// for more information. +#[deprecated(since = "54.0.0", note = "please use `regex_is_match` instead")] +pub fn regexp_is_match_utf8( + array: &GenericStringArray, + regex_array: &GenericStringArray, + flags_array: Option<&GenericStringArray>, +) -> Result { + if array.len() != regex_array.len() { + return Err(ArrowError::ComputeError( + "Cannot perform comparison operation on arrays of different length".to_string(), + )); + } + let nulls = NullBuffer::union(array.nulls(), regex_array.nulls()); + + let mut patterns: HashMap = HashMap::new(); + let mut result = BooleanBufferBuilder::new(array.len()); + + let complete_pattern = match flags_array { + Some(flags) => Box::new( + regex_array + .iter() + .zip(flags.iter()) + .map(|(pattern, flags)| { + pattern.map(|pattern| match flags { + Some(flag) => format!("(?{flag}){pattern}"), + None => pattern.to_string(), + }) + }), + ) as Box>>, + None => Box::new( + regex_array + .iter() + .map(|pattern| pattern.map(|pattern| pattern.to_string())), + ), + }; + + array + .iter() + .zip(complete_pattern) + .map(|(value, pattern)| { + match (value, pattern) { + // Required for Postgres compatibility: + // SELECT 'foobarbequebaz' ~ ''); = true + (Some(_), Some(pattern)) if pattern == *"" => { + result.append(true); + } + (Some(value), Some(pattern)) => { + let existing_pattern = patterns.get(&pattern); + let re = match existing_pattern { + Some(re) => re, + None => { + let re = Regex::new(pattern.as_str()).map_err(|e| { + ArrowError::ComputeError(format!( + "Regular expression did not compile: {e:?}" + )) + })?; + patterns.entry(pattern).or_insert(re) + } + }; + result.append(re.is_match(value)); + } + _ => result.append(false), + } + Ok(()) + }) + .collect::, ArrowError>>()?; + + let data = unsafe { + ArrayDataBuilder::new(DataType::Boolean) + .len(array.len()) + .buffers(vec![result.into()]) + .nulls(nulls) + .build_unchecked() + }; + Ok(BooleanArray::from(data)) +} + /// Perform SQL `array ~ regex_array` operation on /// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`]. /// @@ -38,7 +121,7 @@ use std::sync::Arc; /// which allow special search modes, such as case-insensitive and multi-line mode. /// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags) /// for more information. -pub fn regexp_is_match_utf8<'a, S1, S2, S3>( +pub fn regexp_is_match<'a, S1, S2, S3>( array: &'a S1, regex_array: &'a S2, flags_array: Option<&'a S3>, @@ -120,11 +203,55 @@ where Ok(BooleanArray::from(data)) } +/// Perform SQL `array ~ regex_array` operation on [`StringArray`] / +/// [`LargeStringArray`] and a scalar. +/// +/// See the documentation on [`regexp_is_match_utf8`] for more details. +#[deprecated(since = "54.0.0", note = "please use `regex_is_match_scalar` instead")] +pub fn regexp_is_match_utf8_scalar( + array: &GenericStringArray, + regex: &str, + flag: Option<&str>, +) -> Result { + let null_bit_buffer = array.nulls().map(|x| x.inner().sliced()); + let mut result = BooleanBufferBuilder::new(array.len()); + + let pattern = match flag { + Some(flag) => format!("(?{flag}){regex}"), + None => regex.to_string(), + }; + if pattern.is_empty() { + result.append_n(array.len(), true); + } else { + let re = Regex::new(pattern.as_str()).map_err(|e| { + ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}")) + })?; + for i in 0..array.len() { + let value = array.value(i); + result.append(re.is_match(value)); + } + } + + let buffer = result.into(); + let data = unsafe { + ArrayData::new_unchecked( + DataType::Boolean, + array.len(), + None, + null_bit_buffer, + 0, + vec![buffer], + vec![], + ) + }; + Ok(BooleanArray::from(data)) +} + /// Perform SQL `array ~ regex_array` operation on /// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] and a scalar. /// -/// See the documentation on [`regexp_is_match_utf8`] for more details. -pub fn regexp_is_match_utf8_scalar<'a, S>( +/// See the documentation on [`regexp_is_match`] for more details. +pub fn regexp_is_match_scalar<'a, S>( array: &'a S, regex: &str, flag: Option<&str>, @@ -163,6 +290,7 @@ where vec![], ) }; + Ok(BooleanArray::from(data)) } @@ -603,45 +731,60 @@ mod tests { } test_flag_utf8!( - test_utf8_array_regexp_is_match, + test_array_regexp_is_match_utf8, StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]), StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]), - regexp_is_match_utf8::< - GenericStringArray, - GenericStringArray, - GenericStringArray, - >, + regexp_is_match_utf8, [true, false, true, false, false, true] ); test_flag_utf8!( - test_utf8_array_regexp_is_match_2, + test_array_regexp_is_match_utf8_insensitive, + StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]), + StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]), + StringArray::from(vec!["i"; 6]), + regexp_is_match_utf8, + [true, true, true, true, false, true] + ); + + test_flag_utf8_scalar!( + test_array_regexp_is_match_utf8_scalar, + StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]), + "^ar", + regexp_is_match_utf8_scalar, + [true, false, false, false] + ); + test_flag_utf8_scalar!( + test_array_regexp_is_match_utf8_scalar_empty, + StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]), + "", + regexp_is_match_utf8_scalar, + [true, true, true, true] + ); + test_flag_utf8_scalar!( + test_array_regexp_is_match_utf8_scalar_insensitive, + StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]), + "^ar", + "i", + regexp_is_match_utf8_scalar, + [true, true, false, false] + ); + + test_flag_utf8!( + tes_array_regexp_is_match, StringViewArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]), StringViewArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]), - regexp_is_match_utf8::, + regexp_is_match::, [true, false, true, false, false, true] ); test_flag_utf8!( - test_utf8_array_regexp_is_match_3, + test_array_regexp_is_match_2, StringViewArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]), StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]), - regexp_is_match_utf8::, GenericStringArray>, + regexp_is_match::, GenericStringArray>, [true, false, true, false, false, true] ); - test_flag_utf8!( - test_utf8_array_regexp_is_match_insensitive, - StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]), - StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]), - StringArray::from(vec!["i"; 6]), - regexp_is_match_utf8::< - GenericStringArray, - GenericStringArray, - GenericStringArray, - >, - [true, true, true, true, false, true] - ); - test_flag_utf8!( - test_utf8_array_regexp_is_match_insensitive_2, + test_array_regexp_is_match_insensitive, StringViewArray::from(vec![ "Official Rust implementation of Apache Arrow", "apache/arrow-rs", @@ -661,27 +804,20 @@ mod tests { "" ]), StringViewArray::from(vec!["i"; 7]), - regexp_is_match_utf8::, + regexp_is_match::, [true, true, true, true, true, false, true] ); test_flag_utf8!( - test_utf8_array_regexp_is_match_insensitive_3, + test_array_regexp_is_match_insensitive_2, LargeStringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]), StringViewArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]), StringArray::from(vec!["i"; 6]), - regexp_is_match_utf8::, StringViewArray, GenericStringArray>, + regexp_is_match::, StringViewArray, GenericStringArray>, [true, true, true, true, false, true] ); test_flag_utf8_scalar!( - test_utf8_array_regexp_is_match_scalar, - StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]), - "^ar", - regexp_is_match_utf8_scalar::>, - [true, false, false, false] - ); - test_flag_utf8_scalar!( - test_utf8_array_regexp_is_match_scalar_2, + test_array_regexp_is_match_scalar, StringViewArray::from(vec![ "apache/arrow-rs", "APACHE/ARROW-RS", @@ -689,19 +825,11 @@ mod tests { "PARQUET", ]), "^ap", - regexp_is_match_utf8_scalar::, + regexp_is_match_scalar::, [true, false, false, false] ); - - test_flag_utf8_scalar!( - test_utf8_array_regexp_is_match_empty_scalar, - StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]), - "", - regexp_is_match_utf8_scalar::>, - [true, true, true, true] - ); test_flag_utf8_scalar!( - test_utf8_array_regexp_is_match_empty_scalar_2, + test_array_regexp_is_match_scalar_empty, StringViewArray::from(vec![ "apache/arrow-rs", "APACHE/ARROW-RS", @@ -709,20 +837,11 @@ mod tests { "PARQUET", ]), "", - regexp_is_match_utf8_scalar::, + regexp_is_match_scalar::, [true, true, true, true] ); - - test_flag_utf8_scalar!( - test_utf8_array_regexp_is_match_insensitive_scalar, - StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]), - "^ar", - "i", - regexp_is_match_utf8_scalar::>, - [true, true, false, false] - ); test_flag_utf8_scalar!( - test_utf8_array_regexp_is_match_insensitive_scalar_2, + test_array_regexp_is_match_scalar_insensitive, StringViewArray::from(vec![ "apache/arrow-rs", "APACHE/ARROW-RS", @@ -731,7 +850,7 @@ mod tests { ]), "^ap", "i", - regexp_is_match_utf8_scalar::, + regexp_is_match_scalar::, [true, true, false, false] ); }