From 24988bffb6371205b03a183021db698e7445af34 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Sun, 7 Apr 2024 19:42:15 +0800 Subject: [PATCH] A hacky approach for regexpr rewrite Signed-off-by: Haoyang Li --- .../src/main/python/regexp_test.py | 10 +++++ .../spark/sql/rapids/stringFunctions.scala | 42 +++++++++++++++++-- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/integration_tests/src/main/python/regexp_test.py b/integration_tests/src/main/python/regexp_test.py index ff47d0020f3..2808fd5e4db 100644 --- a/integration_tests/src/main/python/regexp_test.py +++ b/integration_tests/src/main/python/regexp_test.py @@ -444,6 +444,16 @@ def test_regexp_like(): 'regexp_like(a, "a[bc]d")'), conf=_regexp_conf) +@pytest.mark.skipif(is_before_spark_320(), reason='regexp_like is synonym for RLike starting in Spark 3.2.0') +def test_regexp_rlike_startswith(): + gen = mk_str_gen('[abcd]{3,4}[0-9]{0,2}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'a', + 'regexp_like(a, "(abcd)(.*)")', + 'regexp_like(a, "abcd(.*)")'), + conf=_regexp_conf) + def test_regexp_replace_character_set_negated(): gen = mk_str_gen('[abcd]{0,3}[\r\n]{0,2}[abcd]{0,3}') assert_gpu_and_cpu_are_equal_collect( diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 0850b1fcafb..bf74931e805 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -168,7 +168,7 @@ case class GpuStartsWith(left: Expression, right: Expression) override def toString: String = s"gpustartswith($left, $right)" - def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = + def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = lhs.getBase.startsWith(rhs.getBase) override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = { @@ -1082,11 +1082,47 @@ class GpuRLikeMeta( } override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { - GpuRLike(lhs, rhs, pattern.getOrElse( - throw new IllegalStateException("Expression has not been tagged with cuDF regex pattern"))) + val patternStr = pattern.getOrElse( + throw new IllegalStateException("Expression has not been tagged with cuDF regex pattern")) + // if the pattern can be converted to a startswith or endswith pattern, we can use + // GpuStartsWith or GpuEndsWith instead to get better performance + GpuRLike.optimizeSimplePattern(rhs, lhs, patternStr) } } +object GpuRLike { + + // // '(' and ')' are allowed + val specialChars = Seq('^', '$', '.', '|', '*', '?', '+', '[', ']', '{', '}') + + // val endWithPatterns = Seq(".*$", "(.*)$") + // val startWithPatterns = Seq("^.*", "^(.*)") + // val allMatchPatterns = Seq(".*", "(.*)") + + def isSimplePattern(pattern: String): Boolean = { + pattern.forall(c => !specialChars.contains(c)) + } + + def removeBrackets(pattern: String): String = { + if (pattern.startsWith("(") && pattern.endsWith(")")) { + pattern.substring(1, pattern.length - 1) + } else { + pattern + } + } + + def optimizeSimplePattern(rhs: Expression, lhs: Expression, pattern: String): GpuExpression = { + val startWithPattern = removeBrackets(pattern.stripSuffix("([^\n\r\u0085\u2028\u2029]*)")) + if (isSimplePattern(startWithPattern)) { + // println(s"Optimizing $pattern to startWithPattern $startWithPattern") + GpuStartsWith(lhs, GpuLiteral(startWithPattern, StringType)) + } else { + // println(s"Optimizing $pattern to gpurlike") + GpuRLike(lhs, rhs, pattern) + } + } +} + case class GpuRLike(left: Expression, right: Expression, pattern: String) extends GpuBinaryExpressionArgsAnyScalar with ImplicitCastInputTypes