From 7c43e698511b556311165c9e022a8347bd4308af Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Tue, 18 Jun 2024 18:27:49 +0800 Subject: [PATCH] Use Table.gather instead of a custom kernel Signed-off-by: Chong Gao --- .../com/nvidia/spark/rapids/RapidsConf.scala | 2 +- .../spark/rapids/conditionalExpressions.scala | 20 +++++++++++++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 02268c21a2f..4b493793bf4 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -2276,7 +2276,7 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression. val CASE_WHEN_FUSE = conf("spark.rapids.sql.case_when.fuse") - .doc("If when branches is greater than 3 and all then values in case when are string " + + .doc("If when branches is greater than 2 and all then/else values in case when are string " + "scalar, fuse mode improves the performance. By default this is enabled.") .internal() .booleanConf diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala index 271267441c3..1c0ff3f01f5 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala @@ -369,10 +369,14 @@ case class GpuCaseWhen( ) { // when branches size > 2; // return type is string type; - // all the then and else exprs are Scalars. + // all the then and else expressions are Scalars. // Avoid to use multiple `computeIfElse`s which will create multiple temp columns - // 1. select first true index from bool columns + // 1. select first true index from bool columns, if no true, index will be out of bound + // e.g.: + // case when bool result column 0: true, false, false + // case when bool result column 1: false, true, false + // result is: [0, 1, 2] val whenBoolCols = branches.safeMap(_._1.columnarEval(batch).getBase).toArray val firstTrueIndex: ColumnVector = withResource(whenBoolCols) { _ => CaseWhen.selectFirstTrueIndex(whenBoolCols) @@ -387,14 +391,22 @@ case class GpuCaseWhen( .asInstanceOf[UTF8String].getBytes) val scalarCol = ColumnVector.fromUTF8Strings(scalarsBytes: _*) withResource(scalarCol) { _ => - // 3. execute final select - val finalRet = CaseWhen.selectFromIndex(scalarCol, firstTrueIndex) + + val finalRet = withResource(new Table(scalarCol)) { oneColumnTable => + // 3. execute final select + // default gather OutOfBoundsPolicy is nullify, + // If index is out of bound, return null + withResource(oneColumnTable.gather(firstTrueIndex)) { resultTable => + resultTable.getColumn(0).incRefCount() + } + } // return final column vector GpuColumnVector.from(finalRet, dataType) } } } } else { + // execute from tail to front recursively // `elseRet` will be closed in `computeIfElse`. val elseRet = elseValue .map(_.columnarEvalAny(batch))