diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 89fd5bf9191..2b0b46f55ea 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -22,6 +22,8 @@ import scala.collection.mutable.ListBuffer import com.nvidia.spark.rapids.GpuOverrides.regexMetaChars import com.nvidia.spark.rapids.RegexParser.toReadableString +import org.apache.spark.unsafe.types.UTF8String + /** * Regular expression parser based on a Pratt Parser design. * @@ -1988,7 +1990,7 @@ object RegexOptimizationType { case class Contains(literal: String) extends RegexOptimizationType case class PrefixRange(literal: String, length: Int, rangeStart: Int, rangeEnd: Int) extends RegexOptimizationType - case class MultipleContains(literals: Seq[String]) extends RegexOptimizationType + case class MultipleContains(literals: Seq[UTF8String]) extends RegexOptimizationType case object NoOptimization extends RegexOptimizationType } @@ -2057,16 +2059,17 @@ object RegexRewrite { } } - private def getMultipleContainsLiterals(ast: RegexAST): Seq[String] = { + private def getMultipleContainsLiterals(ast: RegexAST): Seq[UTF8String] = { ast match { case RegexGroup(_, term, _) => getMultipleContainsLiterals(term) case RegexChoice(RegexSequence(parts), ls) if isLiteralString(parts) => { getMultipleContainsLiterals(ls) match { case Seq() => Seq.empty - case literals => RegexCharsToString(parts) +: literals + case literals => UTF8String.fromString(RegexCharsToString(parts)) +: literals } } - case RegexSequence(parts) if (isLiteralString(parts)) => Seq(RegexCharsToString(parts)) + case RegexSequence(parts) if (isLiteralString(parts)) => + Seq(UTF8String.fromString(RegexCharsToString(parts))) case _ => Seq.empty } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 79db87f1736..f668195abc7 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -24,7 +24,7 @@ import scala.annotation.tailrec import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import ai.rapids.cudf.{BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexFlag, RegexProgram, RoundMode, Scalar} +import ai.rapids.cudf.{ast, BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexFlag, RegexProgram, RoundMode, Scalar, Table} import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ @@ -1202,7 +1202,7 @@ class GpuRLikeMeta( } case StartsWith(s) => GpuStartsWith(lhs, GpuLiteral(s, StringType)) case Contains(s) => GpuContains(lhs, GpuLiteral(UTF8String.fromString(s), StringType)) - case MultipleContains(ls) => GpuMultipleContains(lhs, ls) + case MultipleContains(ls) => GpuContainsAny(lhs, ls) case PrefixRange(s, length, start, end) => GpuLiteralRangePattern(lhs, GpuLiteral(s, StringType), length, start, end) case _ => throw new IllegalStateException("Unexpected optimization type") @@ -1233,7 +1233,7 @@ case class GpuRLike(left: Expression, right: Expression, pattern: String) override def dataType: DataType = BooleanType } -case class GpuMultipleContains(input: Expression, searchList: Seq[String]) +case class GpuContainsAny(input: Expression, targets: Seq[UTF8String]) extends GpuUnaryExpression with ImplicitCastInputTypes with NullIntolerantShim { override def dataType: DataType = BooleanType @@ -1242,19 +1242,24 @@ case class GpuMultipleContains(input: Expression, searchList: Seq[String]) override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + def multiOrsAst: ast.AstExpression = { + (1 until targets.length) + .foldLeft(new ast.ColumnReference(0).asInstanceOf[ast.AstExpression]) { (acc, id) => + new ast.BinaryOperation(ast.BinaryOperator.NULL_LOGICAL_OR, acc, new ast.ColumnReference(id)) + } + } + override def doColumnar(input: GpuColumnVector): ColumnVector = { - assert(searchList.length > 1) - val accInit = withResource(Scalar.fromString(searchList.head)) { searchScalar => - input.getBase.stringContains(searchScalar) + val targetsBytes = targets.map(t => t.getBytes).toArray + val boolCvs = withResource(ColumnVector.fromUTF8Strings(targetsBytes: _*)) { targetsCv => + input.getBase.stringContains(targetsCv) } - searchList.tail.foldLeft(accInit) { (acc, search) => - val containsSearch = withResource(Scalar.fromString(search)) { searchScalar => - input.getBase.stringContains(searchScalar) - } - withResource(acc) { _ => - withResource(containsSearch) { _ => - acc.or(containsSearch) - } + val boolTable = withResource(boolCvs) { _ => + new Table(boolCvs: _*) + } + withResource(boolTable) { _ => + withResource(multiOrsAst.compile()) { compiledAst => + compiledAst.computeColumn(boolTable) } } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionRewriteSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionRewriteSuite.scala index a55815b95ef..12e12fd957f 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionRewriteSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionRewriteSuite.scala @@ -17,10 +17,12 @@ package com.nvidia.spark.rapids import org.scalatest.funsuite.AnyFunSuite +import org.apache.spark.unsafe.types.UTF8String + class RegularExpressionRewriteSuite extends AnyFunSuite { - private def verifyRewritePattern(patterns: Seq[String], excepted: Seq[RegexOptimizationType]): - Unit = { + private def verifyRewritePattern(patterns: Seq[String], + excepted: Seq[RegexOptimizationType]): Unit = { val results = patterns.map { pattern => val ast = new RegexParser(pattern).parse() RegexRewrite.matchSimplePattern(ast) @@ -87,11 +89,11 @@ class RegularExpressionRewriteSuite extends AnyFunSuite { "(火花|急流)" ) val excepted = Seq( - MultipleContains(Seq("abc", "def")), - MultipleContains(Seq("abc", "def", "ghi")), - MultipleContains(Seq("abc", "def")), - MultipleContains(Seq("abc", "def")), - MultipleContains(Seq("火花", "急流")) + MultipleContains(Seq("abc", "def").map(UTF8String.fromString)), + MultipleContains(Seq("abc", "def", "ghi").map(UTF8String.fromString)), + MultipleContains(Seq("abc", "def").map(UTF8String.fromString)), + MultipleContains(Seq("abc", "def").map(UTF8String.fromString)), + MultipleContains(Seq("火花", "急流").map(UTF8String.fromString)) ) verifyRewritePattern(patterns, excepted) }