Skip to content

Commit

Permalink
add tests and config
Browse files Browse the repository at this point in the history
Signed-off-by: Haoyang Li <[email protected]>
  • Loading branch information
thirtiseven committed Apr 26, 2024
1 parent d290ccd commit eeefa34
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 7 deletions.
13 changes: 12 additions & 1 deletion integration_tests/src/main/python/regexp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
else:
pytestmark = pytest.mark.regexp

_regexp_conf = { 'spark.rapids.sql.regexp.enabled': True }
_regexp_conf = { 'spark.rapids.sql.regexp.enabled': True ,
'spark.rapids.sql.rLikeRegexRewrite.enabled': 'new'}

def mk_str_gen(pattern):
return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}')
Expand Down Expand Up @@ -464,6 +465,16 @@ def test_regexp_rlike_rewrite_optimization():
'regexp_like(a, "(.*)(.*)abcd")'),
conf=_regexp_conf)

@pytest.mark.skipif(is_before_spark_320(), reason='regexp_like is synonym for RLike starting in Spark 3.2.0')
def test_regexp_rlike_rewrite_optimization_str_dig():
gen = mk_str_gen('([abcd]{3,6})?[0-9]{2,5}')
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen).selectExpr(
'a',
'regexp_like(a, "[0-9]{4,}")',
'regexp_like(a, "abcd([0-9]{5})")'),
conf=_regexp_conf)

def test_regexp_replace_character_set_negated():
gen = mk_str_gen('[abcd]{0,3}[\r\n]{0,2}[abcd]{0,3}')
assert_gpu_and_cpu_are_equal_collect(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -926,8 +926,8 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern")
val ENABLE_RLIKE_REGEX_REWRITE = conf("spark.rapids.sql.rLikeRegexRewrite.enabled")
.doc("Enable the optimization to rewrite rlike regex to contains in some cases.")
.internal()
.booleanConf
.createWithDefault(true)
.stringConf
.createWithDefault("new")

val ENABLE_GETJSONOBJECT_LEGACY = conf("spark.rapids.sql.getJsonObject.legacy.enabled")
.doc("When set to true, the get_json_object function will use the legacy implementation " +
Expand Down Expand Up @@ -2630,7 +2630,7 @@ class RapidsConf(conf: Map[String, String]) extends Logging {

lazy val isTieredProjectEnabled: Boolean = get(ENABLE_TIERED_PROJECT)

lazy val isRlikeRegexRewriteEnabled: Boolean = get(ENABLE_RLIKE_REGEX_REWRITE)
lazy val isRlikeRegexRewriteEnabled: String = get(ENABLE_RLIKE_REGEX_REWRITE)

lazy val isLegacyGetJsonObjectEnabled: Boolean = get(ENABLE_GETJSONOBJECT_LEGACY)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1127,12 +1127,12 @@ class GpuRLikeMeta(
}
case Digits(from, _) :: rest
if rest == List() || rest.forall(_ == Wildcard) => {
println(s"!!!GpuStringDigits1: $from")
// println(s"!!!GpuStringDigits1: $from")
GpuStringDigits(lhs, GpuLiteral("", StringType), from)
}
case Fixstring(s) :: Digits(from, _) :: rest
if rest == List() || rest.forall(_ == Wildcard) => {
println(s"!!!GpuStringDigits2: $s, $from")
// println(s"!!!GpuStringDigits2: $s, $from")
GpuStringDigits(lhs, GpuLiteral(s, StringType), from)
}
case Fixstring(s) :: rest
Expand All @@ -1147,6 +1147,34 @@ class GpuRLikeMeta(
}
}

def optimizeSimplePatternLegancy(rhs: Expression, lhs: Expression, parts: List[RegexprPart]):
GpuExpression = {
parts match {
case Wildcard :: rest => {
optimizeSimplePattern(rhs, lhs, rest)
}
case Start :: Wildcard :: List(End) => {
GpuEqualTo(lhs, rhs)
}
case Start :: Fixstring(s) :: rest
if rest.forall(_ == Wildcard) || rest == List() => {
GpuStartsWith(lhs, GpuLiteral(s, StringType))
}
case Fixstring(s) :: List(End) => {
GpuEndsWith(lhs, GpuLiteral(s, StringType))
}
case Fixstring(s) :: rest
if rest == List() || rest.forall(_ == Wildcard) => {
GpuContains(lhs, GpuLiteral(s, StringType))
}
case _ => {
val patternStr = pattern.getOrElse(throw new IllegalStateException(
"Expression has not been tagged with cuDF regex pattern"))
GpuRLike(lhs, rhs, patternStr)
}
}
}

override def tagExprForGpu(): Unit = {
GpuRegExpUtils.tagForRegExpEnabled(this)
expr.right match {
Expand All @@ -1168,12 +1196,20 @@ class GpuRLikeMeta(
}

override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
if (conf.isRlikeRegexRewriteEnabled) {
if (conf.isRlikeRegexRewriteEnabled == "new") {
// println(s"!!!GpuRLike: ${conf.isRlikeRegexRewriteEnabled}")
// if the pattern can be converted to a startswith or endswith pattern, we can use
// GpuStartsWith, GpuEndsWith or GpuContains instead to get better performance
val parts = parseRegexToParts(originalPattern)
optimizeSimplePattern(rhs, lhs, parts)
} else if (conf.isRlikeRegexRewriteEnabled == "legacy") {
// println(s"!!!GpuRLike: ${conf.isRlikeRegexRewriteEnabled}")
// if the pattern can be converted to a startswith or endswith pattern, we can use
// GpuStartsWith, GpuEndsWith or GpuContains instead to get better performance
val parts = parseRegexToParts(originalPattern)
optimizeSimplePatternLegancy(rhs, lhs, parts)
} else {
// println(s"!!!GpuRLike: ${conf.isRlikeRegexRewriteEnabled}")
val patternStr = pattern.getOrElse(throw new IllegalStateException(
"Expression has not been tagged with cuDF regex pattern"))
GpuRLike(lhs, rhs, patternStr)
Expand Down

0 comments on commit eeefa34

Please sign in to comment.