Commit e3798d2

Spill framework refactor for better performance and extensibility [databricks] (#11747)

* Spill framework refactor for better performance and extensibility

Signed-off-by: Alessandro Bellina <[email protected]>

---------

Signed-off-by: Alessandro Bellina <[email protected]>
abellina authored Dec 13, 2024
1 parent 561068c commit e3798d2
Showing 62 changed files with 4,146 additions and 6,908 deletions.
@@ -1101,6 +1101,10 @@ public final int numNulls() {

   public static long getTotalDeviceMemoryUsed(ColumnarBatch batch) {
     long sum = 0;
+    if (batch.numCols() == 1 && batch.column(0) instanceof GpuPackedTableColumn) {
+      // this is a special case for a packed batch
+      return ((GpuPackedTableColumn) batch.column(0)).getTableBuffer().getLength();
+    }
     if (batch.numCols() > 0) {
       if (batch.column(0) instanceof WithTableBuffer) {
         WithTableBuffer wtb = (WithTableBuffer) batch.column(0);
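The new branch short-circuits the per-column sum: a packed batch is a single GpuPackedTableColumn backed by one contiguous device buffer, so its size is just that buffer's length. A minimal Scala sketch of the dispatch, assuming the accessors shown in the diff; this is illustration, not code from this commit:

import org.apache.spark.sql.vectorized.ColumnarBatch

// Hedged sketch: report device memory from the packed buffer when present,
// otherwise fall back to summing the individual GPU columns.
def deviceMemoryUsed(batch: ColumnarBatch): Long =
  if (batch.numCols() == 1 &&
      batch.column(0).isInstanceOf[GpuPackedTableColumn]) {
    // one contiguous allocation backs the whole batch
    batch.column(0).asInstanceOf[GpuPackedTableColumn].getTableBuffer.getLength
  } else {
    (0 until batch.numCols()).map { i =>
      // assumes plain GPU columns, as in the pre-existing branch above
      batch.column(i).asInstanceOf[GpuColumnVector].getBase.getDeviceMemorySize
    }.sum
  }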
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -113,6 +113,29 @@ public GpuColumnVectorFromBuffer(DataType type, ColumnVector cudfColumn,
     this.tableMeta = meta;
   }

+  public static boolean isFromBuffer(ColumnarBatch cb) {
+    if (cb.numCols() > 0) {
+      long bufferAddr = 0L;
+      boolean isSet = false;
+      for (int i = 0; i < cb.numCols(); ++i) {
+        GpuColumnVectorFromBuffer gcvfb = null;
+        if (!(cb.column(i) instanceof GpuColumnVectorFromBuffer)) {
+          return false;
+        } else {
+          gcvfb = (GpuColumnVectorFromBuffer) cb.column(i);
+          if (!isSet) {
+            bufferAddr = gcvfb.buffer.getAddress();
+            isSet = true;
+          } else if (bufferAddr != gcvfb.buffer.getAddress()) {
+            return false;
+          }
+        }
+      }
+      return true;
+    }
+    return false;
+  }
+
   /**
    * Get the underlying contiguous buffer, shared between columns of the original
    * `ContiguousTable`
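isFromBuffer returns true only when every column is a GpuColumnVectorFromBuffer and all of them report the same buffer address, i.e. the batch aliases one contiguous allocation. A hedged sketch of how a caller might branch on it; the branch contents are placeholders, not code from this commit:

import org.apache.spark.sql.vectorized.ColumnarBatch

// Hedged sketch: pick a tracking strategy based on buffer sharing.
def trackForSpill(cb: ColumnarBatch): Unit = {
  if (GpuColumnVectorFromBuffer.isFromBuffer(cb)) {
    // every column aliases the same contiguous buffer: track it once
  } else {
    // columns own independent allocations: track them per column
  }
}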
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -47,6 +47,15 @@ public static boolean isBatchCompressed(ColumnarBatch batch) {
     return batch.numCols() == 1 && batch.column(0) instanceof GpuCompressedColumnVector;
   }

+  public static ColumnarBatch incRefCounts(ColumnarBatch batch) {
+    if (!isBatchCompressed(batch)) {
+      throw new IllegalStateException(
+          "Attempted to incRefCount for a compressed batch, but the batch was not compressed.");
+    }
+    ((GpuCompressedColumnVector)batch.column(0)).buffer.incRefCount();
+    return batch;
+  }
+
   /**
    * Build a columnar batch from a compressed data buffer and specified table metadata
    * NOTE: The data remains compressed and cannot be accessed directly from the columnar batch.
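incRefCounts bumps the reference count on the single compressed buffer so a second owner can hold the same batch; each owner then closes its reference independently and the buffer is released when the count reaches zero. A hedged usage sketch (the ownership convention shown is an assumption):

import org.apache.spark.sql.vectorized.ColumnarBatch

// Hedged sketch: share a compressed batch with a second consumer.
def shareCompressed(batch: ColumnarBatch): Unit = {
  val shared = GpuCompressedColumnVector.incRefCounts(batch) // refcount +1
  try {
    // hand `shared` to the other consumer; `batch` stays valid
  } finally {
    shared.close() // refcount -1; the last close releases the buffer
  }
}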
@@ -21,19 +21,11 @@
 import java.util.NoSuchElementException;
 import java.util.Optional;

-import com.nvidia.spark.Retryable;
 import scala.Option;
 import scala.Tuple2;
 import scala.collection.Iterator;

-import ai.rapids.cudf.ColumnVector;
-import ai.rapids.cudf.DType;
-import ai.rapids.cudf.HostColumnVector;
-import ai.rapids.cudf.HostColumnVectorCore;
-import ai.rapids.cudf.HostMemoryBuffer;
-import ai.rapids.cudf.NvtxColor;
-import ai.rapids.cudf.NvtxRange;
-import ai.rapids.cudf.Table;
+import ai.rapids.cudf.*;
 import com.nvidia.spark.rapids.jni.RowConversion;
 import com.nvidia.spark.rapids.shims.CudfUnsafeRow;

@@ -236,8 +228,7 @@ private HostMemoryBuffer[] getHostBuffersWithRetry(
     try {
       hBuf = HostAlloc$.MODULE$.alloc((dataBytes + offsetBytes),true);
       SpillableHostBuffer sBuf = SpillableHostBuffer$.MODULE$.apply(hBuf, hBuf.getLength(),
-          SpillPriorities$.MODULE$.ACTIVE_ON_DECK_PRIORITY(),
-          RapidsBufferCatalog$.MODULE$.singleton());
+          SpillPriorities$.MODULE$.ACTIVE_ON_DECK_PRIORITY());
       hBuf = null; // taken over by spillable host buffer
       return Tuple2.apply(sBuf, numRowsWrapper);
     } finally {
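The `hBuf = null` assignment above is an ownership hand-off: once the SpillableHostBuffer is constructed it owns the allocation, so the local reference is cleared and the finally block frees the buffer only on failure paths. A hedged Scala sketch of the same idiom (the alloc signature is an assumption):

import ai.rapids.cudf.HostMemoryBuffer

// Hedged sketch: allocate, wrap as spillable, and hand off ownership.
def allocSpillable(bytes: Long): SpillableHostBuffer = {
  var hBuf: HostMemoryBuffer = null
  try {
    hBuf = HostAlloc.alloc(bytes, true) // assumed: (size, preferPinned)
    val sBuf = SpillableHostBuffer(hBuf, hBuf.getLength,
      SpillPriorities.ACTIVE_ON_DECK_PRIORITY)
    hBuf = null // taken over by the spillable host buffer
    sBuf
  } finally {
    if (hBuf != null) hBuf.close() // reached only if wrapping failed
  }
}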
@@ -63,26 +63,6 @@ object RapidsPluginImplicits {
     }
   }

-  implicit class RapidsBufferColumn(rapidsBuffer: RapidsBuffer) {
-
-    /**
-     * safeFree: Is an implicit on RapidsBuffer class that tries to free the resource, if an
-     * Exception was thrown prior to this free, it adds the new exception to the suppressed
-     * exceptions, otherwise just throws
-     *
-     * @param e Exception which we don't want to suppress
-     */
-    def safeFree(e: Throwable = null): Unit = {
-      if (rapidsBuffer != null) {
-        try {
-          rapidsBuffer.free()
-        } catch {
-          case suppressed: Throwable if e != null => e.addSuppressed(suppressed)
-        }
-      }
-    }
-  }
-
   implicit class AutoCloseableSeq[A <: AutoCloseable](val in: collection.SeqLike[A, _]) {
     /**
      * safeClose: Is an implicit on a sequence of AutoCloseable classes that tries to close each
@@ -111,46 +91,12 @@ object RapidsPluginImplicits {
     }
   }

-  implicit class RapidsBufferSeq[A <: RapidsBuffer](val in: collection.SeqLike[A, _]) {
-    /**
-     * safeFree: Is an implicit on a sequence of RapidsBuffer classes that tries to free each
-     * element of the sequence, even if prior free calls fail. In case of failure in any of the
-     * free calls, an Exception is thrown containing the suppressed exceptions (getSuppressed),
-     * if any.
-     */
-    def safeFree(error: Throwable = null): Unit = if (in != null) {
-      var freeException: Throwable = null
-      in.foreach { element =>
-        if (element != null) {
-          try {
-            element.free()
-          } catch {
-            case e: Throwable if error != null => error.addSuppressed(e)
-            case e: Throwable if freeException == null => freeException = e
-            case e: Throwable => freeException.addSuppressed(e)
-          }
-        }
-      }
-      if (freeException != null) {
-        // an exception happened while we were trying to safely free
-        // resources, throw the exception to alert the caller
-        throw freeException
-      }
-    }
-  }
-
   implicit class AutoCloseableArray[A <: AutoCloseable](val in: Array[A]) {
     def safeClose(e: Throwable = null): Unit = if (in != null) {
       in.toSeq.safeClose(e)
     }
   }

-  implicit class RapidsBufferArray[A <: RapidsBuffer](val in: Array[A]) {
-    def safeFree(e: Throwable = null): Unit = if (in != null) {
-      in.toSeq.safeFree(e)
-    }
-  }
-
   class MapsSafely[A, Repr] {
     /**
      * safeMap: safeMap implementation that is leveraged by other type-specific implicits.
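With the RapidsBuffer safeFree implicits removed, cleanup relies on the surviving AutoCloseable safeClose implicits, which attach close-time failures to an in-flight exception as suppressed exceptions. A hedged sketch of the pattern they support:

import com.nvidia.spark.rapids.RapidsPluginImplicits._

// Hedged sketch: run work, then close every resource exactly once; if the
// work threw, close failures become suppressed exceptions on that throwable.
def withAll[T](resources: Seq[AutoCloseable])(work: => T): T = {
  var primary: Throwable = null
  try {
    work
  } catch {
    case t: Throwable =>
      primary = t
      throw t
  } finally {
    resources.safeClose(primary) // failures are suppressed on `primary` when present
  }
}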
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023, NVIDIA CORPORATION.
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -63,26 +63,6 @@ object RapidsPluginImplicits {
     }
   }

-  implicit class RapidsBufferColumn(rapidsBuffer: RapidsBuffer) {
-
-    /**
-     * safeFree: Is an implicit on RapidsBuffer class that tries to free the resource, if an
-     * Exception was thrown prior to this free, it adds the new exception to the suppressed
-     * exceptions, otherwise just throws
-     *
-     * @param e Exception which we don't want to suppress
-     */
-    def safeFree(e: Throwable = null): Unit = {
-      if (rapidsBuffer != null) {
-        try {
-          rapidsBuffer.free()
-        } catch {
-          case suppressed: Throwable if e != null => e.addSuppressed(suppressed)
-        }
-      }
-    }
-  }
-
   implicit class AutoCloseableSeq[A <: AutoCloseable](val in: collection.Iterable[A]) {
     /**
      * safeClose: Is an implicit on a sequence of AutoCloseable classes that tries to close each
@@ -111,46 +91,12 @@ object RapidsPluginImplicits {
     }
   }

-  implicit class RapidsBufferSeq[A <: RapidsBuffer](val in: collection.SeqLike[A, _]) {
-    /**
-     * safeFree: Is an implicit on a sequence of RapidsBuffer classes that tries to free each
-     * element of the sequence, even if prior free calls fail. In case of failure in any of the
-     * free calls, an Exception is thrown containing the suppressed exceptions (getSuppressed),
-     * if any.
-     */
-    def safeFree(error: Throwable = null): Unit = if (in != null) {
-      var freeException: Throwable = null
-      in.foreach { element =>
-        if (element != null) {
-          try {
-            element.free()
-          } catch {
-            case e: Throwable if error != null => error.addSuppressed(e)
-            case e: Throwable if freeException == null => freeException = e
-            case e: Throwable => freeException.addSuppressed(e)
-          }
-        }
-      }
-      if (freeException != null) {
-        // an exception happened while we were trying to safely free
-        // resources, throw the exception to alert the caller
-        throw freeException
-      }
-    }
-  }
-
   implicit class AutoCloseableArray[A <: AutoCloseable](val in: Array[A]) {
     def safeClose(e: Throwable = null): Unit = if (in != null) {
       in.toSeq.safeClose(e)
     }
   }

-  implicit class RapidsBufferArray[A <: RapidsBuffer](val in: Array[A]) {
-    def safeFree(e: Throwable = null): Unit = if (in != null) {
-      in.toSeq.safeFree(e)
-    }
-  }
-
   class IterableMapsSafely[A,
       From[A] <: collection.Iterable[A] with collection.IterableOps[A, From, _]] {
     /**
sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala
14 changes: 0 additions & 14 deletions
@@ -177,20 +177,6 @@ object Arm extends ArmScalaSpecificImpl {
     }
   }

-  /** Executes the provided code block, freeing the RapidsBuffer only if an exception occurs */
-  def freeOnExcept[T <: RapidsBuffer, V](r: T)(block: T => V): V = {
-    try {
-      block(r)
-    } catch {
-      case t: ControlThrowable =>
-        // Don't close for these cases..
-        throw t
-      case t: Throwable =>
-        r.safeFree(t)
-        throw t
-    }
-  }
-
   /** Executes the provided code block and then closes the resource */
   def withResource[T <: AutoCloseable, V](h: CloseableHolder[T])
       (block: CloseableHolder[T] => V): V = {
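freeOnExcept existed because RapidsBuffer was freed rather than closed; with spillable handles now AutoCloseable, Arm's generic closeOnExcept covers the same pattern. A hedged sketch (closeOnExcept as defined elsewhere in Arm; the helper parameters are hypothetical):

import com.nvidia.spark.rapids.Arm.closeOnExcept
import org.apache.spark.sql.vectorized.ColumnarBatch

// Hedged sketch: the batch is closed only if `transform` throws; on success
// the caller takes ownership of the result.
def buildAndTransform(build: () => ColumnarBatch,
    transform: ColumnarBatch => ColumnarBatch): ColumnarBatch = {
  closeOnExcept(build()) { batch =>
    transform(batch)
  }
}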
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@ import java.lang.management.ManagementFactory
 import java.util.concurrent.atomic.AtomicLong

 import ai.rapids.cudf.{Cuda, Rmm, RmmEventHandler}
+import com.nvidia.spark.rapids.spill.SpillableDeviceStore
 import com.sun.management.HotSpotDiagnosticMXBean

 import org.apache.spark.internal.Logging
@@ -34,8 +35,7 @@ import org.apache.spark.sql.rapids.execution.TrampolineUtil
  * depleting the device store
  */
 class DeviceMemoryEventHandler(
-    catalog: RapidsBufferCatalog,
-    store: RapidsDeviceMemoryStore,
+    store: SpillableDeviceStore,
     oomDumpDir: Option[String],
     maxFailedOOMRetries: Int) extends RmmEventHandler with Logging {

@@ -92,8 +92,8 @@ class DeviceMemoryEventHandler(
      * from cuDF. If we succeed, cuDF resets `retryCount`, and so the new count sent to us
      * must be <= than what we saw last, so we can reset our tracking.
      */
-    def resetIfNeeded(retryCount: Int, storeSpillableSize: Long): Unit = {
-      if (storeSpillableSize != 0 || retryCount <= retryCountLastSynced) {
+    def resetIfNeeded(retryCount: Int, couldSpill: Boolean): Unit = {
+      if (couldSpill || retryCount <= retryCountLastSynced) {
         reset()
       }
     }
@@ -114,22 +114,20 @@ class DeviceMemoryEventHandler(
         s"onAllocFailure invoked with invalid retryCount $retryCount")

     try {
-      val storeSize = store.currentSize
-      val storeSpillableSize = store.currentSpillableSize

       val attemptMsg = if (retryCount > 0) {
         s"Attempt ${retryCount}. "
       } else {
         "First attempt. "
       }

       val retryState = oomRetryState.get()
-      retryState.resetIfNeeded(retryCount, storeSpillableSize)
-
-      logInfo(s"Device allocation of $allocSize bytes failed, device store has " +
-        s"$storeSize total and $storeSpillableSize spillable bytes. $attemptMsg" +
-        s"Total RMM allocated is ${Rmm.getTotalBytesAllocated} bytes. ")
-      if (storeSpillableSize == 0) {
+      val amountSpilled = store.spill(allocSize)
+      retryState.resetIfNeeded(retryCount, amountSpilled > 0)
+      logInfo(s"Device allocation of $allocSize bytes failed. " +
+        s"Device store spilled $amountSpilled bytes. $attemptMsg" +
+        s"Total RMM allocated is ${Rmm.getTotalBytesAllocated} bytes.")
+      if (amountSpilled == 0) {
         if (retryState.shouldTrySynchronizing(retryCount)) {
           Cuda.deviceSynchronize()
           logWarning(s"[RETRY ${retryState.getRetriesSoFar}] " +
@@ -149,13 +147,7 @@ class DeviceMemoryEventHandler(
           false
         }
       } else {
-        val targetSize = Math.max(storeSpillableSize - allocSize, 0)
-        logDebug(s"Targeting device store size of $targetSize bytes")
-        val maybeAmountSpilled = catalog.synchronousSpill(store, targetSize, Cuda.DEFAULT_STREAM)
-        maybeAmountSpilled.foreach { amountSpilled =>
-          logInfo(s"Spilled $amountSpilled bytes from the device store")
-          TrampolineUtil.incTaskMetricsMemoryBytesSpilled(amountSpilled)
-        }
+        TrampolineUtil.incTaskMetricsMemoryBytesSpilled(amountSpilled)
         true
       }
     } catch {
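Net effect of the handler changes: rather than computing a target size from store bookkeeping and spilling through the catalog, the handler now asks the device store to free at least the failed allocation size in one call, and retries while progress is made. A condensed, hedged sketch of the new decision (simplified from the diff; synchronize-and-retry and OOM-dump details omitted):

// Hedged sketch, as if inside the handler: `store`, `maxFailedOOMRetries`,
// and `TrampolineUtil` are the members used in the diff above.
def shouldRetryAlloc(allocSize: Long, retryCount: Int): Boolean = {
  val amountSpilled = store.spill(allocSize) // single-call spill API
  if (amountSpilled == 0) {
    // nothing could be spilled: allow a bounded number of retries, then OOM
    retryCount < maxFailedOOMRetries
  } else {
    TrampolineUtil.incTaskMetricsMemoryBytesSpilled(amountSpilled)
    true // progress was made; have RMM retry the allocation
  }
}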