Skip to content

Commit

Permalink
Rewrite multiple literal choice regex to multiple contains in rlike (N…
Browse files Browse the repository at this point in the history
…VIDIA#10977)

* rewrite multiple literal choice to multiple contains, wip

Signed-off-by: Haoyang Li <[email protected]>

* fix bug

Signed-off-by: Haoyang Li <[email protected]>

* optimize memory

Signed-off-by: Haoyang Li <[email protected]>

* remove debug log

Signed-off-by: Haoyang Li <[email protected]>

* address comments

Signed-off-by: Haoyang Li <[email protected]>

* Apply suggestions from code review

Co-authored-by: Gera Shegalov <[email protected]>

* support abc|def case

Signed-off-by: Haoyang Li <[email protected]>

* fix 2.13

Signed-off-by: Haoyang Li <[email protected]>

* fix 2.13 build

Signed-off-by: Haoyang Li <[email protected]>

---------

Signed-off-by: Haoyang Li <[email protected]>
Co-authored-by: Gera Shegalov <[email protected]>
  • Loading branch information
thirtiseven and gerashegalov authored Jun 12, 2024
1 parent 9f73672 commit 2cf5934
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 45 deletions.
8 changes: 7 additions & 1 deletion integration_tests/src/main/python/regexp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ def test_rlike_rewrite_optimization():
'rlike(a, "(.*)(abb)(.*)")',
'rlike(a, "^(abb)(.*)")',
'rlike(a, "^abb")',
'rlike(a, "^.*(aaa)")',
'rlike(a, "\\\\A(abb)(.*)")',
'rlike(a, "\\\\Aabb")',
'rlike(a, "^(abb)\\\\Z")',
Expand All @@ -466,7 +467,12 @@ def test_rlike_rewrite_optimization():
'rlike(a, "ab[a-c]{3}")',
'rlike(a, "a[a-c]{1,3}")',
'rlike(a, "a[a-c]{1,}")',
'rlike(a, "a[a-c]+")'),
'rlike(a, "a[a-c]+")',
'rlike(a, "(aaa|bbb|ccc)")',
'rlike(a, ".*.*(aaa|bbb).*.*")',
'rlike(a, "^.*(aaa|bbb|ccc)")',
'rlike(a, "aaa|bbb")',
'rlike(a, "aaa|(bbb|ccc)")'),
conf=_regexp_conf)

def test_regexp_replace_character_set_negated():
Expand Down
94 changes: 57 additions & 37 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2026,6 +2026,7 @@ object RegexOptimizationType {
case class Contains(literal: String) extends RegexOptimizationType
case class PrefixRange(literal: String, length: Int, rangeStart: Int, rangeEnd: Int)
extends RegexOptimizationType
/**
 * Optimization for patterns that reduce to a choice between plain literals
 * (e.g. "abc|def|ghi"): the regex is replaced by checking whether the input contains
 * any one of `literals`.
 */
case class MultipleContains(literals: Seq[String]) extends RegexOptimizationType
case object NoOptimization extends RegexOptimizationType
}

Expand Down Expand Up @@ -2091,6 +2092,20 @@ object RegexRewrite {
}
}

/**
 * Collects the literals of a pattern that is purely a choice between literal strings,
 * looking through any enclosing groups.
 *
 * @param ast the regex node to inspect
 * @return the literals in pattern order, or an empty Seq when any alternative is not a
 *         plain literal (i.e. the pattern cannot be expressed as "contains any of")
 */
private def getMultipleContainsLiterals(ast: RegexAST): Seq[String] = ast match {
  case RegexGroup(_, inner, _) =>
    // a group around the choice does not change what is matched
    getMultipleContainsLiterals(inner)
  case RegexChoice(RegexSequence(first), rest) if isLiteralString(first) =>
    val remaining = getMultipleContainsLiterals(rest)
    // an empty result means the right-hand side was not all literals, so give up entirely
    if (remaining.isEmpty) {
      Seq.empty
    } else {
      RegexCharsToString(first) +: remaining
    }
  case RegexSequence(chars) if isLiteralString(chars) =>
    Seq(RegexCharsToString(chars))
  case _ =>
    Seq.empty
}

private def isWildcard(ast: RegexAST): Boolean = {
ast match {
case RegexRepetition(RegexChar('.'), SimpleQuantifier('*')) => true
Expand All @@ -2101,11 +2116,8 @@ object RegexRewrite {
}

private def stripLeadingWildcards(astLs: collection.Seq[RegexAST]):
collection.Seq[RegexAST] = astLs match {
case (RegexChar('^') | RegexEscaped('A')) :: tail =>
// if the pattern starts with ^ or \A, strip it too
tail.dropWhile(isWildcard)
case _ => astLs.dropWhile(isWildcard)
collection.Seq[RegexAST] = {
astLs.dropWhile(isWildcard)
}

private def stripTailingWildcards(astLs: collection.Seq[RegexAST]):
Expand All @@ -2124,40 +2136,48 @@ object RegexRewrite {
* Matches the given regex ast to a regex optimization type for regex rewrite
* optimization.
*
* @param ast unparsed children of the Abstract Syntax Tree parsed from a regex pattern.
* @param ast Abstract Syntax Tree parsed from a regex pattern.
* @return The `RegexOptimizationType` for the given pattern.
*/
@scala.annotation.tailrec
def matchSimplePattern(ast: Seq[RegexAST]): RegexOptimizationType = {
ast match {
case (RegexChar('^') | RegexEscaped('A')) :: astTail =>
val noTrailingWildCards = stripTailingWildcards(astTail)
if (isLiteralString(noTrailingWildCards)) {
// ^literal.* => startsWith literal
RegexOptimizationType.StartsWith(RegexCharsToString(noTrailingWildCards))
} else {
val noWildCards = stripLeadingWildcards(noTrailingWildCards)
if (noWildCards.length == noTrailingWildCards.length) {
// TODO startsWith with PrefIxRange
RegexOptimizationType.NoOptimization
} else {
matchSimplePattern(astTail)
}
}
case astLs => {
val noStartsWithAst = stripTailingWildcards(stripLeadingWildcards(astLs))
val prefixRangeInfo = getPrefixRangePattern(noStartsWithAst)
if (prefixRangeInfo.isDefined) {
val (prefix, length, start, end) = prefixRangeInfo.get
// (literal[a-b]{x,y}) => prefix range pattern
RegexOptimizationType.PrefixRange(prefix, length, start, end)
} else if (isLiteralString(noStartsWithAst)) {
// literal.* or (literal).* => contains literal
RegexOptimizationType.Contains(RegexCharsToString(noStartsWithAst))
} else {
RegexOptimizationType.NoOptimization
}
def matchSimplePattern(ast: RegexAST): RegexOptimizationType = {
  // Work on the top-level terms: the children of a sequence, or the node itself otherwise
  val topLevel = ast match {
    case RegexSequence(_) => ast.children()
    case _ => Seq(ast)
  }
  val withoutTail = stripTailingWildcards(topLevel)

  // ^literal or \Aliteral (optionally followed by wildcards) => starts with literal
  val anchored = withoutTail.headOption.exists { first =>
    first == RegexChar('^') || first == RegexEscaped('A')
  }
  val startsWithLiteral: Option[String] =
    if (anchored) {
      val afterAnchor = withoutTail.drop(1)
      if (isLiteralString(afterAnchor)) Some(RegexCharsToString(afterAnchor)) else None
    } else {
      None
    }

  startsWithLiteral match {
    case Some(literal) =>
      RegexOptimizationType.StartsWith(literal)
    case None =>
      val core = stripLeadingWildcards(withoutTail)
      if (isLiteralString(core)) {
        // literal or .*(literal).* => contains literal
        RegexOptimizationType.Contains(RegexCharsToString(core))
      } else {
        // a single remaining term may be a choice of literals (e.g. "abc|def|ghi")
        val choiceLiterals =
          if (core.length == 1) getMultipleContainsLiterals(core.head) else Seq.empty[String]
        if (choiceLiterals.nonEmpty) {
          RegexOptimizationType.MultipleContains(choiceLiterals)
        } else {
          getPrefixRangePattern(core) match {
            case Some((prefix, length, start, end)) =>
              // (literal[a-b]{x,y}) => prefix range pattern (e.g. "abc[a-z]{3}")
              RegexOptimizationType.PrefixRange(prefix, length, start, end)
            case None =>
              // not a simple pattern: no rewrite, the regex is evaluated by cuDF
              RegexOptimizationType.NoOptimization
          }
        }
      }
  }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1073,7 +1073,7 @@ class GpuRLikeMeta(
val originalPattern = str.toString
val regexAst = new RegexParser(originalPattern).parse()
if (conf.isRlikeRegexRewriteEnabled) {
rewriteOptimizationType = RegexRewrite.matchSimplePattern(regexAst.children())
rewriteOptimizationType = RegexRewrite.matchSimplePattern(regexAst)
}
val (transpiledAST, _) = new CudfRegexTranspiler(RegexFindMode)
.getTranspiledAST(regexAst, None, None)
Expand All @@ -1097,6 +1097,7 @@ class GpuRLikeMeta(
}
case StartsWith(s) => GpuStartsWith(lhs, GpuLiteral(s, StringType))
case Contains(s) => GpuContains(lhs, GpuLiteral(s, StringType))
case MultipleContains(ls) => GpuMultipleContains(lhs, ls)
case PrefixRange(s, length, start, end) =>
GpuLiteralRangePattern(lhs, GpuLiteral(s, StringType), length, start, end)
case _ => throw new IllegalStateException("Unexpected optimization type")
Expand Down Expand Up @@ -1126,6 +1127,33 @@ case class GpuRLike(left: Expression, right: Expression, pattern: String)
override def dataType: DataType = BooleanType
}

/**
 * GPU expression that is true for a row when its string value contains at least one of the
 * literals in `searchList`. Created by the rlike rewrite for the `MultipleContains`
 * optimization type.
 *
 * @param input      the string expression to search in
 * @param searchList the literal substrings to look for; must hold more than one entry
 */
case class GpuMultipleContains(input: Expression, searchList: Seq[String])
    extends GpuUnaryExpression with ImplicitCastInputTypes with NullIntolerant {

  override def dataType: DataType = BooleanType

  override def child: Expression = input

  override def inputTypes: Seq[AbstractDataType] = Seq(StringType)

  /**
   * Runs one `stringContains` per literal and ORs the per-literal results together.
   *
   * @param input the column to search
   * @return a column that is the OR of the contains checks for every literal
   */
  override def doColumnar(input: GpuColumnVector): ColumnVector = {
    assert(searchList.length > 1)
    val accInit = withResource(Scalar.fromString(searchList.head)) { searchScalar =>
      input.getBase.stringContains(searchScalar)
    }
    searchList.tail.foldLeft(accInit) { (acc, search) =>
      // Wrap the running result first: if building the scalar or the next stringContains
      // throws, the accumulated column is still released instead of being leaked.
      withResource(acc) { _ =>
        val containsSearch = withResource(Scalar.fromString(search)) { searchScalar =>
          input.getBase.stringContains(searchScalar)
        }
        withResource(containsSearch) { _ =>
          acc.or(containsSearch)
        }
      }
    }
  }
}

case class GpuLiteralRangePattern(left: Expression, right: Expression,
length: Int, start: Int, end: Int)
extends GpuBinaryExpressionArgsAnyScalar with ImplicitCastInputTypes with NullIntolerant {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class RegularExpressionRewriteSuite extends AnyFunSuite {
Unit = {
val results = patterns.map { pattern =>
val ast = new RegexParser(pattern).parse()
RegexRewrite.matchSimplePattern(ast.children())
RegexRewrite.matchSimplePattern(ast)
}
assert(results == excepted)
}
Expand All @@ -40,9 +40,9 @@ class RegularExpressionRewriteSuite extends AnyFunSuite {
test("regex rewrite contains") {
import RegexOptimizationType._
val patterns = Seq(".*abc.*", ".*(abc).*", "^.*(abc).*$", "^.*(.*)(abc).*.*",
raw".*\w.*\Z", raw".*..*\Z")
val excepted = Seq(Contains("abc"), Contains("abc"), NoOptimization, Contains("abc"),
NoOptimization, NoOptimization)
raw".*\w.*\Z", raw".*..*\Z", "^(.*)(abc)")
val excepted = Seq(Contains("abc"), Contains("abc"), NoOptimization, NoOptimization,
NoOptimization, NoOptimization, NoOptimization)
verifyRewritePattern(patterns, excepted)
}

Expand All @@ -67,8 +67,27 @@ class RegularExpressionRewriteSuite extends AnyFunSuite {
PrefixRange("火花急流", 1, 19968, 40869),
NoOptimization, // starts with PrefixRange not supported
NoOptimization, // starts with PrefixRange not supported
PrefixRange("", 6, 48, 57),
PrefixRange("", 3, 48, 57)
NoOptimization, // .* can't match line break so can't be optimized
NoOptimization // .* can't match line break so can't be optimized
)
verifyRewritePattern(patterns, excepted)
}

test("regex rewrite multiple contains") {
  import RegexOptimizationType._
  // each pattern is listed next to the rewrite it is expected to produce
  val cases = Seq(
    "(abc|def).*" -> MultipleContains(Seq("abc", "def")),
    ".*(abc|def|ghi).*" -> MultipleContains(Seq("abc", "def", "ghi")),
    "((abc)|(def))" -> MultipleContains(Seq("abc", "def")),
    "(abc)|(def)" -> MultipleContains(Seq("abc", "def")),
    "(火花|急流)" -> MultipleContains(Seq("火花", "急流"))
  )
  val (patterns, expected) = cases.unzip
  verifyRewritePattern(patterns, expected)
}
Expand Down

0 comments on commit 2cf5934

Please sign in to comment.