Skip to content

Commit

Permalink
use multiple contains in rlike rewrite
Browse files Browse the repository at this point in the history
Signed-off-by: Haoyang Li <[email protected]>
  • Loading branch information
thirtiseven committed Dec 3, 2024
1 parent 017fdef commit cedcd58
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import scala.collection.mutable.ListBuffer
import com.nvidia.spark.rapids.GpuOverrides.regexMetaChars
import com.nvidia.spark.rapids.RegexParser.toReadableString

import org.apache.spark.unsafe.types.UTF8String

/**
* Regular expression parser based on a Pratt Parser design.
*
Expand Down Expand Up @@ -1988,7 +1990,7 @@ object RegexOptimizationType {
case class Contains(literal: String) extends RegexOptimizationType
case class PrefixRange(literal: String, length: Int, rangeStart: Int, rangeEnd: Int)
extends RegexOptimizationType
case class MultipleContains(literals: Seq[String]) extends RegexOptimizationType
case class MultipleContains(literals: Seq[UTF8String]) extends RegexOptimizationType
case object NoOptimization extends RegexOptimizationType
}

Expand Down Expand Up @@ -2057,16 +2059,17 @@ object RegexRewrite {
}
}

private def getMultipleContainsLiterals(ast: RegexAST): Seq[String] = {
private def getMultipleContainsLiterals(ast: RegexAST): Seq[UTF8String] = {
ast match {
case RegexGroup(_, term, _) => getMultipleContainsLiterals(term)
case RegexChoice(RegexSequence(parts), ls) if isLiteralString(parts) => {
getMultipleContainsLiterals(ls) match {
case Seq() => Seq.empty
case literals => RegexCharsToString(parts) +: literals
case literals => UTF8String.fromString(RegexCharsToString(parts)) +: literals
}
}
case RegexSequence(parts) if (isLiteralString(parts)) => Seq(RegexCharsToString(parts))
case RegexSequence(parts) if (isLiteralString(parts)) =>
Seq(UTF8String.fromString(RegexCharsToString(parts)))
case _ => Seq.empty
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1202,7 +1202,9 @@ class GpuRLikeMeta(
}
case StartsWith(s) => GpuStartsWith(lhs, GpuLiteral(s, StringType))
case Contains(s) => GpuContains(lhs, GpuLiteral(UTF8String.fromString(s), StringType))
case MultipleContains(ls) => GpuMultipleContains(lhs, ls)
case MultipleContains(ls) => {
GpuContainsAny(lhs, ls)
}
case PrefixRange(s, length, start, end) =>
GpuLiteralRangePattern(lhs, GpuLiteral(s, StringType), length, start, end)
case _ => throw new IllegalStateException("Unexpected optimization type")
Expand Down Expand Up @@ -1233,7 +1235,7 @@ case class GpuRLike(left: Expression, right: Expression, pattern: String)
override def dataType: DataType = BooleanType
}

case class GpuMultipleContains(input: Expression, searchList: Seq[String])
case class GpuContainsAny(input: Expression, targets: Seq[UTF8String])
extends GpuUnaryExpression with ImplicitCastInputTypes with NullIntolerantShim {

override def dataType: DataType = BooleanType
Expand All @@ -1243,17 +1245,15 @@ case class GpuMultipleContains(input: Expression, searchList: Seq[String])
override def inputTypes: Seq[AbstractDataType] = Seq(StringType)

override def doColumnar(input: GpuColumnVector): ColumnVector = {
assert(searchList.length > 1)
val accInit = withResource(Scalar.fromString(searchList.head)) { searchScalar =>
input.getBase.stringContains(searchScalar)
val targetsBytes = targets.map(t => t.getBytes).toArray
val boolCvs = withResource(ColumnVector.fromUTF8Strings(targetsBytes: _*)) { targetsCv =>
input.getBase.stringContains(targetsCv)
}
searchList.tail.foldLeft(accInit) { (acc, search) =>
val containsSearch = withResource(Scalar.fromString(search)) { searchScalar =>
input.getBase.stringContains(searchScalar)
}
withResource(acc) { _ =>
withResource(containsSearch) { _ =>
acc.or(containsSearch)
// boolCvs is a sequence of ColumnVectors, we need to OR them together
boolCvs.reduce {
(cv1, cv2) => withResource(cv1) { cv1 =>
withResource(cv2) { cv2 =>
cv1.or(cv2)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ package com.nvidia.spark.rapids

import org.scalatest.funsuite.AnyFunSuite

import org.apache.spark.unsafe.types.UTF8String

class RegularExpressionRewriteSuite extends AnyFunSuite {

private def verifyRewritePattern(patterns: Seq[String], excepted: Seq[RegexOptimizationType]):
Unit = {
private def verifyRewritePattern(patterns: Seq[String],
excepted: Seq[RegexOptimizationType]): Unit = {
val results = patterns.map { pattern =>
val ast = new RegexParser(pattern).parse()
RegexRewrite.matchSimplePattern(ast)
Expand Down Expand Up @@ -87,11 +89,11 @@ class RegularExpressionRewriteSuite extends AnyFunSuite {
"(火花|急流)"
)
val excepted = Seq(
MultipleContains(Seq("abc", "def")),
MultipleContains(Seq("abc", "def", "ghi")),
MultipleContains(Seq("abc", "def")),
MultipleContains(Seq("abc", "def")),
MultipleContains(Seq("火花", "急流"))
MultipleContains(Seq("abc", "def").map(UTF8String.fromString)),
MultipleContains(Seq("abc", "def", "ghi").map(UTF8String.fromString)),
MultipleContains(Seq("abc", "def").map(UTF8String.fromString)),
MultipleContains(Seq("abc", "def").map(UTF8String.fromString)),
MultipleContains(Seq("火花", "急流").map(UTF8String.fromString))
)
verifyRewritePattern(patterns, excepted)
}
Expand Down

0 comments on commit cedcd58

Please sign in to comment.