Skip to content

Commit

Permalink
Address comments and add more tests
Browse files Browse the repository at this point in the history
Signed-off-by: Firestarman <[email protected]>
  • Loading branch information
firestarman committed Sep 26, 2023
1 parent 9c4136b commit 2f76d2b
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 54 deletions.
1 change: 1 addition & 0 deletions dist/unshimmed-common-from-spark311.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ com/nvidia/spark/ExclusiveModeGpuDiscoveryPlugin*
com/nvidia/spark/GpuCachedBatchSerializer*
com/nvidia/spark/ParquetCachedBatchSerializer*
com/nvidia/spark/RapidsUDF*
com/nvidia/spark/Retryable*
com/nvidia/spark/SQLPlugin*
com/nvidia/spark/rapids/ColumnarRdd*
com/nvidia/spark/rapids/GpuColumnVectorUtils*
Expand Down
49 changes: 17 additions & 32 deletions sql-plugin/src/main/java/com/nvidia/spark/Retryable.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,48 +17,33 @@
package com.nvidia.spark;

/**
* An interface that can be used by Retry framework of RAPIDS Plugin to handle the GPU OOMs.
* An interface that can be used to retry the processing on non-deterministic
* expressions on the GPU.
*
* GPU memory is a limited resource, so OOM can happen if too many tasks run in parallel.
* Retry framework is introduced to improve the stability by retrying the work when it
* meets OOMs. The overall process of Retry framework is similar as the below.
* ```
* Retryable retryable
* retryable.checkpoint()
* boolean hasOOM = false
* do {
* try {
* runWorkOnGpu(retryable) // May lead to GPU OOM
* hasOOM = false
* } catch (OOMError oom) {
* hasOOM = true
* tryToReleaseSomeGpuMemoryFromLowPriorityTasks()
* retryable.restore()
* }
* } while(hasOOM)
* ```
* In a retry, "checkpoint" will be called first and only once, which is used to save the
* state for later loops. When OOM happens, "restore" will be called to restore the
* state that saved by "checkpoint". After that, it will try the same work again. And
* the whole process runs on Spark executors.
* GPU memory is a limited resource. When it runs out the RAPIDS Accelerator
* for Apache Spark will use several different strategies to try and free more
* GPU memory to let the query complete.
* One of these strategies is to roll back the processioning for one task, pause
* the task thread, then retry the task when more memory is available. This
* works transparently for any stateless deterministic processing. But technically
* an expression/UDF can be non-deterministic and/or keep state in between calls.
* This interface provides a checkpoint method to save any needed state, and a
* restore method to reset the state in the case of a retry.
*
* Retry framework expects the "runWorkOnGpu" always outputs the same result when running
* it multiple times in a retry. So if "runWorkOnGpu" is non-deterministic, it can not be
* used by Retry framework.
* The "Retryable" is designed for this kind of cases. By implementing this interface,
* "runWorkOnGpu" can become deterministic inside a retry process, making it usable for
* Retry framework to improve the stability.
* Please note that a retry is not isolated to a single expression, so a restore can
* be called even after the expression returned one or more batches of results. And
* each time checkpoint it called any previously saved state can be overwritten.
*/
public interface Retryable {
/**
* Save the state, so it can be restored in case of an OOM Retry.
* This is called inside a Spark task context on executors.
* Save the state, so it can be restored in the case of a retry.
* (This is called inside a Spark task context on executors.)
*/
void checkpoint();

/**
* Restore the state that was saved by calling to "checkpoint".
* This is called inside a Spark task context on executors.
* (This is called inside a Spark task context on executors.)
*/
void restore();
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,21 @@

package com.nvidia.spark.rapids

import ai.rapids.cudf.ColumnVector
import ai.rapids.cudf.{ColumnVector, Table}
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingSeq
import com.nvidia.spark.rapids.jni.RmmSpark

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId}
import org.apache.spark.sql.rapids.GpuGreaterThan
import org.apache.spark.sql.rapids.catalyst.expressions.GpuRand
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.{DoubleType, IntegerType}
import org.apache.spark.sql.vectorized.ColumnarBatch

class NonDeterministicRetrySuite extends RmmSparkRetrySuiteBase {
private val NUM_ROWS = 500
private val RAND_SEED = 10
private val batchAttrs = Seq(AttributeReference("int", IntegerType)(ExprId(10)))

private def buildBatch(ints: Seq[Int] = 0 until NUM_ROWS): ColumnarBatch = {
new ColumnarBatch(
Expand All @@ -43,14 +46,14 @@ class NonDeterministicRetrySuite extends RmmSparkRetrySuiteBase {
randCol1.copyToHost()
}
withResource(randHCol1) { _ =>
// store the state, and generate data again
assert(randHCol1.getRowCount.toInt == NUM_ROWS)
// Restore the state, and generate data again
gpuRand.restore()
val randHCol2 = withResource(gpuRand.columnarEval(inputCB)) { randCol2 =>
randCol2.copyToHost()
}
withResource(randHCol2) { _ =>
// check the two random columns are equal.
assert(randHCol1.getRowCount.toInt == NUM_ROWS)
assert(randHCol1.getRowCount == randHCol2.getRowCount)
(0 until randHCol1.getRowCount.toInt).foreach { pos =>
assert(randHCol1.getDouble(pos) == randHCol2.getDouble(pos))
Expand All @@ -61,27 +64,86 @@ class NonDeterministicRetrySuite extends RmmSparkRetrySuiteBase {
}

test("GPU project retry with GPU rand") {
val childOutput = Seq(AttributeReference("int", IntegerType)(NamedExpression.newExprId))
val projectRandOnly = Seq(
GpuAlias(GpuRand(GpuLiteral(RAND_SEED)), "rand")(NamedExpression.newExprId))
val projectList = projectRandOnly ++ childOutput
def projectRand(): Seq[GpuExpression] = Seq(
GpuAlias(GpuRand(GpuLiteral(RAND_SEED)), "rand")())

Seq(true, false).foreach { useTieredProject =>
// expression should be retryable
val randOnlyProjectList = GpuBindReferences.bindGpuReferencesTiered(projectRandOnly,
childOutput, useTieredProject)
assert(randOnlyProjectList.areAllRetryable)
val boundProjectList = GpuBindReferences.bindGpuReferencesTiered(projectList,
childOutput, useTieredProject)
assert(boundProjectList.areAllRetryable)
val boundProjectRand = GpuBindReferences.bindGpuReferencesTiered(projectRand(),
batchAttrs, useTieredProject)
assert(boundProjectRand.areAllRetryable)
// project with and without retry
val batches = Seq(true, false).safeMap { forceRetry =>
val boundProjectList = GpuBindReferences.bindGpuReferencesTiered(
projectRand() ++ batchAttrs, batchAttrs, useTieredProject)
assert(boundProjectList.areAllRetryable)

val sb = closeOnExcept(buildBatch()) { cb =>
SpillableColumnarBatch(cb, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)
}
closeOnExcept(sb) { _ =>
if (forceRetry) {
RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId)
}
}
boundProjectList.projectAndCloseWithRetrySingleBatch(sb)
}
// check the random columns
val randCols = withResource(batches) { case Seq(retriedBatch, batch) =>
assert(retriedBatch.numRows() == batch.numRows())
assert(retriedBatch.numCols() == batch.numCols())
batches.safeMap(_.column(0).asInstanceOf[GpuColumnVector].copyToHost())
}
withResource(randCols) { case Seq(retriedRand, rand) =>
(0 until rand.getRowCount.toInt).foreach { pos =>
assert(retriedRand.getDouble(pos) == rand.getDouble(pos))
}
}
}
}

test("GPU filter retry with GPU rand") {
def filterRand(): Seq[GpuExpression] = Seq(
GpuGreaterThan(
GpuRand(GpuLiteral.create(RAND_SEED, IntegerType)),
GpuLiteral.create(0.1d, DoubleType)))

Seq(true, false).foreach { useTieredProject =>
// filter with and without retry
val tables = Seq(true, false).safeMap { forceRetry =>
val boundCondition = GpuBindReferences.bindGpuReferencesTiered(filterRand(),
batchAttrs, useTieredProject)
assert(boundCondition.areAllRetryable)

val cb = buildBatch()
if (forceRetry) {
RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId)
}
val batchSeq = GpuFilter.filterAndClose(cb, boundCondition,
NoopMetric, NoopMetric, NoopMetric).toSeq
withResource(batchSeq) { _ =>
val tables = batchSeq.safeMap(GpuColumnVector.from)
if (tables.size == 1) {
tables.head
} else {
withResource(tables) { _ =>
assert(tables.size > 1)
Table.concatenate(tables: _*)
}
}
}
}

// project with retry
val sb = closeOnExcept(buildBatch()) { cb =>
SpillableColumnarBatch(cb, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)
// check the outputs
val cols = withResource(tables) { case Seq(retriedTable, table) =>
assert(retriedTable.getRowCount == table.getRowCount)
assert(retriedTable.getNumberOfColumns == table.getNumberOfColumns)
tables.safeMap(_.getColumn(0).copyToHost())
}
RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId)
withResource(boundProjectList.projectAndCloseWithRetrySingleBatch(sb)) { outCB =>
// We can not verify the data, so only rows number here
assertResult(NUM_ROWS)(outCB.numRows())
withResource(cols) { case Seq(retriedInts, ints) =>
(0 until ints.getRowCount.toInt).foreach { pos =>
assert(retriedInts.getInt(pos) == ints.getInt(pos))
}
}
}
}
Expand Down

0 comments on commit 2f76d2b

Please sign in to comment.