From 8e8aefcbbfff20ccf7d61b5d296cbb712e98820f Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Fri, 6 Dec 2024 08:23:29 -0800 Subject: [PATCH] Break down spill into a two-step process, so we can inject a device synchronize between --- .../spark/rapids/spill/SpillFramework.scala | 165 +++++++++--------- 1 file changed, 78 insertions(+), 87 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index 4165b1a6cc0..f0ba32a9b89 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -99,6 +99,11 @@ import org.apache.spark.storage.BlockId * responsibility to initiate that spill, and to track the spilled handle (a device spillable * would have a `host` handle, which tracks the host spilled object). * + * Spill is broken down into two methods: `spill` and `releaseSpilled`. This is a two-stage + * process because we need to make sure that there is no code running kernels on the spilled + * data before we actually free it. See the method documentation for `spill` and `releaseSpilled` + * for more info. + * * A cascade of spills can occur device -> host -> disk, given that host allocations can fail, or * could not fit in the SpillableHostStore's limit (if defined). In this case, the call to spill * will either create a host handle tracking an object on the host store (if we made room), or it @@ -160,6 +165,10 @@ trait SpillableHandle extends StoreHandle { /** * Method called to spill this handle. It can be triggered from the spill store, * or directly against the handle. + * + * This will not free the spilled data. If you would like to free the spill, + * call `releaseSpilled`. + * + * @note The size returned from this method is only used by the spill framework * to track the approximate size. 
It should just return `approxSizeInBytes`, as * that's the size that it used when it first started tracking the object. @@ -167,6 +176,16 @@ trait SpillableHandle extends StoreHandle { */ def spill(): Long + /** + * Part two of the two-stage process for spilling. We call `releaseSpilled` after + * a handle has spilled, and after a device synchronize. This prevents a race + * between threads working on cuDF kernels, that did not synchronize while holding the + * materialized handle's refCount, and the spiller thread (the spiller thread cannot + * free a device buffer that the worker thread isn't done with). + * See https://github.com/NVIDIA/spark-rapids/issues/8610 for more info. + */ + def releaseSpilled(): Unit + /** * Method used to determine whether a handle tracks an object that could be spilled * @note At the level of `SpillableHandle`, the only requirement of spillability @@ -188,6 +207,18 @@ trait DeviceSpillableHandle[T <: AutoCloseable] extends SpillableHandle { private[spill] override def spillable: Boolean = synchronized { super.spillable && dev.isDefined } + + protected def releaseDeviceResource(): Unit = { + SpillFramework.removeFromDeviceStore(this) + synchronized { + dev.foreach(_.close()) + dev = None + } + } + + override def releaseSpilled(): Unit = { + releaseDeviceResource() + } } /** @@ -201,6 +232,18 @@ trait HostSpillableHandle[T <: AutoCloseable] extends SpillableHandle { private[spill] override def spillable: Boolean = synchronized { super.spillable && host.isDefined } + + protected def releaseHostResource(): Unit = { + SpillFramework.removeFromHostStore(this) + synchronized { + host.foreach(_.close()) + host = None + } + } + + override def releaseSpilled(): Unit = { + releaseHostResource() + } } object SpillableHostBufferHandle extends Logging { @@ -284,7 +327,7 @@ class SpillableHostBufferHandle private ( if (!spillable) { 0L } else { - val spilled = synchronized { + synchronized { if (disk.isEmpty && host.isDefined) { 
withResource(DiskHandleStore.makeBuilder) { diskHandleBuilder => val outputChannel = diskHandleBuilder.getChannel @@ -307,19 +350,9 @@ class SpillableHostBufferHandle private ( 0L } } - releaseHostResource() - spilled } } - private def releaseHostResource(): Unit = { - SpillFramework.removeFromHostStore(this) - synchronized { - host.foreach(_.close()) - host = None - } - } - override def close(): Unit = { releaseHostResource() synchronized { @@ -399,12 +432,12 @@ class SpillableDeviceBufferHandle private ( var materialized: DeviceMemoryBuffer = null var hostHandle: SpillableHostBufferHandle = null synchronized { - if (dev.isDefined) { - materialized = dev.get - materialized.incRefCount() - } else if (host.isDefined) { + if (host.isDefined) { // since we spilled, host must be set. hostHandle = host.get + } else if (dev.isDefined) { + materialized = dev.get + materialized.incRefCount() } else { throw new IllegalStateException( "attempting to materialize a closed handle") @@ -426,7 +459,7 @@ class SpillableDeviceBufferHandle private ( if (!spillable) { 0L } else { - val spilled = synchronized { + synchronized { if (host.isEmpty && dev.isDefined) { host = Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff(dev.get)) sizeInBytes @@ -434,21 +467,11 @@ class SpillableDeviceBufferHandle private ( 0L } } - releaseDeviceResources() - spilled - } - } - - private def releaseDeviceResources(): Unit = { - SpillFramework.removeFromDeviceStore(this) - synchronized { - dev.foreach(_.close()) - dev = None } } override def close(): Unit = { - releaseDeviceResources() + releaseDeviceResource() synchronized { host.foreach(_.close()) host = None @@ -483,10 +506,10 @@ class SpillableColumnarBatchHandle private ( var materialized: ColumnarBatch = null var hostHandle: SpillableHostBufferHandle = null synchronized { - if (dev.isDefined) { - materialized = GpuColumnVector.incRefCounts(dev.get) - } else if (host.isDefined) { + if (host.isDefined) { hostHandle = host.get + } else if 
(dev.isDefined) { + materialized = GpuColumnVector.incRefCounts(dev.get) } else { throw new IllegalStateException( "attempting to materialize a closed handle") @@ -511,22 +534,20 @@ class SpillableColumnarBatchHandle private ( if (!spillable) { 0L } else { - val spilled = synchronized { + synchronized { if (host.isEmpty && dev.isDefined) { withChunkedPacker { chunkedPacker => meta = Some(chunkedPacker.getPackedMeta) host = Some(SpillableHostBufferHandle.createHostHandleWithPacker(chunkedPacker)) } + // We return the size we were created with. This is not the actual size + // of this batch when it is packed, and it is used by the calling code + // to figure out more or less how much did we free in the device. approxSizeInBytes } else { 0L } } - releaseDeviceResource() - // We return the size we were created with. This is not the actual size - // of this batch when it is packed, and it is used by the calling code - // to figure out more or less how much did we free in the device. - spilled } } @@ -544,14 +565,6 @@ class SpillableColumnarBatchHandle private ( } } - private def releaseDeviceResource(): Unit = { - SpillFramework.removeFromDeviceStore(this) - synchronized { - dev.foreach(_.close()) - dev = None - } - } - override def close(): Unit = { releaseDeviceResource() synchronized { @@ -616,10 +629,10 @@ class SpillableColumnarBatchFromBufferHandle private ( var materialized: ColumnarBatch = null var hostHandle: SpillableHostBufferHandle = null synchronized { - if (dev.isDefined) { - materialized = GpuColumnVector.incRefCounts(dev.get) - } else if (host.isDefined) { + if (host.isDefined) { hostHandle = host.get + } else if (dev.isDefined) { + materialized = GpuColumnVector.incRefCounts(dev.get) } else { throw new IllegalStateException( "attempting to materialize a closed handle") @@ -644,7 +657,7 @@ class SpillableColumnarBatchFromBufferHandle private ( if (!spillable) { 0 } else { - val spilled = synchronized { + synchronized { if (host.isEmpty && dev.isDefined) 
{ val cvFromBuffer = dev.get.column(0).asInstanceOf[GpuColumnVectorFromBuffer] meta = Some(cvFromBuffer.getTableMeta) @@ -655,16 +668,6 @@ class SpillableColumnarBatchFromBufferHandle private ( 0L } } - releaseDeviceResource() - spilled - } - } - - private def releaseDeviceResource(): Unit = { - SpillFramework.removeFromDeviceStore(this) - synchronized { - dev.foreach(_.close()) - dev = None } } @@ -713,10 +716,10 @@ class SpillableCompressedColumnarBatchHandle private ( var materialized: ColumnarBatch = null var hostHandle: SpillableHostBufferHandle = null synchronized { - if (dev.isDefined) { - materialized = GpuCompressedColumnVector.incRefCounts(dev.get) - } else if (host.isDefined) { + if (host.isDefined) { hostHandle = host.get + } else if (dev.isDefined) { + materialized = GpuCompressedColumnVector.incRefCounts(dev.get) } else { throw new IllegalStateException( "attempting to materialize a closed handle") @@ -738,7 +741,7 @@ class SpillableCompressedColumnarBatchHandle private ( if (!spillable) { 0L } else { - val spilled = synchronized { + synchronized { if (host.isEmpty && dev.isDefined) { val cvFromBuffer = dev.get.column(0).asInstanceOf[GpuCompressedColumnVector] meta = Some(cvFromBuffer.getTableMeta) @@ -749,16 +752,6 @@ class SpillableCompressedColumnarBatchHandle private ( 0L } } - releaseDeviceResource() - spilled - } - } - - private def releaseDeviceResource(): Unit = { - SpillFramework.removeFromDeviceStore(this) - synchronized { - dev.foreach(_.close()) - dev = None } } @@ -832,7 +825,7 @@ class SpillableHostColumnarBatchHandle private ( if (!spillable) { 0L } else { - val spilled = synchronized { + synchronized { if (disk.isEmpty && host.isDefined) { withResource(DiskHandleStore.makeBuilder) { diskHandleBuilder => GpuTaskMetrics.get.spillToDiskTime { @@ -847,16 +840,6 @@ class SpillableHostColumnarBatchHandle private ( 0L } } - releaseHostResource() - spilled - } - } - - private def releaseHostResource(): Unit = { - 
SpillFramework.removeFromHostStore(this) - synchronized { - host.foreach(_.close()) - host = None } } @@ -1019,11 +1002,19 @@ trait SpillableStore extends HandleStore[SpillableHandle] with Logging { // this thread was successful at spilling handle. amountSpilled += spilled numSpilled += 1 - } // else, either: - // - this thread lost the race and the handle was closed - // - another thread spilled it - // - the handle isn't spillable anymore, due to ref count. + } else { + // else, either: + // - this thread lost the race and the handle was closed + // - another thread spilled it + // - the handle isn't spillable anymore, due to ref count. + it.remove() + } } + // spillables is the list of handles that have to be closed + // we synchronize every thread before we release what was spilled + Cuda.deviceSynchronize() + // this is safe to be called unconditionally if another thread spilled + spillables.forEach(_.releaseSpilled()) amountSpilled }