From 4fa0a1dee986e05733dbdbf4971c42ad5e0e84ec Mon Sep 17 00:00:00 2001 From: "Hongbin Ma (Mahone)" Date: Tue, 26 Nov 2024 23:44:45 +0800 Subject: [PATCH] repartition-based fallback for hash aggregate v3 (#11712) Signed-off-by: Hongbin Ma (Mahone) Signed-off-by: Firestarman Co-authored-by: Firestarman --- .../scala/com/nvidia/spark/rapids/Arm.scala | 16 +- .../rapids/AutoClosableArrayBuffer.scala | 54 ++ .../spark/rapids/GpuAggregateExec.scala | 725 ++++++++++-------- .../com/nvidia/spark/rapids/GpuExec.scala | 6 + ...GpuUnboundedToUnboundedAggWindowExec.scala | 29 +- 5 files changed, 476 insertions(+), 354 deletions(-) create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/AutoClosableArrayBuffer.scala diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala index 926f770a683..b0cd798c179 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.util.control.ControlThrowable import com.nvidia.spark.rapids.RapidsPluginImplicits._ @@ -134,6 +134,20 @@ object Arm extends ArmScalaSpecificImpl { } } + /** Executes the provided code block, closing the resources only if an exception occurs */ + def closeOnExcept[T <: AutoCloseable, V](r: ListBuffer[T])(block: ListBuffer[T] => V): V = { + try { + block(r) + } catch { + case t: ControlThrowable => + // Don't close for these cases.. 
+ throw t + case t: Throwable => + r.safeClose(t) + throw t + } + } + /** Executes the provided code block, closing the resources only if an exception occurs */ def closeOnExcept[T <: AutoCloseable, V](r: mutable.Queue[T])(block: mutable.Queue[T] => V): V = { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AutoClosableArrayBuffer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AutoClosableArrayBuffer.scala new file mode 100644 index 00000000000..fb1e10b9c9e --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AutoClosableArrayBuffer.scala @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +/** + * Just a simple wrapper to make working with buffers of AutoClosable things play + * nicely with withResource. 
+ */ +class AutoClosableArrayBuffer[T <: AutoCloseable] extends AutoCloseable { + val data = new ArrayBuffer[T]() + + def append(scb: T): Unit = data.append(scb) + + def last: T = data.last + + def removeLast(): T = data.remove(data.length - 1) + + def foreach[U](f: T => U): Unit = data.foreach(f) + + def map[U](f: T => U): Seq[U] = data.map(f).toSeq + + def toArray[B >: T : ClassTag]: Array[B] = data.toArray + + def size(): Int = data.size + + def clear(): Unit = data.clear() + + def forall(p: T => Boolean): Boolean = data.forall(p) + + def iterator: Iterator[T] = data.iterator + + override def toString: String = s"AutoCloseable(${super.toString})" + + override def close(): Unit = { + data.foreach(_.close()) + data.clear() + } +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala index b5360a62f94..60f6dd68509 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala @@ -16,11 +16,9 @@ package com.nvidia.spark.rapids -import java.util - import scala.annotation.tailrec -import scala.collection.JavaConverters.collectionAsScalaIterableConverter import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf import ai.rapids.cudf.{NvtxColor, NvtxRange} @@ -37,7 +35,7 @@ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, AttributeSeq, AttributeSet, Expression, ExprId, If, NamedExpression, NullsFirst, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, AttributeSeq, AttributeSet, Expression, ExprId, If, NamedExpression, SortOrder} import 
org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, HashPartitioning, Partitioning, UnspecifiedDistribution} @@ -47,11 +45,11 @@ import org.apache.spark.sql.execution.{ExplainUtils, SortExec, SparkPlan} import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids.aggregate.{CpuToGpuAggregateBufferConverter, CudfAggregate, GpuAggregateExpression, GpuToCpuAggregateBufferConverter} -import org.apache.spark.sql.rapids.execution.{GpuShuffleMeta, TrampolineUtil} +import org.apache.spark.sql.rapids.execution.{GpuBatchSubPartitioner, GpuShuffleMeta, TrampolineUtil} import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch -object AggregateUtils { +object AggregateUtils extends Logging { private val aggs = List("min", "max", "avg", "sum", "count", "first", "last") @@ -98,8 +96,10 @@ object AggregateUtils { inputTypes: Seq[DataType], outputTypes: Seq[DataType], isReductionOnly: Boolean): Long = { + def typesToSize(types: Seq[DataType]): Long = types.map(GpuBatchUtils.estimateGpuMemory(_, nullable = false, rowCount = 1)).sum + val inputRowSize = typesToSize(inputTypes) val outputRowSize = typesToSize(outputTypes) // The cudf hash table implementation allocates four 32-bit integers per input row. 
@@ -120,22 +120,198 @@ object AggregateUtils { } // Calculate the max rows that can be processed during computation within the budget - val maxRows = totalBudget / computationBytesPerRow + // Make sure it's not less than 1, otherwise some corner test cases may fail + val maxRows = Math.max(totalBudget / computationBytesPerRow, 1) // Finally compute the input target batching size taking into account the cudf row limits Math.min(inputRowSize * maxRows, Int.MaxValue) } + + /** + * Concatenate batches together and perform a merge aggregation on the result. The input batches + * will be closed as part of this operation. + * + * @param batches batches to concatenate and merge aggregate + * @return lazy spillable batch which has NOT been marked spillable + */ + def concatenateAndMerge( + batches: mutable.ArrayBuffer[SpillableColumnarBatch], + metrics: GpuHashAggregateMetrics, + concatAndMergeHelper: AggHelper): SpillableColumnarBatch = { + // TODO: concatenateAndMerge (and calling code) could output a sequence + // of batches for the partial aggregate case. This would be done in case + // a retry failed a certain number of times. + val concatBatch = withResource(batches) { _ => + val concatSpillable = concatenateBatches(metrics, batches.toSeq) + withResource(concatSpillable) { + _.getColumnarBatch() + } + } + computeAggregateAndClose(metrics, concatBatch, concatAndMergeHelper) + } + + /** + * Try to concat and merge neighbour input batches to reduce the number of output batches. + * For some cases where input is highly aggregate-able, we can merge multiple input batches + * into a single output batch. In such cases we can skip repartition at all. 
+ */ + def streamAggregateNeighours( + aggregatedBatches: CloseableBufferedIterator[SpillableColumnarBatch], + metrics: GpuHashAggregateMetrics, + targetMergeBatchSize: Long, + configuredTargetBatchSize: Long, + helper: AggHelper + ): Iterator[SpillableColumnarBatch] = { + new Iterator[SpillableColumnarBatch] { + + override def hasNext: Boolean = aggregatedBatches.hasNext + + override def next(): SpillableColumnarBatch = { + closeOnExcept(new ArrayBuffer[SpillableColumnarBatch]) { stagingBatches => { + var currentSize = 0L + while (aggregatedBatches.hasNext) { + val nextBatch = aggregatedBatches.head + if (currentSize + nextBatch.sizeInBytes > targetMergeBatchSize) { + if (stagingBatches.size == 1) { + return stagingBatches.head + } else if (stagingBatches.isEmpty) { + aggregatedBatches.next + return nextBatch + } + val merged = concatenateAndMerge(stagingBatches, metrics, helper) + stagingBatches.clear + currentSize = 0L + if (merged.sizeInBytes < configuredTargetBatchSize * 0.5) { + stagingBatches += merged + currentSize += merged.sizeInBytes + } else { + return merged + } + } else { + stagingBatches.append(nextBatch) + currentSize += nextBatch.sizeInBytes + aggregatedBatches.next + } + } + + if (stagingBatches.size == 1) { + return stagingBatches.head + } + concatenateAndMerge(stagingBatches, metrics, helper) + } + } + } + } + } + + /** + * Read the input batches and repartition them into buckets. 
+ */ + def iterateAndRepartition( + aggregatedBatches: Iterator[SpillableColumnarBatch], + metrics: GpuHashAggregateMetrics, + targetMergeBatchSize: Long, + helper: AggHelper, + hashKeys: Seq[GpuExpression], + hashBucketNum: Int, + hashSeed: Int, + batchesByBucket: ArrayBuffer[AutoClosableArrayBuffer[SpillableColumnarBatch]] + ): Boolean = { + + var repartitionHappened = false + if (hashSeed > 200) { + throw new IllegalStateException("Too many times of repartition, may hit a bug?") + } + + def repartitionAndClose(batch: SpillableColumnarBatch): Unit = { + + // OPTIMIZATION + if (!aggregatedBatches.hasNext && batchesByBucket.forall(_.size() == 0)) { + // If this is the only batch (after merging neighbours) to be repartitioned, + // we can just add it to the first bucket and skip repartitioning. + // This is a common case when total input size can fit into a single batch. + batchesByBucket.head.append(batch) + return + } + + withResource(new NvtxWithMetrics("agg repartition", + NvtxColor.CYAN, metrics.repartitionTime)) { _ => + + withResource(new GpuBatchSubPartitioner( + Seq(batch).map(batch => { + withResource(batch) { _ => + batch.getColumnarBatch() + } + }).iterator, + hashKeys, hashBucketNum, hashSeed, "aggRepartition")) { + partitioner => { + (0 until partitioner.partitionsCount).foreach { id => + closeOnExcept(batchesByBucket) { _ => { + val newBatches = partitioner.releaseBatchesByPartition(id) + newBatches.foreach { newBatch => + if (newBatch.numRows() > 0) { + batchesByBucket(id).append(newBatch) + } else { + newBatch.safeClose() + } + } + } + } + } + } + } + } + repartitionHappened = true + } + + while (aggregatedBatches.hasNext) { + repartitionAndClose(aggregatedBatches.next) + } + + // Deal with the over sized buckets + def needRepartitionAgain(bucket: AutoClosableArrayBuffer[SpillableColumnarBatch]) = { + bucket.map(_.sizeInBytes).sum > targetMergeBatchSize && + bucket.size() != 1 && + !bucket.forall(_.numRows() == 1) // this is for test + } + + if 
(repartitionHappened && batchesByBucket.exists(needRepartitionAgain)) { + logDebug("Some of the repartition buckets are over sized, trying to split them") + + val newBuckets = batchesByBucket.flatMap(bucket => { + if (needRepartitionAgain(bucket)) { + val nextLayerBuckets = + ArrayBuffer.fill(hashBucketNum)(new AutoClosableArrayBuffer[SpillableColumnarBatch]()) + // Recursively merge and repartition the over sized bucket + repartitionHappened = + iterateAndRepartition( + new CloseableBufferedIterator(bucket.iterator), metrics, targetMergeBatchSize, + helper, hashKeys, hashBucketNum, hashSeed + 7, + nextLayerBuckets) || repartitionHappened + nextLayerBuckets + } else { + ArrayBuffer.apply(bucket) + } + }) + batchesByBucket.clear() + batchesByBucket.appendAll(newBuckets) + } + + repartitionHappened + } } /** Utility class to hold all of the metrics related to hash aggregation */ case class GpuHashAggregateMetrics( numOutputRows: GpuMetric, numOutputBatches: GpuMetric, - numTasksFallBacked: GpuMetric, + numTasksRepartitioned: GpuMetric, + numTasksSkippedAgg: GpuMetric, opTime: GpuMetric, computeAggTime: GpuMetric, concatTime: GpuMetric, sortTime: GpuMetric, + repartitionTime: GpuMetric, numAggOps: GpuMetric, numPreSplits: GpuMetric, singlePassTasks: GpuMetric, @@ -208,7 +384,7 @@ class AggHelper( private val groupingAttributes = groupingExpressions.map(_.toAttribute) private val aggBufferAttributes = groupingAttributes ++ - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) // `GpuAggregateFunction` can add a pre and post step for update // and merge aggregates. 
@@ -228,7 +404,7 @@ class AggHelper( postStep ++= groupingAttributes postStepAttr ++= groupingAttributes postStepDataTypes ++= - groupingExpressions.map(_.dataType) + groupingExpressions.map(_.dataType) private var ix = groupingAttributes.length for (aggExp <- aggregateExpressions) { @@ -380,9 +556,9 @@ class AggHelper( withResource(new NvtxRange("groupby", NvtxColor.BLUE)) { _ => withResource(GpuColumnVector.from(preProcessed)) { preProcessedTbl => val groupOptions = cudf.GroupByOptions.builder() - .withIgnoreNullKeys(false) - .withKeysSorted(doSortAgg) - .build() + .withIgnoreNullKeys(false) + .withKeysSorted(doSortAgg) + .build() val cudfAggsOnColumn = cudfAggregates.zip(aggOrdinals).map { case (cudfAgg, ord) => cudfAgg.groupByAggregate.onColumn(ord) @@ -390,8 +566,8 @@ class AggHelper( // perform the aggregate val aggTbl = preProcessedTbl - .groupBy(groupOptions, groupingOrdinals: _*) - .aggregate(cudfAggsOnColumn.toSeq: _*) + .groupBy(groupOptions, groupingOrdinals: _*) + .aggregate(cudfAggsOnColumn.toSeq: _*) withResource(aggTbl) { _ => GpuColumnVector.from(aggTbl, postStepDataTypes.toArray) @@ -555,8 +731,8 @@ object GpuAggFirstPassIterator { metrics: GpuHashAggregateMetrics ): Iterator[SpillableColumnarBatch] = { val preprocessProjectIter = cbIter.map { cb => - val sb = SpillableColumnarBatch (cb, SpillPriorities.ACTIVE_ON_DECK_PRIORITY) - aggHelper.preStepBound.projectAndCloseWithRetrySingleBatch (sb) + val sb = SpillableColumnarBatch(cb, SpillPriorities.ACTIVE_ON_DECK_PRIORITY) + aggHelper.preStepBound.projectAndCloseWithRetrySingleBatch(sb) } computeAggregateWithoutPreprocessAndClose(metrics, preprocessProjectIter, aggHelper) } @@ -597,18 +773,18 @@ object GpuAggFinalPassIterator { modeInfo: AggregateModeInfo): BoundExpressionsModeAggregates = { val groupingAttributes = groupingExpressions.map(_.toAttribute) val aggBufferAttributes = groupingAttributes ++ - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + 
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) val boundFinalProjections = if (modeInfo.hasFinalMode || modeInfo.hasCompleteMode) { val finalProjections = groupingAttributes ++ - aggregateExpressions.map(_.aggregateFunction.evaluateExpression) + aggregateExpressions.map(_.aggregateFunction.evaluateExpression) Some(GpuBindReferences.bindGpuReferences(finalProjections, aggBufferAttributes)) } else { None } // allAttributes can be different things, depending on aggregation mode: - // - Partial mode: grouping key + cudf aggregates (e.g. no avg, intead sum::count + // - Partial mode: grouping key + cudf aggregates (e.g. no avg, instead sum::count // - Final mode: grouping key + spark aggregates (e.g. avg) val finalAttributes = groupingAttributes ++ aggregateAttributes @@ -689,17 +865,22 @@ object GpuAggFinalPassIterator { /** * Iterator that takes another columnar batch iterator as input and emits new columnar batches that * are aggregated based on the specified grouping and aggregation expressions. This iterator tries - * to perform a hash-based aggregation but is capable of falling back to a sort-based aggregation - * which can operate on data that is either larger than can be represented by a cudf column or - * larger than can fit in GPU memory. + * to perform a hash-based aggregation but is capable of falling back to a repartition-based + * aggregation which can operate on data that is either larger than can be represented by a cudf + * column or larger than can fit in GPU memory. + * + * In general, GpuMergeAggregateIterator works in this flow: * - * The iterator starts by pulling all batches from the input iterator, performing an initial - * projection and aggregation on each individual batch via `aggregateInputBatches()`. The resulting - * aggregated batches are cached in memory as spillable batches. 
Once all input batches have been - * aggregated, `tryMergeAggregatedBatches()` is called to attempt a merge of the aggregated batches - * into a single batch. If this is successful then the resulting batch can be returned, otherwise - * `buildSortFallbackIterator` is used to sort the aggregated batches by the grouping keys and - * performs a final merge aggregation pass on the sorted batches. + * (1) The iterator starts by pulling all batches from the input iterator, performing an initial + * projection and aggregation on each individual batch via `GpuAggFirstPassIterator`, we call it + * "First Pass Aggregate". + * (2) Then the batches after first pass agg are sent to "streamAggregateNeighours", where it tries + * to concat & merge the neighbour batches into fewer batches, then "iterateAndRepartition" + * repartitions the batches into a fixed number of buckets. Recursive repartition is applied on + * over-sized buckets until each bucket is within the target size. + * We call this phase "Second Pass Aggregate". + * (3) At "Third Pass Aggregate", we take each bucket and perform a final aggregation on all batches + * in the bucket, check "RepartitionAggregateIterator" for details. * * @param firstPassIter iterator that has done a first aggregation pass over the input data. 
* @param inputAttributes input attributes to identify the input columns from the input batches @@ -710,13 +891,12 @@ object GpuAggFinalPassIterator { * @param modeInfo identifies which aggregation modes are being used * @param metrics metrics that will be updated during aggregation * @param configuredTargetBatchSize user-specified value for the targeted input batch size - * @param useTieredProject user-specified option to enable tiered projections * @param allowNonFullyAggregatedOutput if allowed to skip third pass Agg * @param skipAggPassReductionRatio skip if the ratio of rows after a pass is bigger than this value * @param localInputRowsCount metric to track the number of input rows processed locally */ class GpuMergeAggregateIterator( - firstPassIter: Iterator[SpillableColumnarBatch], + firstPassIter: CloseableBufferedIterator[SpillableColumnarBatch], inputAttributes: Seq[Attribute], groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[GpuAggregateExpression], @@ -728,18 +908,22 @@ class GpuMergeAggregateIterator( conf: SQLConf, allowNonFullyAggregatedOutput: Boolean, skipAggPassReductionRatio: Double, - localInputRowsCount: LocalGpuMetric) - extends Iterator[ColumnarBatch] with AutoCloseable with Logging { + localInputRowsCount: LocalGpuMetric +) + extends Iterator[ColumnarBatch] with AutoCloseable with Logging { private[this] val isReductionOnly = groupingExpressions.isEmpty private[this] val targetMergeBatchSize = computeTargetMergeBatchSize(configuredTargetBatchSize) - private[this] val aggregatedBatches = new util.ArrayDeque[SpillableColumnarBatch] - private[this] var outOfCoreIter: Option[GpuOutOfCoreSortIterator] = None - /** Iterator for fetching aggregated batches either if: - * 1. a sort-based fallback has occurred - * 2. 
skip third pass agg has occurred - **/ - private[this] var fallbackIter: Option[Iterator[ColumnarBatch]] = None + private[this] val defaultHashBucketNum = 16 + private[this] val defaultHashSeed = 107 + private[this] var batchesByBucket = + ArrayBuffer.fill(defaultHashBucketNum)(new AutoClosableArrayBuffer[SpillableColumnarBatch]()) + + private[this] var firstBatchChecked = false + + private[this] var bucketIter: Option[RepartitionAggregateIterator] = None + + private[this] var realIter: Option[Iterator[ColumnarBatch]] = None /** Whether a batch is pending for a reduction-only aggregation */ private[this] var hasReductionOnlyBatch: Boolean = isReductionOnly @@ -752,286 +936,168 @@ class GpuMergeAggregateIterator( } override def hasNext: Boolean = { - fallbackIter.map(_.hasNext).getOrElse { + realIter.map(_.hasNext).getOrElse { // reductions produce a result even if the input is empty - hasReductionOnlyBatch || !aggregatedBatches.isEmpty || firstPassIter.hasNext + hasReductionOnlyBatch || firstPassIter.hasNext } } override def next(): ColumnarBatch = { - fallbackIter.map(_.next()).getOrElse { - var shouldSkipThirdPassAgg = false - - // aggregate and merge all pending inputs - if (firstPassIter.hasNext) { - // first pass agg - val rowsAfterFirstPassAgg = aggregateInputBatches() - - // by now firstPassIter has been traversed, so localInputRowsCount is finished updating - if (isReductionOnly || - skipAggPassReductionRatio * localInputRowsCount.value >= rowsAfterFirstPassAgg) { - // second pass agg - tryMergeAggregatedBatches() - - val rowsAfterSecondPassAgg = aggregatedBatches.asScala.foldLeft(0L) { - (totalRows, batch) => totalRows + batch.numRows() - } - shouldSkipThirdPassAgg = - rowsAfterSecondPassAgg > skipAggPassReductionRatio * rowsAfterFirstPassAgg - } else { - shouldSkipThirdPassAgg = true - logInfo(s"Rows after first pass aggregation $rowsAfterFirstPassAgg exceeds " + - s"${skipAggPassReductionRatio * 100}% of " + - s"localInputRowsCount 
${localInputRowsCount.value}, skip the second pass agg") - } - } + realIter.map(_.next()).getOrElse { - if (aggregatedBatches.size() > 1) { - // Unable to merge to a single output, so must fall back - if (allowNonFullyAggregatedOutput && shouldSkipThirdPassAgg) { - // skip third pass agg, return the aggregated batches directly - logInfo(s"Rows after second pass aggregation exceeds " + - s"${skipAggPassReductionRatio * 100}% of " + - s"rows after first pass, skip the third pass agg") - fallbackIter = Some(new Iterator[ColumnarBatch] { - override def hasNext: Boolean = !aggregatedBatches.isEmpty - - override def next(): ColumnarBatch = { - withResource(aggregatedBatches.pop()) { spillableBatch => - spillableBatch.getColumnarBatch() - } - } - }) - } else { - // fallback to sort agg, this is the third pass agg - fallbackIter = Some(buildSortFallbackIterator()) + // Handle reduction-only aggregation + if (isReductionOnly) { + val batches = ArrayBuffer.apply[SpillableColumnarBatch]() + while (firstPassIter.hasNext) { + batches += firstPassIter.next() } - fallbackIter.get.next() - } else if (aggregatedBatches.isEmpty) { - if (hasReductionOnlyBatch) { + + if (batches.isEmpty || batches.forall(_.numRows() == 0)) { hasReductionOnlyBatch = false - generateEmptyReductionBatch() + return generateEmptyReductionBatch() } else { - throw new NoSuchElementException("batches exhausted") + hasReductionOnlyBatch = false + val concat = AggregateUtils.concatenateAndMerge(batches, metrics, concatAndMergeHelper) + return withResource(concat) { cb => + cb.getColumnarBatch() + } } - } else { - // this will be the last batch - hasReductionOnlyBatch = false - withResource(aggregatedBatches.pop()) { spillableBatch => - spillableBatch.getColumnarBatch() + } + + // Handle the case of skipping second and third pass of aggregation + // This only work when spark.rapids.sql.agg.skipAggPassReductionRatio < 1 + if (!firstBatchChecked && firstPassIter.hasNext + && allowNonFullyAggregatedOutput) { + 
firstBatchChecked = true + + val peek = firstPassIter.head + // It's only based on first batch of first pass agg, so it's an estimate + val firstPassReductionRatioEstimate = 1.0 * peek.numRows() / localInputRowsCount.value + if (firstPassReductionRatioEstimate > skipAggPassReductionRatio) { + logDebug("Skipping second and third pass aggregation due to " + + "too high reduction ratio in first pass: " + + s"$firstPassReductionRatioEstimate") + // if so, skip any aggregation, return the origin batch directly + + realIter = Some(ConcatIterator.apply(firstPassIter, configuredTargetBatchSize)) + metrics.numTasksSkippedAgg += 1 + return realIter.get.next() + } else { + logInfo(s"The reduction ratio in first pass is not high enough to skip " + + s"second and third pass aggregation: peek.numRows: ${peek.numRows()}, " + + s"localInputRowsCount.value: ${localInputRowsCount.value}") } } + firstBatchChecked = true + + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val aggBufferAttributes = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + val hashKeys: Seq[GpuExpression] = + GpuBindReferences.bindGpuReferences(groupingAttributes, aggBufferAttributes.toSeq) + + val repartitionHappened = AggregateUtils.iterateAndRepartition( + AggregateUtils.streamAggregateNeighours( + firstPassIter, + metrics, + targetMergeBatchSize, + configuredTargetBatchSize, + concatAndMergeHelper) + , metrics, targetMergeBatchSize, concatAndMergeHelper, + hashKeys, defaultHashBucketNum, defaultHashSeed, batchesByBucket) + if (repartitionHappened) { + metrics.numTasksRepartitioned += 1 + } + + realIter = Some(ConcatIterator.apply( + new CloseableBufferedIterator(buildBucketIterator()), configuredTargetBatchSize)) + realIter.get.next() } } override def close(): Unit = { - aggregatedBatches.forEach(_.safeClose()) - aggregatedBatches.clear() - outOfCoreIter.foreach(_.close()) - outOfCoreIter = None - fallbackIter = None + 
batchesByBucket.foreach(_.close()) + batchesByBucket.clear() hasReductionOnlyBatch = false } private def computeTargetMergeBatchSize(confTargetSize: Long): Long = { val mergedTypes = groupingExpressions.map(_.dataType) ++ aggregateExpressions.map(_.dataType) - AggregateUtils.computeTargetBatchSize(confTargetSize, mergedTypes, mergedTypes,isReductionOnly) + AggregateUtils.computeTargetBatchSize(confTargetSize, mergedTypes, mergedTypes, isReductionOnly) } - /** Aggregate all input batches and place the results in the aggregatedBatches queue. */ - private def aggregateInputBatches(): Long = { - var rowsAfter = 0L - // cache everything in the first pass - while (firstPassIter.hasNext) { - val batch = firstPassIter.next() - rowsAfter += batch.numRows() - aggregatedBatches.add(batch) - } - rowsAfter - } + private lazy val concatAndMergeHelper = + new AggHelper(inputAttributes, groupingExpressions, aggregateExpressions, + forceMerge = true, conf, isSorted = false) + + private case class ConcatIterator( + input: CloseableBufferedIterator[SpillableColumnarBatch], + targetSize: Long) + extends Iterator[ColumnarBatch] { + + override def hasNext: Boolean = input.hasNext + + override def next(): ColumnarBatch = { + // combine all the data into a single batch + val spillCbs = ArrayBuffer[SpillableColumnarBatch]() + var totalBytes = 0L + closeOnExcept(spillCbs) { _ => + while (input.hasNext && (spillCbs.isEmpty || + (totalBytes + input.head.sizeInBytes) < targetSize)) { + val tmp = input.next + totalBytes += tmp.sizeInBytes + spillCbs += tmp + } - /** - * Attempt to merge adjacent batches in the aggregatedBatches queue until either there is only - * one batch or merging adjacent batches would exceed the target batch size. 
- */ - private def tryMergeAggregatedBatches(): Unit = { - while (aggregatedBatches.size() > 1) { - val concatTime = metrics.concatTime - val opTime = metrics.opTime - withResource(new NvtxWithMetrics("agg merge pass", NvtxColor.BLUE, concatTime, - opTime)) { _ => - // continue merging as long as some batches are able to be combined - if (!mergePass()) { - if (aggregatedBatches.size() > 1 && isReductionOnly) { - // We were unable to merge the aggregated batches within the target batch size limit, - // which means normally we would fallback to a sort-based approach. However for - // reduction-only aggregation there are no keys to use for a sort. The only way this - // can work is if all batches are merged. This will exceed the target batch size limit, - // but at this point it is either risk an OOM/cudf error and potentially work or - // not work at all. - logWarning(s"Unable to merge reduction-only aggregated batches within " + - s"target batch limit of $targetMergeBatchSize, attempting to merge remaining " + - s"${aggregatedBatches.size()} batches beyond limit") - withResource(mutable.ArrayBuffer[SpillableColumnarBatch]()) { batchesToConcat => - aggregatedBatches.forEach(b => batchesToConcat += b) - aggregatedBatches.clear() - val batch = concatenateAndMerge(batchesToConcat) - // batch does not need to be marked spillable since it is the last and only batch - // and will be immediately retrieved on the next() call. - aggregatedBatches.add(batch) - } - } - return + val concat = GpuAggregateIterator.concatenateBatches(metrics, spillCbs.toSeq) + withResource(concat) { _ => + concat.getColumnarBatch() } } } } - /** - * Perform a single pass over the aggregated batches attempting to merge adjacent batches. 
- * @return true if at least one merge operation occurred - */ - private def mergePass(): Boolean = { - val batchesToConcat: mutable.ArrayBuffer[SpillableColumnarBatch] = mutable.ArrayBuffer.empty - var wasBatchMerged = false - // Current size in bytes of the batches targeted for the next concatenation - var concatSize: Long = 0L - var batchesLeftInPass = aggregatedBatches.size() - - while (batchesLeftInPass > 0) { - closeOnExcept(batchesToConcat) { _ => - var isConcatSearchFinished = false - // Old batches are picked up at the front of the queue and freshly merged batches are - // appended to the back of the queue. Although tempting to allow the pass to "wrap around" - // and pick up batches freshly merged in this pass, it's avoided to prevent changing the - // order of aggregated batches. - while (batchesLeftInPass > 0 && !isConcatSearchFinished) { - val candidate = aggregatedBatches.getFirst - val potentialSize = concatSize + candidate.sizeInBytes - isConcatSearchFinished = concatSize > 0 && potentialSize > targetMergeBatchSize - if (!isConcatSearchFinished) { - batchesLeftInPass -= 1 - batchesToConcat += aggregatedBatches.removeFirst() - concatSize = potentialSize - } - } - } + private case class RepartitionAggregateIterator(opTime: GpuMetric) + extends Iterator[SpillableColumnarBatch] { - val mergedBatch = if (batchesToConcat.length > 1) { - wasBatchMerged = true - concatenateAndMerge(batchesToConcat) - } else { - // Unable to find a neighboring buffer to produce a valid merge in this pass, - // so simply put this buffer back on the queue for other passes. - batchesToConcat.remove(0) - } + batchesByBucket = batchesByBucket.filter(_.size() > 0) - // Add the merged batch to the end of the aggregated batch queue. Only a single pass over - // the batches is being performed due to the batch count check above, so the single-pass - // loop will terminate before picking up this new batch. 
- aggregatedBatches.addLast(mergedBatch) - batchesToConcat.clear() - concatSize = 0 - } + override def hasNext: Boolean = batchesByBucket.nonEmpty - wasBatchMerged - } + override def next(): SpillableColumnarBatch = { + withResource(new NvtxWithMetrics("RepartitionAggregateIterator.next", + NvtxColor.BLUE, opTime)) { _ => - private lazy val concatAndMergeHelper = - new AggHelper(inputAttributes, groupingExpressions, aggregateExpressions, - forceMerge = true, conf = conf) - - /** - * Concatenate batches together and perform a merge aggregation on the result. The input batches - * will be closed as part of this operation. - * @param batches batches to concatenate and merge aggregate - * @return lazy spillable batch which has NOT been marked spillable - */ - private def concatenateAndMerge( - batches: mutable.ArrayBuffer[SpillableColumnarBatch]): SpillableColumnarBatch = { - // TODO: concatenateAndMerge (and calling code) could output a sequence - // of batches for the partial aggregate case. This would be done in case - // a retry failed a certain number of times. - val concatBatch = withResource(batches) { _ => - val concatSpillable = concatenateBatches(metrics, batches.toSeq) - withResource(concatSpillable) { _.getColumnarBatch() } - } - computeAggregateAndClose(metrics, concatBatch, concatAndMergeHelper) - } - - /** Build an iterator that uses a sort-based approach to merge aggregated batches together. 
*/ - private def buildSortFallbackIterator(): Iterator[ColumnarBatch] = { - logInfo(s"Falling back to sort-based aggregation with ${aggregatedBatches.size()} batches") - metrics.numTasksFallBacked += 1 - val aggregatedBatchIter = new Iterator[ColumnarBatch] { - override def hasNext: Boolean = !aggregatedBatches.isEmpty + if (batchesByBucket.last.size() == 1) { + batchesByBucket.remove(batchesByBucket.size - 1).removeLast() + } else { + // put as many buckets as possible together to aggregate, to reduce agg times + closeOnExcept(new ArrayBuffer[AutoClosableArrayBuffer[SpillableColumnarBatch]]) { + toAggregateBuckets => + var currentSize = 0L + while (batchesByBucket.nonEmpty && + batchesByBucket.last.size() + currentSize < targetMergeBatchSize) { + val bucket = batchesByBucket.remove(batchesByBucket.size - 1) + currentSize += bucket.map(_.sizeInBytes).sum + toAggregateBuckets += bucket + } - override def next(): ColumnarBatch = { - withResource(aggregatedBatches.removeFirst()) { spillable => - spillable.getColumnarBatch() + AggregateUtils.concatenateAndMerge( + toAggregateBuckets.flatMap(_.data), metrics, concatAndMergeHelper) + } } } } + } - if (isReductionOnly) { - // Normally this should never happen because `tryMergeAggregatedBatches` should have done - // a last-ditch effort to concatenate all batches together regardless of target limits. 
- throw new IllegalStateException("Unable to fallback to sort-based aggregation " + - "without grouping keys") - } - - val groupingAttributes = groupingExpressions.map(_.toAttribute) - val ordering = groupingAttributes.map(SortOrder(_, Ascending, NullsFirst, Seq.empty)) - val aggBufferAttributes = groupingAttributes ++ - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - val sorter = new GpuSorter(ordering, aggBufferAttributes) - val aggBatchTypes = aggBufferAttributes.map(_.dataType) - // Use the out of core sort iterator to sort the batches by grouping key - outOfCoreIter = Some(GpuOutOfCoreSortIterator( - aggregatedBatchIter, - sorter, - configuredTargetBatchSize, - opTime = metrics.opTime, - sortTime = metrics.sortTime, - outputBatches = NoopMetric, - outputRows = NoopMetric)) - - // The out of core sort iterator does not guarantee that a batch contains all of the values - // for a particular key, so add a key batching iterator to enforce this. That allows each batch - // to be merge-aggregated safely since all values associated with a particular key are - // guaranteed to be in the same batch. - val keyBatchingIter = new GpuKeyBatchingIterator( - outOfCoreIter.get, - sorter, - aggBatchTypes.toArray, - configuredTargetBatchSize, - numInputRows = NoopMetric, - numInputBatches = NoopMetric, - numOutputRows = NoopMetric, - numOutputBatches = NoopMetric, - concatTime = metrics.concatTime, - opTime = metrics.opTime) - - // Finally wrap the key batching iterator with a merge aggregation on the output batches. 
- new Iterator[ColumnarBatch] { - override def hasNext: Boolean = keyBatchingIter.hasNext - - private val mergeSortedHelper = - new AggHelper(inputAttributes, groupingExpressions, aggregateExpressions, - forceMerge = true, conf, isSorted = true) - - override def next(): ColumnarBatch = { - // batches coming out of the sort need to be merged - val resultSpillable = - computeAggregateAndClose(metrics, keyBatchingIter.next(), mergeSortedHelper) - withResource(resultSpillable) { _ => - resultSpillable.getColumnarBatch() - } - } - } + /** Build an iterator merging aggregated batches in each bucket. */ + private def buildBucketIterator(): Iterator[SpillableColumnarBatch] = { + bucketIter = Some(RepartitionAggregateIterator(opTime = metrics.opTime)) + bucketIter.get } + /** * Generates the result of a reduction-only aggregation on empty input by emitting the * initial value of each aggregator. @@ -1117,13 +1183,13 @@ abstract class GpuBaseAggregateMeta[INPUT <: SparkPlan]( ) if (arrayWithStructsGroupings) { willNotWorkOnGpu("ArrayTypes with Struct children in grouping expressions are not " + - "supported") + "supported") } tagForReplaceMode() if (agg.aggregateExpressions.exists(expr => expr.isDistinct) - && agg.aggregateExpressions.exists(expr => expr.filter.isDefined)) { + && agg.aggregateExpressions.exists(expr => expr.filter.isDefined)) { // Distinct with Filter is not supported on the GPU currently, // This makes sure that if we end up here, the plan falls back to the CPU // which will do the right thing. @@ -1195,15 +1261,15 @@ abstract class GpuBaseAggregateMeta[INPUT <: SparkPlan]( // (due to First). Fall back to CPU in this case. 
if (AggregateUtils.shouldFallbackMultiDistinct(agg.aggregateExpressions)) { willNotWorkOnGpu("Aggregates of non-distinct functions with multiple distinct " + - "functions are non-deterministic for non-distinct functions as it is " + - "computed using First.") + "functions are non-deterministic for non-distinct functions as it is " + + "computed using First.") } } } if (!conf.partialMergeDistinctEnabled && aggPattern.contains(PartialMerge)) { willNotWorkOnGpu("Replacing Partial Merge aggregates disabled. " + - s"Set ${conf.partialMergeDistinctEnabled} to true if desired") + s"Set ${conf.partialMergeDistinctEnabled} to true if desired") } } @@ -1256,11 +1322,11 @@ abstract class GpuBaseAggregateMeta[INPUT <: SparkPlan]( // This is a short term heuristic until we can better understand the cost // of sort vs the cost of doing aggregations so we can better decide. lazy val hasSingleBasicGroupingKey = agg.groupingExpressions.length == 1 && - agg.groupingExpressions.headOption.map(_.dataType).exists { - case StringType | BooleanType | ByteType | ShortType | IntegerType | - LongType | _: DecimalType | DateType | TimestampType => true - case _ => false - } + agg.groupingExpressions.headOption.map(_.dataType).exists { + case StringType | BooleanType | ByteType | ShortType | IntegerType | + LongType | _: DecimalType | DateType | TimestampType => true + case _ => false + } val gpuChild = childPlans.head.convertIfNeeded() val gpuAggregateExpressions = @@ -1314,11 +1380,11 @@ abstract class GpuBaseAggregateMeta[INPUT <: SparkPlan]( } val allowSinglePassAgg = (conf.forceSinglePassPartialSortAgg || - (conf.allowSinglePassPartialSortAgg && - hasSingleBasicGroupingKey && - estimatedPreProcessGrowth > 1.1)) && - canUsePartialSortAgg && - groupingCanBeSorted + (conf.allowSinglePassPartialSortAgg && + hasSingleBasicGroupingKey && + estimatedPreProcessGrowth > 1.1)) && + canUsePartialSortAgg && + groupingCanBeSorted GpuHashAggregateExec( aggRequiredChildDistributionExpressions, @@ 
-1332,7 +1398,8 @@ abstract class GpuBaseAggregateMeta[INPUT <: SparkPlan]( conf.forceSinglePassPartialSortAgg, allowSinglePassAgg, allowNonFullyAggregatedOutput, - conf.skipAggPassReductionRatio) + conf.skipAggPassReductionRatio + ) } } @@ -1351,7 +1418,7 @@ abstract class GpuTypedImperativeSupportedAggregateExecMeta[INPUT <: BaseAggrega private val mayNeedAggBufferConversion: Boolean = agg.aggregateExpressions.exists { expr => expr.aggregateFunction.isInstanceOf[TypedImperativeAggregate[_]] && - (expr.mode == Partial || expr.mode == PartialMerge) + (expr.mode == Partial || expr.mode == PartialMerge) } // overriding data types of Aggregation Buffers if necessary @@ -1420,6 +1487,7 @@ abstract class GpuTypedImperativeSupportedAggregateExecMeta[INPUT <: BaseAggrega allowSinglePassAgg = false, allowNonFullyAggregatedOutput = false, 1) + } else { super.convertToGpu() } @@ -1523,8 +1591,8 @@ object GpuTypedImperativeSupportedAggregateExecMeta { // [A]. there will be a R2C or C2R transition between them // [B]. there exists TypedImperativeAggregate functions in each of them (stages(i).canThisBeReplaced ^ stages(i + 1).canThisBeReplaced) && - containTypedImperativeAggregate(stages(i)) && - containTypedImperativeAggregate(stages(i + 1)) + containTypedImperativeAggregate(stages(i)) && + containTypedImperativeAggregate(stages(i + 1)) } // Return if all internal aggregation buffers are compatible with GPU Overrides. 
@@ -1602,10 +1670,10 @@ object GpuTypedImperativeSupportedAggregateExecMeta { fromCpuToGpu: Boolean): Seq[NamedExpression] = { val converters = mutable.Queue[Either[ - CpuToGpuAggregateBufferConverter, GpuToCpuAggregateBufferConverter]]() + CpuToGpuAggregateBufferConverter, GpuToCpuAggregateBufferConverter]]() mergeAggMeta.childExprs.foreach { case e if e.childExprs.length == 1 && - e.childExprs.head.isInstanceOf[TypedImperativeAggExprMeta[_]] => + e.childExprs.head.isInstanceOf[TypedImperativeAggExprMeta[_]] => e.wrapped.asInstanceOf[AggregateExpression].mode match { case Final | PartialMerge => val typImpAggMeta = e.childExprs.head.asInstanceOf[TypedImperativeAggExprMeta[_]] @@ -1660,16 +1728,16 @@ class GpuHashAggregateMeta( conf: RapidsConf, parent: Option[RapidsMeta[_, _, _]], rule: DataFromReplacementRule) - extends GpuBaseAggregateMeta(agg, agg.requiredChildDistributionExpressions, - conf, parent, rule) + extends GpuBaseAggregateMeta(agg, agg.requiredChildDistributionExpressions, + conf, parent, rule) class GpuSortAggregateExecMeta( override val agg: SortAggregateExec, conf: RapidsConf, parent: Option[RapidsMeta[_, _, _]], rule: DataFromReplacementRule) - extends GpuTypedImperativeSupportedAggregateExecMeta(agg, - agg.requiredChildDistributionExpressions, conf, parent, rule) { + extends GpuTypedImperativeSupportedAggregateExecMeta(agg, + agg.requiredChildDistributionExpressions, conf, parent, rule) { override def tagPlanForGpu(): Unit = { super.tagPlanForGpu() @@ -1716,14 +1784,14 @@ class GpuObjectHashAggregateExecMeta( conf: RapidsConf, parent: Option[RapidsMeta[_, _, _]], rule: DataFromReplacementRule) - extends GpuTypedImperativeSupportedAggregateExecMeta(agg, - agg.requiredChildDistributionExpressions, conf, parent, rule) + extends GpuTypedImperativeSupportedAggregateExecMeta(agg, + agg.requiredChildDistributionExpressions, conf, parent, rule) object GpuHashAggregateExecBase { def calcInputAttributes(aggregateExpressions: Seq[GpuAggregateExpression], - 
childOutput: Seq[Attribute], - inputAggBufferAttributes: Seq[Attribute]): Seq[Attribute] = { + childOutput: Seq[Attribute], + inputAggBufferAttributes: Seq[Attribute]): Seq[Attribute] = { val modes = aggregateExpressions.map(_.mode).distinct if (modes.contains(Final) || modes.contains(PartialMerge)) { // SPARK-31620: when planning aggregates, the partial aggregate uses aggregate function's @@ -1754,7 +1822,7 @@ object GpuHashAggregateExecBase { } /** - * The GPU version of SortAggregateExec that is intended for partial aggregations that are not + * The GPU version of AggregateExec that is intended for partial aggregations that are not * reductions and so it sorts the input data ahead of time to do it in a single pass. * * @param requiredChildDistributionExpressions this is unchanged by the GPU. It is used in @@ -1767,7 +1835,6 @@ object GpuHashAggregateExecBase { * node should project) * @param child incoming plan (where we get input columns from) * @param configuredTargetBatchSize user-configured maximum device memory size of a batch - * @param configuredTieredProjectEnabled configurable optimization to use tiered projections * @param allowNonFullyAggregatedOutput whether we can skip the third pass of aggregation * (can omit non fully aggregated data for non-final * stage of aggregation) @@ -1802,11 +1869,13 @@ case class GpuHashAggregateExec( protected override val outputRowsLevel: MetricsLevel = ESSENTIAL_LEVEL protected override val outputBatchesLevel: MetricsLevel = MODERATE_LEVEL override lazy val additionalMetrics: Map[String, GpuMetric] = Map( - NUM_TASKS_FALL_BACKED -> createMetric(MODERATE_LEVEL, DESCRIPTION_NUM_TASKS_FALL_BACKED), + NUM_TASKS_REPARTITIONED -> createMetric(MODERATE_LEVEL, DESCRIPTION_NUM_TASKS_REPARTITIONED), + NUM_TASKS_SKIPPED_AGG -> createMetric(MODERATE_LEVEL, DESCRIPTION_NUM_TASKS_SKIPPED_AGG), OP_TIME -> createNanoTimingMetric(MODERATE_LEVEL, DESCRIPTION_OP_TIME), AGG_TIME -> createNanoTimingMetric(DEBUG_LEVEL, 
DESCRIPTION_AGG_TIME), CONCAT_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_CONCAT_TIME), SORT_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_SORT_TIME), + REPARTITION_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_REPARTITION_TIME), "NUM_AGGS" -> createMetric(DEBUG_LEVEL, "num agg operations"), "NUM_PRE_SPLITS" -> createMetric(DEBUG_LEVEL, "num pre splits"), "NUM_TASKS_SINGLE_PASS" -> createMetric(MODERATE_LEVEL, "number of single pass tasks"), @@ -1833,11 +1902,13 @@ case class GpuHashAggregateExec( val aggMetrics = GpuHashAggregateMetrics( numOutputRows = gpuLongMetric(NUM_OUTPUT_ROWS), numOutputBatches = gpuLongMetric(NUM_OUTPUT_BATCHES), - numTasksFallBacked = gpuLongMetric(NUM_TASKS_FALL_BACKED), + numTasksRepartitioned = gpuLongMetric(NUM_TASKS_REPARTITIONED), + numTasksSkippedAgg = gpuLongMetric(NUM_TASKS_SKIPPED_AGG), opTime = gpuLongMetric(OP_TIME), computeAggTime = gpuLongMetric(AGG_TIME), concatTime = gpuLongMetric(CONCAT_TIME), sortTime = gpuLongMetric(SORT_TIME), + repartitionTime = gpuLongMetric(REPARTITION_TIME), numAggOps = gpuLongMetric("NUM_AGGS"), numPreSplits = gpuLongMetric("NUM_PRE_SPLITS"), singlePassTasks = gpuLongMetric("NUM_TASKS_SINGLE_PASS"), @@ -1867,11 +1938,12 @@ case class GpuHashAggregateExec( val postBoundReferences = GpuAggFinalPassIterator.setupReferences(groupingExprs, aggregateExprs, aggregateAttrs, resultExprs, modeInfo) - new DynamicGpuPartialSortAggregateIterator(cbIter, inputAttrs, groupingExprs, + new DynamicGpuPartialAggregateIterator(cbIter, inputAttrs, groupingExprs, boundGroupExprs, aggregateExprs, aggregateAttrs, resultExprs, modeInfo, localEstimatedPreProcessGrowth, alreadySorted, expectedOrdering, postBoundReferences, targetBatchSize, aggMetrics, conf, - localForcePre, localAllowPre, allowNonFullyAggregatedOutput, skipAggPassReductionRatio) + localForcePre, localAllowPre, allowNonFullyAggregatedOutput, skipAggPassReductionRatio + ) } } @@ -1914,8 +1986,8 @@ case class 
GpuHashAggregateExec( // Used in de-duping and optimizer rules override def producedAttributes: AttributeSet = AttributeSet(aggregateAttributes) ++ - AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ - AttributeSet(aggregateBufferAttributes) + AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ + AttributeSet(aggregateBufferAttributes) // AllTuples = distribution with a single partition and all tuples of the dataset are co-located. // Clustered = dataset with tuples co-located in the same partition if they share a specific value @@ -1938,7 +2010,7 @@ case class GpuHashAggregateExec( */ override lazy val allAttributes: AttributeSeq = child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++ - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) override def verboseString(maxFields: Int): String = toString(verbose = true, maxFields) @@ -1957,8 +2029,8 @@ case class GpuHashAggregateExec( s"""${loreArgs.mkString(", ")}""" } else { s"$nodeName (keys=$keyString, functions=$functionString)," + - s" filters=${aggregateExpressions.map(_.filter)})" + - s""" ${loreArgs.mkString(", ")}""" + s" filters=${aggregateExpressions.map(_.filter)})" + + s""" ${loreArgs.mkString(", ")}""" } } // @@ -1972,7 +2044,7 @@ case class GpuHashAggregateExec( } } -class DynamicGpuPartialSortAggregateIterator( +class DynamicGpuPartialAggregateIterator( cbIter: Iterator[ColumnarBatch], inputAttrs: Seq[Attribute], groupingExprs: Seq[NamedExpression], @@ -1999,7 +2071,7 @@ class DynamicGpuPartialSortAggregateIterator( // When doing a reduction we don't have the aggIter setup for the very first time // so we have to match what happens for the normal reduction operations. 
override def hasNext: Boolean = aggIter.map(_.hasNext) - .getOrElse(isReductionOnly || cbIter.hasNext) + .getOrElse(isReductionOnly || cbIter.hasNext) private[this] def estimateCardinality(cb: ColumnarBatch): Int = { withResource(boundGroupExprs.project(cb)) { groupingKeys => @@ -2052,7 +2124,8 @@ class DynamicGpuPartialSortAggregateIterator( inputAttrs.map(_.dataType).toArray, preProcessAggHelper.preStepBound, metrics.opTime, metrics.numPreSplits) - val firstPassIter = GpuAggFirstPassIterator(sortedSplitIter, preProcessAggHelper, metrics) + val firstPassIter = GpuAggFirstPassIterator(sortedSplitIter, preProcessAggHelper, + metrics) // Technically on a partial-agg, which this only works for, this last iterator should // be a noop except for some metrics. But for consistency between all of the @@ -2071,6 +2144,7 @@ class DynamicGpuPartialSortAggregateIterator( metrics.opTime, metrics.numPreSplits) val localInputRowsMetrics = new LocalGpuMetric + val firstPassIter = GpuAggFirstPassIterator( splitInputIter.map(cb => { localInputRowsMetrics += cb.numRows() @@ -2080,7 +2154,7 @@ class DynamicGpuPartialSortAggregateIterator( metrics) val mergeIter = new GpuMergeAggregateIterator( - firstPassIter, + new CloseableBufferedIterator(firstPassIter), inputAttrs, groupingExprs, aggregateExprs, @@ -2092,7 +2166,8 @@ class DynamicGpuPartialSortAggregateIterator( conf, allowNonFullyAggregatedOutput, skipAggPassReductionRatio, - localInputRowsMetrics) + localInputRowsMetrics + ) GpuAggFinalPassIterator.makeIter(mergeIter, postBoundReferences, metrics) } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala index 0ffead09de6..3d9b6285a91 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala @@ -66,6 +66,7 @@ object GpuMetric extends Logging { val COLLECT_TIME = "collectTime" val CONCAT_TIME = "concatTime" 
val SORT_TIME = "sortTime" + val REPARTITION_TIME = "repartitionTime" val AGG_TIME = "computeAggTime" val JOIN_TIME = "joinTime" val FILTER_TIME = "filterTime" @@ -73,6 +74,8 @@ object GpuMetric extends Logging { val BUILD_TIME = "buildTime" val STREAM_TIME = "streamTime" val NUM_TASKS_FALL_BACKED = "numTasksFallBacked" + val NUM_TASKS_REPARTITIONED = "numTasksRepartitioned" + val NUM_TASKS_SKIPPED_AGG = "numTasksSkippedAgg" val READ_FS_TIME = "readFsTime" val WRITE_BUFFER_TIME = "writeBufferTime" val FILECACHE_FOOTER_HITS = "filecacheFooterHits" @@ -104,6 +107,7 @@ object GpuMetric extends Logging { val DESCRIPTION_COLLECT_TIME = "collect batch time" val DESCRIPTION_CONCAT_TIME = "concat batch time" val DESCRIPTION_SORT_TIME = "sort time" + val DESCRIPTION_REPARTITION_TIME = "repartition time" val DESCRIPTION_AGG_TIME = "aggregation time" val DESCRIPTION_JOIN_TIME = "join time" val DESCRIPTION_FILTER_TIME = "filter time" @@ -111,6 +115,8 @@ object GpuMetric extends Logging { val DESCRIPTION_BUILD_TIME = "build time" val DESCRIPTION_STREAM_TIME = "stream time" val DESCRIPTION_NUM_TASKS_FALL_BACKED = "number of sort fallback tasks" + val DESCRIPTION_NUM_TASKS_REPARTITIONED = "number of tasks repartitioned for agg" + val DESCRIPTION_NUM_TASKS_SKIPPED_AGG = "number of tasks skipped aggregation" val DESCRIPTION_READ_FS_TIME = "time to read fs data" val DESCRIPTION_WRITE_BUFFER_TIME = "time to write data to buffer" val DESCRIPTION_FILECACHE_FOOTER_HITS = "cached footer hits" diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/window/GpuUnboundedToUnboundedAggWindowExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/window/GpuUnboundedToUnboundedAggWindowExec.scala index d685efe68e0..7c5b55cd0bd 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/window/GpuUnboundedToUnboundedAggWindowExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/window/GpuUnboundedToUnboundedAggWindowExec.scala @@ -17,10 +17,9 @@ package 
com.nvidia.spark.rapids.window import scala.collection.mutable.{ArrayBuffer, ListBuffer} -import scala.reflect.ClassTag import ai.rapids.cudf -import com.nvidia.spark.rapids.{ConcatAndConsumeAll, GpuAlias, GpuBindReferences, GpuBoundReference, GpuColumnVector, GpuExpression, GpuLiteral, GpuMetric, GpuProjectExec, SpillableColumnarBatch, SpillPriorities} +import com.nvidia.spark.rapids.{AutoClosableArrayBuffer, ConcatAndConsumeAll, GpuAlias, GpuBindReferences, GpuBoundReference, GpuColumnVector, GpuExpression, GpuLiteral, GpuMetric, GpuProjectExec, SpillableColumnarBatch, SpillPriorities} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingSeq import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{splitSpillableInHalfByRows, withRetry, withRetryNoSplit} @@ -36,32 +35,6 @@ import org.apache.spark.sql.rapids.aggregate.{CudfAggregate, GpuAggregateExpress import org.apache.spark.sql.types.{DataType, IntegerType, LongType} import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} - -/** - * Just a simple wrapper to make working with buffers of AutoClosable things play - * nicely with withResource. - */ -class AutoClosableArrayBuffer[T <: AutoCloseable]() extends AutoCloseable { - private val data = new ArrayBuffer[T]() - - def append(scb: T): Unit = data.append(scb) - - def last: T = data.last - - def removeLast(): T = data.remove(data.length - 1) - - def foreach[U](f: T => U): Unit = data.foreach(f) - - def toArray[B >: T : ClassTag]: Array[B] = data.toArray - - override def toString: String = s"AutoCloseable(${super.toString})" - - override def close(): Unit = { - data.foreach(_.close()) - data.clear() - } -} - /** * Utilities for conversion between SpillableColumnarBatch, ColumnarBatch, and cudf.Table. */