From 21717dd2cb43d657a49d109bdf47486410bcf51d Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Mon, 1 Jul 2024 19:22:30 +0800 Subject: [PATCH] 240701 repartition agg (#36) * workable version without tests Signed-off-by: Hongbin Ma (Mahone) * doc Signed-off-by: Hongbin Ma (Mahone) * fix scala 2.13 Signed-off-by: Hongbin Ma (Mahone) --------- Signed-off-by: Hongbin Ma (Mahone) Co-authored-by: Hongbin Ma (Mahone) --- .../advanced_configs.md | 1 + .../scala/com/nvidia/spark/rapids/Arm.scala | 19 +- .../spark/rapids/GpuAggregateExec.scala | 426 +++++++++++++----- .../com/nvidia/spark/rapids/GpuExec.scala | 2 + .../com/nvidia/spark/rapids/RapidsConf.scala | 9 + 5 files changed, 330 insertions(+), 127 deletions(-) diff --git a/docs/additional-functionality/advanced_configs.md b/docs/additional-functionality/advanced_configs.md index 0d4a3267e30..25bba0dbd90 100644 --- a/docs/additional-functionality/advanced_configs.md +++ b/docs/additional-functionality/advanced_configs.md @@ -60,6 +60,7 @@ Name | Description | Default Value | Applicable at spark.rapids.shuffle.ucx.activeMessages.forceRndv|Set to true to force 'rndv' mode for all UCX Active Messages. This should only be required with UCX 1.10.x. UCX 1.11.x deployments should set to false.|false|Startup spark.rapids.shuffle.ucx.managementServerHost|The host to be used to start the management server|null|Startup spark.rapids.shuffle.ucx.useWakeup|When set to true, use UCX's event-based progress (epoll) in order to wake up the progress thread when needed, instead of a hot loop.|true|Startup +spark.rapids.sql.agg.fallbackAlgorithm|When agg cannot be done in a single pass, use sort-based fallback or repartition-based fallback.|sort|Runtime spark.rapids.sql.agg.skipAggPassReductionRatio|In non-final aggregation stages, if the previous pass has a row reduction ratio greater than this value, the next aggregation pass will be skipped.Setting this to 1 essentially disables this feature.|1.0|Runtime spark.rapids.sql.allowMultipleJars|Allow multiple rapids-4-spark, spark-rapids-jni, and cudf jars on the classpath. Spark will take the first one it finds, so the version may not be expected. Possisble values are ALWAYS: allow all jars, SAME_REVISION: only allow jars with the same revision, NEVER: do not allow multiple jars at all.|SAME_REVISION|Startup spark.rapids.sql.castDecimalToFloat.enabled|Casting from decimal to floating point types on the GPU returns results that have tiny difference compared to results returned from CPU.|true|Runtime diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala index 926f770a683..96254b9f38d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.util.control.ControlThrowable import com.nvidia.spark.rapids.RapidsPluginImplicits._ @@ -43,7 +43,8 @@ object Arm extends ArmScalaSpecificImpl { } /** Executes the provided code block and then closes the sequence of resources */ - def withResource[T <: AutoCloseable, V](r: Seq[T])(block: Seq[T] => V): V = { + def withResource[T <: AutoCloseable, V](r: scala.collection.Seq[T]) + (block: scala.collection.Seq[T] => V): V = { try { block(r) } finally { @@ -134,6 +135,20 @@ object Arm extends ArmScalaSpecificImpl { } } + /** Executes the provided code block, closing the resources only if an exception occurs */ + def closeOnExcept[T <: AutoCloseable, V](r: ListBuffer[T])(block: ListBuffer[T] => V): V = { + try { + block(r) + } catch { + case t: ControlThrowable => + // Don't close for these cases.. + throw t + case t: Throwable => + r.safeClose(t) + throw t + } + } + /** Executes the provided code block, closing the resources only if an exception occurs */ def closeOnExcept[T <: AutoCloseable, V](r: mutable.Queue[T])(block: mutable.Queue[T] => V): V = { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala index 7e6a1056d01..b28101f3442 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala @@ -16,11 +16,9 @@ package com.nvidia.spark.rapids -import java.util - import scala.annotation.tailrec -import scala.collection.JavaConverters.collectionAsScalaIterableConverter import scala.collection.mutable +import scala.collection.mutable.ListBuffer import ai.rapids.cudf import ai.rapids.cudf.{NvtxColor, NvtxRange} @@ -46,11 +44,11 @@ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.{ExplainUtils, SortExec, SparkPlan} import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.rapids.aggregate.{CpuToGpuAggregateBufferConverter, CudfAggregate, GpuAggregateExpression, GpuToCpuAggregateBufferConverter} -import org.apache.spark.sql.rapids.execution.{GpuShuffleMeta, TrampolineUtil} +import org.apache.spark.sql.rapids.execution.{GpuBatchSubPartitioner, GpuShuffleMeta, TrampolineUtil} import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch -object AggregateUtils { +object AggregateUtils extends Logging { private val aggs = List("min", "max", "avg", "sum", "count", "first", "last") @@ -86,9 +84,10 @@ object AggregateUtils { /** * Computes a target input batch size based on the assumption that computation can consume up to * 4X the configured batch size. - * @param confTargetSize user-configured maximum desired batch size - * @param inputTypes input batch schema - * @param outputTypes output batch schema + * + * @param confTargetSize user-configured maximum desired batch size + * @param inputTypes input batch schema + * @param outputTypes output batch schema * @param isReductionOnly true if this is a reduction-only aggregation without grouping * @return maximum target batch size to keep computation under the 4X configured batch limit */ @@ -99,6 +98,7 @@ object AggregateUtils { isReductionOnly: Boolean): Long = { def typesToSize(types: Seq[DataType]): Long = types.map(GpuBatchUtils.estimateGpuMemory(_, nullable = false, rowCount = 1)).sum + val inputRowSize = typesToSize(inputTypes) val outputRowSize = typesToSize(outputTypes) // The cudf hash table implementation allocates four 32-bit integers per input row. @@ -124,6 +124,129 @@ object AggregateUtils { // Finally compute the input target batching size taking into account the cudf row limits Math.min(inputRowSize * maxRows, Int.MaxValue) } + + + /** + * Concatenate batches together and perform a merge aggregation on the result. The input batches + * will be closed as part of this operation. + * + * @param batches batches to concatenate and merge aggregate + * @return lazy spillable batch which has NOT been marked spillable + */ + private def concatenateAndMerge( + batches: mutable.Buffer[SpillableColumnarBatch], + metrics: GpuHashAggregateMetrics, + concatAndMergeHelper: AggHelper): SpillableColumnarBatch = { + // TODO: concatenateAndMerge (and calling code) could output a sequence + // of batches for the partial aggregate case. This would be done in case + // a retry failed a certain number of times. + val concatBatch = withResource(batches) { _ => + val concatSpillable = concatenateBatches(metrics, batches.toSeq) + withResource(concatSpillable) { + _.getColumnarBatch() + } + } + computeAggregateAndClose(metrics, concatBatch, concatAndMergeHelper) + } + + /** + * Perform a single pass over the aggregated batches attempting to merge adjacent batches. + * + * @return true if at least one merge operation occurred + */ + private def mergePass( + aggregatedBatches: mutable.Buffer[SpillableColumnarBatch], + targetMergeBatchSize: Long, + helper: AggHelper, + metrics: GpuHashAggregateMetrics + ): Boolean = { + val batchesToConcat: mutable.ArrayBuffer[SpillableColumnarBatch] = mutable.ArrayBuffer.empty + var wasBatchMerged = false + // Current size in bytes of the batches targeted for the next concatenation + var concatSize: Long = 0L + var batchesLeftInPass = aggregatedBatches.size + + while (batchesLeftInPass > 0) { + closeOnExcept(batchesToConcat) { _ => + var isConcatSearchFinished = false + // Old batches are picked up at the front of the queue and freshly merged batches are + // appended to the back of the queue. Although tempting to allow the pass to "wrap around" + // and pick up batches freshly merged in this pass, it's avoided to prevent changing the + // order of aggregated batches. + while (batchesLeftInPass > 0 && !isConcatSearchFinished) { + val candidate = aggregatedBatches.head + val potentialSize = concatSize + candidate.sizeInBytes + isConcatSearchFinished = concatSize > 0 && potentialSize > targetMergeBatchSize + if (!isConcatSearchFinished) { + batchesLeftInPass -= 1 + batchesToConcat += aggregatedBatches.remove(0) + concatSize = potentialSize + } + } + } + + val mergedBatch = if (batchesToConcat.length > 1) { + wasBatchMerged = true + concatenateAndMerge(batchesToConcat, metrics, helper) + } else { + // Unable to find a neighboring buffer to produce a valid merge in this pass, + // so simply put this buffer back on the queue for other passes. + batchesToConcat.remove(0) + } + + // Add the merged batch to the end of the aggregated batch queue. Only a single pass over + // the batches is being performed due to the batch count check above, so the single-pass + // loop will terminate before picking up this new batch. + aggregatedBatches += mergedBatch + batchesToConcat.clear() + concatSize = 0 + } + + wasBatchMerged + } + + + /** + * Attempt to merge adjacent batches in the aggregatedBatches queue until either there is only + * one batch or merging adjacent batches would exceed the target batch size. + */ + def tryMergeAggregatedBatches( + aggregatedBatches: mutable.Buffer[SpillableColumnarBatch], + isReductionOnly: Boolean, + metrics: GpuHashAggregateMetrics, + targetMergeBatchSize: Long, + helper: AggHelper + ): Unit = { + while (aggregatedBatches.size > 1) { + val concatTime = metrics.concatTime + val opTime = metrics.opTime + withResource(new NvtxWithMetrics("agg merge pass", NvtxColor.BLUE, concatTime, + opTime)) { _ => + // continue merging as long as some batches are able to be combined + if (!mergePass(aggregatedBatches, targetMergeBatchSize, helper, metrics)) + if (aggregatedBatches.size > 1 && isReductionOnly) { + // We were unable to merge the aggregated batches within the target batch size limit, + // which means normally we would fallback to a sort-based approach. However for + // reduction-only aggregation there are no keys to use for a sort. The only way this + // can work is if all batches are merged. This will exceed the target batch size limit, + // but at this point it is either risk an OOM/cudf error and potentially work or + // not work at all. + logWarning(s"Unable to merge reduction-only aggregated batches within " + + s"target batch limit of $targetMergeBatchSize, attempting to merge remaining " + + s"${aggregatedBatches.size} batches beyond limit") + withResource(mutable.ArrayBuffer[SpillableColumnarBatch]()) { batchesToConcat => + aggregatedBatches.foreach(b => batchesToConcat += b) + aggregatedBatches.clear() + val batch = concatenateAndMerge(batchesToConcat, metrics, helper) + // batch does not need to be marked spillable since it is the last and only batch + // and will be immediately retrieved on the next() call. + aggregatedBatches += batch + } + } + return + } + } + } } /** Utility class to hold all of the metrics related to hash aggregation */ @@ -135,6 +258,7 @@ case class GpuHashAggregateMetrics( computeAggTime: GpuMetric, concatTime: GpuMetric, sortTime: GpuMetric, + repartitionTime: GpuMetric, numAggOps: GpuMetric, numPreSplits: GpuMetric, singlePassTasks: GpuMetric, @@ -711,6 +835,8 @@ object GpuAggFinalPassIterator { * @param useTieredProject user-specified option to enable tiered projections * @param allowNonFullyAggregatedOutput if allowed to skip third pass Agg * @param skipAggPassReductionRatio skip if the ratio of rows after a pass is bigger than this value + * @param aggFallbackAlgorithm use sort-based fallback or repartition-based fallback + * for oversize agg * @param localInputRowsCount metric to track the number of input rows processed locally */ class GpuMergeAggregateIterator( @@ -726,15 +852,17 @@ class GpuMergeAggregateIterator( useTieredProject: Boolean, allowNonFullyAggregatedOutput: Boolean, skipAggPassReductionRatio: Double, + aggFallbackAlgorithm: String, localInputRowsCount: LocalGpuMetric) extends Iterator[ColumnarBatch] with AutoCloseable with Logging { private[this] val isReductionOnly = groupingExpressions.isEmpty private[this] val targetMergeBatchSize = computeTargetMergeBatchSize(configuredTargetBatchSize) - private[this] val aggregatedBatches = new util.ArrayDeque[SpillableColumnarBatch] + private[this] val aggregatedBatches = ListBuffer.empty[SpillableColumnarBatch] private[this] var outOfCoreIter: Option[GpuOutOfCoreSortIterator] = None + private[this] var repartitionIter: Option[RepartitionAggregateIterator] = None /** Iterator for fetching aggregated batches either if: - * 1. a sort-based fallback has occurred + * 1. a sort-based/repartition-based fallback has occurred * 2. skip third pass agg has occurred **/ private[this] var fallbackIter: Option[Iterator[ColumnarBatch]] = None @@ -752,7 +880,7 @@ class GpuMergeAggregateIterator( override def hasNext: Boolean = { fallbackIter.map(_.hasNext).getOrElse { // reductions produce a result even if the input is empty - hasReductionOnlyBatch || !aggregatedBatches.isEmpty || firstPassIter.hasNext + hasReductionOnlyBatch || aggregatedBatches.nonEmpty || firstPassIter.hasNext } } @@ -769,9 +897,11 @@ class GpuMergeAggregateIterator( if (isReductionOnly || skipAggPassReductionRatio * localInputRowsCount.value >= rowsAfterFirstPassAgg) { // second pass agg - tryMergeAggregatedBatches() + AggregateUtils.tryMergeAggregatedBatches( + aggregatedBatches, isReductionOnly, + metrics, targetMergeBatchSize, concatAndMergeHelper) - val rowsAfterSecondPassAgg = aggregatedBatches.asScala.foldLeft(0L) { + val rowsAfterSecondPassAgg = aggregatedBatches.foldLeft(0L) { (totalRows, batch) => totalRows + batch.numRows() } shouldSkipThirdPassAgg = @@ -784,7 +914,7 @@ class GpuMergeAggregateIterator( } } - if (aggregatedBatches.size() > 1) { + if (aggregatedBatches.size > 1) { // Unable to merge to a single output, so must fall back if (allowNonFullyAggregatedOutput && shouldSkipThirdPassAgg) { // skip third pass agg, return the aggregated batches directly @@ -792,17 +922,23 @@ class GpuMergeAggregateIterator( s"${skipAggPassReductionRatio * 100}% of " + s"rows after first pass, skip the third pass agg") fallbackIter = Some(new Iterator[ColumnarBatch] { - override def hasNext: Boolean = !aggregatedBatches.isEmpty + override def hasNext: Boolean = aggregatedBatches.nonEmpty override def next(): ColumnarBatch = { - withResource(aggregatedBatches.pop()) { spillableBatch => + withResource(aggregatedBatches.remove(0)) { spillableBatch => spillableBatch.getColumnarBatch() } } }) } else { // fallback to sort agg, this is the third pass agg - fallbackIter = Some(buildSortFallbackIterator()) + aggFallbackAlgorithm.toLowerCase match { + case "repartition" => + fallbackIter = Some(buildRepartitionFallbackIterator()) + case "sort" => fallbackIter = Some(buildSortFallbackIterator()) + case _ => throw new IllegalArgumentException( + s"Unsupported aggregation fallback algorithm: $aggFallbackAlgorithm") + } } fallbackIter.get.next() } else if (aggregatedBatches.isEmpty) { @@ -815,7 +951,7 @@ class GpuMergeAggregateIterator( } else { // this will be the last batch hasReductionOnlyBatch = false - withResource(aggregatedBatches.pop()) { spillableBatch => + withResource(aggregatedBatches.remove(0)) { spillableBatch => spillableBatch.getColumnarBatch() } } @@ -823,10 +959,12 @@ class GpuMergeAggregateIterator( } override def close(): Unit = { - aggregatedBatches.forEach(_.safeClose()) + aggregatedBatches.foreach(_.safeClose()) aggregatedBatches.clear() outOfCoreIter.foreach(_.close()) outOfCoreIter = None + repartitionIter.foreach(_.close()) + repartitionIter = None fallbackIter = None hasReductionOnlyBatch = false } @@ -843,133 +981,161 @@ class GpuMergeAggregateIterator( while (firstPassIter.hasNext) { val batch = firstPassIter.next() rowsAfter += batch.numRows() - aggregatedBatches.add(batch) + aggregatedBatches += batch } rowsAfter } - /** - * Attempt to merge adjacent batches in the aggregatedBatches queue until either there is only - * one batch or merging adjacent batches would exceed the target batch size. - */ - private def tryMergeAggregatedBatches(): Unit = { - while (aggregatedBatches.size() > 1) { - val concatTime = metrics.concatTime - val opTime = metrics.opTime - withResource(new NvtxWithMetrics("agg merge pass", NvtxColor.BLUE, concatTime, - opTime)) { _ => - // continue merging as long as some batches are able to be combined - if (!mergePass()) { - if (aggregatedBatches.size() > 1 && isReductionOnly) { - // We were unable to merge the aggregated batches within the target batch size limit, - // which means normally we would fallback to a sort-based approach. However for - // reduction-only aggregation there are no keys to use for a sort. The only way this - // can work is if all batches are merged. This will exceed the target batch size limit, - // but at this point it is either risk an OOM/cudf error and potentially work or - // not work at all. - logWarning(s"Unable to merge reduction-only aggregated batches within " + - s"target batch limit of $targetMergeBatchSize, attempting to merge remaining " + - s"${aggregatedBatches.size()} batches beyond limit") - withResource(mutable.ArrayBuffer[SpillableColumnarBatch]()) { batchesToConcat => - aggregatedBatches.forEach(b => batchesToConcat += b) - aggregatedBatches.clear() - val batch = concatenateAndMerge(batchesToConcat) - // batch does not need to be marked spillable since it is the last and only batch - // and will be immediately retrieved on the next() call. - aggregatedBatches.add(batch) - } - } - return + private lazy val concatAndMergeHelper = + new AggHelper(inputAttributes, groupingExpressions, aggregateExpressions, + forceMerge = true, useTieredProject = useTieredProject) + + private def cbIteratorStealingFromBuffer(input: ListBuffer[SpillableColumnarBatch]) = { + val aggregatedBatchIter = new Iterator[ColumnarBatch] { + override def hasNext: Boolean = input.nonEmpty + + override def next(): ColumnarBatch = { + withResource(input.remove(0)) { spillable => + spillable.getColumnarBatch() } } } + aggregatedBatchIter } - /** - * Perform a single pass over the aggregated batches attempting to merge adjacent batches. - * @return true if at least one merge operation occurred - */ - private def mergePass(): Boolean = { - val batchesToConcat: mutable.ArrayBuffer[SpillableColumnarBatch] = mutable.ArrayBuffer.empty - var wasBatchMerged = false - // Current size in bytes of the batches targeted for the next concatenation - var concatSize: Long = 0L - var batchesLeftInPass = aggregatedBatches.size() + private case class RepartitionAggregateIterator( + inputBatches: ListBuffer[SpillableColumnarBatch], + hashKeys: Seq[GpuExpression], + targetSize: Long, + opTime: GpuMetric, + repartitionTime: GpuMetric) extends Iterator[ColumnarBatch] + with AutoCloseable { - while (batchesLeftInPass > 0) { - closeOnExcept(batchesToConcat) { _ => - var isConcatSearchFinished = false - // Old batches are picked up at the front of the queue and freshly merged batches are - // appended to the back of the queue. Although tempting to allow the pass to "wrap around" - // and pick up batches freshly merged in this pass, it's avoided to prevent changing the - // order of aggregated batches. - while (batchesLeftInPass > 0 && !isConcatSearchFinished) { - val candidate = aggregatedBatches.getFirst - val potentialSize = concatSize + candidate.sizeInBytes - isConcatSearchFinished = concatSize > 0 && potentialSize > targetMergeBatchSize - if (!isConcatSearchFinished) { - batchesLeftInPass -= 1 - batchesToConcat += aggregatedBatches.removeFirst() - concatSize = potentialSize + case class AggregatePartition(batches: ListBuffer[SpillableColumnarBatch], seed: Int) + extends AutoCloseable { + override def close(): Unit = { + batches.safeClose() + } + + def totalRows(): Long = batches.map(_.numRows()).sum + + def totalSize(): Long = batches.map(_.sizeInBytes).sum + + def split(): ListBuffer[AggregatePartition] = { + withResource(new NvtxWithMetrics("agg repartition", NvtxColor.CYAN, repartitionTime)) { _ => + if (seed > hashSeed + 20) { + throw new IllegalStateException("At most repartition 3 times for a partition") + } + val totalSize = batches.map(_.sizeInBytes).sum + val newSeed = seed + 10 + val iter = cbIteratorStealingFromBuffer(batches) + withResource(new GpuBatchSubPartitioner( + iter, hashKeys, computeNumPartitions(totalSize), newSeed, "aggRepartition")) { + partitioner => + closeOnExcept(ListBuffer.empty[AggregatePartition]) { partitions => + preparePartitions(newSeed, partitioner, partitions) + partitions + } } } } + } - val mergedBatch = if (batchesToConcat.length > 1) { - wasBatchMerged = true - concatenateAndMerge(batchesToConcat) - } else { - // Unable to find a neighboring buffer to produce a valid merge in this pass, - // so simply put this buffer back on the queue for other passes. - batchesToConcat.remove(0) + private def preparePartitions( + newSeed: Int, + partitioner: GpuBatchSubPartitioner, + partitions: ListBuffer[AggregatePartition]): Unit = { + (0 until partitioner.partitionsCount).foreach { id => + val buffer = ListBuffer.empty[SpillableColumnarBatch] + buffer ++= partitioner.releaseBatchesByPartition(id) + val newPart = AggregatePartition.apply(buffer, newSeed) + if (newPart.totalRows() > 0) { + partitions += newPart + } else { + newPart.safeClose() + } } + } - // Add the merged batch to the end of the aggregated batch queue. Only a single pass over - // the batches is being performed due to the batch count check above, so the single-pass - // loop will terminate before picking up this new batch. - aggregatedBatches.addLast(mergedBatch) - batchesToConcat.clear() - concatSize = 0 + private[this] def computeNumPartitions(totalSize: Long): Int = { + Math.floorDiv(totalSize, targetMergeBatchSize).toInt + 1 } - wasBatchMerged - } + private val hashSeed = 100 + private val aggPartitions = ListBuffer.empty[AggregatePartition] + private val deferredAggPartitions = ListBuffer.empty[AggregatePartition] + deferredAggPartitions += AggregatePartition.apply(inputBatches, hashSeed) - private lazy val concatAndMergeHelper = - new AggHelper(inputAttributes, groupingExpressions, aggregateExpressions, - forceMerge = true, useTieredProject = useTieredProject) + override def hasNext: Boolean = aggPartitions.nonEmpty || deferredAggPartitions.nonEmpty - /** - * Concatenate batches together and perform a merge aggregation on the result. The input batches - * will be closed as part of this operation. - * @param batches batches to concatenate and merge aggregate - * @return lazy spillable batch which has NOT been marked spillable - */ - private def concatenateAndMerge( - batches: mutable.ArrayBuffer[SpillableColumnarBatch]): SpillableColumnarBatch = { - // TODO: concatenateAndMerge (and calling code) could output a sequence - // of batches for the partial aggregate case. This would be done in case - // a retry failed a certain number of times. - val concatBatch = withResource(batches) { _ => - val concatSpillable = concatenateBatches(metrics, batches.toSeq) - withResource(concatSpillable) { _.getColumnarBatch() } + override def next(): ColumnarBatch = { + withResource(new NvtxWithMetrics("RepartitionAggregateIterator.next", + NvtxColor.BLUE, opTime)) { _ => + if (aggPartitions.isEmpty && deferredAggPartitions.nonEmpty) { + val headDeferredPartition = deferredAggPartitions.remove(0) + withResource(headDeferredPartition) { _ => + aggPartitions ++= headDeferredPartition.split() + } + return next() + } + + val headPartition = aggPartitions.remove(0) + if (headPartition.totalSize() > targetMergeBatchSize) { + deferredAggPartitions += headPartition + return next() + } + + withResource(headPartition) { _ => + val batchSizeBeforeMerge = headPartition.batches.size + AggregateUtils.tryMergeAggregatedBatches( + headPartition.batches, isReductionOnly, metrics, + targetMergeBatchSize, concatAndMergeHelper) + if (headPartition.batches.size != 1) { + throw new IllegalStateException( + "Expected a single batch after tryMergeAggregatedBatches, but got " + + s"${headPartition.batches.size} batches. Before merge, there were " + + s"$batchSizeBeforeMerge batches.") + } + headPartition.batches.head.getColumnarBatch() + } + } + } + + override def close(): Unit = { + aggPartitions.foreach(_.safeClose()) + deferredAggPartitions.foreach(_.safeClose()) } - computeAggregateAndClose(metrics, concatBatch, concatAndMergeHelper) } + /** Build an iterator that uses a sort-based approach to merge aggregated batches together. */ - private def buildSortFallbackIterator(): Iterator[ColumnarBatch] = { - logInfo(s"Falling back to sort-based aggregation with ${aggregatedBatches.size()} batches") + private def buildRepartitionFallbackIterator(): Iterator[ColumnarBatch] = { + logInfo(s"Falling back to repartition-based aggregation with " + + s"${aggregatedBatches.size} batches") metrics.numTasksFallBacked += 1 - val aggregatedBatchIter = new Iterator[ColumnarBatch] { - override def hasNext: Boolean = !aggregatedBatches.isEmpty - override def next(): ColumnarBatch = { - withResource(aggregatedBatches.removeFirst()) { spillable => - spillable.getColumnarBatch() - } - } - } + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val aggBufferAttributes = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + + val hashKeys: Seq[GpuExpression] = + GpuBindReferences.bindGpuReferences(groupingAttributes, aggBufferAttributes.toSeq) + + + repartitionIter = Some(RepartitionAggregateIterator( + aggregatedBatches, + hashKeys, + targetMergeBatchSize, + opTime = metrics.opTime, + repartitionTime = metrics.repartitionTime)) + repartitionIter.get + } + + /** Build an iterator that uses a sort-based approach to merge aggregated batches together. */ + private def buildSortFallbackIterator(): Iterator[ColumnarBatch] = { + logInfo(s"Falling back to sort-based aggregation with ${aggregatedBatches.size} batches") + metrics.numTasksFallBacked += 1 + val aggregatedBatchIter = cbIteratorStealingFromBuffer(aggregatedBatches) if (isReductionOnly) { // Normally this should never happen because `tryMergeAggregatedBatches` should have done @@ -1332,7 +1498,8 @@ abstract class GpuBaseAggregateMeta[INPUT <: SparkPlan]( conf.forceSinglePassPartialSortAgg, allowSinglePassAgg, allowNonFullyAggregatedOutput, - conf.skipAggPassReductionRatio) + conf.skipAggPassReductionRatio, + conf.aggFallbackAlgorithm) } } @@ -1420,7 +1587,8 @@ abstract class GpuTypedImperativeSupportedAggregateExecMeta[INPUT <: BaseAggrega false, false, false, - 1) + 1, + conf.aggFallbackAlgorithm) } else { super.convertToGpu() } @@ -1773,6 +1941,8 @@ object GpuHashAggregateExecBase { * (can omit non fully aggregated data for non-final * stage of aggregation) * @param skipAggPassReductionRatio skip if the ratio of rows after a pass is bigger than this value + * @param aggFallbackAlgorithm use sort-based fallback or repartition-based fallback for + * oversize agg */ case class GpuHashAggregateExec( requiredChildDistributionExpressions: Option[Seq[Expression]], @@ -1787,7 +1957,8 @@ case class GpuHashAggregateExec( forceSinglePassAgg: Boolean, allowSinglePassAgg: Boolean, allowNonFullyAggregatedOutput: Boolean, - skipAggPassReductionRatio: Double + skipAggPassReductionRatio: Double, + aggFallbackAlgorithm: String ) extends ShimUnaryExecNode with GpuExec { // lifted directly from `BaseAggregateExec.inputAttributes`, edited comment. @@ -1809,6 +1980,7 @@ case class GpuHashAggregateExec( AGG_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_AGG_TIME), CONCAT_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_CONCAT_TIME), SORT_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_SORT_TIME), + REPARTITION_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_REPARTITION_TIME), "NUM_AGGS" -> createMetric(DEBUG_LEVEL, "num agg operations"), "NUM_PRE_SPLITS" -> createMetric(DEBUG_LEVEL, "num pre splits"), "NUM_TASKS_SINGLE_PASS" -> createMetric(MODERATE_LEVEL, "number of single pass tasks"), @@ -1839,6 +2011,7 @@ case class GpuHashAggregateExec( computeAggTime = gpuLongMetric(AGG_TIME), concatTime = gpuLongMetric(CONCAT_TIME), sortTime = gpuLongMetric(SORT_TIME), + repartitionTime = gpuLongMetric(REPARTITION_TIME), numAggOps = gpuLongMetric("NUM_AGGS"), numPreSplits = gpuLongMetric("NUM_PRE_SPLITS"), singlePassTasks = gpuLongMetric("NUM_TASKS_SINGLE_PASS"), @@ -1873,7 +2046,8 @@ case class GpuHashAggregateExec( boundGroupExprs, aggregateExprs, aggregateAttrs, resultExprs, modeInfo, localEstimatedPreProcessGrowth, alreadySorted, expectedOrdering, postBoundReferences, targetBatchSize, aggMetrics, useTieredProject, - localForcePre, localAllowPre, allowNonFullyAggregatedOutput, skipAggPassReductionRatio) + localForcePre, localAllowPre, allowNonFullyAggregatedOutput, skipAggPassReductionRatio, + aggFallbackAlgorithm) } } @@ -1991,7 +2165,8 @@ class DynamicGpuPartialSortAggregateIterator( forceSinglePassAgg: Boolean, allowSinglePassAgg: Boolean, allowNonFullyAggregatedOutput: Boolean, - skipAggPassReductionRatio: Double + skipAggPassReductionRatio: Double, + aggFallbackAlgorithm: String ) extends Iterator[ColumnarBatch] { private var aggIter: Option[Iterator[ColumnarBatch]] = None private[this] val isReductionOnly = boundGroupExprs.outputTypes.isEmpty @@ -2092,6 +2267,7 @@ class DynamicGpuPartialSortAggregateIterator( useTiered, allowNonFullyAggregatedOutput, skipAggPassReductionRatio, + aggFallbackAlgorithm, localInputRowsMetrics) GpuAggFinalPassIterator.makeIter(mergeIter, postBoundReferences, metrics) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala index d83f20113b2..1cbf899c04d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala @@ -61,6 +61,7 @@ object GpuMetric extends Logging { val COLLECT_TIME = "collectTime" val CONCAT_TIME = "concatTime" val SORT_TIME = "sortTime" + val REPARTITION_TIME = "repartitionTime" val AGG_TIME = "computeAggTime" val JOIN_TIME = "joinTime" val FILTER_TIME = "filterTime" @@ -95,6 +96,7 @@ object GpuMetric extends Logging { val DESCRIPTION_COLLECT_TIME = "collect batch time" val DESCRIPTION_CONCAT_TIME = "concat batch time" val DESCRIPTION_SORT_TIME = "sort time" + val DESCRIPTION_REPARTITION_TIME = "repartition time spent in agg" val DESCRIPTION_AGG_TIME = "aggregation time" val DESCRIPTION_JOIN_TIME = "join time" val DESCRIPTION_FILTER_TIME = "filter time" diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index aad4f05b334..46c2806140e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -1517,6 +1517,13 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern") .checkValue(v => v >= 0 && v <= 1, "The ratio value must be in [0, 1].") .createWithDefault(1.0) + val FALLBACK_ALGORITHM_FOR_OVERSIZE_AGG = conf("spark.rapids.sql.agg.fallbackAlgorithm") + .doc("When agg cannot be done in a single pass, use sort-based fallback or " + + "repartition-based fallback.") + .stringConf + .checkValues(Set("sort", "repartition")) + .createWithDefault("sort") + val FORCE_SINGLE_PASS_PARTIAL_SORT_AGG: ConfEntryWithDefault[Boolean] = conf("spark.rapids.sql.agg.forceSinglePassPartialSort") .doc("Force a single pass partial sort agg to happen in all cases that it could, " + @@ -3079,6 +3086,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val skipAggPassReductionRatio: Double = get(SKIP_AGG_PASS_REDUCTION_RATIO) + lazy val aggFallbackAlgorithm: String = get(FALLBACK_ALGORITHM_FOR_OVERSIZE_AGG) + lazy val isRegExpEnabled: Boolean = get(ENABLE_REGEXP) lazy val maxRegExpStateMemory: Long = {