From 8b88378c4dd26af6fd5e6b593783e4fe9fa35796 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 16 Apr 2024 18:46:56 +0800 Subject: [PATCH] clean up Signed-off-by: Haoyang Li --- .../src/main/python/regexp_test.py | 12 +++-------- .../spark/sql/rapids/stringFunctions.scala | 20 +++++++++++++------ 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/integration_tests/src/main/python/regexp_test.py b/integration_tests/src/main/python/regexp_test.py index e528517d5ef..2cb0f81d91d 100644 --- a/integration_tests/src/main/python/regexp_test.py +++ b/integration_tests/src/main/python/regexp_test.py @@ -27,8 +27,7 @@ else: pytestmark = pytest.mark.regexp -_regexp_conf = { 'spark.rapids.sql.regexp.enabled': True, - 'spark.rapids.sql.rLikeRegexRewrite.enabled': True} +_regexp_conf = { 'spark.rapids.sql.regexp.enabled': True } def mk_str_gen(pattern): return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}') @@ -607,7 +606,7 @@ def test_regexp_hexadecimal_digits(): gen = mk_str_gen( '[abcd]\\\\x00\\\\x7f\\\\x80\\\\xff\\\\x{10ffff}\\\\x{00eeee}[\\\\xa0-\\\\xb0][abcd]') assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, gen, length=10).selectExpr( + lambda spark: unary_op_df(spark, gen).selectExpr( 'rlike(a, "\\\\x7f")', 'rlike(a, "\\\\x80")', 'rlike(a, "[\\\\xa0-\\\\xf0]")', @@ -1044,12 +1043,7 @@ def test_regexp_memory_fallback(): 'a rlike "a{1,6}"', 'a rlike "abcdef"', 'a rlike "(1)(2)(3)"', - 'a rlike "1|2|3|4|5|6"', - 'a rlike "^.*aaaa.*$"', - 'a rlike "^aaaa.*"', - 'a rlike ".*aaaa$"', - 'a rlike ".*aaaa.*"', - 'a rlike "aaaa"', + 'a rlike "1|2|3|4|5|6"' ), cpu_fallback_class_name='RLike', conf={ diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 428b56db214..d3f5cd62d9f 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -1073,7 +1073,7 @@ class GpuRLikeMeta( private var originalPattern: String = "" private var pattern: Option[String] = None - val specialChars = Seq('^', '$', '.', '|', '*', '?', '+', '[', ']', '{', '}', '(', ')') + val specialChars = Seq('^', '$', '.', '|', '*', '?', '+', '[', ']', '{', '}', '\\' ,'(', ')') def isSimplePattern(pat: String): Boolean = { pat.size > 0 && pat.forall(c => !specialChars.contains(c)) @@ -1107,15 +1107,23 @@ class GpuRLikeMeta( def optimizeSimplePattern(rhs: Expression, lhs: Expression, parts: List[RegexprPart]): GpuExpression = { parts match { - case Wildcard :: rest => optimizeSimplePattern(rhs, lhs, rest) - case Start :: Fixstring(s) :: List(End) => GpuEqualTo(lhs, GpuLiteral(s, StringType)) + case Wildcard :: rest => { + optimizeSimplePattern(rhs, lhs, rest) + } + case Start :: Wildcard :: List(End) => { + GpuEqualTo(lhs, rhs) + } case Start :: Fixstring(s) :: rest - if rest.forall(_ == Wildcard) || rest == List() => + if rest.forall(_ == Wildcard) || rest == List() => { GpuStartsWith(lhs, GpuLiteral(s, StringType)) - case Fixstring(s) :: List(End) => GpuEndsWith(lhs, GpuLiteral(s, StringType)) + } + case Fixstring(s) :: List(End) => { + GpuEndsWith(lhs, GpuLiteral(s, StringType)) + } case Fixstring(s) :: rest - if rest == List() || rest.forall(_ == Wildcard) => + if rest == List() || rest.forall(_ == Wildcard) => { GpuContains(lhs, GpuLiteral(s, StringType)) + } case _ => { val patternStr = pattern.getOrElse(throw new IllegalStateException( "Expression has not been tagged with cuDF regex pattern"))