From 2cf59346fd66e1d500a16e8107068dc9e20c3585 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 12 Jun 2024 08:03:27 +0800 Subject: [PATCH] Rewrite multiple literal choice regex to multiple contains in rlike (#10977) * rewrite multiple literal choice to multiple contains, wip Signed-off-by: Haoyang Li * fix bug Signed-off-by: Haoyang Li * optimize memory Signed-off-by: Haoyang Li * remove debug log Signed-off-by: Haoyang Li * address comments Signed-off-by: Haoyang Li * Apply suggestions from code review Co-authored-by: Gera Shegalov * support abc|def case Signed-off-by: Haoyang Li * fix 2.13 Signed-off-by: Haoyang Li * fix 2.13 build Signed-off-by: Haoyang Li --------- Signed-off-by: Haoyang Li Co-authored-by: Gera Shegalov --- .../src/main/python/regexp_test.py | 8 +- .../com/nvidia/spark/rapids/RegexParser.scala | 94 +++++++++++-------- .../spark/sql/rapids/stringFunctions.scala | 30 +++++- .../RegularExpressionRewriteSuite.scala | 31 ++++-- 4 files changed, 118 insertions(+), 45 deletions(-) diff --git a/integration_tests/src/main/python/regexp_test.py b/integration_tests/src/main/python/regexp_test.py index 89929eb6762..18a83870d83 100644 --- a/integration_tests/src/main/python/regexp_test.py +++ b/integration_tests/src/main/python/regexp_test.py @@ -454,6 +454,7 @@ def test_rlike_rewrite_optimization(): 'rlike(a, "(.*)(abb)(.*)")', 'rlike(a, "^(abb)(.*)")', 'rlike(a, "^abb")', + 'rlike(a, "^.*(aaa)")', 'rlike(a, "\\\\A(abb)(.*)")', 'rlike(a, "\\\\Aabb")', 'rlike(a, "^(abb)\\\\Z")', @@ -466,7 +467,12 @@ def test_rlike_rewrite_optimization(): 'rlike(a, "ab[a-c]{3}")', 'rlike(a, "a[a-c]{1,3}")', 'rlike(a, "a[a-c]{1,}")', - 'rlike(a, "a[a-c]+")'), + 'rlike(a, "a[a-c]+")', + 'rlike(a, "(aaa|bbb|ccc)")', + 'rlike(a, ".*.*(aaa|bbb).*.*")', + 'rlike(a, "^.*(aaa|bbb|ccc)")', + 'rlike(a, "aaa|bbb")', + 'rlike(a, "aaa|(bbb|ccc)")'), conf=_regexp_conf) def test_regexp_replace_character_set_negated(): diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 0f5ada9f7fa..1ca155f8a52 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -2026,6 +2026,7 @@ object RegexOptimizationType { case class Contains(literal: String) extends RegexOptimizationType case class PrefixRange(literal: String, length: Int, rangeStart: Int, rangeEnd: Int) extends RegexOptimizationType + case class MultipleContains(literals: Seq[String]) extends RegexOptimizationType case object NoOptimization extends RegexOptimizationType } @@ -2091,6 +2092,20 @@ object RegexRewrite { } } + private def getMultipleContainsLiterals(ast: RegexAST): Seq[String] = { + ast match { + case RegexGroup(_, term, _) => getMultipleContainsLiterals(term) + case RegexChoice(RegexSequence(parts), ls) if isLiteralString(parts) => { + getMultipleContainsLiterals(ls) match { + case Seq() => Seq.empty + case literals => RegexCharsToString(parts) +: literals + } + } + case RegexSequence(parts) if (isLiteralString(parts)) => Seq(RegexCharsToString(parts)) + case _ => Seq.empty + } + } + private def isWildcard(ast: RegexAST): Boolean = { ast match { case RegexRepetition(RegexChar('.'), SimpleQuantifier('*')) => true @@ -2101,11 +2116,8 @@ object RegexRewrite { } private def stripLeadingWildcards(astLs: collection.Seq[RegexAST]): - collection.Seq[RegexAST] = astLs match { - case (RegexChar('^') | RegexEscaped('A')) :: tail => - // if the pattern starts with ^ or \A, strip it too - tail.dropWhile(isWildcard) - case _ => astLs.dropWhile(isWildcard) + collection.Seq[RegexAST] = { + astLs.dropWhile(isWildcard) } private def stripTailingWildcards(astLs: collection.Seq[RegexAST]): @@ -2124,40 +2136,48 @@ object RegexRewrite { * Matches the given regex ast to a regex optimization type for regex rewrite * optimization. * - * @param ast unparsed children of the Abstract Syntax Tree parsed from a regex pattern. + * @param ast Abstract Syntax Tree parsed from a regex pattern. * @return The `RegexOptimizationType` for the given pattern. */ - @scala.annotation.tailrec - def matchSimplePattern(ast: Seq[RegexAST]): RegexOptimizationType = { - ast match { - case (RegexChar('^') | RegexEscaped('A')) :: astTail => - val noTrailingWildCards = stripTailingWildcards(astTail) - if (isLiteralString(noTrailingWildCards)) { - // ^literal.* => startsWith literal - RegexOptimizationType.StartsWith(RegexCharsToString(noTrailingWildCards)) - } else { - val noWildCards = stripLeadingWildcards(noTrailingWildCards) - if (noWildCards.length == noTrailingWildCards.length) { - // TODO startsWith with PrefIxRange - RegexOptimizationType.NoOptimization - } else { - matchSimplePattern(astTail) - } - } - case astLs => { - val noStartsWithAst = stripTailingWildcards(stripLeadingWildcards(astLs)) - val prefixRangeInfo = getPrefixRangePattern(noStartsWithAst) - if (prefixRangeInfo.isDefined) { - val (prefix, length, start, end) = prefixRangeInfo.get - // (literal[a-b]{x,y}) => prefix range pattern - RegexOptimizationType.PrefixRange(prefix, length, start, end) - } else if (isLiteralString(noStartsWithAst)) { - // literal.* or (literal).* => contains literal - RegexOptimizationType.Contains(RegexCharsToString(noStartsWithAst)) - } else { - RegexOptimizationType.NoOptimization - } + def matchSimplePattern(ast: RegexAST): RegexOptimizationType = { + val astLs = ast match { + case RegexSequence(_) => ast.children() + case _ => Seq(ast) + } + val noTailingWildcards = stripTailingWildcards(astLs) + if (noTailingWildcards.headOption.exists( + ast => ast == RegexChar('^') || ast == RegexEscaped('A'))) { + val possibleLiteral = noTailingWildcards.drop(1) + if (isLiteralString(possibleLiteral)) { + return RegexOptimizationType.StartsWith(RegexCharsToString(possibleLiteral)) + } + } + + val noStartsWithAst = stripLeadingWildcards(noTailingWildcards) + + // Check if the pattern is a contains literal pattern + if (isLiteralString(noStartsWithAst)) { + // literal or .*(literal).* => contains literal + return RegexOptimizationType.Contains(RegexCharsToString(noStartsWithAst)) + } + + // Check if the pattern is a multiple contains literal pattern (e.g. "abc|def|ghi") + if (noStartsWithAst.length == 1) { + val containsLiterals = getMultipleContainsLiterals(noStartsWithAst.head) + if (!containsLiterals.isEmpty) { + return RegexOptimizationType.MultipleContains(containsLiterals) } } + + // Check if the pattern is a prefix range pattern (e.g. "abc[a-z]{3}") + val prefixRangeInfo = getPrefixRangePattern(noStartsWithAst) + if (prefixRangeInfo.isDefined) { + val (prefix, length, start, end) = prefixRangeInfo.get + // (literal[a-b]{x,y}) => prefix range pattern + return RegexOptimizationType.PrefixRange(prefix, length, start, end) + } + + // return NoOptimization if the pattern is not a simple pattern and use cuDF + RegexOptimizationType.NoOptimization } -} \ No newline at end of file +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 8fea4014149..dc2845e4461 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -1073,7 +1073,7 @@ class GpuRLikeMeta( val originalPattern = str.toString val regexAst = new RegexParser(originalPattern).parse() if (conf.isRlikeRegexRewriteEnabled) { - rewriteOptimizationType = RegexRewrite.matchSimplePattern(regexAst.children()) + rewriteOptimizationType = RegexRewrite.matchSimplePattern(regexAst) } val (transpiledAST, _) = new CudfRegexTranspiler(RegexFindMode) .getTranspiledAST(regexAst, None, None) @@ -1097,6 +1097,7 @@ class GpuRLikeMeta( } case StartsWith(s) => GpuStartsWith(lhs, GpuLiteral(s, StringType)) case Contains(s) => GpuContains(lhs, GpuLiteral(s, StringType)) + case MultipleContains(ls) => GpuMultipleContains(lhs, ls) case PrefixRange(s, length, start, end) => GpuLiteralRangePattern(lhs, GpuLiteral(s, StringType), length, start, end) case _ => throw new IllegalStateException("Unexpected optimization type") @@ -1126,6 +1127,33 @@ case class GpuRLike(left: Expression, right: Expression, pattern: String) override def dataType: DataType = BooleanType } +case class GpuMultipleContains(input: Expression, searchList: Seq[String]) + extends GpuUnaryExpression with ImplicitCastInputTypes with NullIntolerant { + + override def dataType: DataType = BooleanType + + override def child: Expression = input + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def doColumnar(input: GpuColumnVector): ColumnVector = { + assert(searchList.length > 1) + val accInit = withResource(Scalar.fromString(searchList.head)) { searchScalar => + input.getBase.stringContains(searchScalar) + } + searchList.tail.foldLeft(accInit) { (acc, search) => + val containsSearch = withResource(Scalar.fromString(search)) { searchScalar => + input.getBase.stringContains(searchScalar) + } + withResource(acc) { _ => + withResource(containsSearch) { _ => + acc.or(containsSearch) + } + } + } + } +} + case class GpuLiteralRangePattern(left: Expression, right: Expression, length: Int, start: Int, end: Int) extends GpuBinaryExpressionArgsAnyScalar with ImplicitCastInputTypes with NullIntolerant { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionRewriteSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionRewriteSuite.scala index a140f4123f4..7626c1450c1 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionRewriteSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionRewriteSuite.scala @@ -23,7 +23,7 @@ class RegularExpressionRewriteSuite extends AnyFunSuite { Unit = { val results = patterns.map { pattern => val ast = new RegexParser(pattern).parse() - RegexRewrite.matchSimplePattern(ast.children()) + RegexRewrite.matchSimplePattern(ast) } assert(results == excepted) } @@ -40,9 +40,9 @@ class RegularExpressionRewriteSuite extends AnyFunSuite { test("regex rewrite contains") { import RegexOptimizationType._ val patterns = Seq(".*abc.*", ".*(abc).*", "^.*(abc).*$", "^.*(.*)(abc).*.*", - raw".*\w.*\Z", raw".*..*\Z") - val excepted = Seq(Contains("abc"), Contains("abc"), NoOptimization, Contains("abc"), - NoOptimization, NoOptimization) + raw".*\w.*\Z", raw".*..*\Z", "^(.*)(abc)") + val excepted = Seq(Contains("abc"), Contains("abc"), NoOptimization, NoOptimization, + NoOptimization, NoOptimization, NoOptimization) verifyRewritePattern(patterns, excepted) } @@ -67,8 +67,27 @@ class RegularExpressionRewriteSuite extends AnyFunSuite { PrefixRange("火花急流", 1, 19968, 40869), NoOptimization, // starts with PrefixRange not supported NoOptimization, // starts with PrefixRange not supported - PrefixRange("", 6, 48, 57), - PrefixRange("", 3, 48, 57) + NoOptimization, // .* can't match line break so can't be optimized + NoOptimization // .* can't match line break so can't be optimized + ) + verifyRewritePattern(patterns, excepted) + } + + test("regex rewrite multiple contains") { + import RegexOptimizationType._ + val patterns = Seq( + "(abc|def).*", + ".*(abc|def|ghi).*", + "((abc)|(def))", + "(abc)|(def)", + "(火花|急流)" + ) + val excepted = Seq( + MultipleContains(Seq("abc", "def")), + MultipleContains(Seq("abc", "def", "ghi")), + MultipleContains(Seq("abc", "def")), + MultipleContains(Seq("abc", "def")), + MultipleContains(Seq("火花", "急流")) ) verifyRewritePattern(patterns, excepted) }