From 136e69aecbb330d2b9f2804ebac458b2c6284dc5 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Thu, 24 Oct 2024 11:40:44 +0800 Subject: [PATCH] Fix --- .../aggregate/GpuHyperLogLogPlusPlus.scala | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/GpuHyperLogLogPlusPlus.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/GpuHyperLogLogPlusPlus.scala index 2fef97e5df67..4fc7e3f0acca 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/GpuHyperLogLogPlusPlus.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/GpuHyperLogLogPlusPlus.scala @@ -68,16 +68,16 @@ case class GpuHyperLogLogPlusPlusEvaluation(children: Seq[Expression], override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { - // TODO -// withResourceIfAllowed(childExpr.columnarEval(batch)) { sketches => -// val distinctValues = HLL.estimateDistinctValueFromSketches( + // TODO + // withResourceIfAllowed(childExpr.columnarEval(batch)) { sketches => + // val distinctValues = HLL.estimateDistinctValueFromSketches( // sketches.getBase, numRegistersPerSketch) -// GpuColumnVector.from(distinctValues, LongType) -// } + // GpuColumnVector.from(distinctValues, LongType) + // } val numRows = batch.numRows() val zeros = Array.fill(numRows)(0L) - val longCv = ColumnVector.fromLongs(zeros : _*) + val longCv = ColumnVector.fromLongs(zeros: _*) GpuColumnVector.from(longCv, LongType) } } @@ -140,15 +140,10 @@ case class GpuHyperLogLogPlusPlus(childExpr: Expression, relativeSD: Double) /** * Convert Struct column to long columns */ - private def extractChildren: Seq[Expression] = Seq.tabulate(hllppHelper.numWords) { + override lazy val postUpdate: Seq[Expression] = Seq.tabulate(hllppHelper.numWords) { i => GpuGetStructField(postUpdateAttr.head, i) } - /** - * Convert to long columns - */ - override lazy val postUpdate: Seq[Expression] = extractChildren - /** * convert to Struct */ @@ -158,9 +153,11 @@ case class GpuHyperLogLogPlusPlus(childExpr: Expression, relativeSD: Double) Seq(CudfMergeHLLPP(cuDFBufferType, numRegistersPerSketch)) /** - * Convert to long columns + * Convert Struct column to long columns */ - override lazy val postMerge: Seq[Expression] = extractChildren + override lazy val postMerge: Seq[Expression] = Seq.tabulate(hllppHelper.numWords) { + i => GpuGetStructField(postMergeAttr.head, i) + } override lazy val evaluateExpression: Expression = GpuHyperLogLogPlusPlusEvaluation(aggBufferAttributes, numRegistersPerSketch)