From aac469ea1af7c9ad75a578be817ee2fafaceb603 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 26 Apr 2024 13:23:47 +0800 Subject: [PATCH] Regex string digits pattern (#6) * A hacky approach for regexpr rewrite Signed-off-by: Haoyang Li * Use contains instead for that case Signed-off-by: Haoyang Li * add config to switch Signed-off-by: Haoyang Li * Rewrite some rlike expression to StartsWith/EndsWith/Contains Signed-off-by: Haoyang Li * clean up Signed-off-by: Haoyang Li * wip Signed-off-by: Haoyang Li * wip Signed-off-by: Haoyang Li * add tests and config Signed-off-by: Haoyang Li --------- Signed-off-by: Haoyang Li --- .../src/main/python/regexp_test.py | 28 ++- .../com/nvidia/spark/rapids/RapidsConf.scala | 8 +- .../spark/sql/rapids/stringFunctions.scala | 178 ++++++++++++++---- 3 files changed, 174 insertions(+), 40 deletions(-) diff --git a/integration_tests/src/main/python/regexp_test.py b/integration_tests/src/main/python/regexp_test.py index dcaa711438d..e643b3e3a6c 100644 --- a/integration_tests/src/main/python/regexp_test.py +++ b/integration_tests/src/main/python/regexp_test.py @@ -27,8 +27,8 @@ else: pytestmark = pytest.mark.regexp -_regexp_conf = { 'spark.rapids.sql.regexp.enabled': True, - 'spark.rapids.sql.rLikeRegexRewrite.enabled': False} +_regexp_conf = { 'spark.rapids.sql.regexp.enabled': True , + 'spark.rapids.sql.rLikeRegexRewrite.enabled': 'new'} def mk_str_gen(pattern): return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}') @@ -446,13 +446,33 @@ def test_regexp_like(): conf=_regexp_conf) @pytest.mark.skipif(is_before_spark_320(), reason='regexp_like is synonym for RLike starting in Spark 3.2.0') -def test_regexp_rlike_contains(): +def test_regexp_rlike_rewrite_optimization(): gen = mk_str_gen('[abcd]{3,6}') assert_gpu_and_cpu_are_equal_collect( lambda spark: unary_op_df(spark, gen).selectExpr( 'a', 'regexp_like(a, "(abcd)(.*)")', - 'regexp_like(a, "abcd(.*)")'), + 'regexp_like(a, "abcd(.*)")', + 'regexp_like(a, "(.*)(abcd)(.*)")', + 'regexp_like(a, "^(abcd)(.*)")', + 'regexp_like(a, "^abcd")', + 'regexp_like(a, "(abcd)$")', + 'regexp_like(a, ".*abcd$")', + 'regexp_like(a, "^(abcd)$")', + 'regexp_like(a, "^abcd$")', + 'regexp_like(a, "ab(.*)cd")', + 'regexp_like(a, "^^abcd")', + 'regexp_like(a, "(.*)(.*)abcd")'), + conf=_regexp_conf) + +@pytest.mark.skipif(is_before_spark_320(), reason='regexp_like is synonym for RLike starting in Spark 3.2.0') +def test_regexp_rlike_rewrite_optimization_str_dig(): + gen = mk_str_gen('([abcd]{3,6})?[0-9]{2,5}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'a', + 'regexp_like(a, "[0-9]{4,}")', + 'regexp_like(a, "abcd([0-9]{5})")'), conf=_regexp_conf) def test_regexp_replace_character_set_negated(): diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 76e137b10dd..cc232052c65 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -905,8 +905,9 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern") val ENABLE_RLIKE_REGEX_REWRITE = conf("spark.rapids.sql.rLikeRegexRewrite.enabled") .doc("Enable the optimization to rewrite rlike regex to contains in some cases.") .internal() - .booleanConf - .createWithDefault(true) + .stringConf + .createWithDefault("new") + val ENABLE_GETJSONOBJECT_LEGACY = conf("spark.rapids.sql.getJsonObject.legacy.enabled") .doc("When set to true, the get_json_object function will use the legacy implementation " + "on the GPU. The legacy implementation is faster than the current implementation, but " + @@ -2634,7 +2635,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val isTieredProjectEnabled: Boolean = get(ENABLE_TIERED_PROJECT) - lazy val isRlikeRegexRewriteEnabled: Boolean = get(ENABLE_RLIKE_REGEX_REWRITE) + lazy val isRlikeRegexRewriteEnabled: String = get(ENABLE_RLIKE_REGEX_REWRITE) + lazy val isLegacyGetJsonObjectEnabled: Boolean = get(ENABLE_GETJSONOBJECT_LEGACY) lazy val isExpandPreprojectEnabled: Boolean = get(ENABLE_EXPAND_PREPROJECT) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index ef72606b4c6..e33548868f2 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -26,7 +26,7 @@ import ai.rapids.cudf.{BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, Co import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.jni.CastStrings +import com.nvidia.spark.rapids.jni.{CastStrings, StringDigitsPattern} import com.nvidia.spark.rapids.shims.{ShimExpression, SparkShimImpl} import org.apache.spark.sql.catalyst.expressions._ @@ -1054,45 +1054,125 @@ object GpuRegExpUtils { } +sealed trait RegexprPart +object RegexprPart { + case object Start extends RegexprPart // ^ + case object End extends RegexprPart // $ + case object Wildcard extends RegexprPart // .* or (.*) + case class Digits(from: Int, to: Int) extends RegexprPart // [0-9]{a, b} + case class Fixstring(name: String) extends RegexprPart // normal string without special characters + case class Regexpr(value: String) extends RegexprPart // other strings +} + class GpuRLikeMeta( expr: RLike, conf: RapidsConf, parent: Option[RapidsMeta[_, _, _]], rule: DataFromReplacementRule) extends BinaryExprMeta[RLike](expr, conf, parent, rule) { - + import RegexprPart._ + + private var originalPattern: String = "" private var pattern: Option[String] = None - val specialChars = Seq('^', '$', '.', '|', '*', '?', '+', '[', ']', '{', '}') - - val startWithSuffix = "([^\n\r\u0085\u2028\u2029]*)" - - // val endWithPatterns = Seq(".*$", "(.*)$") - // val startWithPatterns = Seq("^.*", "^(.*)") - // val allMatchPatterns = Seq(".*", "(.*)") - - def isSimplePattern(pattern: String): Boolean = { - pattern.forall(c => !specialChars.contains(c)) + val specialChars = Seq('^', '$', '.', '|', '*', '?', '+', '[', ']', '{', '}', '\\' ,'(', ')') + + def isSimplePattern(pat: String): Boolean = { + pat.size > 0 && pat.forall(c => !specialChars.contains(c)) + } + + def parseRegexToParts(pat: String): List[RegexprPart] = { + pat match { + case "" => + List() + case s if s.startsWith("^") => + Start :: parseRegexToParts(s.substring(1)) + case s if s.endsWith("$") => + parseRegexToParts(s.substring(0, s.length - 1)) :+ End + case s if s.startsWith(".*") => + Wildcard :: parseRegexToParts(s.substring(2)) + case s if s.endsWith(".*") => + parseRegexToParts(s.substring(0, s.length - 2)) :+ Wildcard + case s if s.startsWith("(.*)") => + Wildcard :: parseRegexToParts(s.substring(4)) + case s if s.endsWith("(.*)") => + parseRegexToParts(s.substring(0, s.length - 4)) :+ Wildcard + case s if s.endsWith("([0-9]{5})") => + parseRegexToParts(s.substring(0, s.length - 10)) :+ Digits(5, 5) + case s if s.endsWith("[0-9]{4,}") => + parseRegexToParts(s.substring(0, s.length - 9)) :+ Digits(4, -1) + case s if s.startsWith("(") && s.endsWith(")") => + parseRegexToParts(s.substring(1, s.length - 1)) + case s if isSimplePattern(s) => + Fixstring(s) :: List() + case s => + Regexpr(s) :: List() + } } - def removeBrackets(pattern: String): String = { - if (pattern.startsWith("(") && pattern.endsWith(")")) { - pattern.substring(1, pattern.length - 1) - } else { - pattern + def optimizeSimplePattern(rhs: Expression, lhs: Expression, parts: List[RegexprPart]): + GpuExpression = { + parts match { + case Wildcard :: rest => { + optimizeSimplePattern(rhs, lhs, rest) + } + case Start :: Wildcard :: List(End) => { + GpuEqualTo(lhs, rhs) + } + case Start :: Fixstring(s) :: rest + if rest.forall(_ == Wildcard) || rest == List() => { + GpuStartsWith(lhs, GpuLiteral(s, StringType)) + } + case Fixstring(s) :: List(End) => { + GpuEndsWith(lhs, GpuLiteral(s, StringType)) + } + case Digits(from, _) :: rest + if rest == List() || rest.forall(_ == Wildcard) => { + // println(s"!!!GpuStringDigits1: $from") + GpuStringDigits(lhs, GpuLiteral("", StringType), from) + } + case Fixstring(s) :: Digits(from, _) :: rest + if rest == List() || rest.forall(_ == Wildcard) => { + // println(s"!!!GpuStringDigits2: $s, $from") + GpuStringDigits(lhs, GpuLiteral(s, StringType), from) + } + case Fixstring(s) :: rest + if rest == List() || rest.forall(_ == Wildcard) => { + GpuContains(lhs, GpuLiteral(s, StringType)) + } + case _ => { + val patternStr = pattern.getOrElse(throw new IllegalStateException( + "Expression has not been tagged with cuDF regex pattern")) + GpuRLike(lhs, rhs, patternStr) + } } } - def optimizeSimplePattern(rhs: Expression, lhs: Expression, pattern: String): GpuExpression = { - // check if the pattern is end with startWithSuffix - if (conf.isRlikeRegexRewriteEnabled && pattern.endsWith(startWithSuffix)) { - val startWithPattern = removeBrackets(pattern.stripSuffix(startWithSuffix)) - if (isSimplePattern(startWithPattern)) { - // println(s"Optimizing $pattern to GpuContains $startWithPattern") - return GpuContains(lhs, GpuLiteral(startWithPattern, StringType)) + def optimizeSimplePatternLegancy(rhs: Expression, lhs: Expression, parts: List[RegexprPart]): + GpuExpression = { + parts match { + case Wildcard :: rest => { + optimizeSimplePattern(rhs, lhs, rest) + } + case Start :: Wildcard :: List(End) => { + GpuEqualTo(lhs, rhs) + } + case Start :: Fixstring(s) :: rest + if rest.forall(_ == Wildcard) || rest == List() => { + GpuStartsWith(lhs, GpuLiteral(s, StringType)) + } + case Fixstring(s) :: List(End) => { + GpuEndsWith(lhs, GpuLiteral(s, StringType)) + } + case Fixstring(s) :: rest + if rest == List() || rest.forall(_ == Wildcard) => { + GpuContains(lhs, GpuLiteral(s, StringType)) + } + case _ => { + val patternStr = pattern.getOrElse(throw new IllegalStateException( + "Expression has not been tagged with cuDF regex pattern")) + GpuRLike(lhs, rhs, patternStr) } } - // println(s"Optimizing $pattern to gpurlike") - GpuRLike(lhs, rhs, pattern) } override def tagExprForGpu(): Unit = { @@ -1101,8 +1181,9 @@ class GpuRLikeMeta( case Literal(str: UTF8String, DataTypes.StringType) if str != null => try { // verify that we support this regex and can transpile it to cuDF format - val (transpiledAST, _) = - new CudfRegexTranspiler(RegexFindMode).getTranspiledAST(str.toString, None, None) + originalPattern = str.toString + val (transpiledAST, _) = new CudfRegexTranspiler(RegexFindMode) + .getTranspiledAST(originalPattern, None, None) GpuRegExpUtils.validateRegExpComplexity(this, transpiledAST) pattern = Some(transpiledAST.toRegexString) } catch { @@ -1115,12 +1196,43 @@ class GpuRLikeMeta( } override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { - val patternStr = pattern.getOrElse( - throw new IllegalStateException("Expression has not been tagged with cuDF regex pattern")) - // if the pattern can be converted to a startswith or endswith pattern, we can use - // GpuStartsWith or GpuEndsWith instead to get better performance - optimizeSimplePattern(rhs, lhs, patternStr) + if (conf.isRlikeRegexRewriteEnabled == "new") { + // println(s"!!!GpuRLike: ${conf.isRlikeRegexRewriteEnabled}") + // if the pattern can be converted to a startswith or endswith pattern, we can use + // GpuStartsWith, GpuEndsWith or GpuContains instead to get better performance + val parts = parseRegexToParts(originalPattern) + optimizeSimplePattern(rhs, lhs, parts) + } else if (conf.isRlikeRegexRewriteEnabled == "legacy") { + // println(s"!!!GpuRLike: ${conf.isRlikeRegexRewriteEnabled}") + // if the pattern can be converted to a startswith or endswith pattern, we can use + // GpuStartsWith, GpuEndsWith or GpuContains instead to get better performance + val parts = parseRegexToParts(originalPattern) + optimizeSimplePatternLegancy(rhs, lhs, parts) + } else { + // println(s"!!!GpuRLike: ${conf.isRlikeRegexRewriteEnabled}") + val patternStr = pattern.getOrElse(throw new IllegalStateException( + "Expression has not been tagged with cuDF regex pattern")) + GpuRLike(lhs, rhs, patternStr) + } + } +} + +case class GpuStringDigits(left: Expression, right: Expression, from: Int) + extends GpuBinaryExpressionArgsAnyScalar with ImplicitCastInputTypes with NullIntolerant { + + override def dataType: DataType = BooleanType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) + + override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = { + StringDigitsPattern.stringDigitsPattern(lhs.getBase, rhs.getBase, from) + } + + override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = { + withResource(GpuColumnVector.from(lhs, numRows, left.dataType)) { expandedLhs => + doColumnar(expandedLhs, rhs) } + } } case class GpuRLike(left: Expression, right: Expression, pattern: String)