Skip to content

Commit

Permalink
Use Host UDF
Browse files Browse the repository at this point in the history
  • Loading branch information
Chong Gao committed Dec 17, 2024
1 parent 4d5b1e7 commit 0d2b2a2
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import ai.rapids.cudf.{DType, GroupByAggregation, ReductionAggregation}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm.withResourceIfAllowed
import com.nvidia.spark.rapids.RapidsPluginImplicits.ReallyAGpuExpression
import com.nvidia.spark.rapids.jni.HLLPP
import com.nvidia.spark.rapids.jni.HLLPPHostUDF
import com.nvidia.spark.rapids.shims.ShimExpression

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
Expand All @@ -35,20 +35,28 @@ case class CudfHLLPP(override val dataType: DataType,
precision: Int) extends CudfAggregate {
override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar =
(input: cudf.ColumnVector) => input.reduce(
ReductionAggregation.HLLPP(precision), DType.STRUCT)
ReductionAggregation.hostUDF(
HLLPPHostUDF.createHLLPPHostUDF(HLLPPHostUDF.AggregationType.Reduction, precision)),
DType.STRUCT)
override lazy val groupByAggregate: GroupByAggregation =
GroupByAggregation.HLLPP(precision)
GroupByAggregation.hostUDF(
HLLPPHostUDF.createHLLPPHostUDF(HLLPPHostUDF.AggregationType.GroupBy, precision)
)
override val name: String = "CudfHyperLogLogPlusPlus"
}

case class CudfMergeHLLPP(override val dataType: DataType,
precision: Int)
extends CudfAggregate {
override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar =
(input: cudf.ColumnVector) =>
input.reduce(ReductionAggregation.mergeHLL(precision), DType.STRUCT)
(input: cudf.ColumnVector) => input.reduce(
ReductionAggregation.hostUDF(
HLLPPHostUDF.createHLLPPHostUDF(HLLPPHostUDF.AggregationType.Reduction_MERGE, precision)),
DType.STRUCT)
override lazy val groupByAggregate: GroupByAggregation =
GroupByAggregation.mergeHLL(precision)
GroupByAggregation.hostUDF(
HLLPPHostUDF.createHLLPPHostUDF(HLLPPHostUDF.AggregationType.GroupByMerge, precision)
)
override val name: String = "CudfMergeHyperLogLogPlusPlus"
}

Expand All @@ -69,7 +77,7 @@ case class GpuHyperLogLogPlusPlusEvaluation(childExpr: Expression,

override def columnarEval(batch: ColumnarBatch): GpuColumnVector = {
withResourceIfAllowed(childExpr.columnarEval(batch)) { sketches =>
val distinctValues = HLLPP.estimateDistinctValueFromSketches(
val distinctValues = HLLPPHostUDF.estimateDistinctValueFromSketches(
sketches.getBase, precision)
GpuColumnVector.from(distinctValues, LongType)
}
Expand Down
1 change: 1 addition & 0 deletions tools/generated_files/330/operatorsScore.csv
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ HiveGenericUDF,4
HiveHash,4
HiveSimpleUDF,4
Hour,4
HyperLogLogPlusPlus,4
Hypot,4
If,4
In,4
Expand Down
4 changes: 4 additions & 0 deletions tools/generated_files/330/supportedExprs.csv
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,10 @@ First,S,`first_value`; `first`,None,reduction,input,S,S,S,S,S,S,S,S,PS,S,S,S,S,N
First,S,`first_value`; `first`,None,reduction,result,S,S,S,S,S,S,S,S,PS,S,S,S,S,NS,PS,PS,PS,NS,NS,NS
First,S,`first_value`; `first`,None,window,input,S,S,S,S,S,S,S,S,PS,S,S,S,S,NS,PS,PS,PS,NS,NS,NS
First,S,`first_value`; `first`,None,window,result,S,S,S,S,S,S,S,S,PS,S,S,S,S,NS,PS,PS,PS,NS,NS,NS
HyperLogLogPlusPlus,S,`approx_count_distinct`,None,aggregation,input,S,S,S,S,S,S,S,S,PS,S,S,NS,S,NS,NS,NS,NS,NS,NS,NS
HyperLogLogPlusPlus,S,`approx_count_distinct`,None,aggregation,result,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
HyperLogLogPlusPlus,S,`approx_count_distinct`,None,reduction,input,S,S,S,S,S,S,S,S,PS,S,S,NS,S,NS,NS,NS,NS,NS,NS,NS
HyperLogLogPlusPlus,S,`approx_count_distinct`,None,reduction,result,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
Last,S,`last_value`; `last`,None,aggregation,input,S,S,S,S,S,S,S,S,PS,S,S,S,S,NS,PS,PS,PS,NS,NS,NS
Last,S,`last_value`; `last`,None,aggregation,result,S,S,S,S,S,S,S,S,PS,S,S,S,S,NS,PS,PS,PS,NS,NS,NS
Last,S,`last_value`; `last`,None,reduction,input,S,S,S,S,S,S,S,S,PS,S,S,S,S,NS,PS,PS,PS,NS,NS,NS
Expand Down

0 comments on commit 0d2b2a2

Please sign in to comment.