Skip to content

Commit

Permalink
clean up
Browse files: browse the repository at this point in the history
Signed-off-by: Haoyang Li <[email protected]>
  • Loading branch information
thirtiseven committed Apr 16, 2024
1 parent 552cf7e commit 8b88378
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 15 deletions.
12 changes: 3 additions & 9 deletions integration_tests/src/main/python/regexp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@
else:
pytestmark = pytest.mark.regexp

_regexp_conf = { 'spark.rapids.sql.regexp.enabled': True,
'spark.rapids.sql.rLikeRegexRewrite.enabled': True}
_regexp_conf = { 'spark.rapids.sql.regexp.enabled': True }

def mk_str_gen(pattern):
return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}')
Expand Down Expand Up @@ -607,7 +606,7 @@ def test_regexp_hexadecimal_digits():
gen = mk_str_gen(
'[abcd]\\\\x00\\\\x7f\\\\x80\\\\xff\\\\x{10ffff}\\\\x{00eeee}[\\\\xa0-\\\\xb0][abcd]')
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen, length=10).selectExpr(
lambda spark: unary_op_df(spark, gen).selectExpr(
'rlike(a, "\\\\x7f")',
'rlike(a, "\\\\x80")',
'rlike(a, "[\\\\xa0-\\\\xf0]")',
Expand Down Expand Up @@ -1044,12 +1043,7 @@ def test_regexp_memory_fallback():
'a rlike "a{1,6}"',
'a rlike "abcdef"',
'a rlike "(1)(2)(3)"',
'a rlike "1|2|3|4|5|6"',
'a rlike "^.*aaaa.*$"',
'a rlike "^aaaa.*"',
'a rlike ".*aaaa$"',
'a rlike ".*aaaa.*"',
'a rlike "aaaa"',
'a rlike "1|2|3|4|5|6"'
),
cpu_fallback_class_name='RLike',
conf={
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1073,7 +1073,7 @@ class GpuRLikeMeta(
private var originalPattern: String = ""
private var pattern: Option[String] = None

val specialChars = Seq('^', '$', '.', '|', '*', '?', '+', '[', ']', '{', '}', '(', ')')
val specialChars = Seq('^', '$', '.', '|', '*', '?', '+', '[', ']', '{', '}', '\\' ,'(', ')')

def isSimplePattern(pat: String): Boolean = {
pat.size > 0 && pat.forall(c => !specialChars.contains(c))
Expand Down Expand Up @@ -1107,15 +1107,23 @@ class GpuRLikeMeta(
def optimizeSimplePattern(rhs: Expression, lhs: Expression, parts: List[RegexprPart]):
GpuExpression = {
parts match {
case Wildcard :: rest => optimizeSimplePattern(rhs, lhs, rest)
case Start :: Fixstring(s) :: List(End) => GpuEqualTo(lhs, GpuLiteral(s, StringType))
case Wildcard :: rest => {
optimizeSimplePattern(rhs, lhs, rest)
}
case Start :: Wildcard :: List(End) => {
GpuEqualTo(lhs, rhs)
}
case Start :: Fixstring(s) :: rest
if rest.forall(_ == Wildcard) || rest == List() =>
if rest.forall(_ == Wildcard) || rest == List() => {
GpuStartsWith(lhs, GpuLiteral(s, StringType))
case Fixstring(s) :: List(End) => GpuEndsWith(lhs, GpuLiteral(s, StringType))
}
case Fixstring(s) :: List(End) => {
GpuEndsWith(lhs, GpuLiteral(s, StringType))
}
case Fixstring(s) :: rest
if rest == List() || rest.forall(_ == Wildcard) =>
if rest == List() || rest.forall(_ == Wildcard) => {
GpuContains(lhs, GpuLiteral(s, StringType))
}
case _ => {
val patternStr = pattern.getOrElse(throw new IllegalStateException(
"Expression has not been tagged with cuDF regex pattern"))
Expand Down

0 comments on commit 8b88378

Please sign in to comment.