From 2498204313dbc2f21cbdd76f84d5b92068949620 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Sat, 29 Jun 2024 10:06:05 +0800 Subject: [PATCH] Support regex patterns with brackets when rewriting to PrefixRange pattern in rlike. (#11088) * Remove bracket when necessary in PrefixRange patten in Regex rewrite Signed-off-by: Haoyang Li * add pytest cases Signed-off-by: Haoyang Li * fix scala 2.13 build Signed-off-by: Haoyang Li --------- Signed-off-by: Haoyang Li --- .../src/main/python/regexp_test.py | 2 + .../com/nvidia/spark/rapids/RegexParser.scala | 45 ++++++++++--------- .../RegularExpressionRewriteSuite.scala | 4 ++ 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/integration_tests/src/main/python/regexp_test.py b/integration_tests/src/main/python/regexp_test.py index 18a83870d83..c2062605ca1 100644 --- a/integration_tests/src/main/python/regexp_test.py +++ b/integration_tests/src/main/python/regexp_test.py @@ -468,6 +468,8 @@ def test_rlike_rewrite_optimization(): 'rlike(a, "a[a-c]{1,3}")', 'rlike(a, "a[a-c]{1,}")', 'rlike(a, "a[a-c]+")', + 'rlike(a, "(ab)([a-c]{1})")', + 'rlike(a, "(ab[a-c]{1})")', 'rlike(a, "(aaa|bbb|ccc)")', 'rlike(a, ".*.*(aaa|bbb).*.*")', 'rlike(a, "^.*(aaa|bbb|ccc)")', diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 1ca155f8a52..362a9cce293 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -2035,7 +2035,7 @@ object RegexRewrite { @scala.annotation.tailrec private def removeBrackets(astLs: collection.Seq[RegexAST]): collection.Seq[RegexAST] = { astLs match { - case collection.Seq(RegexGroup(_, term, None)) => removeBrackets(term.children()) + case collection.Seq(RegexGroup(_, RegexSequence(terms), None)) => removeBrackets(terms) case _ => astLs } } @@ -2051,28 +2051,31 @@ object RegexRewrite { Option[(String, Int, Int, Int)] = { val haveLiteralPrefix = isLiteralString(astLs.dropRight(1)) val endsWithRange = astLs.lastOption match { - case Some(RegexRepetition( - RegexCharacterClass(false, ListBuffer(RegexCharacterRange(a,b))), - quantifier)) => { - val (start, end) = (a, b) match { - case (RegexChar(start), RegexChar(end)) => (start, end) - case _ => return None - } - val length = quantifier match { - // In Rlike, contains [a-b]{minLen,maxLen} pattern is equivalent to contains - // [a-b]{minLen} because the matching will return the result once it finds the - // minimum match so y here is unnecessary. - case QuantifierVariableLength(minLen, _) => minLen - case QuantifierFixedLength(len) => len - case SimpleQuantifier(ch) => ch match { - case '*' | '?' => 0 - case '+' => 1 + case Some(ast) => removeBrackets(collection.Seq(ast)) match { + case collection.Seq(RegexRepetition( + RegexCharacterClass(false, ListBuffer(RegexCharacterRange(a,b))), + quantifier)) => { + val (start, end) = (a, b) match { + case (RegexChar(start), RegexChar(end)) => (start, end) + case _ => return None + } + val length = quantifier match { + // In Rlike, contains [a-b]{minLen,maxLen} pattern is equivalent to contains + // [a-b]{minLen} because the matching will return the result once it finds the + // minimum match so y here is unnecessary. + case QuantifierVariableLength(minLen, _) => minLen + case QuantifierFixedLength(len) => len + case SimpleQuantifier(ch) => ch match { + case '*' | '?' => 0 + case '+' => 1 + case _ => return None + } case _ => return None } - case _ => return None + // Convert start and end to code points + Some((length, start.toInt, end.toInt)) } - // Convert start and end to code points - Some((length, start.toInt, end.toInt)) + case _ => None } case _ => None } @@ -2153,7 +2156,7 @@ object RegexRewrite { } } - val noStartsWithAst = stripLeadingWildcards(noTailingWildcards) + val noStartsWithAst = removeBrackets(stripLeadingWildcards(noTailingWildcards)) // Check if the pattern is a contains literal pattern if (isLiteralString(noStartsWithAst)) { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionRewriteSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionRewriteSuite.scala index 7626c1450c1..a55815b95ef 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionRewriteSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionRewriteSuite.scala @@ -52,6 +52,8 @@ class RegularExpressionRewriteSuite extends AnyFunSuite { "(.*)abc[0-9]{1,3}(.*)", "(.*)abc[0-9a-z]{1,3}(.*)", "(.*)abc[0-9]{2}.*", + "((abc))([0-9]{3})", + "(abc[0-9]{3})", "^abc[0-9]{1,3}", "火花急流[\u4e00-\u9fa5]{1}", "^[0-9]{6}", @@ -63,6 +65,8 @@ class RegularExpressionRewriteSuite extends AnyFunSuite { PrefixRange("abc", 1, 48, 57), NoOptimization, // prefix followed by a multi-range not supported PrefixRange("abc", 2, 48, 57), + PrefixRange("abc", 3, 48, 57), + PrefixRange("abc", 3, 48, 57), NoOptimization, // starts with PrefixRange not supported PrefixRange("火花急流", 1, 19968, 40869), NoOptimization, // starts with PrefixRange not supported