Skip to content

Commit

Permalink
Support regex patterns with brackets when rewriting to PrefixRange pa…
Browse files Browse the repository at this point in the history
…ttern in rlike. (NVIDIA#11088)

* Remove bracket when necessary in PrefixRange patten in Regex rewrite

Signed-off-by: Haoyang Li <[email protected]>

* add pytest cases

Signed-off-by: Haoyang Li <[email protected]>

* fix scala 2.13 build

Signed-off-by: Haoyang Li <[email protected]>

---------

Signed-off-by: Haoyang Li <[email protected]>
  • Loading branch information
thirtiseven authored Jun 29, 2024
1 parent f954026 commit 2498204
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 21 deletions.
2 changes: 2 additions & 0 deletions integration_tests/src/main/python/regexp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,8 @@ def test_rlike_rewrite_optimization():
'rlike(a, "a[a-c]{1,3}")',
'rlike(a, "a[a-c]{1,}")',
'rlike(a, "a[a-c]+")',
'rlike(a, "(ab)([a-c]{1})")',
'rlike(a, "(ab[a-c]{1})")',
'rlike(a, "(aaa|bbb|ccc)")',
'rlike(a, ".*.*(aaa|bbb).*.*")',
'rlike(a, "^.*(aaa|bbb|ccc)")',
Expand Down
45 changes: 24 additions & 21 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2035,7 +2035,7 @@ object RegexRewrite {
@scala.annotation.tailrec
private def removeBrackets(astLs: collection.Seq[RegexAST]): collection.Seq[RegexAST] = {
astLs match {
case collection.Seq(RegexGroup(_, term, None)) => removeBrackets(term.children())
case collection.Seq(RegexGroup(_, RegexSequence(terms), None)) => removeBrackets(terms)
case _ => astLs
}
}
Expand All @@ -2051,28 +2051,31 @@ object RegexRewrite {
Option[(String, Int, Int, Int)] = {
val haveLiteralPrefix = isLiteralString(astLs.dropRight(1))
val endsWithRange = astLs.lastOption match {
case Some(RegexRepetition(
RegexCharacterClass(false, ListBuffer(RegexCharacterRange(a,b))),
quantifier)) => {
val (start, end) = (a, b) match {
case (RegexChar(start), RegexChar(end)) => (start, end)
case _ => return None
}
val length = quantifier match {
// In Rlike, contains [a-b]{minLen,maxLen} pattern is equivalent to contains
// [a-b]{minLen} because the matching will return the result once it finds the
// minimum match so y here is unnecessary.
case QuantifierVariableLength(minLen, _) => minLen
case QuantifierFixedLength(len) => len
case SimpleQuantifier(ch) => ch match {
case '*' | '?' => 0
case '+' => 1
case Some(ast) => removeBrackets(collection.Seq(ast)) match {
case collection.Seq(RegexRepetition(
RegexCharacterClass(false, ListBuffer(RegexCharacterRange(a,b))),
quantifier)) => {
val (start, end) = (a, b) match {
case (RegexChar(start), RegexChar(end)) => (start, end)
case _ => return None
}
val length = quantifier match {
// In Rlike, contains [a-b]{minLen,maxLen} pattern is equivalent to contains
// [a-b]{minLen} because the matching will return the result once it finds the
// minimum match so y here is unnecessary.
case QuantifierVariableLength(minLen, _) => minLen
case QuantifierFixedLength(len) => len
case SimpleQuantifier(ch) => ch match {
case '*' | '?' => 0
case '+' => 1
case _ => return None
}
case _ => return None
}
case _ => return None
// Convert start and end to code points
Some((length, start.toInt, end.toInt))
}
// Convert start and end to code points
Some((length, start.toInt, end.toInt))
case _ => None
}
case _ => None
}
Expand Down Expand Up @@ -2153,7 +2156,7 @@ object RegexRewrite {
}
}

val noStartsWithAst = stripLeadingWildcards(noTailingWildcards)
val noStartsWithAst = removeBrackets(stripLeadingWildcards(noTailingWildcards))

// Check if the pattern is a contains literal pattern
if (isLiteralString(noStartsWithAst)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class RegularExpressionRewriteSuite extends AnyFunSuite {
"(.*)abc[0-9]{1,3}(.*)",
"(.*)abc[0-9a-z]{1,3}(.*)",
"(.*)abc[0-9]{2}.*",
"((abc))([0-9]{3})",
"(abc[0-9]{3})",
"^abc[0-9]{1,3}",
"火花急流[\u4e00-\u9fa5]{1}",
"^[0-9]{6}",
Expand All @@ -63,6 +65,8 @@ class RegularExpressionRewriteSuite extends AnyFunSuite {
PrefixRange("abc", 1, 48, 57),
NoOptimization, // prefix followed by a multi-range not supported
PrefixRange("abc", 2, 48, 57),
PrefixRange("abc", 3, 48, 57),
PrefixRange("abc", 3, 48, 57),
NoOptimization, // starts with PrefixRange not supported
PrefixRange("火花急流", 1, 19968, 40869),
NoOptimization, // starts with PrefixRange not supported
Expand Down

0 comments on commit 2498204

Please sign in to comment.