Skip to content

Commit

Permalink
Use Table.gather instead of a custom kernel
Browse files Browse the repository at this point in the history
Signed-off-by: Chong Gao <[email protected]>
  • Loading branch information
Chong Gao committed Jun 18, 2024
1 parent 72f76e5 commit 7c43e69
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2276,7 +2276,7 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression.

val CASE_WHEN_FUSE =
conf("spark.rapids.sql.case_when.fuse")
.doc("If when branches is greater than 3 and all then values in case when are string " +
.doc("If when branches is greater than 2 and all then/else values in case when are string " +
"scalar, fuse mode improves the performance. By default this is enabled.")
.internal()
.booleanConf
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,10 +369,14 @@ case class GpuCaseWhen(
) {
// when branches size > 2;
// return type is string type;
// all the then and else exprs are Scalars.
// all the then and else expressions are Scalars.
// Avoid to use multiple `computeIfElse`s which will create multiple temp columns

// 1. select first true index from bool columns
// 1. select first true index from bool columns, if no true, index will be out of bound
// e.g.:
// case when bool result column 0: true, false, false
// case when bool result column 1: false, true, false
// result is: [0, 1, 2]
val whenBoolCols = branches.safeMap(_._1.columnarEval(batch).getBase).toArray
val firstTrueIndex: ColumnVector = withResource(whenBoolCols) { _ =>
CaseWhen.selectFirstTrueIndex(whenBoolCols)
Expand All @@ -387,14 +391,22 @@ case class GpuCaseWhen(
.asInstanceOf[UTF8String].getBytes)
val scalarCol = ColumnVector.fromUTF8Strings(scalarsBytes: _*)
withResource(scalarCol) { _ =>
// 3. execute final select
val finalRet = CaseWhen.selectFromIndex(scalarCol, firstTrueIndex)

val finalRet = withResource(new Table(scalarCol)) { oneColumnTable =>
// 3. execute final select
// default gather OutOfBoundsPolicy is nullify,
// If index is out of bound, return null
withResource(oneColumnTable.gather(firstTrueIndex)) { resultTable =>
resultTable.getColumn(0).incRefCount()
}
}
// return final column vector
GpuColumnVector.from(finalRet, dataType)
}
}
}
} else {
// execute from tail to front recursively
// `elseRet` will be closed in `computeIfElse`.
val elseRet = elseValue
.map(_.columnarEvalAny(batch))
Expand Down

0 comments on commit 7c43e69

Please sign in to comment.