Skip to content

Commit

Permalink
Break down spill into a two-step process, so we can inject a device s…
Browse files Browse the repository at this point in the history
…ynchronize between
  • Loading branch information
abellina committed Dec 6, 2024
1 parent 418bfe4 commit 8e8aefc
Showing 1 changed file with 78 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ import org.apache.spark.storage.BlockId
* responsibility to initiate that spill, and to track the spilled handle (a device spillable
* would have a `host` handle, which tracks the host spilled object).
*
* Spill is broken down into two methods: `spill` and `releaseSpilled`. This is a two stage
* process because we need to make sure that there is no code running kernels on the spilled
 * data before we actually free it. See the method documentation for `spill` and `releaseSpilled`
* for more info.
*
* A cascade of spills can occur device -> host -> disk, given that host allocations can fail, or
* could not fit in the SpillableHostStore's limit (if defined). In this case, the call to spill
* will either create a host handle tracking an object on the host store (if we made room), or it
Expand Down Expand Up @@ -160,13 +165,27 @@ trait SpillableHandle extends StoreHandle {
/**
* Method called to spill this handle. It can be triggered from the spill store,
* or directly against the handle.
*
 * This will not free the spilled data. If you would like to free the spilled
 * data, call `releaseSpilled`.
*
* @note The size returned from this method is only used by the spill framework
* to track the approximate size. It should just return `approxSizeInBytes`, as
* that's the size that it used when it first started tracking the object.
* @return approxSizeInBytes if spilled, 0 for any other reason (not spillable, closed)
*/
def spill(): Long

/**
* Part two of the two-stage process for spilling. We call `releaseSpilled` after
* a handle has spilled, and after a device synchronize. This prevents a race
 * between threads working on cuDF kernels that did not synchronize while holding the
* materialized handle's refCount, and the spiller thread (the spiller thread cannot
* free a device buffer that the worker thread isn't done with).
* See https://github.com/NVIDIA/spark-rapids/issues/8610 for more info.
*/
def releaseSpilled(): Unit

/**
* Method used to determine whether a handle tracks an object that could be spilled
* @note At the level of `SpillableHandle`, the only requirement of spillability
Expand All @@ -188,6 +207,18 @@ trait DeviceSpillableHandle[T <: AutoCloseable] extends SpillableHandle {
private[spill] override def spillable: Boolean = synchronized {
super.spillable && dev.isDefined
}

protected def releaseDeviceResource(): Unit = {
SpillFramework.removeFromDeviceStore(this)
synchronized {
dev.foreach(_.close())
dev = None
}
}

override def releaseSpilled(): Unit = {
releaseDeviceResource()
}
}

/**
Expand All @@ -201,6 +232,18 @@ trait HostSpillableHandle[T <: AutoCloseable] extends SpillableHandle {
private[spill] override def spillable: Boolean = synchronized {
super.spillable && host.isDefined
}

protected def releaseHostResource(): Unit = {
SpillFramework.removeFromHostStore(this)
synchronized {
host.foreach(_.close())
host = None
}
}

override def releaseSpilled(): Unit = {
releaseHostResource()
}
}

object SpillableHostBufferHandle extends Logging {
Expand Down Expand Up @@ -284,7 +327,7 @@ class SpillableHostBufferHandle private (
if (!spillable) {
0L
} else {
val spilled = synchronized {
synchronized {
if (disk.isEmpty && host.isDefined) {
withResource(DiskHandleStore.makeBuilder) { diskHandleBuilder =>
val outputChannel = diskHandleBuilder.getChannel
Expand All @@ -307,19 +350,9 @@ class SpillableHostBufferHandle private (
0L
}
}
releaseHostResource()
spilled
}
}

private def releaseHostResource(): Unit = {
SpillFramework.removeFromHostStore(this)
synchronized {
host.foreach(_.close())
host = None
}
}

override def close(): Unit = {
releaseHostResource()
synchronized {
Expand Down Expand Up @@ -399,12 +432,12 @@ class SpillableDeviceBufferHandle private (
var materialized: DeviceMemoryBuffer = null
var hostHandle: SpillableHostBufferHandle = null
synchronized {
if (dev.isDefined) {
materialized = dev.get
materialized.incRefCount()
} else if (host.isDefined) {
if (host.isDefined) {
// since we spilled, host must be set.
hostHandle = host.get
} else if (dev.isDefined) {
materialized = dev.get
materialized.incRefCount()
} else {
throw new IllegalStateException(
"attempting to materialize a closed handle")
Expand All @@ -426,29 +459,19 @@ class SpillableDeviceBufferHandle private (
if (!spillable) {
0L
} else {
val spilled = synchronized {
synchronized {
if (host.isEmpty && dev.isDefined) {
host = Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff(dev.get))
sizeInBytes
} else {
0L
}
}
releaseDeviceResources()
spilled
}
}

private def releaseDeviceResources(): Unit = {
SpillFramework.removeFromDeviceStore(this)
synchronized {
dev.foreach(_.close())
dev = None
}
}

override def close(): Unit = {
releaseDeviceResources()
releaseDeviceResource()
synchronized {
host.foreach(_.close())
host = None
Expand Down Expand Up @@ -483,10 +506,10 @@ class SpillableColumnarBatchHandle private (
var materialized: ColumnarBatch = null
var hostHandle: SpillableHostBufferHandle = null
synchronized {
if (dev.isDefined) {
materialized = GpuColumnVector.incRefCounts(dev.get)
} else if (host.isDefined) {
if (host.isDefined) {
hostHandle = host.get
} else if (dev.isDefined) {
materialized = GpuColumnVector.incRefCounts(dev.get)
} else {
throw new IllegalStateException(
"attempting to materialize a closed handle")
Expand All @@ -511,22 +534,20 @@ class SpillableColumnarBatchHandle private (
if (!spillable) {
0L
} else {
val spilled = synchronized {
synchronized {
if (host.isEmpty && dev.isDefined) {
withChunkedPacker { chunkedPacker =>
meta = Some(chunkedPacker.getPackedMeta)
host = Some(SpillableHostBufferHandle.createHostHandleWithPacker(chunkedPacker))
}
// We return the size we were created with. This is not the actual size
// of this batch when it is packed, and it is used by the calling code
// to figure out more or less how much did we free in the device.
approxSizeInBytes
} else {
0L
}
}
releaseDeviceResource()
// We return the size we were created with. This is not the actual size
// of this batch when it is packed, and it is used by the calling code
// to figure out more or less how much did we free in the device.
spilled
}
}

Expand All @@ -544,14 +565,6 @@ class SpillableColumnarBatchHandle private (
}
}

private def releaseDeviceResource(): Unit = {
SpillFramework.removeFromDeviceStore(this)
synchronized {
dev.foreach(_.close())
dev = None
}
}

override def close(): Unit = {
releaseDeviceResource()
synchronized {
Expand Down Expand Up @@ -616,10 +629,10 @@ class SpillableColumnarBatchFromBufferHandle private (
var materialized: ColumnarBatch = null
var hostHandle: SpillableHostBufferHandle = null
synchronized {
if (dev.isDefined) {
materialized = GpuColumnVector.incRefCounts(dev.get)
} else if (host.isDefined) {
if (host.isDefined) {
hostHandle = host.get
} else if (dev.isDefined) {
materialized = GpuColumnVector.incRefCounts(dev.get)
} else {
throw new IllegalStateException(
"attempting to materialize a closed handle")
Expand All @@ -644,7 +657,7 @@ class SpillableColumnarBatchFromBufferHandle private (
if (!spillable) {
0
} else {
val spilled = synchronized {
synchronized {
if (host.isEmpty && dev.isDefined) {
val cvFromBuffer = dev.get.column(0).asInstanceOf[GpuColumnVectorFromBuffer]
meta = Some(cvFromBuffer.getTableMeta)
Expand All @@ -655,16 +668,6 @@ class SpillableColumnarBatchFromBufferHandle private (
0L
}
}
releaseDeviceResource()
spilled
}
}

private def releaseDeviceResource(): Unit = {
SpillFramework.removeFromDeviceStore(this)
synchronized {
dev.foreach(_.close())
dev = None
}
}

Expand Down Expand Up @@ -713,10 +716,10 @@ class SpillableCompressedColumnarBatchHandle private (
var materialized: ColumnarBatch = null
var hostHandle: SpillableHostBufferHandle = null
synchronized {
if (dev.isDefined) {
materialized = GpuCompressedColumnVector.incRefCounts(dev.get)
} else if (host.isDefined) {
if (host.isDefined) {
hostHandle = host.get
} else if (dev.isDefined) {
materialized = GpuCompressedColumnVector.incRefCounts(dev.get)
} else {
throw new IllegalStateException(
"attempting to materialize a closed handle")
Expand All @@ -738,7 +741,7 @@ class SpillableCompressedColumnarBatchHandle private (
if (!spillable) {
0L
} else {
val spilled = synchronized {
synchronized {
if (host.isEmpty && dev.isDefined) {
val cvFromBuffer = dev.get.column(0).asInstanceOf[GpuCompressedColumnVector]
meta = Some(cvFromBuffer.getTableMeta)
Expand All @@ -749,16 +752,6 @@ class SpillableCompressedColumnarBatchHandle private (
0L
}
}
releaseDeviceResource()
spilled
}
}

private def releaseDeviceResource(): Unit = {
SpillFramework.removeFromDeviceStore(this)
synchronized {
dev.foreach(_.close())
dev = None
}
}

Expand Down Expand Up @@ -832,7 +825,7 @@ class SpillableHostColumnarBatchHandle private (
if (!spillable) {
0L
} else {
val spilled = synchronized {
synchronized {
if (disk.isEmpty && host.isDefined) {
withResource(DiskHandleStore.makeBuilder) { diskHandleBuilder =>
GpuTaskMetrics.get.spillToDiskTime {
Expand All @@ -847,16 +840,6 @@ class SpillableHostColumnarBatchHandle private (
0L
}
}
releaseHostResource()
spilled
}
}

private def releaseHostResource(): Unit = {
SpillFramework.removeFromHostStore(this)
synchronized {
host.foreach(_.close())
host = None
}
}

Expand Down Expand Up @@ -1019,11 +1002,19 @@ trait SpillableStore extends HandleStore[SpillableHandle] with Logging {
// this thread was successful at spilling handle.
amountSpilled += spilled
numSpilled += 1
} // else, either:
// - this thread lost the race and the handle was closed
// - another thread spilled it
// - the handle isn't spillable anymore, due to ref count.
} else {
// else, either:
// - this thread lost the race and the handle was closed
// - another thread spilled it
// - the handle isn't spillable anymore, due to ref count.
it.remove()
}
}
// spillables is the list of handles that have to be closed
// we synchronize the device before we release what was spilled, so that any
// kernels still reading the spilled buffers have finished
Cuda.deviceSynchronize()
// this is safe to call unconditionally, even if another thread performed the spill
spillables.forEach(_.releaseSpilled())

amountSpilled
}
Expand Down

0 comments on commit 8e8aefc

Please sign in to comment.