Improve performance of Sort for the common single batch use case (#10572)

Signed-off-by: Robert (Bobby) Evans <[email protected]>
revans2 authored Mar 13, 2024
1 parent f26ce1f commit 9105fd7
Showing 2 changed files with 106 additions and 98 deletions.
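At a glance: incoming batches are now sorted and made spillable up front by the new GpuSpillableProjectedSortEachBatchIterator, and GpuOutOfCoreSortIterator only falls back to its split-and-merge first pass when more than one batch arrives. A minimal sketch of that dispatch, condensed from the next() logic in the diff below (simplified; the real code also handles the no-column special case and wraps the work in retry blocks):

    val scb = alreadySortedIter.next()   // batches arrive pre-sorted, spillable, projected columns kept
    if (!alreadySortedIter.hasNext) {
      sorted.add(scb)                    // single-batch fast path: no splitting, no merge sort
    } else {
      firstPassReadBatches(scb)          // multiple batches: split each one up for merge sorting
    }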
118 changes: 77 additions & 41 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -178,6 +178,45 @@ case class GpuSortEachBatchIterator(
}
}

/**
* Create an iterator that will sort each batch as it comes in. It keeps any projected
* columns in place after the sort, on the assumption that the caller may want to combine
* them in some way afterwards.
*/
object GpuSpillableProjectedSortEachBatchIterator {
def apply(
iter: Iterator[ColumnarBatch],
sorter: GpuSorter,
opTime: GpuMetric = NoopMetric,
sortTime: GpuMetric = NoopMetric): Iterator[SpillableColumnarBatch] = {
val spillableIter = iter.flatMap { cb =>
// Filter out empty batches and make them spillable
if (cb.numRows() > 0) {
Some(SpillableColumnarBatch(cb, SpillPriorities.ACTIVE_ON_DECK_PRIORITY))
} else {
cb.close()
None
}
}

val sortedBatchIter = spillableIter.flatMap { scb =>
withRetry(scb, splitSpillableInHalfByRows) { attemptScb =>
opTime.ns {
val sortedTbl = withResource(attemptScb.getColumnarBatch()) { attemptCb =>
sorter.appendProjectedAndSort(attemptCb, sortTime)
}
withResource(sortedTbl) { _ =>
closeOnExcept(GpuColumnVector.from(sortedTbl, sorter.projectedBatchTypes)) { cb =>
SpillableColumnarBatch(cb, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)
}
}
}
}
}
sortedBatchIter
}
}
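
A usage sketch for the new iterator, assuming an upstream inputIter: Iterator[ColumnarBatch] and a configured sorter: GpuSorter (the metric arguments default to NoopMetric); this mirrors how GpuOutOfCoreSortIterator consumes it below:

    val sortedBatches: Iterator[SpillableColumnarBatch] =
      GpuSpillableProjectedSortEachBatchIterator(inputIter, sorter)

    sortedBatches.foreach { scb =>
      // Each batch comes back sorted and still carrying the projected sort columns.
      withResource(scb) { _ =>
        withResource(scb.getColumnarBatch()) { cb =>
          // hand cb to split/merge logic; downstream code strips the projected columns before output
        }
      }
    }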

/**
* Holds data for the out of core sort. It includes the batch of data and the first row in that
* batch so we can sort the batches.
@@ -249,6 +288,12 @@ case class GpuOutOfCoreSortIterator(
outputRows: GpuMetric) extends Iterator[ColumnarBatch]
with AutoCloseable {

/**
* An iterator over batches that have already been sorted. Each batch still contains the
* projected columns, which need to be removed before the data is returned.
*/
val alreadySortedIter = GpuSpillableProjectedSortEachBatchIterator(iter, sorter, opTime, sortTime)

private val cpuOrd = new LazilyGeneratedOrdering(sorter.cpuOrdering)
// A priority queue of data that is not merged yet.
private val pending = new Pending(cpuOrd)
@@ -258,7 +303,7 @@
// how much data, in bytes, that is stored in `sorted`
private var sortedSize = 0L

override def hasNext: Boolean = !sorted.isEmpty || !pending.isEmpty || iter.hasNext
override def hasNext: Boolean = !sorted.isEmpty || !pending.isEmpty || alreadySortedIter.hasNext

// Use types for the UnsafeProjection otherwise we need to have CPU BoundAttributeReferences
// used for converting between columnar data and rows (to get the first row in each batch).
@@ -398,45 +443,28 @@
}

/**
* First pass through the data. Read in all of the batches, sort each batch and split them up into
* smaller chunks for later merge sorting.
* Take a single sorted batch from the `alreadySortedIter`, split it up, and store the pieces
* for merging.
*/
private final def firstPassReadBatches(): Unit = {
while(iter.hasNext) {
val spillBatch = closeOnExcept(iter.next()) { batch =>
SpillableColumnarBatch(batch, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)
private final def splitOneSortedBatch(scb: SpillableColumnarBatch): Unit = {
withResource(new NvtxWithMetrics("split input batch", NvtxColor.CYAN, opTime)) { _ =>
val ret = withRetryNoSplit(scb) { attempt =>
onFirstPassSplit()
splitAfterSort(attempt)
}
val sortedIt =
withResource(new NvtxWithMetrics("initial sort", NvtxColor.CYAN, opTime)){ _ =>
withRetry(spillBatch, splitSpillableInHalfByRows) { attemptScb =>
onFirstPassSort()
withResource(attemptScb.getColumnarBatch()) { attemptCb =>
sorter.appendProjectedAndSort(attemptCb, sortTime)
}
}
}
saveSplitResult(ret)
}
}

withResource(new NvtxWithMetrics("split input batch", NvtxColor.CYAN, opTime)) { _ =>
while(sortedIt.hasNext) {
val sortedTbl = sortedIt.next()
val rows = sortedTbl.getRowCount.toInt
// filter out empty batches
if (rows > 0) {
val sp = withResource(sortedTbl) { _ =>
closeOnExcept(GpuColumnVector.from(sortedTbl, sorter.projectedBatchTypes)) { cb =>
SpillableColumnarBatch(cb, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)
}
}
val ret = withRetryNoSplit(sp) { attempt =>
onFirstPassSplit()
splitAfterSort(attempt)
}
saveSplitResult(ret)
} else {
sortedTbl.close()
}
}
}
/**
* First pass through the data. Conceptually we read in all of the batches, which are already
* sorted, and split them up into smaller chunks for later merge sorting. We only do this when
* there is more than one batch to sort.
*/
private final def firstPassReadBatches(scb: SpillableColumnarBatch): Unit = {
splitOneSortedBatch(scb)
while (alreadySortedIter.hasNext) {
splitOneSortedBatch(alreadySortedIter.next())
}
}
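
For context, the retry helpers used above follow a consistent pattern. A minimal sketch under stated assumptions (inputScb and otherScb are SpillableColumnarBatches already on hand, and doSortWork stands in for whatever body is actually passed):

    // withRetry retries the body on GpuRetryOOM; on GpuSplitAndRetryOOM it applies the supplied
    // policy (splitSpillableInHalfByRows) and retries each piece, yielding an iterator of results.
    val pieces: Iterator[SpillableColumnarBatch] =
      withRetry(inputScb, splitSpillableInHalfByRows) { attempt =>
        doSortWork(attempt)
      }

    // withRetryNoSplit retries the whole input on GpuRetryOOM but has no split policy, so a
    // GpuSplitAndRetryOOM propagates to the caller (see the first-pass-split test below).
    val whole: SpillableColumnarBatch = withRetryNoSplit(otherScb) { attempt =>
      doSortWork(attempt)
    }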

@@ -564,10 +592,19 @@
override def next(): ColumnarBatch = {
if (sorter.projectedBatchSchema.isEmpty) {
// special case, no columns just rows
iter.next()
withRetryNoSplit(alreadySortedIter.next()) { scb =>
// This should have no columns so no need to remove anything from the projected data
scb.getColumnarBatch()
}
} else {
if (pending.isEmpty && sorted.isEmpty) {
firstPassReadBatches()
closeOnExcept(alreadySortedIter.next()) { scb =>
if (!alreadySortedIter.hasNext) {
sorted.add(scb)
} else {
firstPassReadBatches(scb)
}
}
}
withResource(new NvtxWithMetrics("Sort next output batch", NvtxColor.CYAN, opTime)) { _ =>
val ret = mergeSortEnoughToOutput().getOrElse(concatOutput())
@@ -590,7 +627,6 @@
}

/** Callbacks designed for unit tests only. Don't do any heavy things inside. */
protected def onFirstPassSort(): Unit = {}
protected def onFirstPassSplit(): Unit = {}
protected def onMergeSortSplit(): Unit = {}
protected def onConcatOutput(): Unit = {}
86 changes: 29 additions & 57 deletions GpuSortRetrySuite.scala
@@ -33,6 +33,11 @@ class GpuSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSugar {
private val gpuSorter = new GpuSorter(Seq(sortOrder), Array(attrs))
private val NUM_ROWS = 100

private def batchIter(batches: Int): Iterator[ColumnarBatch] =
(0 until batches).map { _ =>
buildBatch
}.toIterator

private def buildBatch: ColumnarBatch = {
val ints = (NUM_ROWS / 2 until NUM_ROWS) ++ (0 until NUM_ROWS / 2)
new ColumnarBatch(
@@ -41,66 +46,36 @@ class GpuSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSugar {

test("GPU out-of-core sort without OOM failures") {
val outCoreIter = new GpuOutOfCoreSortIteratorThatThrows(
Iterator(buildBatch),
batchIter(2),
gpuSorter,
targetSize = 1024)
withResource(outCoreIter) { _ =>
withResource(outCoreIter.next()) { cb =>
// only one batch
assertResult(NUM_ROWS)(cb.numRows())
assertResult(true)(GpuColumnVector.isTaggedAsFinalBatch(cb))
}
}
}

test("GPU out-of-core sort with retry when first-pass-sort GpuRetryOOM") {
val outCoreIter = new GpuOutOfCoreSortIteratorThatThrows(
Iterator(buildBatch),
gpuSorter,
targetSize = 1024,
firstPassSortExp = new GpuRetryOOM())
withResource(outCoreIter) { _ =>
withResource(outCoreIter.next()) { cb =>
// only one batch
assertResult(NUM_ROWS)(cb.numRows())
assertResult(true)(GpuColumnVector.isTaggedAsFinalBatch(cb))
}
}
}

test("GPU out-of-core sort with retry when first-pass-sort GpuSplitAndRetryOOM") {
val outCoreIter = new GpuOutOfCoreSortIteratorThatThrows(
Iterator(buildBatch),
gpuSorter,
targetSize = 1024,
firstPassSortExp = new GpuSplitAndRetryOOM())
withResource(outCoreIter) { _ =>
withResource(outCoreIter.next()) { cb =>
// only one batch
assertResult(NUM_ROWS)(cb.numRows())
assertResult(NUM_ROWS * 2)(cb.numRows())
assertResult(true)(GpuColumnVector.isTaggedAsFinalBatch(cb))
}
}
}

test("GPU out-of-core sort with retry when first-pass-split GpuRetryOOM") {
val outCoreIter = new GpuOutOfCoreSortIteratorThatThrows(
Iterator(buildBatch),
batchIter(2),
gpuSorter,
targetSize = 1024,
firstPassSplitExp = new GpuRetryOOM())
withResource(outCoreIter) { _ =>
withResource(outCoreIter.next()) { cb =>
// only one batch
assertResult(NUM_ROWS)(cb.numRows())
assertResult(NUM_ROWS * 2)(cb.numRows())
assertResult(true)(GpuColumnVector.isTaggedAsFinalBatch(cb))
}
}
}

test("GPU out-of-core sort throws when first-pass-split GpuSplitAndRetryOOM") {
val outCoreIter = new GpuOutOfCoreSortIteratorThatThrows(
Iterator(buildBatch),
batchIter(2),
gpuSorter,
targetSize = 1024,
firstPassSplitExp = new GpuSplitAndRetryOOM())
@@ -113,7 +88,7 @@ class GpuSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSugar {

test("GPU out-of-core sort with retry when merge-sort-split GpuRetryOOM") {
val outCoreIter = new GpuOutOfCoreSortIteratorThatThrows(
Iterator(buildBatch, buildBatch),
batchIter(2),
gpuSorter,
targetSize = 400,
mergeSortExp = new GpuRetryOOM())
@@ -130,7 +105,7 @@ class GpuSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSugar {

test("GPU out-of-core sort throws when merge-sort-split GpuSplitAndRetryOOM") {
val outCoreIter = new GpuOutOfCoreSortIteratorThatThrows(
Iterator(buildBatch, buildBatch),
batchIter(2),
gpuSorter,
targetSize = 400,
mergeSortExp = new GpuSplitAndRetryOOM())
@@ -143,7 +118,7 @@ class GpuSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSugar {

test("GPU out-of-core sort with retry when concat-output GpuRetryOOM") {
val outCoreIter = new GpuOutOfCoreSortIteratorThatThrows(
Iterator(buildBatch, buildBatch),
batchIter(2),
gpuSorter,
targetSize = 400,
concatOutExp = new GpuRetryOOM())
@@ -160,7 +135,7 @@ class GpuSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSugar {

test("GPU out-of-core sort throws when concat-output GpuSplitAndRetryOOM") {
val outCoreIter = new GpuOutOfCoreSortIteratorThatThrows(
Iterator(buildBatch, buildBatch),
batchIter(2),
gpuSorter,
targetSize = 400,
concatOutExp = new GpuSplitAndRetryOOM())
@@ -175,7 +150,6 @@ class GpuSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSugar {
iter: Iterator[ColumnarBatch],
sorter: GpuSorter,
targetSize: Long,
firstPassSortExp: Throwable = null,
firstPassSplitExp: Throwable = null,
mergeSortExp: Throwable = null,
concatOutExp: Throwable = null,
@@ -185,11 +159,6 @@ class GpuSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSugar {

private var expCnt = expMaxCount

override def onFirstPassSort(): Unit = if (firstPassSortExp != null && expCnt > 0) {
expCnt -= 1
throw firstPassSortExp
}

override def onFirstPassSplit(): Unit = if (firstPassSplitExp != null && expCnt > 0) {
expCnt -= 1
throw firstPassSplitExp
@@ -207,8 +176,8 @@ class GpuSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSugar {
}

test("GPU each batch sort with GpuRetryOOM") {
val eachBatchIter = new GpuSortEachBatchIterator(
Iterator(buildBatch, buildBatch),
val eachBatchIter = GpuSortEachBatchIterator(
batchIter(2),
gpuSorter,
singleBatch = false)
RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 2,
@@ -230,16 +199,19 @@ class GpuSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSugar {
}

test("GPU each batch sort throws GpuSplitAndRetryOOM") {
val inputIter = Iterator(buildBatch, buildBatch)
val eachBatchIter = new GpuSortEachBatchIterator(
inputIter,
gpuSorter,
singleBatch = false)
RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1,
RmmSpark.OomInjectionType.GPU.ordinal, 0)
assertThrows[GpuSplitAndRetryOOM] {
eachBatchIter.next()
val inputIter = batchIter(2)
try {
val eachBatchIter = GpuSortEachBatchIterator(
inputIter,
gpuSorter,
singleBatch = false)
RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1,
RmmSpark.OomInjectionType.GPU.ordinal, 0)
assertThrows[GpuSplitAndRetryOOM] {
eachBatchIter.next()
}
} finally {
inputIter.foreach(_.close())
}
inputIter.foreach(_.close())
}
}
