diff --git a/integration_tests/src/main/python/conditionals_test.py b/integration_tests/src/main/python/conditionals_test.py index b95ed53f398..aaa390476a4 100644 --- a/integration_tests/src/main/python/conditionals_test.py +++ b/integration_tests/src/main/python/conditionals_test.py @@ -379,3 +379,30 @@ def test_case_when_all_then_values_are_scalars_with_nulls(): "tab", sql_without_else, conf = {'spark.rapids.sql.case_when.fuse': 'true'}) + +@pytest.mark.parametrize('combine_string_contains_enabled', [True, False]) +def test_combine_string_contains_in_case_when(combine_string_contains_enabled): + data_gen = [("c1", string_gen)] + sql = """ + SELECT + CASE + WHEN INSTR(c1, 'a') > 0 THEN 'a' + WHEN INSTR(c1, 'b') > 0 THEN 'b' + WHEN INSTR(c1, 'c') > 0 THEN 'c' + ELSE '' + END as output_1, + CASE + WHEN INSTR(c1, 'c') > 0 THEN 'c' + WHEN INSTR(c1, 'd') > 0 THEN 'd' + WHEN INSTR(c1, 'e') > 0 THEN 'e' + ELSE '' + END as output_2 + from tab + """ + # spark.rapids.sql.combined.expressions.enabled is true by default + assert_gpu_and_cpu_are_equal_sql( + lambda spark : gen_df(spark, data_gen), + "tab", + sql, + { "spark.rapids.sql.expression.combined.GpuContains" : combine_string_contains_enabled} + ) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index f933b7e51a5..c3f737409cd 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -20,6 +20,8 @@ import java.nio.charset.Charset import java.text.DecimalFormatSymbols import java.util.{EnumSet, Locale, Optional} +import scala.annotation.tailrec +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexFlag, RegexProgram, RoundMode, Scalar} @@ -32,6 +34,7 @@ import com.nvidia.spark.rapids.jni.RegexRewriteUtils import com.nvidia.spark.rapids.shims.{ShimExpression, SparkShimImpl} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.rapids.catalyst.expressions._ import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.unsafe.types.UTF8String @@ -391,7 +394,8 @@ case class GpuContains(left: Expression, right: Expression) extends GpuBinaryExpressionArgsAnyScalar with Predicate with ImplicitCastInputTypes - with NullIntolerant { + with NullIntolerant + with GpuCombinable { override def inputTypes: Seq[DataType] = Seq(StringType) @@ -411,6 +415,106 @@ case class GpuContains(left: Expression, right: Expression) doColumnar(expandedLhs, rhs) } } + + /** + * Get a combiner that can be used to find candidates to combine + */ + override def getCombiner(): GpuExpressionCombiner = new ContainsCombiner(this) +} + +case class GpuMultiContains(left: Expression, targets: Seq[UTF8String], output: StructType) + extends GpuExpression with ShimExpression { + + override def otherCopyArgs: Seq[AnyRef] = Nil + + override def dataType: DataType = output + + override def nullable: Boolean = false + + override def prettyName: String = "multi_contains" + + override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { + val targetsBytes = targets.map(t => t.getBytes).toArray + val boolCvs = withResource(ColumnVector.fromUTF8Strings(targetsBytes: _*)) { targetsCv => + withResource(left.columnarEval(batch)) { lhs => + lhs.getBase.stringContains(targetsCv) + } + } + withResource(boolCvs) { _ => + val retView = ColumnView.makeStructView(batch.numRows(), boolCvs: _*) + GpuColumnVector.from(retView.copyToColumnVector(), dataType) + } + } + + override def children: Seq[Expression] = Seq(left) +} + +class ContainsCombiner(private val exp: GpuContains) extends GpuExpressionCombiner { + private var outputLocation = 0 + /** + * A mapping between an expression and where in the output struct of + * the MultiGetJsonObject will the output be. + */ + private val toCombine = mutable.HashMap.empty[GpuExpressionEquals, Int] + addExpression(exp) + + override def toString: String = s"ContainsCombiner $toCombine" + + override def hashCode: Int = { + // We already know that we are Contains, and what we can combine is based + // on the string column being the same. + "Contains".hashCode + (exp.left.semanticHash() * 17) + } + + /** + * only combine when targets are literals + */ + override def equals(o: Any): Boolean = o match { + case other: ContainsCombiner => exp.left.semanticEquals(other.exp.left) && + exp.right.isInstanceOf[GpuLiteral] && other.exp.right.isInstanceOf[GpuLiteral] + case _ => false + } + + override def addExpression(e: Expression): Unit = { + val localOutputLocation = outputLocation + outputLocation += 1 + val key = GpuExpressionEquals(e) + if (!toCombine.contains(key)) { + toCombine.put(key, localOutputLocation) + } + } + + override def useCount: Int = toCombine.size + + private def fieldName(id: Int): String = + s"_mc_$id" + + @tailrec + private def extractLiteral(exp: Expression): GpuLiteral = exp match { + case l: GpuLiteral => l + case a: Alias => extractLiteral(a.child) + case other => throw new RuntimeException("Unsupported expression in contains combiner, " + + "should be a literal type, actual type is " + other.getClass.getName) + } + + private lazy val multiContains: GpuMultiContains = { + val input = toCombine.head._1.e.asInstanceOf[GpuContains].left + val fieldsNPaths = toCombine.toSeq.map { + case (k, id) => + (id, k.e) + }.sortBy(_._1).map { + case (id, e: GpuContains) => + val target = extractLiteral(e.right).value.asInstanceOf[UTF8String] + (StructField(fieldName(id), e.dataType, e.nullable), target) + } + val dt = StructType(fieldsNPaths.map(_._1)) + GpuMultiContains(input, fieldsNPaths.map(_._2), dt) + } + + override def getReplacementExpression(e: Expression): Expression = { + val localId = toCombine(GpuExpressionEquals(e)) + GpuGetStructField(multiContains, localId, Some(fieldName(localId))) + } } case class GpuSubstring(str: Expression, pos: Expression, len: Expression) @@ -1097,7 +1201,7 @@ class GpuRLikeMeta( GpuRLike(lhs, rhs, patternStr) } case StartsWith(s) => GpuStartsWith(lhs, GpuLiteral(s, StringType)) - case Contains(s) => GpuContains(lhs, GpuLiteral(s, StringType)) + case Contains(s) => GpuContains(lhs, GpuLiteral(UTF8String.fromString(s), StringType)) case MultipleContains(ls) => GpuMultipleContains(lhs, ls) case PrefixRange(s, length, start, end) => GpuLiteralRangePattern(lhs, GpuLiteral(s, StringType), length, start, end)