diff --git a/docs/additional-functionality/advanced_configs.md b/docs/additional-functionality/advanced_configs.md
index 033e332b99c..7be166ed5de 100644
--- a/docs/additional-functionality/advanced_configs.md
+++ b/docs/additional-functionality/advanced_configs.md
@@ -60,6 +60,7 @@ Name | Description | Default Value | Applicable at
spark.rapids.shuffle.ucx.activeMessages.forceRndv|Set to true to force 'rndv' mode for all UCX Active Messages. This should only be required with UCX 1.10.x. UCX 1.11.x deployments should set to false.|false|Startup
spark.rapids.shuffle.ucx.managementServerHost|The host to be used to start the management server|null|Startup
spark.rapids.shuffle.ucx.useWakeup|When set to true, use UCX's event-based progress (epoll) in order to wake up the progress thread when needed, instead of a hot loop.|true|Startup
+spark.rapids.sql.agg.fallbackAlgorithm|When agg cannot be done in a single pass, use sort-based fallback or repartition-based fallback.|sort|Runtime
spark.rapids.sql.agg.skipAggPassReductionRatio|In non-final aggregation stages, if the previous pass has a row reduction ratio greater than this value, the next aggregation pass will be skipped.Setting this to 1 essentially disables this feature.|1.0|Runtime
spark.rapids.sql.allowMultipleJars|Allow multiple rapids-4-spark, spark-rapids-jni, and cudf jars on the classpath. Spark will take the first one it finds, so the version may not be expected. Possisble values are ALWAYS: allow all jars, SAME_REVISION: only allow jars with the same revision, NEVER: do not allow multiple jars at all.|SAME_REVISION|Startup
spark.rapids.sql.castDecimalToFloat.enabled|Casting from decimal to floating point types on the GPU returns results that have tiny difference compared to results returned from CPU.|true|Runtime
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala
index 926f770a683..96254b9f38d 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala
@@ -16,7 +16,7 @@
package com.nvidia.spark.rapids
import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.util.control.ControlThrowable
import com.nvidia.spark.rapids.RapidsPluginImplicits._
@@ -43,7 +43,8 @@ object Arm extends ArmScalaSpecificImpl {
}
/** Executes the provided code block and then closes the sequence of resources */
- def withResource[T <: AutoCloseable, V](r: Seq[T])(block: Seq[T] => V): V = {
+ def withResource[T <: AutoCloseable, V](r: scala.collection.Seq[T])
+ (block: scala.collection.Seq[T] => V): V = {
try {
block(r)
} finally {
@@ -134,6 +135,20 @@ object Arm extends ArmScalaSpecificImpl {
}
}
+ /** Executes the provided code block, closing the resources only if an exception occurs */
+ def closeOnExcept[T <: AutoCloseable, V](r: ListBuffer[T])(block: ListBuffer[T] => V): V = {
+ try {
+ block(r)
+ } catch {
+ case t: ControlThrowable =>
+ // Don't close for these cases..
+ throw t
+ case t: Throwable =>
+ r.safeClose(t)
+ throw t
+ }
+ }
+
/** Executes the provided code block, closing the resources only if an exception occurs */
def closeOnExcept[T <: AutoCloseable, V](r: mutable.Queue[T])(block: mutable.Queue[T] => V): V = {
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala
index 7e6a1056d01..b28101f3442 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala
@@ -16,11 +16,9 @@
package com.nvidia.spark.rapids
-import java.util
-
import scala.annotation.tailrec
-import scala.collection.JavaConverters.collectionAsScalaIterableConverter
import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
import ai.rapids.cudf
import ai.rapids.cudf.{NvtxColor, NvtxRange}
@@ -46,11 +44,11 @@ import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.{ExplainUtils, SortExec, SparkPlan}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.rapids.aggregate.{CpuToGpuAggregateBufferConverter, CudfAggregate, GpuAggregateExpression, GpuToCpuAggregateBufferConverter}
-import org.apache.spark.sql.rapids.execution.{GpuShuffleMeta, TrampolineUtil}
+import org.apache.spark.sql.rapids.execution.{GpuBatchSubPartitioner, GpuShuffleMeta, TrampolineUtil}
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
-object AggregateUtils {
+object AggregateUtils extends Logging {
private val aggs = List("min", "max", "avg", "sum", "count", "first", "last")
@@ -86,9 +84,10 @@ object AggregateUtils {
/**
* Computes a target input batch size based on the assumption that computation can consume up to
* 4X the configured batch size.
- * @param confTargetSize user-configured maximum desired batch size
- * @param inputTypes input batch schema
- * @param outputTypes output batch schema
+ *
+ * @param confTargetSize user-configured maximum desired batch size
+ * @param inputTypes input batch schema
+ * @param outputTypes output batch schema
* @param isReductionOnly true if this is a reduction-only aggregation without grouping
* @return maximum target batch size to keep computation under the 4X configured batch limit
*/
@@ -99,6 +98,7 @@ object AggregateUtils {
isReductionOnly: Boolean): Long = {
def typesToSize(types: Seq[DataType]): Long =
types.map(GpuBatchUtils.estimateGpuMemory(_, nullable = false, rowCount = 1)).sum
+
val inputRowSize = typesToSize(inputTypes)
val outputRowSize = typesToSize(outputTypes)
// The cudf hash table implementation allocates four 32-bit integers per input row.
@@ -124,6 +124,129 @@ object AggregateUtils {
// Finally compute the input target batching size taking into account the cudf row limits
Math.min(inputRowSize * maxRows, Int.MaxValue)
}
+
+
+ /**
+ * Concatenate batches together and perform a merge aggregation on the result. The input batches
+ * will be closed as part of this operation.
+ *
+ * @param batches batches to concatenate and merge aggregate
+ * @return lazy spillable batch which has NOT been marked spillable
+ */
+ private def concatenateAndMerge(
+ batches: mutable.Buffer[SpillableColumnarBatch],
+ metrics: GpuHashAggregateMetrics,
+ concatAndMergeHelper: AggHelper): SpillableColumnarBatch = {
+ // TODO: concatenateAndMerge (and calling code) could output a sequence
+ // of batches for the partial aggregate case. This would be done in case
+ // a retry failed a certain number of times.
+ val concatBatch = withResource(batches) { _ =>
+ val concatSpillable = concatenateBatches(metrics, batches.toSeq)
+ withResource(concatSpillable) {
+ _.getColumnarBatch()
+ }
+ }
+ computeAggregateAndClose(metrics, concatBatch, concatAndMergeHelper)
+ }
+
+ /**
+ * Perform a single pass over the aggregated batches attempting to merge adjacent batches.
+ *
+ * @return true if at least one merge operation occurred
+ */
+ private def mergePass(
+ aggregatedBatches: mutable.Buffer[SpillableColumnarBatch],
+ targetMergeBatchSize: Long,
+ helper: AggHelper,
+ metrics: GpuHashAggregateMetrics
+ ): Boolean = {
+ val batchesToConcat: mutable.ArrayBuffer[SpillableColumnarBatch] = mutable.ArrayBuffer.empty
+ var wasBatchMerged = false
+ // Current size in bytes of the batches targeted for the next concatenation
+ var concatSize: Long = 0L
+ var batchesLeftInPass = aggregatedBatches.size
+
+ while (batchesLeftInPass > 0) {
+ closeOnExcept(batchesToConcat) { _ =>
+ var isConcatSearchFinished = false
+ // Old batches are picked up at the front of the queue and freshly merged batches are
+ // appended to the back of the queue. Although tempting to allow the pass to "wrap around"
+ // and pick up batches freshly merged in this pass, it's avoided to prevent changing the
+ // order of aggregated batches.
+ while (batchesLeftInPass > 0 && !isConcatSearchFinished) {
+ val candidate = aggregatedBatches.head
+ val potentialSize = concatSize + candidate.sizeInBytes
+ isConcatSearchFinished = concatSize > 0 && potentialSize > targetMergeBatchSize
+ if (!isConcatSearchFinished) {
+ batchesLeftInPass -= 1
+ batchesToConcat += aggregatedBatches.remove(0)
+ concatSize = potentialSize
+ }
+ }
+ }
+
+ val mergedBatch = if (batchesToConcat.length > 1) {
+ wasBatchMerged = true
+ concatenateAndMerge(batchesToConcat, metrics, helper)
+ } else {
+ // Unable to find a neighboring buffer to produce a valid merge in this pass,
+ // so simply put this buffer back on the queue for other passes.
+ batchesToConcat.remove(0)
+ }
+
+ // Add the merged batch to the end of the aggregated batch queue. Only a single pass over
+ // the batches is being performed due to the batch count check above, so the single-pass
+ // loop will terminate before picking up this new batch.
+ aggregatedBatches += mergedBatch
+ batchesToConcat.clear()
+ concatSize = 0
+ }
+
+ wasBatchMerged
+ }
+
+
+ /**
+ * Attempt to merge adjacent batches in the aggregatedBatches queue until either there is only
+ * one batch or merging adjacent batches would exceed the target batch size.
+ */
+ def tryMergeAggregatedBatches(
+ aggregatedBatches: mutable.Buffer[SpillableColumnarBatch],
+ isReductionOnly: Boolean,
+ metrics: GpuHashAggregateMetrics,
+ targetMergeBatchSize: Long,
+ helper: AggHelper
+ ): Unit = {
+ while (aggregatedBatches.size > 1) {
+ val concatTime = metrics.concatTime
+ val opTime = metrics.opTime
+ withResource(new NvtxWithMetrics("agg merge pass", NvtxColor.BLUE, concatTime,
+ opTime)) { _ =>
+ // continue merging as long as some batches are able to be combined
+ if (!mergePass(aggregatedBatches, targetMergeBatchSize, helper, metrics))
+ if (aggregatedBatches.size > 1 && isReductionOnly) {
+ // We were unable to merge the aggregated batches within the target batch size limit,
+ // which means normally we would fallback to a sort-based approach. However for
+ // reduction-only aggregation there are no keys to use for a sort. The only way this
+ // can work is if all batches are merged. This will exceed the target batch size limit,
+ // but at this point it is either risk an OOM/cudf error and potentially work or
+ // not work at all.
+ logWarning(s"Unable to merge reduction-only aggregated batches within " +
+ s"target batch limit of $targetMergeBatchSize, attempting to merge remaining " +
+ s"${aggregatedBatches.size} batches beyond limit")
+ withResource(mutable.ArrayBuffer[SpillableColumnarBatch]()) { batchesToConcat =>
+ aggregatedBatches.foreach(b => batchesToConcat += b)
+ aggregatedBatches.clear()
+ val batch = concatenateAndMerge(batchesToConcat, metrics, helper)
+ // batch does not need to be marked spillable since it is the last and only batch
+ // and will be immediately retrieved on the next() call.
+ aggregatedBatches += batch
+ }
+ }
+ return
+ }
+ }
+ }
}
/** Utility class to hold all of the metrics related to hash aggregation */
@@ -135,6 +258,7 @@ case class GpuHashAggregateMetrics(
computeAggTime: GpuMetric,
concatTime: GpuMetric,
sortTime: GpuMetric,
+ repartitionTime: GpuMetric,
numAggOps: GpuMetric,
numPreSplits: GpuMetric,
singlePassTasks: GpuMetric,
@@ -711,6 +835,8 @@ object GpuAggFinalPassIterator {
* @param useTieredProject user-specified option to enable tiered projections
* @param allowNonFullyAggregatedOutput if allowed to skip third pass Agg
* @param skipAggPassReductionRatio skip if the ratio of rows after a pass is bigger than this value
+ * @param aggFallbackAlgorithm use sort-based fallback or repartition-based fallback
+ * for oversize agg
* @param localInputRowsCount metric to track the number of input rows processed locally
*/
class GpuMergeAggregateIterator(
@@ -726,15 +852,17 @@ class GpuMergeAggregateIterator(
useTieredProject: Boolean,
allowNonFullyAggregatedOutput: Boolean,
skipAggPassReductionRatio: Double,
+ aggFallbackAlgorithm: String,
localInputRowsCount: LocalGpuMetric)
extends Iterator[ColumnarBatch] with AutoCloseable with Logging {
private[this] val isReductionOnly = groupingExpressions.isEmpty
private[this] val targetMergeBatchSize = computeTargetMergeBatchSize(configuredTargetBatchSize)
- private[this] val aggregatedBatches = new util.ArrayDeque[SpillableColumnarBatch]
+ private[this] val aggregatedBatches = ListBuffer.empty[SpillableColumnarBatch]
private[this] var outOfCoreIter: Option[GpuOutOfCoreSortIterator] = None
+ private[this] var repartitionIter: Option[RepartitionAggregateIterator] = None
/** Iterator for fetching aggregated batches either if:
- * 1. a sort-based fallback has occurred
+ * 1. a sort-based/repartition-based fallback has occurred
* 2. skip third pass agg has occurred
**/
private[this] var fallbackIter: Option[Iterator[ColumnarBatch]] = None
@@ -752,7 +880,7 @@ class GpuMergeAggregateIterator(
override def hasNext: Boolean = {
fallbackIter.map(_.hasNext).getOrElse {
// reductions produce a result even if the input is empty
- hasReductionOnlyBatch || !aggregatedBatches.isEmpty || firstPassIter.hasNext
+ hasReductionOnlyBatch || aggregatedBatches.nonEmpty || firstPassIter.hasNext
}
}
@@ -769,9 +897,11 @@ class GpuMergeAggregateIterator(
if (isReductionOnly ||
skipAggPassReductionRatio * localInputRowsCount.value >= rowsAfterFirstPassAgg) {
// second pass agg
- tryMergeAggregatedBatches()
+ AggregateUtils.tryMergeAggregatedBatches(
+ aggregatedBatches, isReductionOnly,
+ metrics, targetMergeBatchSize, concatAndMergeHelper)
- val rowsAfterSecondPassAgg = aggregatedBatches.asScala.foldLeft(0L) {
+ val rowsAfterSecondPassAgg = aggregatedBatches.foldLeft(0L) {
(totalRows, batch) => totalRows + batch.numRows()
}
shouldSkipThirdPassAgg =
@@ -784,7 +914,7 @@ class GpuMergeAggregateIterator(
}
}
- if (aggregatedBatches.size() > 1) {
+ if (aggregatedBatches.size > 1) {
// Unable to merge to a single output, so must fall back
if (allowNonFullyAggregatedOutput && shouldSkipThirdPassAgg) {
// skip third pass agg, return the aggregated batches directly
@@ -792,17 +922,23 @@ class GpuMergeAggregateIterator(
s"${skipAggPassReductionRatio * 100}% of " +
s"rows after first pass, skip the third pass agg")
fallbackIter = Some(new Iterator[ColumnarBatch] {
- override def hasNext: Boolean = !aggregatedBatches.isEmpty
+ override def hasNext: Boolean = aggregatedBatches.nonEmpty
override def next(): ColumnarBatch = {
- withResource(aggregatedBatches.pop()) { spillableBatch =>
+ withResource(aggregatedBatches.remove(0)) { spillableBatch =>
spillableBatch.getColumnarBatch()
}
}
})
} else {
// fallback to sort agg, this is the third pass agg
- fallbackIter = Some(buildSortFallbackIterator())
+ aggFallbackAlgorithm.toLowerCase match {
+ case "repartition" =>
+ fallbackIter = Some(buildRepartitionFallbackIterator())
+ case "sort" => fallbackIter = Some(buildSortFallbackIterator())
+ case _ => throw new IllegalArgumentException(
+ s"Unsupported aggregation fallback algorithm: $aggFallbackAlgorithm")
+ }
}
fallbackIter.get.next()
} else if (aggregatedBatches.isEmpty) {
@@ -815,7 +951,7 @@ class GpuMergeAggregateIterator(
} else {
// this will be the last batch
hasReductionOnlyBatch = false
- withResource(aggregatedBatches.pop()) { spillableBatch =>
+ withResource(aggregatedBatches.remove(0)) { spillableBatch =>
spillableBatch.getColumnarBatch()
}
}
@@ -823,10 +959,12 @@ class GpuMergeAggregateIterator(
}
override def close(): Unit = {
- aggregatedBatches.forEach(_.safeClose())
+ aggregatedBatches.foreach(_.safeClose())
aggregatedBatches.clear()
outOfCoreIter.foreach(_.close())
outOfCoreIter = None
+ repartitionIter.foreach(_.close())
+ repartitionIter = None
fallbackIter = None
hasReductionOnlyBatch = false
}
@@ -843,133 +981,161 @@ class GpuMergeAggregateIterator(
while (firstPassIter.hasNext) {
val batch = firstPassIter.next()
rowsAfter += batch.numRows()
- aggregatedBatches.add(batch)
+ aggregatedBatches += batch
}
rowsAfter
}
- /**
- * Attempt to merge adjacent batches in the aggregatedBatches queue until either there is only
- * one batch or merging adjacent batches would exceed the target batch size.
- */
- private def tryMergeAggregatedBatches(): Unit = {
- while (aggregatedBatches.size() > 1) {
- val concatTime = metrics.concatTime
- val opTime = metrics.opTime
- withResource(new NvtxWithMetrics("agg merge pass", NvtxColor.BLUE, concatTime,
- opTime)) { _ =>
- // continue merging as long as some batches are able to be combined
- if (!mergePass()) {
- if (aggregatedBatches.size() > 1 && isReductionOnly) {
- // We were unable to merge the aggregated batches within the target batch size limit,
- // which means normally we would fallback to a sort-based approach. However for
- // reduction-only aggregation there are no keys to use for a sort. The only way this
- // can work is if all batches are merged. This will exceed the target batch size limit,
- // but at this point it is either risk an OOM/cudf error and potentially work or
- // not work at all.
- logWarning(s"Unable to merge reduction-only aggregated batches within " +
- s"target batch limit of $targetMergeBatchSize, attempting to merge remaining " +
- s"${aggregatedBatches.size()} batches beyond limit")
- withResource(mutable.ArrayBuffer[SpillableColumnarBatch]()) { batchesToConcat =>
- aggregatedBatches.forEach(b => batchesToConcat += b)
- aggregatedBatches.clear()
- val batch = concatenateAndMerge(batchesToConcat)
- // batch does not need to be marked spillable since it is the last and only batch
- // and will be immediately retrieved on the next() call.
- aggregatedBatches.add(batch)
- }
- }
- return
+ private lazy val concatAndMergeHelper =
+ new AggHelper(inputAttributes, groupingExpressions, aggregateExpressions,
+ forceMerge = true, useTieredProject = useTieredProject)
+
+ private def cbIteratorStealingFromBuffer(input: ListBuffer[SpillableColumnarBatch]) = {
+ val aggregatedBatchIter = new Iterator[ColumnarBatch] {
+ override def hasNext: Boolean = input.nonEmpty
+
+ override def next(): ColumnarBatch = {
+ withResource(input.remove(0)) { spillable =>
+ spillable.getColumnarBatch()
}
}
}
+ aggregatedBatchIter
}
- /**
- * Perform a single pass over the aggregated batches attempting to merge adjacent batches.
- * @return true if at least one merge operation occurred
- */
- private def mergePass(): Boolean = {
- val batchesToConcat: mutable.ArrayBuffer[SpillableColumnarBatch] = mutable.ArrayBuffer.empty
- var wasBatchMerged = false
- // Current size in bytes of the batches targeted for the next concatenation
- var concatSize: Long = 0L
- var batchesLeftInPass = aggregatedBatches.size()
+ private case class RepartitionAggregateIterator(
+ inputBatches: ListBuffer[SpillableColumnarBatch],
+ hashKeys: Seq[GpuExpression],
+ targetSize: Long,
+ opTime: GpuMetric,
+ repartitionTime: GpuMetric) extends Iterator[ColumnarBatch]
+ with AutoCloseable {
- while (batchesLeftInPass > 0) {
- closeOnExcept(batchesToConcat) { _ =>
- var isConcatSearchFinished = false
- // Old batches are picked up at the front of the queue and freshly merged batches are
- // appended to the back of the queue. Although tempting to allow the pass to "wrap around"
- // and pick up batches freshly merged in this pass, it's avoided to prevent changing the
- // order of aggregated batches.
- while (batchesLeftInPass > 0 && !isConcatSearchFinished) {
- val candidate = aggregatedBatches.getFirst
- val potentialSize = concatSize + candidate.sizeInBytes
- isConcatSearchFinished = concatSize > 0 && potentialSize > targetMergeBatchSize
- if (!isConcatSearchFinished) {
- batchesLeftInPass -= 1
- batchesToConcat += aggregatedBatches.removeFirst()
- concatSize = potentialSize
+ case class AggregatePartition(batches: ListBuffer[SpillableColumnarBatch], seed: Int)
+ extends AutoCloseable {
+ override def close(): Unit = {
+ batches.safeClose()
+ }
+
+ def totalRows(): Long = batches.map(_.numRows()).sum
+
+ def totalSize(): Long = batches.map(_.sizeInBytes).sum
+
+ def split(): ListBuffer[AggregatePartition] = {
+ withResource(new NvtxWithMetrics("agg repartition", NvtxColor.CYAN, repartitionTime)) { _ =>
+ if (seed > hashSeed + 20) {
+ throw new IllegalStateException("At most repartition 3 times for a partition")
+ }
+ val totalSize = batches.map(_.sizeInBytes).sum
+ val newSeed = seed + 10
+ val iter = cbIteratorStealingFromBuffer(batches)
+ withResource(new GpuBatchSubPartitioner(
+ iter, hashKeys, computeNumPartitions(totalSize), newSeed, "aggRepartition")) {
+ partitioner =>
+ closeOnExcept(ListBuffer.empty[AggregatePartition]) { partitions =>
+ preparePartitions(newSeed, partitioner, partitions)
+ partitions
+ }
}
}
}
+ }
- val mergedBatch = if (batchesToConcat.length > 1) {
- wasBatchMerged = true
- concatenateAndMerge(batchesToConcat)
- } else {
- // Unable to find a neighboring buffer to produce a valid merge in this pass,
- // so simply put this buffer back on the queue for other passes.
- batchesToConcat.remove(0)
+ private def preparePartitions(
+ newSeed: Int,
+ partitioner: GpuBatchSubPartitioner,
+ partitions: ListBuffer[AggregatePartition]): Unit = {
+ (0 until partitioner.partitionsCount).foreach { id =>
+ val buffer = ListBuffer.empty[SpillableColumnarBatch]
+ buffer ++= partitioner.releaseBatchesByPartition(id)
+ val newPart = AggregatePartition.apply(buffer, newSeed)
+ if (newPart.totalRows() > 0) {
+ partitions += newPart
+ } else {
+ newPart.safeClose()
+ }
}
+ }
- // Add the merged batch to the end of the aggregated batch queue. Only a single pass over
- // the batches is being performed due to the batch count check above, so the single-pass
- // loop will terminate before picking up this new batch.
- aggregatedBatches.addLast(mergedBatch)
- batchesToConcat.clear()
- concatSize = 0
+ private[this] def computeNumPartitions(totalSize: Long): Int = {
+ Math.floorDiv(totalSize, targetMergeBatchSize).toInt + 1
}
- wasBatchMerged
- }
+ private val hashSeed = 100
+ private val aggPartitions = ListBuffer.empty[AggregatePartition]
+ private val deferredAggPartitions = ListBuffer.empty[AggregatePartition]
+ deferredAggPartitions += AggregatePartition.apply(inputBatches, hashSeed)
- private lazy val concatAndMergeHelper =
- new AggHelper(inputAttributes, groupingExpressions, aggregateExpressions,
- forceMerge = true, useTieredProject = useTieredProject)
+ override def hasNext: Boolean = aggPartitions.nonEmpty || deferredAggPartitions.nonEmpty
- /**
- * Concatenate batches together and perform a merge aggregation on the result. The input batches
- * will be closed as part of this operation.
- * @param batches batches to concatenate and merge aggregate
- * @return lazy spillable batch which has NOT been marked spillable
- */
- private def concatenateAndMerge(
- batches: mutable.ArrayBuffer[SpillableColumnarBatch]): SpillableColumnarBatch = {
- // TODO: concatenateAndMerge (and calling code) could output a sequence
- // of batches for the partial aggregate case. This would be done in case
- // a retry failed a certain number of times.
- val concatBatch = withResource(batches) { _ =>
- val concatSpillable = concatenateBatches(metrics, batches.toSeq)
- withResource(concatSpillable) { _.getColumnarBatch() }
+ override def next(): ColumnarBatch = {
+ withResource(new NvtxWithMetrics("RepartitionAggregateIterator.next",
+ NvtxColor.BLUE, opTime)) { _ =>
+ if (aggPartitions.isEmpty && deferredAggPartitions.nonEmpty) {
+ val headDeferredPartition = deferredAggPartitions.remove(0)
+ withResource(headDeferredPartition) { _ =>
+ aggPartitions ++= headDeferredPartition.split()
+ }
+ return next()
+ }
+
+ val headPartition = aggPartitions.remove(0)
+ if (headPartition.totalSize() > targetMergeBatchSize) {
+ deferredAggPartitions += headPartition
+ return next()
+ }
+
+ withResource(headPartition) { _ =>
+ val batchSizeBeforeMerge = headPartition.batches.size
+ AggregateUtils.tryMergeAggregatedBatches(
+ headPartition.batches, isReductionOnly, metrics,
+ targetMergeBatchSize, concatAndMergeHelper)
+ if (headPartition.batches.size != 1) {
+ throw new IllegalStateException(
+ "Expected a single batch after tryMergeAggregatedBatches, but got " +
+ s"${headPartition.batches.size} batches. Before merge, there were " +
+ s"$batchSizeBeforeMerge batches.")
+ }
+ headPartition.batches.head.getColumnarBatch()
+ }
+ }
+ }
+
+ override def close(): Unit = {
+ aggPartitions.foreach(_.safeClose())
+ deferredAggPartitions.foreach(_.safeClose())
}
- computeAggregateAndClose(metrics, concatBatch, concatAndMergeHelper)
}
+
/** Build an iterator that uses a sort-based approach to merge aggregated batches together. */
- private def buildSortFallbackIterator(): Iterator[ColumnarBatch] = {
- logInfo(s"Falling back to sort-based aggregation with ${aggregatedBatches.size()} batches")
+ private def buildRepartitionFallbackIterator(): Iterator[ColumnarBatch] = {
+ logInfo(s"Falling back to repartition-based aggregation with " +
+ s"${aggregatedBatches.size} batches")
metrics.numTasksFallBacked += 1
- val aggregatedBatchIter = new Iterator[ColumnarBatch] {
- override def hasNext: Boolean = !aggregatedBatches.isEmpty
- override def next(): ColumnarBatch = {
- withResource(aggregatedBatches.removeFirst()) { spillable =>
- spillable.getColumnarBatch()
- }
- }
- }
+ val groupingAttributes = groupingExpressions.map(_.toAttribute)
+ val aggBufferAttributes = groupingAttributes ++
+ aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+
+ val hashKeys: Seq[GpuExpression] =
+ GpuBindReferences.bindGpuReferences(groupingAttributes, aggBufferAttributes.toSeq)
+
+
+ repartitionIter = Some(RepartitionAggregateIterator(
+ aggregatedBatches,
+ hashKeys,
+ targetMergeBatchSize,
+ opTime = metrics.opTime,
+ repartitionTime = metrics.repartitionTime))
+ repartitionIter.get
+ }
+
+ /** Build an iterator that uses a sort-based approach to merge aggregated batches together. */
+ private def buildSortFallbackIterator(): Iterator[ColumnarBatch] = {
+ logInfo(s"Falling back to sort-based aggregation with ${aggregatedBatches.size} batches")
+ metrics.numTasksFallBacked += 1
+ val aggregatedBatchIter = cbIteratorStealingFromBuffer(aggregatedBatches)
if (isReductionOnly) {
// Normally this should never happen because `tryMergeAggregatedBatches` should have done
@@ -1332,7 +1498,8 @@ abstract class GpuBaseAggregateMeta[INPUT <: SparkPlan](
conf.forceSinglePassPartialSortAgg,
allowSinglePassAgg,
allowNonFullyAggregatedOutput,
- conf.skipAggPassReductionRatio)
+ conf.skipAggPassReductionRatio,
+ conf.aggFallbackAlgorithm)
}
}
@@ -1420,7 +1587,8 @@ abstract class GpuTypedImperativeSupportedAggregateExecMeta[INPUT <: BaseAggrega
false,
false,
false,
- 1)
+ 1,
+ conf.aggFallbackAlgorithm)
} else {
super.convertToGpu()
}
@@ -1773,6 +1941,8 @@ object GpuHashAggregateExecBase {
* (can omit non fully aggregated data for non-final
* stage of aggregation)
* @param skipAggPassReductionRatio skip if the ratio of rows after a pass is bigger than this value
+ * @param aggFallbackAlgorithm use sort-based fallback or repartition-based fallback for
+ * oversize agg
*/
case class GpuHashAggregateExec(
requiredChildDistributionExpressions: Option[Seq[Expression]],
@@ -1787,7 +1957,8 @@ case class GpuHashAggregateExec(
forceSinglePassAgg: Boolean,
allowSinglePassAgg: Boolean,
allowNonFullyAggregatedOutput: Boolean,
- skipAggPassReductionRatio: Double
+ skipAggPassReductionRatio: Double,
+ aggFallbackAlgorithm: String
) extends ShimUnaryExecNode with GpuExec {
// lifted directly from `BaseAggregateExec.inputAttributes`, edited comment.
@@ -1809,6 +1980,7 @@ case class GpuHashAggregateExec(
AGG_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_AGG_TIME),
CONCAT_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_CONCAT_TIME),
SORT_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_SORT_TIME),
+ REPARTITION_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_REPARTITION_TIME),
"NUM_AGGS" -> createMetric(DEBUG_LEVEL, "num agg operations"),
"NUM_PRE_SPLITS" -> createMetric(DEBUG_LEVEL, "num pre splits"),
"NUM_TASKS_SINGLE_PASS" -> createMetric(MODERATE_LEVEL, "number of single pass tasks"),
@@ -1839,6 +2011,7 @@ case class GpuHashAggregateExec(
computeAggTime = gpuLongMetric(AGG_TIME),
concatTime = gpuLongMetric(CONCAT_TIME),
sortTime = gpuLongMetric(SORT_TIME),
+ repartitionTime = gpuLongMetric(REPARTITION_TIME),
numAggOps = gpuLongMetric("NUM_AGGS"),
numPreSplits = gpuLongMetric("NUM_PRE_SPLITS"),
singlePassTasks = gpuLongMetric("NUM_TASKS_SINGLE_PASS"),
@@ -1873,7 +2046,8 @@ case class GpuHashAggregateExec(
boundGroupExprs, aggregateExprs, aggregateAttrs, resultExprs, modeInfo,
localEstimatedPreProcessGrowth, alreadySorted, expectedOrdering,
postBoundReferences, targetBatchSize, aggMetrics, useTieredProject,
- localForcePre, localAllowPre, allowNonFullyAggregatedOutput, skipAggPassReductionRatio)
+ localForcePre, localAllowPre, allowNonFullyAggregatedOutput, skipAggPassReductionRatio,
+ aggFallbackAlgorithm)
}
}
@@ -1991,7 +2165,8 @@ class DynamicGpuPartialSortAggregateIterator(
forceSinglePassAgg: Boolean,
allowSinglePassAgg: Boolean,
allowNonFullyAggregatedOutput: Boolean,
- skipAggPassReductionRatio: Double
+ skipAggPassReductionRatio: Double,
+ aggFallbackAlgorithm: String
) extends Iterator[ColumnarBatch] {
private var aggIter: Option[Iterator[ColumnarBatch]] = None
private[this] val isReductionOnly = boundGroupExprs.outputTypes.isEmpty
@@ -2092,6 +2267,7 @@ class DynamicGpuPartialSortAggregateIterator(
useTiered,
allowNonFullyAggregatedOutput,
skipAggPassReductionRatio,
+ aggFallbackAlgorithm,
localInputRowsMetrics)
GpuAggFinalPassIterator.makeIter(mergeIter, postBoundReferences, metrics)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala
index d83f20113b2..1cbf899c04d 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala
@@ -61,6 +61,7 @@ object GpuMetric extends Logging {
val COLLECT_TIME = "collectTime"
val CONCAT_TIME = "concatTime"
val SORT_TIME = "sortTime"
+ val REPARTITION_TIME = "repartitionTime"
val AGG_TIME = "computeAggTime"
val JOIN_TIME = "joinTime"
val FILTER_TIME = "filterTime"
@@ -95,6 +96,7 @@ object GpuMetric extends Logging {
val DESCRIPTION_COLLECT_TIME = "collect batch time"
val DESCRIPTION_CONCAT_TIME = "concat batch time"
val DESCRIPTION_SORT_TIME = "sort time"
+ val DESCRIPTION_REPARTITION_TIME = "repartition time spent in agg"
val DESCRIPTION_AGG_TIME = "aggregation time"
val DESCRIPTION_JOIN_TIME = "join time"
val DESCRIPTION_FILTER_TIME = "filter time"
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
index aad4f05b334..46c2806140e 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -1517,6 +1517,13 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern")
.checkValue(v => v >= 0 && v <= 1, "The ratio value must be in [0, 1].")
.createWithDefault(1.0)
+ val FALLBACK_ALGORITHM_FOR_OVERSIZE_AGG = conf("spark.rapids.sql.agg.fallbackAlgorithm")
+ .doc("When agg cannot be done in a single pass, use sort-based fallback or " +
+ "repartition-based fallback.")
+ .stringConf
+ .checkValues(Set("sort", "repartition"))
+ .createWithDefault("sort")
+
val FORCE_SINGLE_PASS_PARTIAL_SORT_AGG: ConfEntryWithDefault[Boolean] =
conf("spark.rapids.sql.agg.forceSinglePassPartialSort")
.doc("Force a single pass partial sort agg to happen in all cases that it could, " +
@@ -3079,6 +3086,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {
lazy val skipAggPassReductionRatio: Double = get(SKIP_AGG_PASS_REDUCTION_RATIO)
+ lazy val aggFallbackAlgorithm: String = get(FALLBACK_ALGORITHM_FOR_OVERSIZE_AGG)
+
lazy val isRegExpEnabled: Boolean = get(ENABLE_REGEXP)
lazy val maxRegExpStateMemory: Long = {