diff --git a/integration_tests/src/main/python/regexp_test.py b/integration_tests/src/main/python/regexp_test.py
index d277f298949..b3f81aed025 100644
--- a/integration_tests/src/main/python/regexp_test.py
+++ b/integration_tests/src/main/python/regexp_test.py
@@ -27,7 +27,8 @@ else:
     pytestmark = pytest.mark.regexp
 
 
-_regexp_conf = { 'spark.rapids.sql.regexp.enabled': True }
+_regexp_conf = { 'spark.rapids.sql.regexp.enabled': True,
+                 'spark.rapids.sql.rLikeRegexRewrite.enabled': 'new'}
 
 def mk_str_gen(pattern):
     return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}')
@@ -464,6 +465,16 @@ def test_regexp_rlike_rewrite_optimization():
                 'regexp_like(a, "(.*)(.*)abcd")'),
         conf=_regexp_conf)
 
+@pytest.mark.skipif(is_before_spark_320(), reason='regexp_like is synonym for RLike starting in Spark 3.2.0')
+def test_regexp_rlike_rewrite_optimization_str_dig():
+    gen = mk_str_gen('([abcd]{3,6})?[0-9]{2,5}')
+    assert_gpu_and_cpu_are_equal_collect(
+            lambda spark: unary_op_df(spark, gen).selectExpr(
+                'a',
+                'regexp_like(a, "[0-9]{4,}")',
+                'regexp_like(a, "abcd([0-9]{5})")'),
+        conf=_regexp_conf)
+
 def test_regexp_replace_character_set_negated():
     gen = mk_str_gen('[abcd]{0,3}[\r\n]{0,2}[abcd]{0,3}')
     assert_gpu_and_cpu_are_equal_collect(
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
index 9a5b28ae43e..b3f49b95a84 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -926,8 +926,10 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern")
   val ENABLE_RLIKE_REGEX_REWRITE = conf("spark.rapids.sql.rLikeRegexRewrite.enabled")
-      .doc("Enable the optimization to rewrite rlike regex to contains in some cases.")
+      .doc("Selects how rlike regex patterns are rewritten to simpler string operations: " +
+        "'new' uses the current rewrite rules, 'legacy' uses only the equals/startsWith/" +
+        "endsWith/contains rules, and any other value disables the rewrite.")
       .internal()
-      .booleanConf
-      .createWithDefault(true)
+      .stringConf
+      .createWithDefault("new")
 
   val ENABLE_GETJSONOBJECT_LEGACY = conf("spark.rapids.sql.getJsonObject.legacy.enabled")
       .doc("When set to true, the get_json_object function will use the legacy implementation " +
@@ -2630,7 +2632,7 @@ class RapidsConf(conf: Map[String, String]) extends Logging {
 
   lazy val isTieredProjectEnabled: Boolean = get(ENABLE_TIERED_PROJECT)
 
-  lazy val isRlikeRegexRewriteEnabled: Boolean = get(ENABLE_RLIKE_REGEX_REWRITE)
+  lazy val isRlikeRegexRewriteEnabled: String = get(ENABLE_RLIKE_REGEX_REWRITE)
 
   lazy val isLegacyGetJsonObjectEnabled: Boolean = get(ENABLE_GETJSONOBJECT_LEGACY)
 
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
index 7d5b9daa2b6..e33548868f2 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
@@ -1127,12 +1127,10 @@ class GpuRLikeMeta(
       }
       case Digits(from, _) :: rest
         if rest == List() || rest.forall(_ == Wildcard) => {
-        println(s"!!!GpuStringDigits1: $from")
         GpuStringDigits(lhs, GpuLiteral("", StringType), from)
       }
       case Fixstring(s) :: Digits(from, _) :: rest
         if rest == List() || rest.forall(_ == Wildcard) => {
-        println(s"!!!GpuStringDigits2: $s, $from")
         GpuStringDigits(lhs, GpuLiteral(s, StringType), from)
       }
       case Fixstring(s) :: rest
@@ -1147,6 +1145,41 @@ class GpuRLikeMeta(
     }
   }
 
+  /**
+   * Rewrite rules applied when spark.rapids.sql.rLikeRegexRewrite.enabled is "legacy".
+   * Only the equals/startsWith/endsWith/contains rewrites are attempted; any other
+   * pattern falls back to GpuRLike with the tagged cuDF pattern.
+   */
+  def optimizeSimplePatternLegancy(rhs: Expression, lhs: Expression, parts: List[RegexprPart]):
+      GpuExpression = {
+    parts match {
+      case Wildcard :: rest => {
+        // recurse into the legacy rules so "legacy" mode never picks up the newer rewrites
+        optimizeSimplePatternLegancy(rhs, lhs, rest)
+      }
+      case Start :: Wildcard :: List(End) => {
+        // NOTE(review): rewrites ^.*$ to lhs == rhs (the pattern literal) - confirm intended
+        GpuEqualTo(lhs, rhs)
+      }
+      case Start :: Fixstring(s) :: rest
+        if rest.forall(_ == Wildcard) || rest == List() => {
+        GpuStartsWith(lhs, GpuLiteral(s, StringType))
+      }
+      case Fixstring(s) :: List(End) => {
+        GpuEndsWith(lhs, GpuLiteral(s, StringType))
+      }
+      case Fixstring(s) :: rest
+        if rest == List() || rest.forall(_ == Wildcard) => {
+        GpuContains(lhs, GpuLiteral(s, StringType))
+      }
+      case _ => {
+        val patternStr = pattern.getOrElse(throw new IllegalStateException(
+          "Expression has not been tagged with cuDF regex pattern"))
+        GpuRLike(lhs, rhs, patternStr)
+      }
+    }
+  }
+
   override def tagExprForGpu(): Unit = {
     GpuRegExpUtils.tagForRegExpEnabled(this)
     expr.right match {
@@ -1168,12 +1201,16 @@ class GpuRLikeMeta(
   }
 
   override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
-    if (conf.isRlikeRegexRewriteEnabled) {
+    if (conf.isRlikeRegexRewriteEnabled == "new") {
       // if the pattern can be converted to a startswith or endswith pattern, we can use
       // GpuStartsWith, GpuEndsWith or GpuContains instead to get better performance
       val parts = parseRegexToParts(originalPattern)
       optimizeSimplePattern(rhs, lhs, parts)
+    } else if (conf.isRlikeRegexRewriteEnabled == "legacy") {
+      // same idea, but restricted to the rules in optimizeSimplePatternLegancy
+      val parts = parseRegexToParts(originalPattern)
+      optimizeSimplePatternLegancy(rhs, lhs, parts)
     } else {
       val patternStr = pattern.getOrElse(throw new IllegalStateException(
         "Expression has not been tagged with cuDF regex pattern"))
       GpuRLike(lhs, rhs, patternStr)