Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use faster multi-contains in rlike regex rewrite #11810

Merged
merged 8 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import scala.collection.mutable.ListBuffer
import com.nvidia.spark.rapids.GpuOverrides.regexMetaChars
import com.nvidia.spark.rapids.RegexParser.toReadableString

import org.apache.spark.unsafe.types.UTF8String

/**
* Regular expression parser based on a Pratt Parser design.
*
Expand Down Expand Up @@ -1988,7 +1990,7 @@ object RegexOptimizationType {
case class Contains(literal: String) extends RegexOptimizationType
case class PrefixRange(literal: String, length: Int, rangeStart: Int, rangeEnd: Int)
extends RegexOptimizationType
case class MultipleContains(literals: Seq[String]) extends RegexOptimizationType
case class MultipleContains(literals: Seq[UTF8String]) extends RegexOptimizationType
case object NoOptimization extends RegexOptimizationType
}

Expand Down Expand Up @@ -2057,16 +2059,17 @@ object RegexRewrite {
}
}

private def getMultipleContainsLiterals(ast: RegexAST): Seq[String] = {
private def getMultipleContainsLiterals(ast: RegexAST): Seq[UTF8String] = {
ast match {
case RegexGroup(_, term, _) => getMultipleContainsLiterals(term)
case RegexChoice(RegexSequence(parts), ls) if isLiteralString(parts) => {
getMultipleContainsLiterals(ls) match {
case Seq() => Seq.empty
case literals => RegexCharsToString(parts) +: literals
case literals => UTF8String.fromString(RegexCharsToString(parts)) +: literals
}
}
case RegexSequence(parts) if (isLiteralString(parts)) => Seq(RegexCharsToString(parts))
case RegexSequence(parts) if (isLiteralString(parts)) =>
Seq(UTF8String.fromString(RegexCharsToString(parts)))
case _ => Seq.empty
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1202,7 +1202,7 @@ class GpuRLikeMeta(
}
case StartsWith(s) => GpuStartsWith(lhs, GpuLiteral(s, StringType))
case Contains(s) => GpuContains(lhs, GpuLiteral(UTF8String.fromString(s), StringType))
case MultipleContains(ls) => GpuMultipleContains(lhs, ls)
case MultipleContains(ls) => GpuContainsAny(lhs, ls)
case PrefixRange(s, length, start, end) =>
GpuLiteralRangePattern(lhs, GpuLiteral(s, StringType), length, start, end)
case _ => throw new IllegalStateException("Unexpected optimization type")
Expand Down Expand Up @@ -1233,7 +1233,7 @@ case class GpuRLike(left: Expression, right: Expression, pattern: String)
override def dataType: DataType = BooleanType
}

case class GpuMultipleContains(input: Expression, searchList: Seq[String])
case class GpuContainsAny(input: Expression, targets: Seq[UTF8String])
extends GpuUnaryExpression with ImplicitCastInputTypes with NullIntolerantShim {

override def dataType: DataType = BooleanType
Expand All @@ -1243,20 +1243,26 @@ case class GpuMultipleContains(input: Expression, searchList: Seq[String])
override def inputTypes: Seq[AbstractDataType] = Seq(StringType)

override def doColumnar(input: GpuColumnVector): ColumnVector = {
assert(searchList.length > 1)
val accInit = withResource(Scalar.fromString(searchList.head)) { searchScalar =>
input.getBase.stringContains(searchScalar)
val targetsBytes = targets.map(t => t.getBytes).toArray
val boolCvs = withResource(ColumnVector.fromUTF8Strings(targetsBytes: _*)) { targetsCv =>
input.getBase.stringContains(targetsCv)
}
searchList.tail.foldLeft(accInit) { (acc, search) =>
val containsSearch = withResource(Scalar.fromString(search)) { searchScalar =>
input.getBase.stringContains(searchScalar)
}
withResource(acc) { _ =>
withResource(containsSearch) { _ =>
acc.or(containsSearch)
var ret: ColumnVector = null
withResource(boolCvs) { _ =>
boolCvs.indices.foreach { i =>
if (ret == null) {
ret = boolCvs(i)
boolCvs(i) = null
} else {
val tmp = ret.or(boolCvs(i))
ret.close()
ret = tmp
boolCvs(i).close()
boolCvs(i) = null
}
}
}
ret
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ package com.nvidia.spark.rapids

import org.scalatest.funsuite.AnyFunSuite

import org.apache.spark.unsafe.types.UTF8String

class RegularExpressionRewriteSuite extends AnyFunSuite {

private def verifyRewritePattern(patterns: Seq[String], excepted: Seq[RegexOptimizationType]):
Unit = {
private def verifyRewritePattern(patterns: Seq[String],
excepted: Seq[RegexOptimizationType]): Unit = {
val results = patterns.map { pattern =>
val ast = new RegexParser(pattern).parse()
RegexRewrite.matchSimplePattern(ast)
Expand Down Expand Up @@ -87,11 +89,11 @@ class RegularExpressionRewriteSuite extends AnyFunSuite {
"(火花|急流)"
)
val excepted = Seq(
MultipleContains(Seq("abc", "def")),
MultipleContains(Seq("abc", "def", "ghi")),
MultipleContains(Seq("abc", "def")),
MultipleContains(Seq("abc", "def")),
MultipleContains(Seq("火花", "急流"))
MultipleContains(Seq("abc", "def").map(UTF8String.fromString)),
MultipleContains(Seq("abc", "def", "ghi").map(UTF8String.fromString)),
MultipleContains(Seq("abc", "def").map(UTF8String.fromString)),
MultipleContains(Seq("abc", "def").map(UTF8String.fromString)),
MultipleContains(Seq("火花", "急流").map(UTF8String.fromString))
)
verifyRewritePattern(patterns, excepted)
}
Expand Down
Loading