From efb351a9e82ec418e96a99c26f7db3a3c4c673d5 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Thu, 24 Oct 2024 11:54:25 +0800 Subject: [PATCH] Fix HLL++ evaluation to call HLL JNI instead of returning zeros --- .../aggregate/GpuHyperLogLogPlusPlus.scala | 31 ++++++++----------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/GpuHyperLogLogPlusPlus.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/GpuHyperLogLogPlusPlus.scala index 4fc7e3f0acca..244536c8c46e 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/GpuHyperLogLogPlusPlus.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/GpuHyperLogLogPlusPlus.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.rapids.aggregate import scala.collection.immutable.Seq import ai.rapids.cudf -import ai.rapids.cudf.{ColumnVector, DType, GroupByAggregation, ReductionAggregation} +import ai.rapids.cudf.{DType, GroupByAggregation, ReductionAggregation} import com.nvidia.spark.rapids._ -//import com.nvidia.spark.rapids.Arm.withResourceIfAllowed -//import com.nvidia.spark.rapids.RapidsPluginImplicits.ReallyAGpuExpression -//import com.nvidia.spark.rapids.jni.HLL +import com.nvidia.spark.rapids.Arm.withResourceIfAllowed +import com.nvidia.spark.rapids.RapidsPluginImplicits.ReallyAGpuExpression +import com.nvidia.spark.rapids.jni.HLL import com.nvidia.spark.rapids.shims.ShimExpression import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} @@ -57,7 +57,7 @@ case class CudfMergeHLLPP(override val dataType: DataType, * Perform the final evaluation step to compute approximate count distinct from sketches. 
* Input is long columns, first construct struct of long then feed to cuDF */ -case class GpuHyperLogLogPlusPlusEvaluation(children: Seq[Expression], +case class GpuHyperLogLogPlusPlusEvaluation(childExpr: Expression, numRegistersPerSketch: Int) extends GpuExpression with ShimExpression { override def dataType: DataType = LongType @@ -66,19 +66,14 @@ case class GpuHyperLogLogPlusPlusEvaluation(children: Seq[Expression], override def nullable: Boolean = false - override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { + override def children: scala.Seq[Expression] = Seq(childExpr) - // TODO - // withResourceIfAllowed(childExpr.columnarEval(batch)) { sketches => - // val distinctValues = HLL.estimateDistinctValueFromSketches( - // sketches.getBase, numRegistersPerSketch) - // GpuColumnVector.from(distinctValues, LongType) - // } - - val numRows = batch.numRows() - val zeros = Array.fill(numRows)(0L) - val longCv = ColumnVector.fromLongs(zeros: _*) - GpuColumnVector.from(longCv, LongType) + override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { + withResourceIfAllowed(childExpr.columnarEval(batch)) { sketches => + val distinctValues = HLL.estimateDistinctValueFromSketches( + sketches.getBase, numRegistersPerSketch) + GpuColumnVector.from(distinctValues, LongType) + } } } @@ -160,7 +155,7 @@ case class GpuHyperLogLogPlusPlus(childExpr: Expression, relativeSD: Double) } override lazy val evaluateExpression: Expression = - GpuHyperLogLogPlusPlusEvaluation(aggBufferAttributes, numRegistersPerSketch) + GpuHyperLogLogPlusPlusEvaluation(genStruct.head, numRegistersPerSketch) override def dataType: DataType = LongType