Skip to content

Commit

Permalink
Implement native support StringViewArray for regex_is_match function
Browse files Browse the repository at this point in the history
Signed-off-by: Tai Le Manh <[email protected]>
  • Loading branch information
tlm365 committed Sep 17, 2024
1 parent e80deea commit 65e6839
Showing 1 changed file with 179 additions and 60 deletions.
239 changes: 179 additions & 60 deletions arrow-string/src/regexp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,89 @@ use regex::Regex;
use std::collections::HashMap;
use std::sync::Arc;

/// Perform SQL `array ~ regex_array` operation on [`StringArray`] / [`LargeStringArray`].
/// If `regex_array` element has an empty value, the corresponding result value is always true.
///
/// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] flag, which allow
/// special search modes, such as case insensitive and multi-line mode.
/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags)
/// for more information.
#[deprecated(since = "54.0.0", note = "please use `regex_is_match` instead")]
pub fn regexp_is_match_utf8<OffsetSize: OffsetSizeTrait>(
array: &GenericStringArray<OffsetSize>,
regex_array: &GenericStringArray<OffsetSize>,
flags_array: Option<&GenericStringArray<OffsetSize>>,
) -> Result<BooleanArray, ArrowError> {
if array.len() != regex_array.len() {
return Err(ArrowError::ComputeError(
"Cannot perform comparison operation on arrays of different length".to_string(),
));
}
let nulls = NullBuffer::union(array.nulls(), regex_array.nulls());

let mut patterns: HashMap<String, Regex> = HashMap::new();
let mut result = BooleanBufferBuilder::new(array.len());

let complete_pattern = match flags_array {
Some(flags) => Box::new(
regex_array
.iter()
.zip(flags.iter())
.map(|(pattern, flags)| {
pattern.map(|pattern| match flags {
Some(flag) => format!("(?{flag}){pattern}"),
None => pattern.to_string(),
})
}),
) as Box<dyn Iterator<Item = Option<String>>>,
None => Box::new(
regex_array
.iter()
.map(|pattern| pattern.map(|pattern| pattern.to_string())),
),
};

array
.iter()
.zip(complete_pattern)
.map(|(value, pattern)| {
match (value, pattern) {
// Required for Postgres compatibility:
// SELECT 'foobarbequebaz' ~ ''); = true
(Some(_), Some(pattern)) if pattern == *"" => {
result.append(true);
}
(Some(value), Some(pattern)) => {
let existing_pattern = patterns.get(&pattern);
let re = match existing_pattern {
Some(re) => re,
None => {
let re = Regex::new(pattern.as_str()).map_err(|e| {
ArrowError::ComputeError(format!(
"Regular expression did not compile: {e:?}"
))
})?;
patterns.entry(pattern).or_insert(re)
}
};
result.append(re.is_match(value));
}
_ => result.append(false),
}
Ok(())
})
.collect::<Result<Vec<()>, ArrowError>>()?;

let data = unsafe {
ArrayDataBuilder::new(DataType::Boolean)
.len(array.len())
.buffers(vec![result.into()])
.nulls(nulls)
.build_unchecked()
};
Ok(BooleanArray::from(data))
}

/// Perform SQL `array ~ regex_array` operation on
/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`].
///
Expand All @@ -38,7 +121,7 @@ use std::sync::Arc;
/// which allow special search modes, such as case-insensitive and multi-line mode.
/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags)
/// for more information.
pub fn regexp_is_match_utf8<'a, S1, S2, S3>(
pub fn regexp_is_match<'a, S1, S2, S3>(
array: &'a S1,
regex_array: &'a S2,
flags_array: Option<&'a S3>,
Expand Down Expand Up @@ -120,11 +203,55 @@ where
Ok(BooleanArray::from(data))
}

/// Perform SQL `array ~ regex_array` operation on [`StringArray`] /
/// [`LargeStringArray`] and a scalar.
///
/// See the documentation on [`regexp_is_match_utf8`] for more details.
#[deprecated(since = "54.0.0", note = "please use `regex_is_match_scalar` instead")]
pub fn regexp_is_match_utf8_scalar<OffsetSize: OffsetSizeTrait>(
array: &GenericStringArray<OffsetSize>,
regex: &str,
flag: Option<&str>,
) -> Result<BooleanArray, ArrowError> {
let null_bit_buffer = array.nulls().map(|x| x.inner().sliced());
let mut result = BooleanBufferBuilder::new(array.len());

let pattern = match flag {
Some(flag) => format!("(?{flag}){regex}"),
None => regex.to_string(),
};
if pattern.is_empty() {
result.append_n(array.len(), true);
} else {
let re = Regex::new(pattern.as_str()).map_err(|e| {
ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}"))
})?;
for i in 0..array.len() {
let value = array.value(i);
result.append(re.is_match(value));
}
}

let buffer = result.into();
let data = unsafe {
ArrayData::new_unchecked(
DataType::Boolean,
array.len(),
None,
null_bit_buffer,
0,
vec![buffer],
vec![],
)
};
Ok(BooleanArray::from(data))
}

/// Perform SQL `array ~ regex_array` operation on
/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] and a scalar.
///
/// See the documentation on [`regexp_is_match_utf8`] for more details.
pub fn regexp_is_match_utf8_scalar<'a, S>(
/// See the documentation on [`regexp_is_match`] for more details.
pub fn regexp_is_match_scalar<'a, S>(
array: &'a S,
regex: &str,
flag: Option<&str>,
Expand Down Expand Up @@ -163,6 +290,7 @@ where
vec![],
)
};

Ok(BooleanArray::from(data))
}

Expand Down Expand Up @@ -603,45 +731,60 @@ mod tests {
}

test_flag_utf8!(
test_utf8_array_regexp_is_match,
test_array_regexp_is_match_utf8,
StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
regexp_is_match_utf8::<
GenericStringArray<i32>,
GenericStringArray<i32>,
GenericStringArray<i32>,
>,
regexp_is_match_utf8,
[true, false, true, false, false, true]
);
test_flag_utf8!(
test_utf8_array_regexp_is_match_2,
test_array_regexp_is_match_utf8_insensitive,
StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
StringArray::from(vec!["i"; 6]),
regexp_is_match_utf8,
[true, true, true, true, false, true]
);

test_flag_utf8_scalar!(
test_array_regexp_is_match_utf8_scalar,
StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
"^ar",
regexp_is_match_utf8_scalar,
[true, false, false, false]
);
test_flag_utf8_scalar!(
test_array_regexp_is_match_utf8_scalar_empty,
StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
"",
regexp_is_match_utf8_scalar,
[true, true, true, true]
);
test_flag_utf8_scalar!(
test_array_regexp_is_match_utf8_scalar_insensitive,
StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
"^ar",
"i",
regexp_is_match_utf8_scalar,
[true, true, false, false]
);

test_flag_utf8!(
tes_array_regexp_is_match,
StringViewArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
StringViewArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
regexp_is_match_utf8::<StringViewArray, StringViewArray, StringViewArray>,
regexp_is_match::<StringViewArray, StringViewArray, StringViewArray>,
[true, false, true, false, false, true]
);
test_flag_utf8!(
test_utf8_array_regexp_is_match_3,
test_array_regexp_is_match_2,
StringViewArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
regexp_is_match_utf8::<StringViewArray, GenericStringArray<i32>, GenericStringArray<i32>>,
regexp_is_match::<StringViewArray, GenericStringArray<i32>, GenericStringArray<i32>>,
[true, false, true, false, false, true]
);

test_flag_utf8!(
test_utf8_array_regexp_is_match_insensitive,
StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
StringArray::from(vec!["i"; 6]),
regexp_is_match_utf8::<
GenericStringArray<i32>,
GenericStringArray<i32>,
GenericStringArray<i32>,
>,
[true, true, true, true, false, true]
);
test_flag_utf8!(
test_utf8_array_regexp_is_match_insensitive_2,
test_array_regexp_is_match_insensitive,
StringViewArray::from(vec![
"Official Rust implementation of Apache Arrow",
"apache/arrow-rs",
Expand All @@ -661,68 +804,44 @@ mod tests {
""
]),
StringViewArray::from(vec!["i"; 7]),
regexp_is_match_utf8::<StringViewArray, StringViewArray, StringViewArray>,
regexp_is_match::<StringViewArray, StringViewArray, StringViewArray>,
[true, true, true, true, true, false, true]
);
test_flag_utf8!(
test_utf8_array_regexp_is_match_insensitive_3,
test_array_regexp_is_match_insensitive_2,
LargeStringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
StringViewArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
StringArray::from(vec!["i"; 6]),
regexp_is_match_utf8::<GenericStringArray<i64>, StringViewArray, GenericStringArray<i32>>,
regexp_is_match::<GenericStringArray<i64>, StringViewArray, GenericStringArray<i32>>,
[true, true, true, true, false, true]
);

test_flag_utf8_scalar!(
test_utf8_array_regexp_is_match_scalar,
StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
"^ar",
regexp_is_match_utf8_scalar::<GenericStringArray<i32>>,
[true, false, false, false]
);
test_flag_utf8_scalar!(
test_utf8_array_regexp_is_match_scalar_2,
test_array_regexp_is_match_scalar,
StringViewArray::from(vec![
"apache/arrow-rs",
"APACHE/ARROW-RS",
"parquet",
"PARQUET",
]),
"^ap",
regexp_is_match_utf8_scalar::<StringViewArray>,
regexp_is_match_scalar::<StringViewArray>,
[true, false, false, false]
);

test_flag_utf8_scalar!(
test_utf8_array_regexp_is_match_empty_scalar,
StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
"",
regexp_is_match_utf8_scalar::<GenericStringArray<i32>>,
[true, true, true, true]
);
test_flag_utf8_scalar!(
test_utf8_array_regexp_is_match_empty_scalar_2,
test_array_regexp_is_match_scalar_empty,
StringViewArray::from(vec![
"apache/arrow-rs",
"APACHE/ARROW-RS",
"parquet",
"PARQUET",
]),
"",
regexp_is_match_utf8_scalar::<StringViewArray>,
regexp_is_match_scalar::<StringViewArray>,
[true, true, true, true]
);

test_flag_utf8_scalar!(
test_utf8_array_regexp_is_match_insensitive_scalar,
StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
"^ar",
"i",
regexp_is_match_utf8_scalar::<GenericStringArray<i32>>,
[true, true, false, false]
);
test_flag_utf8_scalar!(
test_utf8_array_regexp_is_match_insensitive_scalar_2,
test_array_regexp_is_match_scalar_insensitive,
StringViewArray::from(vec![
"apache/arrow-rs",
"APACHE/ARROW-RS",
Expand All @@ -731,7 +850,7 @@ mod tests {
]),
"^ap",
"i",
regexp_is_match_utf8_scalar::<StringViewArray>,
regexp_is_match_scalar::<StringViewArray>,
[true, true, false, false]
);
}

0 comments on commit 65e6839

Please sign in to comment.