Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use faster multi-contains in rlike regex rewrite #11810

Merged
merged 8 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import scala.collection.mutable.ListBuffer
import com.nvidia.spark.rapids.GpuOverrides.regexMetaChars
import com.nvidia.spark.rapids.RegexParser.toReadableString

import org.apache.spark.unsafe.types.UTF8String

/**
* Regular expression parser based on a Pratt Parser design.
*
Expand Down Expand Up @@ -1988,7 +1990,7 @@ object RegexOptimizationType {
case class Contains(literal: String) extends RegexOptimizationType
case class PrefixRange(literal: String, length: Int, rangeStart: Int, rangeEnd: Int)
extends RegexOptimizationType
case class MultipleContains(literals: Seq[String]) extends RegexOptimizationType
case class MultipleContains(literals: Seq[UTF8String]) extends RegexOptimizationType
case object NoOptimization extends RegexOptimizationType
}

Expand Down Expand Up @@ -2057,16 +2059,17 @@ object RegexRewrite {
}
}

private def getMultipleContainsLiterals(ast: RegexAST): Seq[String] = {
private def getMultipleContainsLiterals(ast: RegexAST): Seq[UTF8String] = {
ast match {
case RegexGroup(_, term, _) => getMultipleContainsLiterals(term)
case RegexChoice(RegexSequence(parts), ls) if isLiteralString(parts) => {
getMultipleContainsLiterals(ls) match {
case Seq() => Seq.empty
case literals => RegexCharsToString(parts) +: literals
case literals => UTF8String.fromString(RegexCharsToString(parts)) +: literals
}
}
case RegexSequence(parts) if (isLiteralString(parts)) => Seq(RegexCharsToString(parts))
case RegexSequence(parts) if (isLiteralString(parts)) =>
Seq(UTF8String.fromString(RegexCharsToString(parts)))
case _ => Seq.empty
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1202,7 +1202,7 @@ class GpuRLikeMeta(
}
case StartsWith(s) => GpuStartsWith(lhs, GpuLiteral(s, StringType))
case Contains(s) => GpuContains(lhs, GpuLiteral(UTF8String.fromString(s), StringType))
case MultipleContains(ls) => GpuMultipleContains(lhs, ls)
case MultipleContains(ls) => GpuContainsAny(lhs, ls)
case PrefixRange(s, length, start, end) =>
GpuLiteralRangePattern(lhs, GpuLiteral(s, StringType), length, start, end)
case _ => throw new IllegalStateException("Unexpected optimization type")
Expand Down Expand Up @@ -1233,7 +1233,7 @@ case class GpuRLike(left: Expression, right: Expression, pattern: String)
override def dataType: DataType = BooleanType
}

case class GpuMultipleContains(input: Expression, searchList: Seq[String])
case class GpuContainsAny(input: Expression, targets: Seq[UTF8String])
extends GpuUnaryExpression with ImplicitCastInputTypes with NullIntolerantShim {

override def dataType: DataType = BooleanType
Expand All @@ -1243,19 +1243,15 @@ case class GpuMultipleContains(input: Expression, searchList: Seq[String])
override def inputTypes: Seq[AbstractDataType] = Seq(StringType)

override def doColumnar(input: GpuColumnVector): ColumnVector = {
assert(searchList.length > 1)
val accInit = withResource(Scalar.fromString(searchList.head)) { searchScalar =>
input.getBase.stringContains(searchScalar)
val targetsBytes = targets.map(t => t.getBytes).toArray
val boolCvs = withResource(ColumnVector.fromUTF8Strings(targetsBytes: _*)) { targetsCv =>
input.getBase.stringContains(targetsCv)
}
searchList.tail.foldLeft(accInit) { (acc, search) =>
val containsSearch = withResource(Scalar.fromString(search)) { searchScalar =>
input.getBase.stringContains(searchScalar)
}
withResource(acc) { _ =>
withResource(containsSearch) { _ =>
acc.or(containsSearch)
}
withResource(boolCvs) { _ =>
val falseCv = withResource(Scalar.fromBool(false)) { falseScalar =>
ColumnVector.fromScalar(falseScalar, input.getRowCount.toInt)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't reduceLeft do what you want without requiring the creation of a boolean column? Also, how do you close the intermediate values when running foldLeft or reduceLeft?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wouldn't reduceLeft do what you want and not require creating a boolean column?

For the lambda (l, r) => l or r in the reduce, l can be either the first value, which is protected by the outer withResource, or an intermediate value, which needs its own withResource, so I could not find a unified way to write the lambda. I updated the code to a slightly tricky foldLeft approach that avoids creating a boolean column.

Also, how do you close the intermediate values when running foldLeft or reduceLeft?

They are closed by the withResource in (l, r) => withResource(l) { _ => l.or(r)}, where l is always an intermediate value.

}
boolCvs.foldLeft(falseCv)((l, r) => withResource(l) { _ => l.or(r)})
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ package com.nvidia.spark.rapids

import org.scalatest.funsuite.AnyFunSuite

import org.apache.spark.unsafe.types.UTF8String

class RegularExpressionRewriteSuite extends AnyFunSuite {

private def verifyRewritePattern(patterns: Seq[String], excepted: Seq[RegexOptimizationType]):
Unit = {
private def verifyRewritePattern(patterns: Seq[String],
excepted: Seq[RegexOptimizationType]): Unit = {
val results = patterns.map { pattern =>
val ast = new RegexParser(pattern).parse()
RegexRewrite.matchSimplePattern(ast)
Expand Down Expand Up @@ -87,11 +89,11 @@ class RegularExpressionRewriteSuite extends AnyFunSuite {
"(火花|急流)"
)
val excepted = Seq(
MultipleContains(Seq("abc", "def")),
MultipleContains(Seq("abc", "def", "ghi")),
MultipleContains(Seq("abc", "def")),
MultipleContains(Seq("abc", "def")),
MultipleContains(Seq("火花", "急流"))
MultipleContains(Seq("abc", "def").map(UTF8String.fromString)),
MultipleContains(Seq("abc", "def", "ghi").map(UTF8String.fromString)),
MultipleContains(Seq("abc", "def").map(UTF8String.fromString)),
MultipleContains(Seq("abc", "def").map(UTF8String.fromString)),
MultipleContains(Seq("火花", "急流").map(UTF8String.fromString))
)
verifyRewritePattern(patterns, excepted)
}
Expand Down
Loading