diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java index f2be4264162..b72a389d2a0 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java @@ -1101,6 +1101,10 @@ public final int numNulls() { public static long getTotalDeviceMemoryUsed(ColumnarBatch batch) { long sum = 0; + if (batch.numCols() == 1 && batch.column(0) instanceof GpuPackedTableColumn) { + // this is a special case for a packed batch + return ((GpuPackedTableColumn) batch.column(0)).getTableBuffer().getLength(); + } if (batch.numCols() > 0) { if (batch.column(0) instanceof WithTableBuffer) { WithTableBuffer wtb = (WithTableBuffer) batch.column(0); diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVectorFromBuffer.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVectorFromBuffer.java index e23fa76c9f3..b5ed621821b 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVectorFromBuffer.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVectorFromBuffer.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -113,6 +113,29 @@ public GpuColumnVectorFromBuffer(DataType type, ColumnVector cudfColumn, this.tableMeta = meta; } + public static boolean isFromBuffer(ColumnarBatch cb) { + if (cb.numCols() > 0) { + long bufferAddr = 0L; + boolean isSet = false; + for (int i = 0; i < cb.numCols(); ++i) { + GpuColumnVectorFromBuffer gcvfb = null; + if (!(cb.column(i) instanceof GpuColumnVectorFromBuffer)) { + return false; + } else { + gcvfb = (GpuColumnVectorFromBuffer) cb.column(i); + if (!isSet) { + bufferAddr = gcvfb.buffer.getAddress(); + isSet = true; + } else if (bufferAddr != gcvfb.buffer.getAddress()) { + return false; + } + } + } + return true; + } + return false; + } + /** * Get the underlying contiguous buffer, shared between columns of the original * `ContiguousTable` diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuCompressedColumnVector.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuCompressedColumnVector.java index cd34f35ecab..1dc85cb2031 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuCompressedColumnVector.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuCompressedColumnVector.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,6 +47,15 @@ public static boolean isBatchCompressed(ColumnarBatch batch) { return batch.numCols() == 1 && batch.column(0) instanceof GpuCompressedColumnVector; } + public static ColumnarBatch incRefCounts(ColumnarBatch batch) { + if (!isBatchCompressed(batch)) { + throw new IllegalStateException( + "Attempted to incRefCount for a compressed batch, but the batch was not compressed."); + } + ((GpuCompressedColumnVector)batch.column(0)).buffer.incRefCount(); + return batch; + } + /** * Build a columnar batch from a compressed data buffer and specified table metadata * NOTE: The data remains compressed and cannot be accessed directly from the columnar batch. diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/InternalRowToColumnarBatchIterator.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/InternalRowToColumnarBatchIterator.java index 0aa3f0978e9..400b54626d8 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/InternalRowToColumnarBatchIterator.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/InternalRowToColumnarBatchIterator.java @@ -21,19 +21,11 @@ import java.util.NoSuchElementException; import java.util.Optional; -import com.nvidia.spark.Retryable; import scala.Option; import scala.Tuple2; import scala.collection.Iterator; -import ai.rapids.cudf.ColumnVector; -import ai.rapids.cudf.DType; -import ai.rapids.cudf.HostColumnVector; -import ai.rapids.cudf.HostColumnVectorCore; -import ai.rapids.cudf.HostMemoryBuffer; -import ai.rapids.cudf.NvtxColor; -import ai.rapids.cudf.NvtxRange; -import ai.rapids.cudf.Table; +import ai.rapids.cudf.*; import com.nvidia.spark.rapids.jni.RowConversion; import com.nvidia.spark.rapids.shims.CudfUnsafeRow; @@ -236,8 +228,7 @@ private HostMemoryBuffer[] getHostBuffersWithRetry( try { hBuf = HostAlloc$.MODULE$.alloc((dataBytes + offsetBytes),true); SpillableHostBuffer sBuf = SpillableHostBuffer$.MODULE$.apply(hBuf, hBuf.getLength(), - SpillPriorities$.MODULE$.ACTIVE_ON_DECK_PRIORITY(), - RapidsBufferCatalog$.MODULE$.singleton()); + SpillPriorities$.MODULE$.ACTIVE_ON_DECK_PRIORITY()); hBuf = null; // taken over by spillable host buffer return Tuple2.apply(sBuf, numRowsWrapper); } finally { diff --git a/sql-plugin/src/main/scala-2.12/com/nvidia/spark/rapids/implicits.scala b/sql-plugin/src/main/scala-2.12/com/nvidia/spark/rapids/implicits.scala index eddad69ba97..89e717788f9 100644 --- a/sql-plugin/src/main/scala-2.12/com/nvidia/spark/rapids/implicits.scala +++ b/sql-plugin/src/main/scala-2.12/com/nvidia/spark/rapids/implicits.scala @@ -63,26 +63,6 @@ object RapidsPluginImplicits { } } - implicit class RapidsBufferColumn(rapidsBuffer: RapidsBuffer) { - - /** - * safeFree: Is an implicit on RapidsBuffer class that tries to free the resource, if an - * Exception was thrown prior to this free, it adds the new exception to the suppressed - * exceptions, otherwise just throws - * - * @param e Exception which we don't want to suppress - */ - def safeFree(e: Throwable = null): Unit = { - if (rapidsBuffer != null) { - try { - rapidsBuffer.free() - } catch { - case suppressed: Throwable if e != null => e.addSuppressed(suppressed) - } - } - } - } - implicit class AutoCloseableSeq[A <: AutoCloseable](val in: collection.SeqLike[A, _]) { /** * safeClose: Is an implicit on a sequence of AutoCloseable classes that tries to close each @@ -111,46 +91,12 @@ object RapidsPluginImplicits { } } - implicit class RapidsBufferSeq[A <: RapidsBuffer](val in: collection.SeqLike[A, _]) { - /** - * safeFree: Is an implicit on a sequence of RapidsBuffer classes that tries to free each - * element of the sequence, even if prior free calls fail. In case of failure in any of the - * free calls, an Exception is thrown containing the suppressed exceptions (getSuppressed), - * if any. - */ - def safeFree(error: Throwable = null): Unit = if (in != null) { - var freeException: Throwable = null - in.foreach { element => - if (element != null) { - try { - element.free() - } catch { - case e: Throwable if error != null => error.addSuppressed(e) - case e: Throwable if freeException == null => freeException = e - case e: Throwable => freeException.addSuppressed(e) - } - } - } - if (freeException != null) { - // an exception happened while we were trying to safely free - // resources, throw the exception to alert the caller - throw freeException - } - } - } - implicit class AutoCloseableArray[A <: AutoCloseable](val in: Array[A]) { def safeClose(e: Throwable = null): Unit = if (in != null) { in.toSeq.safeClose(e) } } - implicit class RapidsBufferArray[A <: RapidsBuffer](val in: Array[A]) { - def safeFree(e: Throwable = null): Unit = if (in != null) { - in.toSeq.safeFree(e) - } - } - class MapsSafely[A, Repr] { /** * safeMap: safeMap implementation that is leveraged by other type-specific implicits. diff --git a/sql-plugin/src/main/scala-2.13/com/nvidia/spark/rapids/implicits.scala b/sql-plugin/src/main/scala-2.13/com/nvidia/spark/rapids/implicits.scala index 5bdded6dbd4..1e4c5e39a19 100644 --- a/sql-plugin/src/main/scala-2.13/com/nvidia/spark/rapids/implicits.scala +++ b/sql-plugin/src/main/scala-2.13/com/nvidia/spark/rapids/implicits.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -63,26 +63,6 @@ object RapidsPluginImplicits { } } - implicit class RapidsBufferColumn(rapidsBuffer: RapidsBuffer) { - - /** - * safeFree: Is an implicit on RapidsBuffer class that tries to free the resource, if an - * Exception was thrown prior to this free, it adds the new exception to the suppressed - * exceptions, otherwise just throws - * - * @param e Exception which we don't want to suppress - */ - def safeFree(e: Throwable = null): Unit = { - if (rapidsBuffer != null) { - try { - rapidsBuffer.free() - } catch { - case suppressed: Throwable if e != null => e.addSuppressed(suppressed) - } - } - } - } - implicit class AutoCloseableSeq[A <: AutoCloseable](val in: collection.Iterable[A]) { /** * safeClose: Is an implicit on a sequence of AutoCloseable classes that tries to close each @@ -111,46 +91,12 @@ object RapidsPluginImplicits { } } - implicit class RapidsBufferSeq[A <: RapidsBuffer](val in: collection.SeqLike[A, _]) { - /** - * safeFree: Is an implicit on a sequence of RapidsBuffer classes that tries to free each - * element of the sequence, even if prior free calls fail. In case of failure in any of the - * free calls, an Exception is thrown containing the suppressed exceptions (getSuppressed), - * if any. - */ - def safeFree(error: Throwable = null): Unit = if (in != null) { - var freeException: Throwable = null - in.foreach { element => - if (element != null) { - try { - element.free() - } catch { - case e: Throwable if error != null => error.addSuppressed(e) - case e: Throwable if freeException == null => freeException = e - case e: Throwable => freeException.addSuppressed(e) - } - } - } - if (freeException != null) { - // an exception happened while we were trying to safely free - // resources, throw the exception to alert the caller - throw freeException - } - } - } - implicit class AutoCloseableArray[A <: AutoCloseable](val in: Array[A]) { def safeClose(e: Throwable = null): Unit = if (in != null) { in.toSeq.safeClose(e) } } - implicit class RapidsBufferArray[A <: RapidsBuffer](val in: Array[A]) { - def safeFree(e: Throwable = null): Unit = if (in != null) { - in.toSeq.safeFree(e) - } - } - class IterableMapsSafely[A, From[A] <: collection.Iterable[A] with collection.IterableOps[A, From, _]] { /** diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala index b0cd798c179..fcf65e1bc00 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala @@ -177,20 +177,6 @@ object Arm extends ArmScalaSpecificImpl { } } - /** Executes the provided code block, freeing the RapidsBuffer only if an exception occurs */ - def freeOnExcept[T <: RapidsBuffer, V](r: T)(block: T => V): V = { - try { - block(r) - } catch { - case t: ControlThrowable => - // Don't close for these cases.. - throw t - case t: Throwable => - r.safeFree(t) - throw t - } - } - /** Executes the provided code block and then closes the resource */ def withResource[T <: AutoCloseable, V](h: CloseableHolder[T]) (block: CloseableHolder[T] => V): V = { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala index 72808d1f376..9c867bb6a90 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ import java.lang.management.ManagementFactory import java.util.concurrent.atomic.AtomicLong import ai.rapids.cudf.{Cuda, Rmm, RmmEventHandler} +import com.nvidia.spark.rapids.spill.SpillableDeviceStore import com.sun.management.HotSpotDiagnosticMXBean import org.apache.spark.internal.Logging @@ -34,8 +35,7 @@ import org.apache.spark.sql.rapids.execution.TrampolineUtil * depleting the device store */ class DeviceMemoryEventHandler( - catalog: RapidsBufferCatalog, - store: RapidsDeviceMemoryStore, + store: SpillableDeviceStore, oomDumpDir: Option[String], maxFailedOOMRetries: Int) extends RmmEventHandler with Logging { @@ -92,8 +92,8 @@ class DeviceMemoryEventHandler( * from cuDF. If we succeed, cuDF resets `retryCount`, and so the new count sent to us * must be <= than what we saw last, so we can reset our tracking. */ - def resetIfNeeded(retryCount: Int, storeSpillableSize: Long): Unit = { - if (storeSpillableSize != 0 || retryCount <= retryCountLastSynced) { + def resetIfNeeded(retryCount: Int, couldSpill: Boolean): Unit = { + if (couldSpill || retryCount <= retryCountLastSynced) { reset() } } @@ -114,9 +114,6 @@ class DeviceMemoryEventHandler( s"onAllocFailure invoked with invalid retryCount $retryCount") try { - val storeSize = store.currentSize - val storeSpillableSize = store.currentSpillableSize - val attemptMsg = if (retryCount > 0) { s"Attempt ${retryCount}. " } else { @@ -124,12 +121,13 @@ class DeviceMemoryEventHandler( } val retryState = oomRetryState.get() - retryState.resetIfNeeded(retryCount, storeSpillableSize) - logInfo(s"Device allocation of $allocSize bytes failed, device store has " + - s"$storeSize total and $storeSpillableSize spillable bytes. $attemptMsg" + - s"Total RMM allocated is ${Rmm.getTotalBytesAllocated} bytes. ") - if (storeSpillableSize == 0) { + val amountSpilled = store.spill(allocSize) + retryState.resetIfNeeded(retryCount, amountSpilled > 0) + logInfo(s"Device allocation of $allocSize bytes failed. " + + s"Device store spilled $amountSpilled bytes. $attemptMsg" + + s"Total RMM allocated is ${Rmm.getTotalBytesAllocated} bytes.") + if (amountSpilled == 0) { if (retryState.shouldTrySynchronizing(retryCount)) { Cuda.deviceSynchronize() logWarning(s"[RETRY ${retryState.getRetriesSoFar}] " + @@ -149,13 +147,7 @@ class DeviceMemoryEventHandler( false } } else { - val targetSize = Math.max(storeSpillableSize - allocSize, 0) - logDebug(s"Targeting device store size of $targetSize bytes") - val maybeAmountSpilled = catalog.synchronousSpill(store, targetSize, Cuda.DEFAULT_STREAM) - maybeAmountSpilled.foreach { amountSpilled => - logInfo(s"Spilled $amountSpilled bytes from the device store") - TrampolineUtil.incTaskMetricsMemoryBytesSpilled(amountSpilled) - } + TrampolineUtil.incTaskMetricsMemoryBytesSpilled(amountSpilled) true } } catch { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala index b0c86773166..42776a6cab0 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala @@ -22,6 +22,8 @@ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import ai.rapids.cudf._ +import com.nvidia.spark.rapids.jni.RmmSpark +import com.nvidia.spark.rapids.spill.SpillFramework import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.internal.Logging @@ -169,7 +171,9 @@ object GpuDeviceManager extends Logging { chunkedPackMemoryResource = None poolSizeLimit = 0L - RapidsBufferCatalog.close() + SpillFramework.shutdown() + RmmSpark.clearEventHandler() + Rmm.clearEventHandler() GpuShuffleEnv.shutdown() // try to avoid segfault on RMM shutdown val timeout = System.nanoTime() + TimeUnit.SECONDS.toNanos(10) @@ -278,6 +282,8 @@ object GpuDeviceManager extends Logging { } } + private var memoryEventHandler: DeviceMemoryEventHandler = _ + private def initializeRmm(gpuId: Int, rapidsConf: Option[RapidsConf]): Unit = { if (!Rmm.isInitialized) { val conf = rapidsConf.getOrElse(new RapidsConf(SparkEnv.get.conf)) @@ -385,8 +391,25 @@ object GpuDeviceManager extends Logging { } } - RapidsBufferCatalog.init(conf) - GpuShuffleEnv.init(conf, RapidsBufferCatalog.getDiskBlockManager()) + SpillFramework.initialize(conf) + + memoryEventHandler = new DeviceMemoryEventHandler( + SpillFramework.stores.deviceStore, + conf.gpuOomDumpDir, + conf.gpuOomMaxRetries) + + if (conf.sparkRmmStateEnable) { + val debugLoc = if (conf.sparkRmmDebugLocation.isEmpty) { + null + } else { + conf.sparkRmmDebugLocation + } + RmmSpark.setEventHandler(memoryEventHandler, debugLoc) + } else { + logWarning("SparkRMM retry has been disabled") + Rmm.setEventHandler(memoryEventHandler) + } + GpuShuffleEnv.init(conf) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala index 6a34d15dc6e..6079c0352df 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala @@ -18,6 +18,7 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.{DefaultHostMemoryAllocator, HostMemoryAllocator, HostMemoryBuffer, MemoryBuffer, PinnedMemoryPool} import com.nvidia.spark.rapids.jni.{CpuRetryOOM, RmmSpark} +import com.nvidia.spark.rapids.spill.SpillFramework import org.apache.spark.TaskContext import org.apache.spark.internal.Logging @@ -137,9 +138,7 @@ private class HostAlloc(nonPinnedLimit: Long) extends HostMemoryAllocator with L require(retryCount >= 0, s"spillAndCheckRetry invoked with invalid retryCount $retryCount") - val store = RapidsBufferCatalog.getHostStorage - val storeSize = store.currentSize - val storeSpillableSize = store.currentSpillableSize + val store = SpillFramework.stores.hostStore val totalSize: Long = synchronized { currentPinnedAllocated + currentNonPinnedAllocated } @@ -150,21 +149,20 @@ private class HostAlloc(nonPinnedLimit: Long) extends HostMemoryAllocator with L "First attempt" } - logInfo(s"Host allocation of $allocSize bytes failed, host store has " + - s"$storeSize total and $storeSpillableSize spillable bytes. $attemptMsg.") - if (storeSpillableSize == 0) { - logWarning(s"Host store exhausted, unable to allocate $allocSize bytes. " + - s"Total host allocated is $totalSize bytes.") - false - } else { - val targetSize = Math.max(storeSpillableSize - allocSize, 0) - logDebug(s"Targeting host store size of $targetSize bytes") - // We could not make it work so try and spill enough to make it work - val maybeAmountSpilled = - RapidsBufferCatalog.synchronousSpill(RapidsBufferCatalog.getHostStorage, targetSize) - maybeAmountSpilled.foreach { amountSpilled => - logInfo(s"Spilled $amountSpilled bytes from the host store") + val amountSpilled = store.spill(allocSize) + + if (amountSpilled == 0) { + val shouldRetry = store.numHandles > 0 + val exhaustedMsg = s"Host store exhausted, unable to allocate $allocSize bytes. " + + s"Total host allocated is $totalSize bytes. $attemptMsg." + if (!shouldRetry) { + logWarning(exhaustedMsg) + } else { + logWarning(s"$exhaustedMsg Attempting a retry.") } + shouldRetry + } else { + logInfo(s"Spilled $amountSpilled bytes from the host store") true } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala index 80acddcb257..f1561e2c251 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -51,6 +51,13 @@ object MetaUtils { ct.getMetadataDirectBuffer, ct.getRowCount) + def buildTableMeta(tableId: Int, compressed: GpuCompressedColumnVector): TableMeta = + buildTableMeta( + tableId, + compressed.getTableBuffer.getLength, + compressed.getTableMeta.bufferMeta().getByteBuffer, + compressed.getTableMeta.rowCount()) + def buildTableMeta(tableId: Int, bufferSize: Long, packedMeta: ByteBuffer, rowCount: Long): TableMeta = { val fbb = new FlatBufferBuilder(1024) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala deleted file mode 100644 index a332755745f..00000000000 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala +++ /dev/null @@ -1,485 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.nvidia.spark.rapids - -import java.io.File -import java.nio.channels.WritableByteChannel - -import scala.collection.mutable.ArrayBuffer - -import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, Table} -import com.nvidia.spark.rapids.Arm.withResource -import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.StorageTier.StorageTier -import com.nvidia.spark.rapids.format.TableMeta - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.rapids.RapidsDiskBlockManager -import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.vectorized.ColumnarBatch - -/** - * An identifier for a RAPIDS buffer that can be automatically spilled between buffer stores. - * NOTE: Derived classes MUST implement proper hashCode and equals methods, as these objects are - * used as keys in hash maps. Scala case classes are recommended. - */ -trait RapidsBufferId { - val tableId: Int - - /** - * Indicates whether the buffer may share a spill file with other buffers. - * If false then the spill file will be automatically removed when the buffer is freed. - * If true then the spill file will not be automatically removed, and another subsystem needs - * to be responsible for cleaning up the spill files for those types of buffers. - */ - val canShareDiskPaths: Boolean = false - - /** - * Generate a path to a local file that can be used to spill the corresponding buffer to disk. - * The path must be unique across all buffers unless canShareDiskPaths is true. - */ - def getDiskPath(diskBlockManager: RapidsDiskBlockManager): File -} - -/** Enumeration of the storage tiers */ -object StorageTier extends Enumeration { - type StorageTier = Value - val DEVICE: StorageTier = Value(0, "device memory") - val HOST: StorageTier = Value(1, "host memory") - val DISK: StorageTier = Value(2, "local disk") -} - -/** - * ChunkedPacker is an Iterator that uses a cudf::chunked_pack to copy a cuDF `Table` - * to a target buffer in chunks. - * - * Each chunk is sized at most `bounceBuffer.getLength`, and the caller should cudaMemcpy - * bytes from `bounceBuffer` to a target buffer after each call to `next()`. - * - * @note `ChunkedPacker` must be closed by the caller as it has GPU and host resources - * associated with it. - * - * @param id The RapidsBufferId for this pack operation to be included in the metadata - * @param table cuDF Table to chunk_pack - * @param bounceBuffer GPU memory to be used for packing. The buffer should be at least 1MB - * in length. - */ -class ChunkedPacker( - id: RapidsBufferId, - table: Table, - bounceBuffer: DeviceMemoryBuffer) - extends Iterator[MemoryBuffer] - with Logging - with AutoCloseable { - - private var closed: Boolean = false - - // When creating cudf::chunked_pack use a pool if available, otherwise default to the - // per-device memory resource - private val chunkedPack = { - val pool = GpuDeviceManager.chunkedPackMemoryResource - val cudfChunkedPack = try { - pool.flatMap { chunkedPool => - Some(table.makeChunkedPack(bounceBuffer.getLength, chunkedPool)) - } - } catch { - case _: OutOfMemoryError => - if (!ChunkedPacker.warnedAboutPoolFallback) { - ChunkedPacker.warnedAboutPoolFallback = true - logWarning( - s"OOM while creating chunked_pack using pool sized ${pool.map(_.getMaxSize)}B. " + - "Falling back to the per-device memory resource.") - } - None - } - - // if the pool is not configured, or we got an OOM, try again with the per-device pool - cudfChunkedPack.getOrElse { - table.makeChunkedPack(bounceBuffer.getLength) - } - } - - private val tableMeta = withResource(chunkedPack.buildMetadata()) { packedMeta => - MetaUtils.buildTableMeta( - id.tableId, - chunkedPack.getTotalContiguousSize, - packedMeta.getMetadataDirectBuffer, - table.getRowCount) - } - - // take out a lease on the bounce buffer - bounceBuffer.incRefCount() - - def getTotalContiguousSize: Long = chunkedPack.getTotalContiguousSize - - def getMeta: TableMeta = { - tableMeta - } - - override def hasNext: Boolean = synchronized { - if (closed) { - throw new IllegalStateException(s"ChunkedPacker for $id is closed") - } - chunkedPack.hasNext - } - - def next(): MemoryBuffer = synchronized { - if (closed) { - throw new IllegalStateException(s"ChunkedPacker for $id is closed") - } - val bytesWritten = chunkedPack.next(bounceBuffer) - // we increment the refcount because the caller has no idea where - // this memory came from, so it should close it. - bounceBuffer.slice(0, bytesWritten) - } - - override def close(): Unit = synchronized { - if (!closed) { - closed = true - val toClose = new ArrayBuffer[AutoCloseable]() - toClose.append(chunkedPack, bounceBuffer) - toClose.safeClose() - } - } -} - -object ChunkedPacker { - private var warnedAboutPoolFallback: Boolean = false -} - -/** - * This iterator encapsulates a buffer's internal `MemoryBuffer` access - * for spill reasons. Internally, there are two known implementations: - * - either this is a "single shot" copy, where the entirety of the `RapidsBuffer` is - * already represented as a single contiguous blob of memory, then the expectation - * is that this iterator is exhausted with a single call to `next` - * - or, we have a `RapidsBuffer` that isn't contiguous. This iteration will then - * drive a `ChunkedPacker` to pack the `RapidsBuffer`'s table as needed. The - * iterator will likely need several calls to `next` to be exhausted. - * - * @param buffer `RapidsBuffer` to copy out of its tier. - */ -class RapidsBufferCopyIterator(buffer: RapidsBuffer) - extends Iterator[MemoryBuffer] with AutoCloseable with Logging { - - private val chunkedPacker: Option[ChunkedPacker] = if (buffer.supportsChunkedPacker) { - Some(buffer.makeChunkedPacker) - } else { - None - } - def isChunked: Boolean = chunkedPacker.isDefined - - // this is used for the single shot case to flag when `next` is call - // to satisfy the Iterator interface - private var singleShotCopyHasNext: Boolean = false - private var singleShotBuffer: MemoryBuffer = _ - - if (!isChunked) { - singleShotCopyHasNext = true - singleShotBuffer = buffer.getMemoryBuffer - } - - override def hasNext: Boolean = - chunkedPacker.map(_.hasNext).getOrElse(singleShotCopyHasNext) - - override def next(): MemoryBuffer = { - require(hasNext, - "next called on exhausted iterator") - chunkedPacker.map(_.next()).getOrElse { - singleShotCopyHasNext = false - singleShotBuffer.slice(0, singleShotBuffer.getLength) - } - } - - def getTotalCopySize: Long = { - chunkedPacker - .map(_.getTotalContiguousSize) - .getOrElse(singleShotBuffer.getLength) - } - - override def close(): Unit = { - val toClose = new ArrayBuffer[AutoCloseable]() - toClose.appendAll(chunkedPacker) - toClose.appendAll(Option(singleShotBuffer)) - - toClose.safeClose() - } -} - -/** Interface provided by all types of RAPIDS buffers */ -trait RapidsBuffer extends AutoCloseable { - /** The buffer identifier for this buffer. */ - val id: RapidsBufferId - - /** - * The size of this buffer in bytes in its _current_ store. As the buffer goes through - * contiguous split (either added as a contiguous table already, or spilled to host), - * its size changes because contiguous_split adds its own alignment padding. - * - * @note Do not use this size to allocate a target buffer to copy, always use `getPackedSize.` - */ - val memoryUsedBytes: Long - - /** - * The size of this buffer if it has already gone through contiguous_split. - * - * @note Use this function when allocating a target buffer for spill or shuffle purposes. - */ - def getPackedSizeBytes: Long = memoryUsedBytes - - /** - * At spill time, obtain an iterator used to copy this buffer to a different tier. - */ - def getCopyIterator: RapidsBufferCopyIterator = - new RapidsBufferCopyIterator(this) - - /** Descriptor for how the memory buffer is formatted */ - def meta: TableMeta - - /** The storage tier for this buffer */ - val storageTier: StorageTier - - /** - * Get the columnar batch within this buffer. The caller must have - * successfully acquired the buffer beforehand. - * @param sparkTypes the spark data types the batch should have - * @see [[addReference]] - * @note It is the responsibility of the caller to close the batch. - * @note If the buffer is compressed data then the resulting batch will be built using - * `GpuCompressedColumnVector`, and it is the responsibility of the caller to deal - * with decompressing the data if necessary. - */ - def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch - - /** - * Get the host-backed columnar batch from this buffer. The caller must have - * successfully acquired the buffer beforehand. - * - * If this `RapidsBuffer` was added originally to the device tier, or if this is - * a just a buffer (not a batch), this function will throw. - * - * @param sparkTypes the spark data types the batch should have - * @see [[addReference]] - * @note It is the responsibility of the caller to close the batch. - */ - def getHostColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = { - throw new IllegalStateException(s"$this does not support host columnar batches.") - } - - /** - * Get the underlying memory buffer. This may be either a HostMemoryBuffer or a DeviceMemoryBuffer - * depending on where the buffer currently resides. - * The caller must have successfully acquired the buffer beforehand. - * @see [[addReference]] - * @note It is the responsibility of the caller to close the buffer. - */ - def getMemoryBuffer: MemoryBuffer - - val supportsChunkedPacker: Boolean = false - - /** - * Makes a new chunked packer. It is the responsibility of the caller to close this. - */ - def makeChunkedPacker: ChunkedPacker = { - throw new NotImplementedError("not implemented for this store") - } - - /** - * Copy the content of this buffer into the specified memory buffer, starting from the given - * offset. - * - * @param srcOffset offset to start copying from. - * @param dst the memory buffer to copy into. - * @param dstOffset offset to copy into. - * @param length number of bytes to copy. - * @param stream CUDA stream to use - */ - def copyToMemoryBuffer( - srcOffset: Long, dst: MemoryBuffer, dstOffset: Long, length: Long, stream: Cuda.Stream): Unit - - /** - * Get the device memory buffer from the underlying storage. If the buffer currently resides - * outside of device memory, a new DeviceMemoryBuffer is created with the data copied over. - * The caller must have successfully acquired the buffer beforehand. - * @see [[addReference]] - * @note It is the responsibility of the caller to close the buffer. - */ - def getDeviceMemoryBuffer: DeviceMemoryBuffer - - /** - * Get the host memory buffer from the underlying storage. If the buffer currently resides - * outside of host memory, a new HostMemoryBuffer is created with the data copied over. - * The caller must have successfully acquired the buffer beforehand. - * @see [[addReference]] - * @note It is the responsibility of the caller to close the buffer. - */ - def getHostMemoryBuffer: HostMemoryBuffer - - /** - * Try to add a reference to this buffer to acquire it. - * @note The close method must be called for every successfully obtained reference. - * @return true if the reference was added or false if this buffer is no longer valid - */ - def addReference(): Boolean - - /** - * Schedule the release of the buffer's underlying resources. - * Subsequent attempts to acquire the buffer will fail. As soon as the - * buffer has no outstanding references, the resources will be released. - *
- * This is separate from the close method which does not normally release - * resources. close will only release resources if called as the last - * outstanding reference and the buffer was previously marked as freed. - */ - def free(): Unit - - /** - * Get the spill priority value for this buffer. Lower values are higher - * priority for spilling, meaning buffers with lower values will be - * preferred for spilling over buffers with a higher value. - */ - def getSpillPriority: Long - - /** - * Set the spill priority for this buffer. Lower values are higher priority - * for spilling, meaning buffers with lower values will be preferred for - * spilling over buffers with a higher value. - * @note should only be called from the buffer catalog - * @param priority new priority value for this buffer - */ - def setSpillPriority(priority: Long): Unit - - /** - * Function invoked by the `RapidsBufferStore.addBuffer` method that prompts - * the specific `RapidsBuffer` to check its reference counting to make itself - * spillable or not. Only `RapidsTable` and `RapidsHostMemoryBuffer` implement - * this method. - */ - def updateSpillability(): Unit = {} - - /** - * Obtains a read lock on this instance of `RapidsBuffer` and calls the function - * in `body` while holding the lock. - * @param body function that takes a `MemoryBuffer` and produces `K` - * @tparam K any return type specified by `body` - * @return the result of body(memoryBuffer) - */ - def withMemoryBufferReadLock[K](body: MemoryBuffer => K): K - - /** - * Obtains a write lock on this instance of `RapidsBuffer` and calls the function - * in `body` while holding the lock. - * @param body function that takes a `MemoryBuffer` and produces `K` - * @tparam K any return type specified by `body` - * @return the result of body(memoryBuffer) - */ - def withMemoryBufferWriteLock[K](body: MemoryBuffer => K): K -} - -/** - * A buffer with no corresponding device data (zero rows or columns). - * These buffers are not tracked in buffer stores since they have no - * device memory. They are only tracked in the catalog and provide - * a representative `ColumnarBatch` but cannot provide a - * `MemoryBuffer`. - * @param id buffer ID to associate with the buffer - * @param meta schema metadata - */ -sealed class DegenerateRapidsBuffer( - override val id: RapidsBufferId, - override val meta: TableMeta) extends RapidsBuffer { - - override val memoryUsedBytes: Long = 0L - - override val storageTier: StorageTier = StorageTier.DEVICE - - override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = { - val rowCount = meta.rowCount - val packedMeta = meta.packedMetaAsByteBuffer() - if (packedMeta != null) { - withResource(DeviceMemoryBuffer.allocate(0)) { deviceBuffer => - withResource(Table.fromPackedTable(meta.packedMetaAsByteBuffer(), deviceBuffer)) { table => - GpuColumnVectorFromBuffer.from(table, deviceBuffer, meta, sparkTypes) - } - } - } else { - // no packed metadata, must be a table with zero columns - new ColumnarBatch(Array.empty, rowCount.toInt) - } - } - - override def free(): Unit = {} - - override def getMemoryBuffer: MemoryBuffer = - throw new UnsupportedOperationException("degenerate buffer has no memory buffer") - - override def copyToMemoryBuffer(srcOffset: Long, dst: MemoryBuffer, dstOffset: Long, length: Long, - stream: Cuda.Stream): Unit = - throw new UnsupportedOperationException("degenerate buffer cannot copy to memory buffer") - - override def getDeviceMemoryBuffer: DeviceMemoryBuffer = - throw new UnsupportedOperationException("degenerate buffer has no device memory buffer") - - override def getHostMemoryBuffer: HostMemoryBuffer = - throw new UnsupportedOperationException("degenerate buffer has no host memory buffer") - - override def addReference(): Boolean = true - - override def getSpillPriority: Long = Long.MaxValue - - override def setSpillPriority(priority: Long): Unit = {} - - override def withMemoryBufferReadLock[K](body: MemoryBuffer => K): K = { - throw new UnsupportedOperationException("degenerate buffer has no memory buffer") - } - - override def withMemoryBufferWriteLock[K](body: MemoryBuffer => K): K = { - throw new UnsupportedOperationException("degenerate buffer has no memory buffer") - } - - override def close(): Unit = {} -} - -trait RapidsHostBatchBuffer extends AutoCloseable { - /** - * Get the host-backed columnar batch from this buffer. The caller must have - * successfully acquired the buffer beforehand. - * - * If this `RapidsBuffer` was added originally to the device tier, or if this is - * a just a buffer (not a batch), this function will throw. - * - * @param sparkTypes the spark data types the batch should have - * @see [[addReference]] - * @note It is the responsibility of the caller to close the batch. - */ - def getHostColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch - - val memoryUsedBytes: Long -} - -trait RapidsBufferChannelWritable { - /** - * At spill time, write this buffer to an nio WritableByteChannel. - * @param writableChannel that this buffer can just write itself to, either byte-for-byte - * or via serialization if needed. - * @param stream the Cuda.Stream for the spilling thread. If the `RapidsBuffer` that - * implements this method is on the device, synchronization may be needed - * for staged copies. - * @return the amount of bytes written to the channel - */ - def writeToChannel(writableChannel: WritableByteChannel, stream: Cuda.Stream): Long -} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala deleted file mode 100644 index f61291a31ce..00000000000 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala +++ /dev/null @@ -1,1005 +0,0 @@ -/* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.nvidia.spark.rapids - -import java.util.concurrent.ConcurrentHashMap -import java.util.function.BiFunction - -import scala.collection.JavaConverters.collectionAsScalaIterableConverter - -import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, Rmm, Table} -import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} -import com.nvidia.spark.rapids.RapidsBufferCatalog.getExistingRapidsBufferAndAcquire -import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.StorageTier.StorageTier -import com.nvidia.spark.rapids.format.TableMeta -import com.nvidia.spark.rapids.jni.RmmSpark - -import org.apache.spark.{SparkConf, SparkEnv} -import org.apache.spark.internal.Logging -import org.apache.spark.sql.rapids.{RapidsDiskBlockManager, TempSpillBufferId} -import org.apache.spark.sql.vectorized.ColumnarBatch - -/** - * Exception thrown when inserting a buffer into the catalog with a duplicate buffer ID - * and storage tier combination. - */ -class DuplicateBufferException(s: String) extends RuntimeException(s) {} - -/** - * An object that client code uses to interact with an underlying RapidsBufferId. - * - * A handle is obtained when a buffer, batch, or table is added to the spill framework - * via the `RapidsBufferCatalog` api. - */ -trait RapidsBufferHandle extends AutoCloseable { - val id: RapidsBufferId - - /** - * Sets the spill priority for this handle and updates the maximum priority - * for the underlying `RapidsBuffer` if this new priority is the maximum. - * @param newPriority new priority for this handle - */ - def setSpillPriority(newPriority: Long): Unit -} - -/** - * Catalog for lookup of buffers by ID. The constructor is only visible for testing, generally - * `RapidsBufferCatalog.singleton` should be used instead. - */ -class RapidsBufferCatalog( - deviceStorage: RapidsDeviceMemoryStore = RapidsBufferCatalog.deviceStorage, - hostStorage: RapidsHostMemoryStore = RapidsBufferCatalog.hostStorage) - extends AutoCloseable with Logging { - - /** Map of buffer IDs to buffers sorted by storage tier */ - private[this] val bufferMap = new ConcurrentHashMap[RapidsBufferId, Seq[RapidsBuffer]] - - /** Map of buffer IDs to buffer handles in insertion order */ - private[this] val bufferIdToHandles = - new ConcurrentHashMap[RapidsBufferId, Seq[RapidsBufferHandleImpl]]() - - /** A counter used to skip a spill attempt if we detect a different thread has spilled */ - @volatile private[this] var spillCount: Integer = 0 - - class RapidsBufferHandleImpl( - override val id: RapidsBufferId, - var priority: Long) - extends RapidsBufferHandle { - - private var closed = false - - override def toString: String = - s"buffer handle $id at $priority" - - override def setSpillPriority(newPriority: Long): Unit = { - priority = newPriority - updateUnderlyingRapidsBuffer(this) - } - - /** - * Get the spill priority that was associated with this handle. Since there can - * be multiple handles associated with one `RapidsBuffer`, the priority returned - * here is only useful for code in the catalog that updates the maximum priority - * for the underlying `RapidsBuffer` as handles are added and removed. - * - * @return this handle's spill priority - */ - def getSpillPriority: Long = priority - - override def close(): Unit = synchronized { - // since the handle is stored in the catalog in addition to being - // handed out to potentially a `SpillableColumnarBatch` or `SpillableBuffer` - // there is a chance we may double close it. For example, a broadcast exec - // that is closing its spillable (and therefore the handle) + the handle being - // closed from the catalog's close method. - if (!closed) { - removeBuffer(this) - } - closed = true - } - } - - /** - * Makes a new `RapidsBufferHandle` associated with `id`, keeping track - * of the spill priority and callback within this handle. - * - * This function also adds the handle for internal tracking in the catalog. - * - * @param id the `RapidsBufferId` that this handle refers to - * @param spillPriority the spill priority specified on creation of the handle - * @note public for testing - * @return a new instance of `RapidsBufferHandle` - */ - def makeNewHandle( - id: RapidsBufferId, - spillPriority: Long): RapidsBufferHandle = { - val handle = new RapidsBufferHandleImpl(id, spillPriority) - trackNewHandle(handle) - handle - } - - /** - * Adds a handle to the internal `bufferIdToHandles` map. - * - * The priority and callback of the `RapidsBuffer` will also be updated. - * - * @param handle handle to start tracking - */ - private def trackNewHandle(handle: RapidsBufferHandleImpl): Unit = { - bufferIdToHandles.compute(handle.id, (_, h) => { - var handles = h - if (handles == null) { - handles = Seq.empty[RapidsBufferHandleImpl] - } - handles :+ handle - }) - updateUnderlyingRapidsBuffer(handle) - } - - /** - * Called when the `RapidsBufferHandle` is no longer needed by calling code - * - * If this is the last handle associated with a `RapidsBuffer`, `stopTrackingHandle` - * returns true, otherwise it returns false. - * - * @param handle handle to stop tracking - * @return true: if this was the last `RapidsBufferHandle` associated with the - * underlying buffer. - * false: if there are remaining live handles - */ - private def stopTrackingHandle(handle: RapidsBufferHandle): Boolean = { - withResource(acquireBuffer(handle)) { buffer => - val id = handle.id - var maxPriority = Long.MinValue - val newHandles = bufferIdToHandles.compute(id, (_, handles) => { - if (handles == null) { - throw new IllegalStateException( - s"$id not found and we attempted to remove handles!") - } - if (handles.size == 1) { - require(handles.head == handle, - "Tried to remove a single handle, and we couldn't match on it") - null - } else { - val newHandles = handles.filter(h => h != handle).map { h => - maxPriority = maxPriority.max(h.getSpillPriority) - h - } - if (newHandles.isEmpty) { - null // remove since no more handles exist, should not happen - } else { - newHandles - } - } - }) - - if (newHandles == null) { - // tell calling code that no more handles exist, - // for this RapidsBuffer - true - } else { - // more handles remain, our priority changed so we need to update things - buffer.setSpillPriority(maxPriority) - false // we have handles left - } - } - } - - /** - * Adds a buffer to the catalog and store. This does NOT take ownership of the - * buffer, so it is the responsibility of the caller to close it. - * - * This version of `addBuffer` should not be called from the shuffle catalogs - * since they provide their own ids. - * - * @param buffer buffer that will be owned by the store - * @param tableMeta metadata describing the buffer layout - * @param initialSpillPriority starting spill priority value for the buffer - * @param needsSync whether the spill framework should stream synchronize while adding - * this device buffer (defaults to true) - * @return RapidsBufferHandle handle for this buffer - */ - def addBuffer( - buffer: MemoryBuffer, - tableMeta: TableMeta, - initialSpillPriority: Long, - needsSync: Boolean = true): RapidsBufferHandle = synchronized { - // first time we see `buffer` - val existing = getExistingRapidsBufferAndAcquire(buffer) - existing match { - case None => - addBuffer( - TempSpillBufferId(), - buffer, - tableMeta, - initialSpillPriority, - needsSync) - case Some(rapidsBuffer) => - withResource(rapidsBuffer) { _ => - makeNewHandle(rapidsBuffer.id, initialSpillPriority) - } - } - } - - /** - * Adds a contiguous table to the device storage. This does NOT take ownership of the - * contiguous table, so it is the responsibility of the caller to close it. The refcount of the - * underlying device buffer will be incremented so the contiguous table can be closed before - * this buffer is destroyed. - * - * This version of `addContiguousTable` should not be called from the shuffle catalogs - * since they provide their own ids. - * - * @param contigTable contiguous table to track in storage - * @param initialSpillPriority starting spill priority value for the buffer - * @param needsSync whether the spill framework should stream synchronize while adding - * this device buffer (defaults to true) - * @return RapidsBufferHandle handle for this table - */ - def addContiguousTable( - contigTable: ContiguousTable, - initialSpillPriority: Long, - needsSync: Boolean = true): RapidsBufferHandle = synchronized { - val existing = getExistingRapidsBufferAndAcquire(contigTable.getBuffer) - existing match { - case None => - addContiguousTable( - TempSpillBufferId(), - contigTable, - initialSpillPriority, - needsSync) - case Some(rapidsBuffer) => - withResource(rapidsBuffer) { _ => - makeNewHandle(rapidsBuffer.id, initialSpillPriority) - } - } - } - - /** - * Adds a contiguous table to the device storage. This does NOT take ownership of the - * contiguous table, so it is the responsibility of the caller to close it. The refcount of the - * underlying device buffer will be incremented so the contiguous table can be closed before - * this buffer is destroyed. - * - * @param id the RapidsBufferId to use for this buffer - * @param contigTable contiguous table to track in storage - * @param initialSpillPriority starting spill priority value for the buffer - * @param needsSync whether the spill framework should stream synchronize while adding - * this device buffer (defaults to true) - * @return RapidsBufferHandle handle for this table - */ - def addContiguousTable( - id: RapidsBufferId, - contigTable: ContiguousTable, - initialSpillPriority: Long, - needsSync: Boolean): RapidsBufferHandle = synchronized { - addBuffer( - id, - contigTable.getBuffer, - MetaUtils.buildTableMeta(id.tableId, contigTable), - initialSpillPriority, - needsSync) - } - - /** - * Adds a buffer to either the device or host storage. This does NOT take - * ownership of the buffer, so it is the responsibility of the caller to close it. - * - * @param id the RapidsBufferId to use for this buffer - * @param buffer buffer that will be owned by the target store - * @param tableMeta metadata describing the buffer layout - * @param initialSpillPriority starting spill priority value for the buffer - * @param needsSync whether the spill framework should stream synchronize while adding - * this buffer (defaults to true) - * @return RapidsBufferHandle handle for this RapidsBuffer - */ - def addBuffer( - id: RapidsBufferId, - buffer: MemoryBuffer, - tableMeta: TableMeta, - initialSpillPriority: Long, - needsSync: Boolean): RapidsBufferHandle = synchronized { - val rapidsBuffer = buffer match { - case gpuBuffer: DeviceMemoryBuffer => - deviceStorage.addBuffer( - id, - gpuBuffer, - tableMeta, - initialSpillPriority, - needsSync) - case hostBuffer: HostMemoryBuffer => - hostStorage.addBuffer( - id, - hostBuffer, - tableMeta, - initialSpillPriority, - needsSync) - case _ => - throw new IllegalArgumentException( - s"Cannot call addBuffer with buffer $buffer") - } - registerNewBuffer(rapidsBuffer) - makeNewHandle(id, initialSpillPriority) - } - - /** - * Adds a batch to the device storage. This does NOT take ownership of the - * batch, so it is the responsibility of the caller to close it. - * - * @param batch batch that will be added to the store - * @param initialSpillPriority starting spill priority value for the batch - * @param needsSync whether the spill framework should stream synchronize while adding - * this batch (defaults to true) - * @return RapidsBufferHandle handle for this RapidsBuffer - */ - def addBatch( - batch: ColumnarBatch, - initialSpillPriority: Long, - needsSync: Boolean = true): RapidsBufferHandle = { - require(batch.numCols() > 0, - "Cannot call addBatch with a batch that doesn't have columns") - batch.column(0) match { - case _: RapidsHostColumnVector => - addHostBatch(batch, initialSpillPriority, needsSync) - case _ => - closeOnExcept(GpuColumnVector.from(batch)) { table => - addTable(table, initialSpillPriority, needsSync) - } - } - } - - /** - * Adds a table to the device storage. - * - * This takes ownership of the table. The reason for this is that tables - * don't have a reference count, so we cannot cleanly capture ownership by increasing - * ref count and decreasing from the caller. - * - * @param table table that will be owned by the store - * @param initialSpillPriority starting spill priority value - * @param needsSync whether the spill framework should stream synchronize while adding - * this table (defaults to true) - * @return RapidsBufferHandle handle for this RapidsBuffer - */ - def addTable( - table: Table, - initialSpillPriority: Long, - needsSync: Boolean = true): RapidsBufferHandle = { - addTable(TempSpillBufferId(), table, initialSpillPriority, needsSync) - } - - /** - * Adds a table to the device storage. - * - * This takes ownership of the table. The reason for this is that tables - * don't have a reference count, so we cannot cleanly capture ownership by increasing - * ref count and decreasing from the caller. - * - * @param id specific RapidsBufferId to use for this table - * @param table table that will be owned by the store - * @param initialSpillPriority starting spill priority value - * @param needsSync whether the spill framework should stream synchronize while adding - * this table (defaults to true) - * @return RapidsBufferHandle handle for this RapidsBuffer - */ - def addTable( - id: RapidsBufferId, - table: Table, - initialSpillPriority: Long, - needsSync: Boolean): RapidsBufferHandle = { - val rapidsBuffer = deviceStorage.addTable( - id, - table, - initialSpillPriority, - needsSync) - registerNewBuffer(rapidsBuffer) - makeNewHandle(id, initialSpillPriority) - } - - - /** - * Add a host-backed ColumnarBatch to the catalog. This is only called from addBatch - * after we detect that this is a host-backed batch. - */ - private def addHostBatch( - hostCb: ColumnarBatch, - initialSpillPriority: Long, - needsSync: Boolean): RapidsBufferHandle = { - val id = TempSpillBufferId() - val rapidsBuffer = hostStorage.addBatch( - id, - hostCb, - initialSpillPriority, - needsSync) - registerNewBuffer(rapidsBuffer) - makeNewHandle(id, initialSpillPriority) - } - - /** - * Register a degenerate RapidsBufferId given a TableMeta - * @note this is called from the shuffle catalogs only - */ - def registerDegenerateBuffer( - bufferId: RapidsBufferId, - meta: TableMeta): RapidsBufferHandle = synchronized { - val buffer = new DegenerateRapidsBuffer(bufferId, meta) - registerNewBuffer(buffer) - makeNewHandle(buffer.id, buffer.getSpillPriority) - } - - /** - * Called by the catalog when a handle is first added to the catalog, or to refresh - * the priority of the underlying buffer if a handle's priority changed. - */ - private def updateUnderlyingRapidsBuffer(handle: RapidsBufferHandle): Unit = { - withResource(acquireBuffer(handle)) { buffer => - val handles = bufferIdToHandles.get(buffer.id) - val maxPriority = handles.map(_.getSpillPriority).max - // update the priority of the underlying RapidsBuffer to be the - // maximum priority for all handles associated with it - buffer.setSpillPriority(maxPriority) - } - } - - /** - * Lookup the buffer that corresponds to the specified handle at the highest storage tier, - * and acquire it. - * NOTE: It is the responsibility of the caller to close the buffer. - * @param handle handle associated with this `RapidsBuffer` - * @return buffer that has been acquired - */ - def acquireBuffer(handle: RapidsBufferHandle): RapidsBuffer = { - val id = handle.id - def lookupAndReturn: Option[RapidsBuffer] = { - val buffers = bufferMap.get(id) - if (buffers == null || buffers.isEmpty) { - throw new NoSuchElementException( - s"Cannot locate buffers associated with ID: $id") - } - val buffer = buffers.head - if (buffer.addReference()) { - Some(buffer) - } else { - None - } - } - - // fast path - (0 until RapidsBufferCatalog.MAX_BUFFER_LOOKUP_ATTEMPTS).foreach { _ => - val mayBuffer = lookupAndReturn - if (mayBuffer.isDefined) { - return mayBuffer.get - } - } - - // try one last time after locking the catalog (slow path) - // if there is a lot of contention here, I would rather lock the world than - // have tasks error out with "Unable to acquire" - synchronized { - val mayBuffer = lookupAndReturn - if (mayBuffer.isDefined) { - return mayBuffer.get - } - } - throw new IllegalStateException(s"Unable to acquire buffer for ID: $id") - } - - /** - * Acquires a RapidsBuffer that the caller expects to be host-backed and not - * device bound. This ensures that the buffer acquired implements the correct - * trait, otherwise it throws and removes its buffer acquisition. - * - * @param handle handle associated with this `RapidsBuffer` - * @return host-backed RapidsBuffer that has been acquired - */ - def acquireHostBatchBuffer(handle: RapidsBufferHandle): RapidsHostBatchBuffer = { - closeOnExcept(acquireBuffer(handle)) { - case hrb: RapidsHostBatchBuffer => hrb - case other => - throw new IllegalStateException( - s"Attempted to acquire a RapidsHostBatchBuffer, but got $other instead") - } - } - - /** - * Lookup the buffer that corresponds to the specified buffer ID at the specified storage tier, - * and acquire it. - * NOTE: It is the responsibility of the caller to close the buffer. - * @param id buffer identifier - * @return buffer that has been acquired, None if not found - */ - def acquireBuffer(id: RapidsBufferId, tier: StorageTier): Option[RapidsBuffer] = { - val buffers = bufferMap.get(id) - if (buffers != null) { - buffers.find(_.storageTier == tier).foreach(buffer => - if (buffer.addReference()) { - return Some(buffer) - } - ) - } - None - } - - /** - * Check if the buffer that corresponds to the specified buffer ID is stored in a slower storage - * tier. - * - * @param id buffer identifier - * @param tier storage tier to check - * @note public for testing - * @return true if the buffer is stored in multiple tiers - */ - def isBufferSpilled(id: RapidsBufferId, tier: StorageTier): Boolean = { - val buffers = bufferMap.get(id) - buffers != null && buffers.exists(_.storageTier > tier) - } - - /** Get the table metadata corresponding to a buffer ID. */ - def getBufferMeta(id: RapidsBufferId): TableMeta = { - val buffers = bufferMap.get(id) - if (buffers == null || buffers.isEmpty) { - throw new NoSuchElementException(s"Cannot locate buffer associated with ID: $id") - } - buffers.head.meta - } - - /** - * Register a new buffer with the catalog. An exception will be thrown if an - * existing buffer was registered with the same buffer ID and storage tier. - * @note public for testing - */ - def registerNewBuffer(buffer: RapidsBuffer): Unit = { - val updater = new BiFunction[RapidsBufferId, Seq[RapidsBuffer], Seq[RapidsBuffer]] { - override def apply(key: RapidsBufferId, value: Seq[RapidsBuffer]): Seq[RapidsBuffer] = { - if (value == null) { - Seq(buffer) - } else { - val(first, second) = value.partition(_.storageTier < buffer.storageTier) - if (second.nonEmpty && second.head.storageTier == buffer.storageTier) { - throw new DuplicateBufferException( - s"Buffer ID ${buffer.id} at tier ${buffer.storageTier} already registered " + - s"${second.head}") - } - first ++ Seq(buffer) ++ second - } - } - } - - bufferMap.compute(buffer.id, updater) - } - - /** - * Free memory in `store` by spilling buffers to the spill store synchronously. - * @param store store to spill from - * @param targetTotalSize maximum total size of this store after spilling completes - * @param stream CUDA stream to use or omit for default stream - * @return optionally number of bytes that were spilled, or None if this call - * made no attempt to spill due to a detected spill race - */ - def synchronousSpill( - store: RapidsBufferStore, - targetTotalSize: Long, - stream: Cuda.Stream = Cuda.DEFAULT_STREAM): Option[Long] = { - if (store.spillStore == null) { - throw new OutOfMemoryError("Requested to spill without a spill store") - } - require(targetTotalSize >= 0, s"Negative spill target size: $targetTotalSize") - - val mySpillCount = spillCount - - // we have to hold this lock while freeing buffers, otherwise we could run - // into the case where a buffer is spilled yet it is aliased in addBuffer - // via an event handler that hasn't been reset (it resets during the free) - synchronized { - if (mySpillCount != spillCount) { - // a different thread already spilled, returning - // None which lets the calling code know that rmm should retry allocation - None - } else { - // this thread wins the race and should spill - spillCount += 1 - Some(store.synchronousSpill(targetTotalSize, this, stream)) - } - } - } - - def updateTiers(bufferSpill: SpillAction): Long = bufferSpill match { - case BufferSpill(spilledBuffer, maybeNewBuffer) => - logDebug(s"Spilled ${spilledBuffer.id} from tier ${spilledBuffer.storageTier}. " + - s"Removing. Registering ${maybeNewBuffer.map(_.id).getOrElse ("None")} " + - s"${maybeNewBuffer}") - maybeNewBuffer.foreach(registerNewBuffer) - removeBufferTier(spilledBuffer.id, spilledBuffer.storageTier) - spilledBuffer.memoryUsedBytes - - case BufferUnspill(unspilledBuffer, maybeNewBuffer) => - logDebug(s"Unspilled ${unspilledBuffer.id} from tier ${unspilledBuffer.storageTier}. " + - s"Removing. Registering ${maybeNewBuffer.map(_.id).getOrElse ("None")} " + - s"${maybeNewBuffer}") - maybeNewBuffer.foreach(registerNewBuffer) - removeBufferTier(unspilledBuffer.id, unspilledBuffer.storageTier) - unspilledBuffer.memoryUsedBytes - } - - /** - * Copies `buffer` to the `deviceStorage` store, registering a new `RapidsBuffer` in - * the process - * @param buffer - buffer to copy - * @param stream - Cuda.Stream to synchronize on - * @return - The `RapidsBuffer` instance that was added to the device store. - */ - def unspillBufferToDeviceStore( - buffer: RapidsBuffer, - stream: Cuda.Stream): RapidsBuffer = synchronized { - // try to acquire the buffer, if it's already in the store - // do not create a new one, else add a reference - acquireBuffer(buffer.id, StorageTier.DEVICE) match { - case None => - val maybeNewBuffer = deviceStorage.copyBuffer(buffer, this, stream) - maybeNewBuffer.map { newBuffer => - newBuffer.addReference() // add a reference since we are about to use it - registerNewBuffer(newBuffer) - newBuffer - }.get // the GPU store has to return a buffer here for now, or throw OOM - case Some(existingBuffer) => existingBuffer - } - } - - /** - * Copies `buffer` to the `hostStorage` store, registering a new `RapidsBuffer` in - * the process - * - * @param buffer - buffer to copy - * @param stream - Cuda.Stream to synchronize on - * @return - The `RapidsBuffer` instance that was added to the host store. - */ - def unspillBufferToHostStore( - buffer: RapidsBuffer, - stream: Cuda.Stream): RapidsBuffer = synchronized { - // try to acquire the buffer, if it's already in the store - // do not create a new one, else add a reference - acquireBuffer(buffer.id, StorageTier.HOST) match { - case Some(existingBuffer) => existingBuffer - case None => - val maybeNewBuffer = hostStorage.copyBuffer(buffer, this, stream) - maybeNewBuffer.map { newBuffer => - logDebug(s"got new RapidsHostMemoryStore buffer ${newBuffer.id}") - newBuffer.addReference() // add a reference since we are about to use it - updateTiers(BufferUnspill(buffer, Some(newBuffer))) - buffer.safeFree() - newBuffer - }.get // the host store has to return a buffer here for now, or throw OOM - } - } - - - /** - * Remove a buffer ID from the catalog at the specified storage tier. - * @note public for testing - */ - def removeBufferTier(id: RapidsBufferId, tier: StorageTier): Unit = synchronized { - val updater = new BiFunction[RapidsBufferId, Seq[RapidsBuffer], Seq[RapidsBuffer]] { - override def apply(key: RapidsBufferId, value: Seq[RapidsBuffer]): Seq[RapidsBuffer] = { - val updated = value.filter(_.storageTier != tier) - if (updated.isEmpty) { - null - } else { - updated - } - } - } - bufferMap.computeIfPresent(id, updater) - } - - /** - * Remove a buffer handle from the catalog and, if it this was the final handle, - * release the resources of the registered buffers. - * - * @return true: if the buffer for this handle was removed from the spill framework - * (`handle` was the last handle) - * false: if buffer was not removed due to other live handles. - */ - private def removeBuffer(handle: RapidsBufferHandle): Boolean = synchronized { - // if this is the last handle, remove the buffer - if (stopTrackingHandle(handle)) { - logDebug(s"Removing buffer ${handle.id}") - bufferMap.remove(handle.id).safeFree() - true - } else { - false - } - } - - /** Return the number of buffers currently in the catalog. */ - def numBuffers: Int = bufferMap.size() - - override def close(): Unit = { - bufferIdToHandles.values.asScala.toSeq.flatMap(_.seq).safeClose() - - bufferIdToHandles.clear() - } -} - -object RapidsBufferCatalog extends Logging { - private val MAX_BUFFER_LOOKUP_ATTEMPTS = 100 - - private var deviceStorage: RapidsDeviceMemoryStore = _ - private var hostStorage: RapidsHostMemoryStore = _ - private var diskBlockManager: RapidsDiskBlockManager = _ - private var diskStorage: RapidsDiskStore = _ - private var memoryEventHandler: DeviceMemoryEventHandler = _ - private var _shouldUnspill: Boolean = _ - private var _singleton: RapidsBufferCatalog = null - - def singleton: RapidsBufferCatalog = { - if (_singleton == null) { - synchronized { - if (_singleton == null) { - _singleton = new RapidsBufferCatalog(deviceStorage) - } - } - } - _singleton - } - - private lazy val conf: SparkConf = { - val env = SparkEnv.get - if (env != null) { - env.conf - } else { - // For some unit tests - new SparkConf() - } - } - - /** - * Set a `RapidsDeviceMemoryStore` instance to use when instantiating our - * catalog. - * @note This should only be called from tests! - */ - def setDeviceStorage(rdms: RapidsDeviceMemoryStore): Unit = { - deviceStorage = rdms - } - - /** - * Set a `RapidsDiskStore` instance to use when instantiating our - * catalog. - * - * @note This should only be called from tests! - */ - def setDiskStorage(rdms: RapidsDiskStore): Unit = { - diskStorage = rdms - } - - /** - * Set a `RapidsHostMemoryStore` instance to use when instantiating our - * catalog. - * - * @note This should only be called from tests! - */ - def setHostStorage(rhms: RapidsHostMemoryStore): Unit = { - hostStorage = rhms - } - - /** - * Set a `RapidsBufferCatalog` instance to use our singleton. - * @note This should only be called from tests! - */ - def setCatalog(catalog: RapidsBufferCatalog): Unit = synchronized { - if (_singleton != null) { - _singleton.close() - } - _singleton = catalog - } - - def init(rapidsConf: RapidsConf): Unit = { - // We are going to re-initialize so make sure all of the old things were closed... - closeImpl() - assert(memoryEventHandler == null) - deviceStorage = new RapidsDeviceMemoryStore( - rapidsConf.chunkedPackBounceBufferSize, - rapidsConf.spillToDiskBounceBufferSize) - diskBlockManager = new RapidsDiskBlockManager(conf) - val hostSpillStorageSize = if (rapidsConf.offHeapLimitEnabled) { - // Disable the limit because it is handled by the RapidsHostMemoryStore - None - } else if (rapidsConf.hostSpillStorageSize == -1) { - // + 1 GiB by default to match backwards compatibility - Some(rapidsConf.pinnedPoolSize + (1024 * 1024 * 1024)) - } else { - Some(rapidsConf.hostSpillStorageSize) - } - hostStorage = new RapidsHostMemoryStore(hostSpillStorageSize) - diskStorage = new RapidsDiskStore(diskBlockManager) - deviceStorage.setSpillStore(hostStorage) - hostStorage.setSpillStore(diskStorage) - - logInfo("Installing GPU memory handler for spill") - memoryEventHandler = new DeviceMemoryEventHandler( - singleton, - deviceStorage, - rapidsConf.gpuOomDumpDir, - rapidsConf.gpuOomMaxRetries) - - if (rapidsConf.sparkRmmStateEnable) { - val debugLoc = if (rapidsConf.sparkRmmDebugLocation.isEmpty) { - null - } else { - rapidsConf.sparkRmmDebugLocation - } - - RmmSpark.setEventHandler(memoryEventHandler, debugLoc) - } else { - logWarning("SparkRMM retry has been disabled") - Rmm.setEventHandler(memoryEventHandler) - } - - _shouldUnspill = rapidsConf.isUnspillEnabled - } - - def close(): Unit = { - logInfo("Closing storage") - closeImpl() - } - - /** - * Only used in unit tests, it returns the number of buffers in the catalog. - */ - def numBuffers: Int = { - _singleton.numBuffers - } - - private def closeImpl(): Unit = synchronized { - Seq(_singleton, deviceStorage, hostStorage, diskStorage).safeClose() - - _singleton = null - // Workaround for shutdown ordering problems where device buffers allocated - // with this handler are being freed after the handler is destroyed - //Rmm.clearEventHandler() - memoryEventHandler = null - deviceStorage = null - hostStorage = null - diskStorage = null - } - - def getDeviceStorage: RapidsDeviceMemoryStore = deviceStorage - - def getHostStorage: RapidsHostMemoryStore = hostStorage - - def shouldUnspill: Boolean = _shouldUnspill - - /** - * Adds a contiguous table to the device storage. This does NOT take ownership of the - * contiguous table, so it is the responsibility of the caller to close it. The refcount of the - * underlying device buffer will be incremented so the contiguous table can be closed before - * this buffer is destroyed. - * @param contigTable contiguous table to trackNewHandle in device storage - * @param initialSpillPriority starting spill priority value for the buffer - * @return RapidsBufferHandle associated with this buffer - */ - def addContiguousTable( - contigTable: ContiguousTable, - initialSpillPriority: Long): RapidsBufferHandle = { - singleton.addContiguousTable(contigTable, initialSpillPriority) - } - - /** - * Adds a buffer to the catalog and store. This does NOT take ownership of the - * buffer, so it is the responsibility of the caller to close it. - * @param buffer buffer that will be owned by the store - * @param tableMeta metadata describing the buffer layout - * @param initialSpillPriority starting spill priority value for the buffer - * @return RapidsBufferHandle associated with this buffer - */ - def addBuffer( - buffer: MemoryBuffer, - tableMeta: TableMeta, - initialSpillPriority: Long): RapidsBufferHandle = { - singleton.addBuffer(buffer, tableMeta, initialSpillPriority) - } - - def addBatch( - batch: ColumnarBatch, - initialSpillPriority: Long): RapidsBufferHandle = { - singleton.addBatch(batch, initialSpillPriority) - } - - /** - * Lookup the buffer that corresponds to the specified buffer handle and acquire it. - * NOTE: It is the responsibility of the caller to close the buffer. - * @param handle buffer handle - * @return buffer that has been acquired - */ - def acquireBuffer(handle: RapidsBufferHandle): RapidsBuffer = - singleton.acquireBuffer(handle) - - /** - * Acquires a RapidsBuffer that the caller expects to be host-backed and not - * device bound. This ensures that the buffer acquired implements the correct - * trait, otherwise it throws and removes its buffer acquisition. - * - * @param handle handle associated with this `RapidsBuffer` - * @return host-backed RapidsBuffer that has been acquired - */ - def acquireHostBatchBuffer(handle: RapidsBufferHandle): RapidsHostBatchBuffer = - singleton.acquireHostBatchBuffer(handle) - - def getDiskBlockManager(): RapidsDiskBlockManager = diskBlockManager - - /** - * Free memory in `store` by spilling buffers to its spill store synchronously. - * @param store store to spill from - * @param targetTotalSize maximum total size of this store after spilling completes - * @param stream CUDA stream to use or omit for default stream - * @return optionally number of bytes that were spilled, or None if this call - * made no attempt to spill due to a detected spill race - */ - def synchronousSpill( - store: RapidsBufferStore, - targetTotalSize: Long, - stream: Cuda.Stream = Cuda.DEFAULT_STREAM): Option[Long] = { - singleton.synchronousSpill(store, targetTotalSize, stream) - } - - /** - * Given a `MemoryBuffer` find out if a `MemoryBuffer.EventHandler` is associated - * with it. - * - * After getting the `RapidsBuffer` try to acquire it via `addReference`. - * If successful, we can point to this buffer with a new handle, otherwise the buffer is - * about to be removed/freed (unlikely, because we are holding onto the reference as we - * are adding it again). - * - * @note public for testing - * @param buffer - the `MemoryBuffer` to inspect - * @return - Some(RapidsBuffer): the handler is associated with a rapids buffer - * and the rapids buffer is currently valid, or - * - * - None: if no `RapidsBuffer` is associated with this buffer (it is - * brand new to the store, or the `RapidsBuffer` is invalid and - * about to be removed). - */ - private def getExistingRapidsBufferAndAcquire(buffer: MemoryBuffer): Option[RapidsBuffer] = { - buffer match { - case hb: HostMemoryBuffer => - HostAlloc.findEventHandler(hb) { - case rapidsBuffer: RapidsBuffer => - if (rapidsBuffer.addReference()) { - Some(rapidsBuffer) - } else { - None - } - }.flatten - case _ => - val eh = buffer.getEventHandler - eh match { - case null => - None - case rapidsBuffer: RapidsBuffer => - if (rapidsBuffer.addReference()) { - Some(rapidsBuffer) - } else { - None - } - case _ => - throw new IllegalStateException("Unknown event handler") - } - } - } -} - diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala deleted file mode 100644 index b1ee9e7a863..00000000000 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala +++ /dev/null @@ -1,640 +0,0 @@ -/* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.nvidia.spark.rapids - -import java.util.Comparator -import java.util.concurrent.locks.ReentrantReadWriteLock - -import scala.collection.mutable - -import ai.rapids.cudf.{BaseDeviceMemoryBuffer, Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, NvtxColor, NvtxRange} -import com.nvidia.spark.rapids.Arm._ -import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.StorageTier.{DEVICE, HOST, StorageTier} -import com.nvidia.spark.rapids.format.TableMeta - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.vectorized.ColumnarBatch - -/** - * Helper case classes that contain the buffer we spilled or unspilled from our current tier - * and likely a new buffer created in a target store tier, but it can be set to None. - * If the buffer already exists in the target store, `newBuffer` will be None. - * @param spillBuffer a `RapidsBuffer` we spilled or unspilled from this store - * @param newBuffer an optional `RapidsBuffer` in the target store. - */ -trait SpillAction { - val spillBuffer: RapidsBuffer - val newBuffer: Option[RapidsBuffer] -} - -case class BufferSpill(spillBuffer: RapidsBuffer, newBuffer: Option[RapidsBuffer]) - extends SpillAction - -case class BufferUnspill(spillBuffer: RapidsBuffer, newBuffer: Option[RapidsBuffer]) - extends SpillAction - -/** - * Base class for all buffer store types. - * - * @param tier storage tier of this store - * @param catalog catalog to register this store - */ -abstract class RapidsBufferStore(val tier: StorageTier) - extends AutoCloseable with Logging { - - val name: String = tier.toString - - private class BufferTracker { - private[this] val comparator: Comparator[RapidsBufferBase] = - (o1: RapidsBufferBase, o2: RapidsBufferBase) => - java.lang.Long.compare(o1.getSpillPriority, o2.getSpillPriority) - // buffers: contains all buffers in this store, whether spillable or not - private[this] val buffers = new java.util.HashMap[RapidsBufferId, RapidsBufferBase] - // spillable: contains only those buffers that are currently spillable - private[this] val spillable = new HashedPriorityQueue[RapidsBufferBase](comparator) - // spilling: contains only those buffers that are currently being spilled, but - // have not been removed from the store - private[this] val spilling = new mutable.HashSet[RapidsBufferId]() - // total bytes stored, regardless of spillable status - private[this] var totalBytesStored: Long = 0L - // total bytes that are currently eligible to be spilled - private[this] var totalBytesSpillable: Long = 0L - - def add(buffer: RapidsBufferBase): Unit = synchronized { - val old = buffers.put(buffer.id, buffer) - // it is unlikely that the buffer was in this collection, but removing - // anyway. We assume the buffer is safe in this tier, and is not spilling - spilling.remove(buffer.id) - if (old != null) { - throw new DuplicateBufferException(s"duplicate buffer registered: ${buffer.id}") - } - totalBytesStored += buffer.memoryUsedBytes - - // device buffers "spillability" is handled via DeviceMemoryBuffer ref counting - // so spillableOnAdd should be false, all other buffer tiers are spillable at - // all times. - if (spillableOnAdd && buffer.memoryUsedBytes > 0) { - if (spillable.offer(buffer)) { - totalBytesSpillable += buffer.memoryUsedBytes - } - } - } - - def remove(id: RapidsBufferId): Unit = synchronized { - // when removing a buffer we no longer need to know if it was spilling - spilling.remove(id) - val obj = buffers.remove(id) - if (obj != null) { - totalBytesStored -= obj.memoryUsedBytes - if (spillable.remove(obj)) { - totalBytesSpillable -= obj.memoryUsedBytes - } - } - } - - def freeAll(): Unit = { - val values = synchronized { - val buffs = buffers.values().toArray(new Array[RapidsBufferBase](0)) - buffers.clear() - spillable.clear() - spilling.clear() - buffs - } - // We need to release the `RapidsBufferStore` lock to prevent a lock order inversion - // deadlock: (1) `RapidsBufferBase.free` calls (2) `RapidsBufferStore.remove` and - // (1) `RapidsBufferStore.freeAll` calls (2) `RapidsBufferBase.free`. - values.safeFree() - } - - /** - * Sets a buffers state to spillable or non-spillable. - * - * If the buffer is currently being spilled or it is no longer in the `buffers` collection - * (e.g. it is not in this store), the action is skipped. - * - * @param buffer the buffer to mark as spillable or not - * @param isSpillable whether the buffer should now be spillable - */ - def setSpillable(buffer: RapidsBufferBase, isSpillable: Boolean): Unit = synchronized { - if (isSpillable && buffer.memoryUsedBytes > 0) { - // if this buffer is in the store and isn't currently spilling - if (!spilling.contains(buffer.id) && buffers.containsKey(buffer.id)) { - // try to add it to the spillable collection - if (spillable.offer(buffer)) { - totalBytesSpillable += buffer.memoryUsedBytes - logDebug(s"Buffer ${buffer.id} is spillable. " + - s"total=${totalBytesStored} spillable=${totalBytesSpillable}") - } // else it was already there (unlikely) - } - } else { - if (spillable.remove(buffer)) { - totalBytesSpillable -= buffer.memoryUsedBytes - logDebug(s"Buffer ${buffer.id} is not spillable. " + - s"total=${totalBytesStored}, spillable=${totalBytesSpillable}") - } // else it was already removed - } - } - - def nextSpillableBuffer(): RapidsBufferBase = synchronized { - val buffer = spillable.poll() - if (buffer != null) { - // mark the id as "spilling" (this buffer is in the middle of a spill operation) - spilling.add(buffer.id) - totalBytesSpillable -= buffer.memoryUsedBytes - logDebug(s"Spilling buffer ${buffer.id}. size=${buffer.memoryUsedBytes} " + - s"total=${totalBytesStored}, new spillable=${totalBytesSpillable}") - } - buffer - } - - def updateSpillPriority(buffer: RapidsBufferBase, priority:Long): Unit = synchronized { - buffer.updateSpillPriorityValue(priority) - spillable.priorityUpdated(buffer) - } - - def getTotalBytes: Long = synchronized { totalBytesStored } - - def getTotalSpillableBytes: Long = synchronized { totalBytesSpillable } - } - - /** - * Stores that need to stay within a specific byte limit of buffers stored override - * this function. Only the `HostMemoryBufferStore` requires such a limit. - * @return maximum amount of bytes that can be stored in the store, None for no - * limit - */ - def getMaxSize: Option[Long] = None - - private[this] val buffers = new BufferTracker - - /** A store that can be used for spilling. */ - var spillStore: RapidsBufferStore = _ - - /** Return the current byte total of buffers in this store. */ - def currentSize: Long = buffers.getTotalBytes - - def currentSpillableSize: Long = buffers.getTotalSpillableBytes - - /** - * A store that manages spillability of buffers should override this method - * to false, otherwise `BufferTracker` treats buffers as always spillable. - */ - protected def spillableOnAdd: Boolean = true - - /** - * Specify another store that can be used when this store needs to spill. - * @note Only one spill store can be registered. This will throw if a - * spill store has already been registered. - */ - def setSpillStore(store: RapidsBufferStore): Unit = { - require(spillStore == null, "spill store already registered") - spillStore = store - } - - /** - * Adds an existing buffer from another store to this store. The buffer must already - * have an active reference by the caller and needs to be eventually closed by the caller - * (i.e.: this method will not take ownership of the incoming buffer object). - * This does not need to update the catalog, the caller is responsible for that. - * @param buffer data from another store - * @param catalog RapidsBufferCatalog we may need to modify during this copy - * @param stream CUDA stream to use for copy or null - * @return the new buffer that was created - */ - def copyBuffer( - buffer: RapidsBuffer, - catalog: RapidsBufferCatalog, - stream: Cuda.Stream): Option[RapidsBufferBase] = { - createBuffer(buffer, catalog, stream).map { newBuffer => - freeOnExcept(newBuffer) { newBuffer => - addBuffer(newBuffer) - newBuffer - } - } - } - - protected def setSpillable(buffer: RapidsBufferBase, isSpillable: Boolean): Unit = { - buffers.setSpillable(buffer, isSpillable) - } - - /** - * Create a new buffer from an existing buffer in another store. - * If the data transfer will be performed asynchronously, this method is responsible for - * adding a reference to the existing buffer and later closing it when the transfer completes. - * - * @note DO NOT close the buffer unless adding a reference! - * @note `createBuffer` impls should synchronize against `stream` before returning, if needed. - * @param buffer data from another store - * @param catalog RapidsBufferCatalog we may need to modify during this create - * @param stream CUDA stream to use or null - * @return the new buffer that was created. - */ - protected def createBuffer( - buffer: RapidsBuffer, - catalog: RapidsBufferCatalog, - stream: Cuda.Stream): Option[RapidsBufferBase] - - /** Update bookkeeping for a new buffer */ - protected def addBuffer(buffer: RapidsBufferBase): Unit = { - buffers.add(buffer) - buffer.updateSpillability() - } - - /** - * Adds a buffer to the spill framework, stream synchronizing with the producer - * stream to ensure that the buffer is fully materialized, and can be safely copied - * as part of the spill. - * - * @param needsSync true if we should stream synchronize before adding the buffer - */ - protected def addBuffer(buffer: RapidsBufferBase, needsSync: Boolean): Unit = { - if (needsSync) { - Cuda.DEFAULT_STREAM.sync() - } - addBuffer(buffer) - } - - override def close(): Unit = { - buffers.freeAll() - } - - def nextSpillable(): RapidsBuffer = { - buffers.nextSpillableBuffer() - } - - def synchronousSpill( - targetTotalSize: Long, - catalog: RapidsBufferCatalog, - stream: Cuda.Stream = Cuda.DEFAULT_STREAM): Long = { - if (currentSpillableSize > targetTotalSize) { - logWarning(s"Targeting a ${name} size of $targetTotalSize. " + - s"Current total ${currentSize}. " + - s"Current spillable ${currentSpillableSize}") - val bufferSpills = new mutable.ArrayBuffer[BufferSpill]() - withResource(new NvtxRange(s"${name} sync spill", NvtxColor.ORANGE)) { _ => - logWarning(s"${name} store spilling to reduce usage from " + - s"${currentSize} total (${currentSpillableSize} spillable) " + - s"to $targetTotalSize bytes") - - // If the store has 0 spillable bytes left, it has exhausted. - try { - var exhausted = false - var totalSpilled = 0L - while (!exhausted && - currentSpillableSize > targetTotalSize) { - val nextSpillableBuffer = nextSpillable() - if (nextSpillableBuffer != null) { - if (nextSpillableBuffer.addReference()) { - withResource(nextSpillableBuffer) { _ => - val bufferHasSpilled = - catalog.isBufferSpilled( - nextSpillableBuffer.id, - nextSpillableBuffer.storageTier) - val bufferSpill = if (!bufferHasSpilled) { - spillBuffer( - nextSpillableBuffer, this, catalog, stream) - } else { - // if `nextSpillableBuffer` already spilled, we still need to - // remove it from our tier and call free on it, but set - // `newBuffer` to None because there's nothing to register - // as it has already spilled. - BufferSpill(nextSpillableBuffer, None) - } - totalSpilled += bufferSpill.spillBuffer.memoryUsedBytes - bufferSpills.append(bufferSpill) - catalog.updateTiers(bufferSpill) - } - } - } - } - if (totalSpilled <= 0) { - // we didn't spill in this iteration, exit loop - exhausted = true - logWarning("Unable to spill enough to meet request. " + - s"Total=${currentSize} " + - s"Spillable=${currentSpillableSize} " + - s"Target=$targetTotalSize") - } - totalSpilled - } finally { - if (bufferSpills.nonEmpty) { - // This is a hack in order to completely synchronize with the GPU before we free - // a buffer. It is necessary because of non-synchronous cuDF calls that could fall - // behind where the CPU is. Freeing a rapids buffer in these cases needs to wait for - // all launched GPU work, otherwise crashes or data corruption could occur. - // A more performant implementation would be to synchronize on the thread that read - // the buffer via events. - // https://github.com/NVIDIA/spark-rapids/issues/8610 - Cuda.deviceSynchronize() - bufferSpills.foreach(_.spillBuffer.safeFree()) - } - } - } - } else { - 0L // nothing spilled - } - } - - /** - * Given a specific `RapidsBuffer` spill it to `spillStore` - * - * @return a `BufferSpill` instance with the target buffer in this store, and an optional - * new `RapidsBuffer` in the target spill store if this rapids buffer hadn't already - * spilled. - * @note called with catalog lock held - */ - private def spillBuffer( - buffer: RapidsBuffer, - store: RapidsBufferStore, - catalog: RapidsBufferCatalog, - stream: Cuda.Stream): BufferSpill = { - // copy the buffer to spillStore - var maybeNewBuffer: Option[RapidsBuffer] = None - var lastTier: Option[StorageTier] = None - var nextSpillStore = store.spillStore - while (maybeNewBuffer.isEmpty && nextSpillStore != null) { - lastTier = Some(nextSpillStore.tier) - // copy buffer if it fits - maybeNewBuffer = nextSpillStore.copyBuffer(buffer, catalog, stream) - - // if it didn't fit, we can try a lower tier that has more space - if (maybeNewBuffer.isEmpty) { - nextSpillStore = nextSpillStore.spillStore - } - } - if (maybeNewBuffer.isEmpty) { - throw new IllegalStateException( - s"Unable to spill buffer ${buffer.id} of size ${buffer.memoryUsedBytes} " + - s"to tier ${lastTier}") - } - // return the buffer to free and the new buffer to register - BufferSpill(buffer, maybeNewBuffer) - } - - /** - * Tries to make room for `buffer` in the host store by spilling. - * - * @param buffer buffer that will be copied to the host store if it fits - * @param stream CUDA stream to synchronize for memory operations - * @return true if the buffer fits after a potential spill - */ - protected def trySpillToMaximumSize( - buffer: RapidsBuffer, - catalog: RapidsBufferCatalog, - stream: Cuda.Stream): Boolean = { - true // default to success, HostMemoryStore overrides this - } - - /** Base class for all buffers in this store. */ - abstract class RapidsBufferBase( - override val id: RapidsBufferId, - _meta: TableMeta, - initialSpillPriority: Long, - catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton) - extends RapidsBuffer { - private val MAX_UNSPILL_ATTEMPTS = 100 - - // isValid and refcount must be used with the `RapidsBufferBase` lock held - protected[this] var isValid = true - protected[this] var refcount = 0 - - private[this] var spillPriority: Long = initialSpillPriority - - private[this] val rwl: ReentrantReadWriteLock = new ReentrantReadWriteLock() - - - def meta: TableMeta = _meta - - /** Release the underlying resources for this buffer. */ - protected def releaseResources(): Unit - - /** - * Materialize the memory buffer from the underlying storage. - * - * If the buffer resides in device or host memory, only reference count is incremented. - * If the buffer resides in secondary storage, a new host or device memory buffer is created, - * with the data copied to the new buffer. - * The caller must have successfully acquired the buffer beforehand. - * @see [[addReference]] - * @note It is the responsibility of the caller to close the buffer. - * @note This is an internal API only used by Rapids buffer stores. - */ - protected def materializeMemoryBuffer: MemoryBuffer = getMemoryBuffer - - override def addReference(): Boolean = synchronized { - if (isValid) { - refcount += 1 - } - isValid - } - - override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = { - // NOTE: Cannot hold a lock on this buffer here because memory is being - // allocated. Allocations can trigger synchronous spills which can - // deadlock if another thread holds the device store lock and is trying - // to spill to this store. - withResource(getDeviceMemoryBuffer) { deviceBuffer => - columnarBatchFromDeviceBuffer(deviceBuffer, sparkTypes) - } - } - - protected def columnarBatchFromDeviceBuffer(devBuffer: DeviceMemoryBuffer, - sparkTypes: Array[DataType]): ColumnarBatch = { - val bufferMeta = meta.bufferMeta() - if (bufferMeta == null || bufferMeta.codecBufferDescrsLength == 0) { - MetaUtils.getBatchFromMeta(devBuffer, meta, sparkTypes) - } else { - GpuCompressedColumnVector.from(devBuffer, meta) - } - } - - override def copyToMemoryBuffer(srcOffset: Long, dst: MemoryBuffer, dstOffset: Long, - length: Long, stream: Cuda.Stream): Unit = { - withResource(getMemoryBuffer) { memBuff => - dst match { - case _: HostMemoryBuffer => - // TODO: consider moving to the async version. - dst.copyFromMemoryBuffer(dstOffset, memBuff, srcOffset, length, stream) - case _: BaseDeviceMemoryBuffer => - dst.copyFromMemoryBufferAsync(dstOffset, memBuff, srcOffset, length, stream) - case _ => - throw new IllegalStateException(s"Infeasible destination buffer type ${dst.getClass}") - } - } - } - - /** - * TODO: we want to remove this method from the buffer, instead we want the catalog - * to be responsible for producing the DeviceMemoryBuffer by asking the buffer. This - * hides the RapidsBuffer from clients and simplifies locking. - */ - override def getDeviceMemoryBuffer: DeviceMemoryBuffer = { - if (RapidsBufferCatalog.shouldUnspill) { - (0 until MAX_UNSPILL_ATTEMPTS).foreach { _ => - catalog.acquireBuffer(id, DEVICE) match { - case Some(buffer) => - withResource(buffer) { _ => - return buffer.getDeviceMemoryBuffer - } - case _ => - try { - logDebug(s"Unspilling $this $id to $DEVICE") - val newBuffer = catalog.unspillBufferToDeviceStore( - this, - Cuda.DEFAULT_STREAM) - withResource(newBuffer) { _ => - return newBuffer.getDeviceMemoryBuffer - } - } catch { - case _: DuplicateBufferException => - logDebug(s"Lost device buffer registration race for buffer $id, retrying...") - } - } - } - throw new IllegalStateException(s"Unable to get device memory buffer for ID: $id") - } else { - materializeMemoryBuffer match { - case h: HostMemoryBuffer => - withResource(h) { _ => - closeOnExcept(DeviceMemoryBuffer.allocate(h.getLength)) { deviceBuffer => - logDebug(s"copying ${h.getLength} from host $h to device $deviceBuffer " + - s"of size ${deviceBuffer.getLength}") - deviceBuffer.copyFromHostBuffer(h) - deviceBuffer - } - } - case d: DeviceMemoryBuffer => d - case b => throw new IllegalStateException(s"Unrecognized buffer: $b") - } - } - } - - override def getHostMemoryBuffer: HostMemoryBuffer = { - (0 until MAX_UNSPILL_ATTEMPTS).foreach { _ => - catalog.acquireBuffer(id, HOST) match { - case Some(buffer) => - withResource(buffer) { _ => - return buffer.getHostMemoryBuffer - } - case _ => - try { - logDebug(s"Unspilling $this $id to $HOST") - val newBuffer = catalog.unspillBufferToHostStore( - this, - Cuda.DEFAULT_STREAM) - withResource(newBuffer) { _ => - return newBuffer.getHostMemoryBuffer - } - } catch { - case _: DuplicateBufferException => - logDebug(s"Lost host buffer registration race for buffer $id, retrying...") - } - } - } - throw new IllegalStateException(s"Unable to get host memory buffer for ID: $id") - } - - /** - * close() is called by client code to decrease the ref count of this RapidsBufferBase. - * In the off chance that by the time close is invoked, the buffer was freed (not valid) - * then this close call winds up freeing the resources of the rapids buffer. - */ - override def close(): Unit = synchronized { - if (refcount == 0) { - throw new IllegalStateException("Buffer already closed") - } - refcount -= 1 - if (refcount == 0 && !isValid) { - freeBuffer() - } - } - - /** - * Mark the buffer as freed and no longer valid. This is called by the store when removing a - * buffer (it is no longer tracked). - * - * @note The resources may not be immediately released if the buffer has outstanding references. - * In that case the resources will be released when the reference count reaches zero. - */ - override def free(): Unit = synchronized { - if (isValid) { - isValid = false - buffers.remove(id) - if (refcount == 0) { - freeBuffer() - } - } else { - logWarning(s"Trying to free an invalid buffer => $id, size = ${memoryUsedBytes}, $this") - } - } - - override def getSpillPriority: Long = spillPriority - - override def setSpillPriority(priority: Long): Unit = - buffers.updateSpillPriority(this, priority) - - private[RapidsBufferStore] def updateSpillPriorityValue(priority: Long): Unit = { - spillPriority = priority - } - - override def withMemoryBufferReadLock[K](body: MemoryBuffer => K): K = { - withResource(getMemoryBuffer) { buff => - val lock = rwl.readLock() - try { - lock.lock() - body(buff) - } finally { - lock.unlock() - } - } - } - - override def withMemoryBufferWriteLock[K](body: MemoryBuffer => K): K = { - withResource(getMemoryBuffer) { buff => - val lock = rwl.writeLock() - try { - lock.lock() - body(buff) - } finally { - lock.unlock() - } - } - } - - /** Must be called with a lock on the buffer */ - private def freeBuffer(): Unit = { - releaseResources() - } - - override def toString: String = s"$name buffer size=$memoryUsedBytes" - } -} - -/** - * Buffers that inherit from this type do not support changing the spillable status - * of a `RapidsBuffer`. This is only used right now for disk. - * @param tier storage tier of this store - */ -abstract class RapidsBufferStoreWithoutSpill(override val tier: StorageTier) - extends RapidsBufferStore(tier) { - - override def setSpillable(rapidsBuffer: RapidsBufferBase, isSpillable: Boolean): Unit = { - throw new NotImplementedError(s"This store ${this} does not implement setSpillable") - } -} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala deleted file mode 100644 index c56806bc965..00000000000 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala +++ /dev/null @@ -1,518 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.nvidia.spark.rapids - -import java.nio.channels.WritableByteChannel -import java.util.concurrent.ConcurrentHashMap - -import scala.collection.mutable - -import ai.rapids.cudf.{ColumnVector, Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, Table} -import com.nvidia.spark.rapids.Arm._ -import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableSeq -import com.nvidia.spark.rapids.StorageTier.StorageTier -import com.nvidia.spark.rapids.format.TableMeta - -import org.apache.spark.sql.rapids.GpuTaskMetrics -import org.apache.spark.sql.rapids.storage.RapidsStorageUtils -import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.vectorized.ColumnarBatch - -/** - * Buffer storage using device memory. - * @param chunkedPackBounceBufferSize this is the size of the bounce buffer to be used - * during spill in chunked_pack. The parameter defaults to 128MB, - * with a rule-of-thumb of 1MB per SM. - */ -class RapidsDeviceMemoryStore( - chunkedPackBounceBufferSize: Long = 128L*1024*1024, - hostBounceBufferSize: Long = 128L*1024*1024) - extends RapidsBufferStore(StorageTier.DEVICE) { - - // The RapidsDeviceMemoryStore handles spillability via ref counting - override protected def spillableOnAdd: Boolean = false - - // bounce buffer to be used during chunked pack in GPU to host memory spill - private var chunkedPackBounceBuffer: DeviceMemoryBuffer = - DeviceMemoryBuffer.allocate(chunkedPackBounceBufferSize) - - private var hostSpillBounceBuffer: HostMemoryBuffer = - HostMemoryBuffer.allocate(hostBounceBufferSize) - - override protected def createBuffer( - other: RapidsBuffer, - catalog: RapidsBufferCatalog, - stream: Cuda.Stream): Option[RapidsBufferBase] = { - val memoryBuffer = withResource(other.getCopyIterator) { copyIterator => - copyIterator.next() - } - withResource(memoryBuffer) { _ => - val deviceBuffer = { - memoryBuffer match { - case d: DeviceMemoryBuffer => d - case h: HostMemoryBuffer => - GpuTaskMetrics.get.readSpillFromHostTime { - closeOnExcept(DeviceMemoryBuffer.allocate(memoryBuffer.getLength)) { deviceBuffer => - logDebug(s"copying from host $h to device $deviceBuffer") - deviceBuffer.copyFromHostBuffer(h, stream) - deviceBuffer - } - } - case b => throw new IllegalStateException(s"Unrecognized buffer: $b") - } - } - Some(new RapidsDeviceMemoryBuffer( - other.id, - deviceBuffer.getLength, - other.meta, - deviceBuffer, - other.getSpillPriority)) - } - } - - /** - * Adds a buffer to the device storage. This does NOT take ownership of the - * buffer, so it is the responsibility of the caller to close it. - * - * This function is called only from the RapidsBufferCatalog, under the - * catalog lock. - * - * @param id the RapidsBufferId to use for this buffer - * @param buffer buffer that will be owned by the store - * @param tableMeta metadata describing the buffer layout - * @param initialSpillPriority starting spill priority value for the buffer - * @param needsSync whether the spill framework should stream synchronize while adding - * this device buffer (defaults to true) - * @return the RapidsBuffer instance that was added. - */ - def addBuffer( - id: RapidsBufferId, - buffer: DeviceMemoryBuffer, - tableMeta: TableMeta, - initialSpillPriority: Long, - needsSync: Boolean): RapidsBuffer = { - buffer.incRefCount() - val rapidsBuffer = new RapidsDeviceMemoryBuffer( - id, - buffer.getLength, - tableMeta, - buffer, - initialSpillPriority) - freeOnExcept(rapidsBuffer) { _ => - logDebug(s"Adding receive side table for: [id=$id, size=${buffer.getLength}, " + - s"uncompressed=${rapidsBuffer.meta.bufferMeta.uncompressedSize}, " + - s"meta_id=${tableMeta.bufferMeta.id}, " + - s"meta_size=${tableMeta.bufferMeta.size}]") - addBuffer(rapidsBuffer, needsSync) - rapidsBuffer - } - } - - /** - * Adds a table to the device storage. - * - * This takes ownership of the table. - * - * This function is called only from the RapidsBufferCatalog, under the - * catalog lock. - * - * @param id the RapidsBufferId to use for this table - * @param table table that will be owned by the store - * @param initialSpillPriority starting spill priority value - * @param needsSync whether the spill framework should stream synchronize while adding - * this table (defaults to true) - * @return the RapidsBuffer instance that was added. - */ - def addTable( - id: RapidsBufferId, - table: Table, - initialSpillPriority: Long, - needsSync: Boolean): RapidsBuffer = { - val rapidsTable = new RapidsTable( - id, - table, - initialSpillPriority) - freeOnExcept(rapidsTable) { _ => - addBuffer(rapidsTable, needsSync) - rapidsTable - } - } - - /** - * A per cuDF column event handler that handles calls to .close() - * inside of the `ColumnVector` lock. - */ - class RapidsDeviceColumnEventHandler - extends ColumnVector.EventHandler { - - // Every RapidsTable that references this column has an entry in this map. - // The value represents the number of times (normally 1) that a ColumnVector - // appears in the RapidsTable. This is also the ColumnVector refCount at which - // the column is considered spillable. - // The map is protected via the ColumnVector lock. - private val registration = new mutable.HashMap[RapidsTable, Int]() - - /** - * Every RapidsTable iterates through its columns and either creates - * a `ColumnTracking` object and associates it with the column's - * `eventHandler` or calls into the existing one, and registers itself. - * - * The registration has two goals: it accounts for repetition of a column - * in a `RapidsTable`. If a table has the same column repeated it must adjust - * the refCount at which this column is considered spillable. - * - * The second goal is to account for aliasing. If two tables alias this column - * we are going to mark it as non spillable. - * - * @param rapidsTable - the table that is registering itself with this tracker - */ - def register(rapidsTable: RapidsTable, repetition: Int): Unit = { - registration.put(rapidsTable, repetition) - } - - /** - * This is invoked during `RapidsTable.free` in order to remove the entry - * in `registration`. - * @param rapidsTable - the table that is de-registering itself - */ - def deregister(rapidsTable: RapidsTable): Unit = { - registration.remove(rapidsTable) - } - - // called with the cudfCv lock held from cuDF's side - override def onClosed(cudfCv: ColumnVector, refCount: Int): Unit = { - // we only handle spillability if there is a single table registered - // (no aliasing) - if (registration.size == 1) { - val (rapidsTable, spillableRefCount) = registration.head - if (spillableRefCount == refCount) { - rapidsTable.onColumnSpillable(cudfCv) - } - } - } - } - - /** - * A `RapidsTable` is the spill store holder of a cuDF `Table`. - * - * The table is not contiguous in GPU memory. Instead, this `RapidsBuffer` instance - * allows us to use the cuDF chunked_pack API to make the table contiguous as the spill - * is happening. - * - * This class owns the cuDF table and will close it when `close` is called. - * - * @param id the `RapidsBufferId` this table is associated with - * @param table the cuDF table that we are managing - * @param spillPriority a starting spill priority - */ - class RapidsTable( - id: RapidsBufferId, - table: Table, - spillPriority: Long) - extends RapidsBufferBase( - id, - null, - spillPriority) - with RapidsBufferChannelWritable { - - /** The storage tier for this buffer */ - override val storageTier: StorageTier = StorageTier.DEVICE - - override val supportsChunkedPacker: Boolean = true - - // This is the current size in batch form. It is to be used while this - // table hasn't migrated to another store. - private val unpackedSizeInBytes: Long = GpuColumnVector.getTotalDeviceMemoryUsed(table) - - // By default all columns are NOT spillable since we are not the only owners of - // the columns (the caller is holding onto a ColumnarBatch that will be closed - // after instantiation, triggering onClosed callbacks) - // This hash set contains the columns that are currently spillable. - private val columnSpillability = new ConcurrentHashMap[ColumnVector, Boolean]() - - private val numDistinctColumns = - (0 until table.getNumberOfColumns).map(table.getColumn).distinct.size - - // we register our event callbacks as the very first action to deal with - // spillability - registerOnCloseEventHandler() - - /** Release the underlying resources for this buffer. */ - override protected def releaseResources(): Unit = { - table.close() - } - - private lazy val (cachedMeta, cachedPackedSize) = { - withResource(makeChunkedPacker) { cp => - (cp.getMeta, cp.getTotalContiguousSize) - } - } - - override def meta: TableMeta = cachedMeta - - override val memoryUsedBytes: Long = unpackedSizeInBytes - - override def getPackedSizeBytes: Long = cachedPackedSize - - override def makeChunkedPacker: ChunkedPacker = - new ChunkedPacker(id, table, chunkedPackBounceBuffer) - - /** - * Mark a column as spillable - * - * @param column the ColumnVector to mark as spillable - */ - def onColumnSpillable(column: ColumnVector): Unit = { - columnSpillability.put(column, true) - updateSpillability() - } - - /** - * Update the spillability state of this RapidsTable. This is invoked from - * two places: - * - * - from the onColumnSpillable callback, which is invoked from a - * ColumnVector.EventHandler.onClosed callback. - * - * - after adding a table to the store to mark the table as spillable if - * all columns are spillable. - */ - override def updateSpillability(): Unit = { - setSpillable(this, columnSpillability.size == numDistinctColumns) - } - - /** - * Produce a `ColumnarBatch` from our table, and in the process make ourselves - * not spillable. - * - * @param sparkTypes the spark data types the batch should have - */ - override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = { - columnSpillability.clear() - setSpillable(this, false) - GpuColumnVector.from(table, sparkTypes) - } - - /** - * Get the underlying memory buffer. This may be either a HostMemoryBuffer or a - * DeviceMemoryBuffer depending on where the buffer currently resides. - * The caller must have successfully acquired the buffer beforehand. - * - * @see [[addReference]] - * @note It is the responsibility of the caller to close the buffer. - */ - override def getMemoryBuffer: MemoryBuffer = { - throw new UnsupportedOperationException( - "RapidsDeviceMemoryBatch doesn't support getMemoryBuffer") - } - - override def free(): Unit = { - // lets remove our handler from the chain of handlers for each column - removeOnCloseEventHandler() - super.free() - } - - private def registerOnCloseEventHandler(): Unit = { - val columns = (0 until table.getNumberOfColumns).map(table.getColumn) - // cudfColumns could contain duplicates. We need to take this into account when we are - // deciding the floor refCount for a duplicated column - val repetitionPerColumn = new mutable.HashMap[ColumnVector, Int]() - columns.foreach { col => - val repetitionCount = repetitionPerColumn.getOrElse(col, 0) - repetitionPerColumn(col) = repetitionCount + 1 - } - repetitionPerColumn.foreach { case (distinctCv, repetition) => - // lock the column because we are setting its event handler, and we are inspecting - // its refCount. - distinctCv.synchronized { - val eventHandler = distinctCv.getEventHandler match { - case null => - val eventHandler = new RapidsDeviceColumnEventHandler - distinctCv.setEventHandler(eventHandler) - eventHandler - case existing: RapidsDeviceColumnEventHandler => - existing - case other => - throw new IllegalStateException( - s"Invalid column event handler $other") - } - eventHandler.register(this, repetition) - if (repetition == distinctCv.getRefCount) { - onColumnSpillable(distinctCv) - } - } - } - } - - // this method is called from free() - private def removeOnCloseEventHandler(): Unit = { - val distinctColumns = - (0 until table.getNumberOfColumns).map(table.getColumn).distinct - distinctColumns.foreach { distinctCv => - distinctCv.synchronized { - distinctCv.getEventHandler match { - case eventHandler: RapidsDeviceColumnEventHandler => - eventHandler.deregister(this) - case t => - throw new IllegalStateException( - s"Invalid column event handler $t") - } - } - } - } - - override def writeToChannel(outputChannel: WritableByteChannel, stream: Cuda.Stream): Long = { - var written: Long = 0L - withResource(getCopyIterator) { copyIter => - while(copyIter.hasNext) { - withResource(copyIter.next()) { slice => - val iter = - new MemoryBufferToHostByteBufferIterator( - slice, - hostSpillBounceBuffer, - stream) - iter.foreach { bb => - try { - while (bb.hasRemaining) { - written += outputChannel.write(bb) - } - } finally { - RapidsStorageUtils.dispose(bb) - } - } - } - } - written - } - } - - } - - class RapidsDeviceMemoryBuffer( - id: RapidsBufferId, - size: Long, - meta: TableMeta, - contigBuffer: DeviceMemoryBuffer, - spillPriority: Long) - extends RapidsBufferBase(id, meta, spillPriority) - with MemoryBuffer.EventHandler - with RapidsBufferChannelWritable { - - override val memoryUsedBytes: Long = size - - override val storageTier: StorageTier = StorageTier.DEVICE - - // If this require triggers, we are re-adding a `DeviceMemoryBuffer` outside of - // the catalog lock, which should not possible. The event handler is set to null - // when we free the `RapidsDeviceMemoryBuffer` and if the buffer is not free, we - // take out another handle (in the catalog). - // TODO: This is not robust (to rely on outside locking and addReference/free) - // and should be revisited. - require(contigBuffer.setEventHandler(this) == null, - "DeviceMemoryBuffer with non-null event handler failed to add!!") - - /** - * Override from the MemoryBuffer.EventHandler interface. - * - * If we are being invoked we have the `contigBuffer` lock, as this callback - * is being invoked from `MemoryBuffer.close` - * - * @param refCount - contigBuffer's current refCount - */ - override def onClosed(refCount: Int): Unit = { - // refCount == 1 means only 1 reference exists to `contigBuffer` in the - // RapidsDeviceMemoryBuffer (we own it) - if (refCount == 1) { - // setSpillable is being called here as an extension of `MemoryBuffer.close()` - // we hold the MemoryBuffer lock and we could be called from a Spark task thread - // Since we hold the MemoryBuffer lock, `incRefCount` waits for us. The only other - // call to `setSpillable` is also under this same MemoryBuffer lock (see: - // `getDeviceMemoryBuffer`) - setSpillable(this, true) - } - } - - override protected def releaseResources(): Unit = synchronized { - // we need to disassociate this RapidsBuffer from the underlying buffer - contigBuffer.close() - } - - /** - * Get and increase the reference count of the device memory buffer - * in this RapidsBuffer, while making the RapidsBuffer non-spillable. - * - * @note It is the responsibility of the caller to close the DeviceMemoryBuffer - */ - override def getDeviceMemoryBuffer: DeviceMemoryBuffer = synchronized { - contigBuffer.synchronized { - setSpillable(this, false) - contigBuffer.incRefCount() - contigBuffer - } - } - - override def getMemoryBuffer: MemoryBuffer = getDeviceMemoryBuffer - - override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = { - // calling `getDeviceMemoryBuffer` guarantees that we have marked this RapidsBuffer - // as not spillable and increased its refCount atomically - withResource(getDeviceMemoryBuffer) { buff => - columnarBatchFromDeviceBuffer(buff, sparkTypes) - } - } - - /** - * We overwrite free to make sure we don't have a handler for the underlying - * contigBuffer, since this `RapidsBuffer` is no longer tracked. - */ - override def free(): Unit = synchronized { - if (isValid) { - // it is going to be invalid when calling super.free() - contigBuffer.setEventHandler(null) - } - super.free() - } - - override def writeToChannel(outputChannel: WritableByteChannel, stream: Cuda.Stream): Long = { - var written: Long = 0L - val iter = new MemoryBufferToHostByteBufferIterator( - contigBuffer, - hostSpillBounceBuffer, - stream) - iter.foreach { bb => - try { - while (bb.hasRemaining) { - written += outputChannel.write(bb) - } - } finally { - RapidsStorageUtils.dispose(bb) - } - } - written - } - - } - override def close(): Unit = { - try { - super.close() - } finally { - Seq(chunkedPackBounceBuffer, hostSpillBounceBuffer).safeClose() - chunkedPackBounceBuffer = null - hostSpillBounceBuffer = null - } - } -} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala deleted file mode 100644 index eb3692d434a..00000000000 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala +++ /dev/null @@ -1,256 +0,0 @@ -/* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.nvidia.spark.rapids - -import java.io.{File, FileInputStream} -import java.nio.channels.{Channels, FileChannel} -import java.nio.channels.FileChannel.MapMode -import java.nio.file.StandardOpenOption -import java.util.concurrent.ConcurrentHashMap - -import ai.rapids.cudf.{Cuda, HostMemoryBuffer, MemoryBuffer} -import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} -import com.nvidia.spark.rapids.StorageTier.StorageTier -import com.nvidia.spark.rapids.format.TableMeta -import org.apache.commons.io.IOUtils - -import org.apache.spark.TaskContext -import org.apache.spark.sql.rapids.{GpuTaskMetrics, RapidsDiskBlockManager} -import org.apache.spark.sql.rapids.execution.{SerializedHostTableUtils, TrampolineUtil} -import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.vectorized.ColumnarBatch - -/** A buffer store using files on the local disks. */ -class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager) - extends RapidsBufferStoreWithoutSpill(StorageTier.DISK) { - private[this] val sharedBufferFiles = new ConcurrentHashMap[RapidsBufferId, File] - - private def reportDiskAllocMetrics(metrics: GpuTaskMetrics): String = { - val taskId = TaskContext.get().taskAttemptId() - val totalSize = metrics.getDiskBytesAllocated - val maxSize = metrics.getMaxDiskBytesAllocated - s"total size for task $taskId is $totalSize, max size is $maxSize" - } - - override protected def createBuffer( - incoming: RapidsBuffer, - catalog: RapidsBufferCatalog, - stream: Cuda.Stream): Option[RapidsBufferBase] = { - // assuming that the disk store gets contiguous buffers - val id = incoming.id - val path = if (id.canShareDiskPaths) { - sharedBufferFiles.computeIfAbsent(id, _ => id.getDiskPath(diskBlockManager)) - } else { - id.getDiskPath(diskBlockManager) - } - - val (fileOffset, uncompressedSize, diskLength) = if (id.canShareDiskPaths) { - // only one writer at a time for now when using shared files - path.synchronized { - writeToFile(incoming, path, append = true, stream) - } - } else { - writeToFile(incoming, path, append = false, stream) - } - logDebug(s"Spilled to $path $fileOffset:$diskLength") - val buff = incoming match { - case _: RapidsHostBatchBuffer => - new RapidsDiskColumnarBatch( - id, - fileOffset, - uncompressedSize, - diskLength, - incoming.meta, - incoming.getSpillPriority) - - case _ => - new RapidsDiskBuffer( - id, - fileOffset, - uncompressedSize, - diskLength, - incoming.meta, - incoming.getSpillPriority) - } - TrampolineUtil.incTaskMetricsDiskBytesSpilled(uncompressedSize) - - val metrics = GpuTaskMetrics.get - metrics.incDiskBytesAllocated(uncompressedSize) - logDebug(s"acquiring resources for disk buffer $id of size $uncompressedSize bytes") - logDebug(reportDiskAllocMetrics(metrics)) - Some(buff) - } - - /** - * Copy a host buffer to a file. It leverages [[RapidsSerializerManager]] from - * [[RapidsDiskBlockManager]] to do compression or encryption if needed. - * - * @param incoming the rapid buffer to be written into a file - * @param path file path - * @param append whether to append or written into the beginning of the file - * @param stream cuda stream - * @return a tuple of file offset, memory byte size and written size on disk. File offset is where - * buffer starts in the targeted file path. Memory byte size is the size of byte buffer - * occupied in memory before writing to disk. Written size on disk is actual byte size - * written to disk. - */ - private def writeToFile( - incoming: RapidsBuffer, - path: File, - append: Boolean, - stream: Cuda.Stream): (Long, Long, Long) = { - incoming match { - case fileWritable: RapidsBufferChannelWritable => - val option = if (append) { - Array(StandardOpenOption.CREATE, StandardOpenOption.APPEND) - } else { - Array(StandardOpenOption.CREATE, StandardOpenOption.WRITE) - } - var currentPos, writtenBytes = 0L - - GpuTaskMetrics.get.spillToDiskTime { - withResource(FileChannel.open(path.toPath, option: _*)) { fc => - currentPos = fc.position() - withResource(Channels.newOutputStream(fc)) { os => - withResource(diskBlockManager.getSerializerManager() - .wrapStream(incoming.id, os)) { cos => - val outputChannel = Channels.newChannel(cos) - writtenBytes = fileWritable.writeToChannel(outputChannel, stream) - } - } - (currentPos, writtenBytes, path.length() - currentPos) - } - } - case other => - throw new IllegalStateException( - s"Unable to write $other to file") - } - } - - /** - * A RapidsDiskBuffer that is mean to represent device-bound memory. This - * buffer can produce a device-backed ColumnarBatch. - */ - class RapidsDiskBuffer( - id: RapidsBufferId, - fileOffset: Long, - uncompressedSize: Long, - onDiskSizeInBytes: Long, - meta: TableMeta, - spillPriority: Long) - extends RapidsBufferBase(id, meta, spillPriority) { - - // FIXME: Need to be clean up. Tracked in https://github.com/NVIDIA/spark-rapids/issues/9496 - override val memoryUsedBytes: Long = uncompressedSize - - override val storageTier: StorageTier = StorageTier.DISK - - override def getMemoryBuffer: MemoryBuffer = synchronized { - require(onDiskSizeInBytes > 0, - s"$this attempted an invalid 0-byte mmap of a file") - val path = id.getDiskPath(diskBlockManager) - val serializerManager = diskBlockManager.getSerializerManager() - val memBuffer = if (serializerManager.isRapidsSpill(id)) { - // Only go through serializerManager's stream wrapper for spill case - closeOnExcept(HostAlloc.alloc(uncompressedSize)) { - decompressed => GpuTaskMetrics.get.readSpillFromDiskTime { - withResource(FileChannel.open(path.toPath, StandardOpenOption.READ)) { c => - c.position(fileOffset) - withResource(Channels.newInputStream(c)) { compressed => - withResource(serializerManager.wrapStream(id, compressed)) { in => - withResource(new HostMemoryOutputStream(decompressed)) { out => - IOUtils.copy(in, out) - } - decompressed - } - } - } - } - } - } else { - // Reserved mmap read fashion for UCX shuffle path. Also it's skipping encryption and - // compression. - HostMemoryBuffer.mapFile(path, MapMode.READ_WRITE, fileOffset, onDiskSizeInBytes) - } - memBuffer - } - - override def close(): Unit = synchronized { - super.close() - } - - override protected def releaseResources(): Unit = { - logDebug(s"releasing resources for disk buffer $id of size $memoryUsedBytes bytes") - val metrics = GpuTaskMetrics.get - metrics.decDiskBytesAllocated(memoryUsedBytes) - logDebug(reportDiskAllocMetrics(metrics)) - - // Buffers that share paths must be cleaned up elsewhere - if (id.canShareDiskPaths) { - sharedBufferFiles.remove(id) - } else { - val path = id.getDiskPath(diskBlockManager) - if (!path.delete() && path.exists()) { - logWarning(s"Unable to delete spill path $path") - } - } - } - } - - /** - * A RapidsDiskBuffer that should remain in the host, producing host-backed - * ColumnarBatch if the caller invokes getHostColumnarBatch, but not producing - * anything on the device. - */ - class RapidsDiskColumnarBatch( - id: RapidsBufferId, - fileOffset: Long, - size: Long, - uncompressedSize: Long, - // TODO: remove meta - meta: TableMeta, - spillPriority: Long) - extends RapidsDiskBuffer( - id, fileOffset, size, uncompressedSize, meta, spillPriority) - with RapidsHostBatchBuffer { - - override def getMemoryBuffer: MemoryBuffer = - throw new IllegalStateException( - "Called getMemoryBuffer on a disk buffer that needs deserialization") - - override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = - throw new IllegalStateException( - "Called getColumnarBatch on a disk buffer that needs deserialization") - - override def getHostColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = { - require(fileOffset == 0, - "Attempted to obtain a HostColumnarBatch from a spilled RapidsBuffer that is sharing " + - "paths on disk") - val path = id.getDiskPath(diskBlockManager) - withResource(new FileInputStream(path)) { fis => - withResource(diskBlockManager.getSerializerManager() - .wrapStream(id, fis)) { fs => - val (header, hostBuffer) = SerializedHostTableUtils.readTableHeaderAndBuffer(fs) - val hostCols = withResource(hostBuffer) { _ => - SerializedHostTableUtils.buildHostColumns(header, hostBuffer, sparkTypes) - } - new ColumnarBatch(hostCols.toArray, header.getNumRows) - } - } - } - } -} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala deleted file mode 100644 index 235ed9ddb45..00000000000 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala +++ /dev/null @@ -1,484 +0,0 @@ -/* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.nvidia.spark.rapids - -import java.io.DataOutputStream -import java.nio.channels.{Channels, WritableByteChannel} -import java.util.concurrent.ConcurrentHashMap - -import scala.collection.mutable - -import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, HostColumnVector, HostMemoryBuffer, JCudfSerialization, MemoryBuffer} -import com.nvidia.spark.rapids.Arm.{closeOnExcept, freeOnExcept, withResource} -import com.nvidia.spark.rapids.SpillPriorities.{applyPriorityOffset, HOST_MEMORY_BUFFER_SPILL_OFFSET} -import com.nvidia.spark.rapids.StorageTier.StorageTier -import com.nvidia.spark.rapids.format.TableMeta - -import org.apache.spark.TaskContext -import org.apache.spark.sql.rapids.GpuTaskMetrics -import org.apache.spark.sql.rapids.storage.RapidsStorageUtils -import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.vectorized.ColumnarBatch - -/** - * A buffer store using host memory. - * @param maxSize maximum size in bytes for all buffers in this store - */ -class RapidsHostMemoryStore( - maxSize: Option[Long]) - extends RapidsBufferStore(StorageTier.HOST) { - - override protected def spillableOnAdd: Boolean = false - - override def getMaxSize: Option[Long] = maxSize - - def addBuffer( - id: RapidsBufferId, - buffer: HostMemoryBuffer, - tableMeta: TableMeta, - initialSpillPriority: Long, - needsSync: Boolean): RapidsBuffer = { - buffer.incRefCount() - val rapidsBuffer = new RapidsHostMemoryBuffer( - id, - buffer.getLength, - tableMeta, - initialSpillPriority, - buffer) - freeOnExcept(rapidsBuffer) { _ => - logDebug(s"Adding host buffer for: [id=$id, size=${buffer.getLength}, " + - s"uncompressed=${rapidsBuffer.meta.bufferMeta.uncompressedSize}, " + - s"meta_id=${tableMeta.bufferMeta.id}, " + - s"meta_size=${tableMeta.bufferMeta.size}]") - addBuffer(rapidsBuffer, needsSync) - rapidsBuffer - } - } - - def addBatch(id: RapidsBufferId, - hostCb: ColumnarBatch, - initialSpillPriority: Long, - needsSync: Boolean): RapidsBuffer = { - RapidsHostColumnVector.incRefCounts(hostCb) - val rapidsBuffer = new RapidsHostColumnarBatch( - id, - hostCb, - initialSpillPriority) - freeOnExcept(rapidsBuffer) { _ => - addBuffer(rapidsBuffer, needsSync) - rapidsBuffer - } - } - - override protected def trySpillToMaximumSize( - buffer: RapidsBuffer, - catalog: RapidsBufferCatalog, - stream: Cuda.Stream): Boolean = { - maxSize.forall { ms => - // this spillStore has a maximum size requirement (host only). We need to spill from it - // in order to make room for `buffer`. - val targetTotalSize = ms - buffer.memoryUsedBytes - if (targetTotalSize < 0) { - // lets not spill to host when the buffer we are about - // to spill is larger than our limit - false - } else { - val amountSpilled = synchronousSpill(targetTotalSize, catalog, stream) - if (amountSpilled != 0) { - logDebug(s"Task ${TaskContext.get.taskAttemptId()} spilled $amountSpilled bytes from" + - s"${name} to make room for ${buffer.id}") - } - // if after spill we can fit the new buffer, return true - buffer.memoryUsedBytes <= (ms - currentSize) - } - } - } - - override protected def createBuffer( - other: RapidsBuffer, - catalog: RapidsBufferCatalog, - stream: Cuda.Stream): Option[RapidsBufferBase] = { - val wouldFit = trySpillToMaximumSize(other, catalog, stream) - if (!wouldFit) { - // skip host - logWarning(s"Buffer $other with size ${other.memoryUsedBytes} does not fit " + - s"in the host store, skipping tier.") - None - } else { - // If the other is from the local disk store, we are unspilling to host memory. - if (other.storageTier == StorageTier.DISK) { - logDebug(s"copying RapidsDiskStore buffer ${other.id} to a HostMemoryBuffer") - val hostBuffer = other.getMemoryBuffer.asInstanceOf[HostMemoryBuffer] - Some(new RapidsHostMemoryBuffer( - other.id, - hostBuffer.getLength(), - other.meta, - applyPriorityOffset(other.getSpillPriority, HOST_MEMORY_BUFFER_SPILL_OFFSET), - hostBuffer)) - } else { - withResource(other.getCopyIterator) { otherBufferIterator => - val isChunked = otherBufferIterator.isChunked - val totalCopySize = otherBufferIterator.getTotalCopySize - closeOnExcept(HostAlloc.tryAlloc(totalCopySize)) { hb => - hb.map { hostBuffer => - val spillNs = GpuTaskMetrics.get.spillToHostTime { - var hostOffset = 0L - val start = System.nanoTime() - while (otherBufferIterator.hasNext) { - val otherBuffer = otherBufferIterator.next() - withResource(otherBuffer) { _ => - otherBuffer match { - case devBuffer: DeviceMemoryBuffer => - hostBuffer.copyFromMemoryBufferAsync( - hostOffset, devBuffer, 0, otherBuffer.getLength, stream) - hostOffset += otherBuffer.getLength - case _ => - throw new IllegalStateException("copying from buffer without device memory") - } - } - } - stream.sync() - System.nanoTime() - start - } - val szMB = (totalCopySize.toDouble / 1024.0 / 1024.0).toLong - val bw = (szMB.toDouble / (spillNs.toDouble / 1000000000.0)).toLong - logDebug(s"Spill to host (chunked=$isChunked) " + - s"size=$szMB MiB bandwidth=$bw MiB/sec") - new RapidsHostMemoryBuffer( - other.id, - totalCopySize, - other.meta, - applyPriorityOffset(other.getSpillPriority, HOST_MEMORY_BUFFER_SPILL_OFFSET), - hostBuffer) - }.orElse { - // skip host - logWarning(s"Buffer $other with size ${other.memoryUsedBytes} does not fit " + - s"in the host store, skipping tier.") - None - } - } - } - } - } - } - - def numBytesFree: Option[Long] = maxSize.map(_ - currentSize) - - class RapidsHostMemoryBuffer( - id: RapidsBufferId, - size: Long, - meta: TableMeta, - spillPriority: Long, - buffer: HostMemoryBuffer) - extends RapidsBufferBase(id, meta, spillPriority) - with RapidsBufferChannelWritable - with MemoryBuffer.EventHandler { - override val storageTier: StorageTier = StorageTier.HOST - - override def getMemoryBuffer: MemoryBuffer = getHostMemoryBuffer - - override def getHostMemoryBuffer: HostMemoryBuffer = synchronized { - buffer.synchronized { - setSpillable(this, false) - buffer.incRefCount() - buffer - } - } - - override def writeToChannel(outputChannel: WritableByteChannel, ignored: Cuda.Stream): Long = { - var written: Long = 0L - val iter = new HostByteBufferIterator(buffer) - iter.foreach { bb => - try { - while (bb.hasRemaining) { - written += outputChannel.write(bb) - } - } finally { - RapidsStorageUtils.dispose(bb) - } - } - written - } - - override def updateSpillability(): Unit = { - if (buffer.getRefCount == 1) { - setSpillable(this, true) - } - } - - override protected def releaseResources(): Unit = { - buffer.close() - } - - /** The size of this buffer in bytes. */ - override val memoryUsedBytes: Long = size - - // If this require triggers, we are re-adding a `HostMemoryBuffer` outside of - // the catalog lock, which should not possible. The event handler is set to null - // when we free the `RapidsHostMemoryBuffer` and if the buffer is not free, we - // take out another handle (in the catalog). - HostAlloc.addEventHandler(buffer, this) - - /** - * Override from the MemoryBuffer.EventHandler interface. - * - * If we are being invoked we have the `buffer` lock, as this callback - * is being invoked from `MemoryBuffer.close` - * - * @param refCount - buffer's current refCount - */ - override def onClosed(refCount: Int): Unit = { - // refCount == 1 means only 1 reference exists to `buffer` in the - // RapidsHostMemoryBuffer (we own it) - if (refCount == 1) { - // setSpillable is being called here as an extension of `MemoryBuffer.close()` - // we hold the MemoryBuffer lock and we could be called from a Spark task thread - // Since we hold the MemoryBuffer lock, `incRefCount` waits for us. The only other - // call to `setSpillable` is also under this same MemoryBuffer lock (see: - // `getMemoryBuffer`) - setSpillable(this, true) - } - } - - /** - * We overwrite free to make sure we don't have a handler for the underlying - * buffer, since this `RapidsBuffer` is no longer tracked. - */ - override def free(): Unit = synchronized { - if (isValid) { - // it is going to be invalid when calling super.free() - HostAlloc.removeEventHandler(buffer, this) - } - super.free() - } - } - - /** - * A per cuDF host column event handler that handles calls to .close() - * inside of the `HostColumnVector` lock. - */ - class RapidsHostColumnEventHandler - extends HostColumnVector.EventHandler { - - // Every RapidsHostColumnarBatch that references this column has an entry in this map. - // The value represents the number of times (normally 1) that a ColumnVector - // appears in the RapidsHostColumnarBatch. This is also the HosColumnVector refCount at which - // the column is considered spillable. - // The map is protected via the ColumnVector lock. - private val registration = new mutable.HashMap[RapidsHostColumnarBatch, Int]() - - /** - * Every RapidsHostColumnarBatch iterates through its columns and either creates - * a `RapidsHostColumnEventHandler` object and associates it with the column's - * `eventHandler` or calls into the existing one, and registers itself. - * - * The registration has two goals: it accounts for repetition of a column - * in a `RapidsHostColumnarBatch`. If a batch has the same column repeated it must adjust - * the refCount at which this column is considered spillable. - * - * The second goal is to account for aliasing. If two host batches alias this column - * we are going to mark it as non spillable. - * - * @param rapidsHostCb - the host batch that is registering itself with this tracker - */ - def register(rapidsHostCb: RapidsHostColumnarBatch, repetition: Int): Unit = { - registration.put(rapidsHostCb, repetition) - } - - /** - * This is invoked during `RapidsHostColumnarBatch.free` in order to remove the entry - * in `registration`. - * - * @param rapidsHostCb - the batch that is de-registering itself - */ - def deregister(rapidsHostCb: RapidsHostColumnarBatch): Unit = { - registration.remove(rapidsHostCb) - } - - // called with the cudf HostColumnVector lock held from cuDF's side - override def onClosed(cudfCv: HostColumnVector, refCount: Int): Unit = { - // we only handle spillability if there is a single batch registered - // (no aliasing) - if (registration.size == 1) { - val (rapidsHostCb, spillableRefCount) = registration.head - if (spillableRefCount == refCount) { - rapidsHostCb.onColumnSpillable(cudfCv) - } - } - } - } - - /** - * A `RapidsHostColumnarBatch` is the spill store holder of ColumnarBatch backed by - * HostColumnVector. - * - * This class owns the host batch and will close it when `close` is called. - * - * @param id the `RapidsBufferId` this batch is associated with - * @param batch the host ColumnarBatch we are managing - * @param spillPriority a starting spill priority - */ - class RapidsHostColumnarBatch( - id: RapidsBufferId, - hostCb: ColumnarBatch, - spillPriority: Long) - extends RapidsBufferBase( - id, - null, - spillPriority) - with RapidsBufferChannelWritable - with RapidsHostBatchBuffer { - - override val storageTier: StorageTier = StorageTier.HOST - - // By default all columns are NOT spillable since we are not the only owners of - // the columns (the caller is holding onto a ColumnarBatch that will be closed - // after instantiation, triggering onClosed callbacks) - // This hash set contains the columns that are currently spillable. - private val columnSpillability = new ConcurrentHashMap[HostColumnVector, Boolean]() - - private val numDistinctColumns = RapidsHostColumnVector.extractBases(hostCb).distinct.size - - // we register our event callbacks as the very first action to deal with - // spillability - registerOnCloseEventHandler() - - /** Release the underlying resources for this buffer. */ - override protected def releaseResources(): Unit = { - hostCb.close() - } - - override def meta: TableMeta = { - null - } - - // This is the current size in batch form. It is to be used while this - // batch hasn't migrated to another store. - override val memoryUsedBytes: Long = RapidsHostColumnVector.getTotalHostMemoryUsed(hostCb) - - /** - * Mark a column as spillable - * - * @param column the ColumnVector to mark as spillable - */ - def onColumnSpillable(column: HostColumnVector): Unit = { - columnSpillability.put(column, true) - updateSpillability() - } - - /** - * Update the spillability state of this RapidsHostColumnarBatch. This is invoked from - * two places: - * - * - from the onColumnSpillable callback, which is invoked from a - * HostColumnVector.EventHandler.onClosed callback. - * - * - after adding a batch to the store to mark the batch as spillable if - * all columns are spillable. - */ - override def updateSpillability(): Unit = { - setSpillable(this, columnSpillability.size == numDistinctColumns) - } - - override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = { - throw new UnsupportedOperationException( - "RapidsHostColumnarBatch does not support getColumnarBatch") - } - - override def getHostColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = { - columnSpillability.clear() - setSpillable(this, false) - RapidsHostColumnVector.incRefCounts(hostCb) - } - - override def getMemoryBuffer: MemoryBuffer = { - throw new UnsupportedOperationException( - "RapidsHostColumnarBatch does not support getMemoryBuffer") - } - - override def getCopyIterator: RapidsBufferCopyIterator = { - throw new UnsupportedOperationException( - "RapidsHostColumnarBatch does not support getCopyIterator") - } - - override def writeToChannel(outputChannel: WritableByteChannel, ignored: Cuda.Stream): Long = { - withResource(Channels.newOutputStream(outputChannel)) { outputStream => - withResource(new DataOutputStream(outputStream)) { dos => - val columns = RapidsHostColumnVector.extractBases(hostCb) - JCudfSerialization.writeToStream(columns, dos, 0, hostCb.numRows()) - dos.size() - } - } - } - - override def free(): Unit = { - // lets remove our handler from the chain of handlers for each column - removeOnCloseEventHandler() - super.free() - } - - private def registerOnCloseEventHandler(): Unit = { - val columns = RapidsHostColumnVector.extractBases(hostCb) - // cudfColumns could contain duplicates. We need to take this into account when we are - // deciding the floor refCount for a duplicated column - val repetitionPerColumn = new mutable.HashMap[HostColumnVector, Int]() - columns.foreach { col => - val repetitionCount = repetitionPerColumn.getOrElse(col, 0) - repetitionPerColumn(col) = repetitionCount + 1 - } - repetitionPerColumn.foreach { case (distinctCv, repetition) => - // lock the column because we are setting its event handler, and we are inspecting - // its refCount. - distinctCv.synchronized { - val eventHandler = distinctCv.getEventHandler match { - case null => - val eventHandler = new RapidsHostColumnEventHandler - distinctCv.setEventHandler(eventHandler) - eventHandler - case existing: RapidsHostColumnEventHandler => - existing - case other => - throw new IllegalStateException( - s"Invalid column event handler $other") - } - eventHandler.register(this, repetition) - if (repetition == distinctCv.getRefCount) { - onColumnSpillable(distinctCv) - } - } - } - } - - // this method is called from free() - private def removeOnCloseEventHandler(): Unit = { - val distinctColumns = RapidsHostColumnVector.extractBases(hostCb).distinct - distinctColumns.foreach { distinctCv => - distinctCv.synchronized { - distinctCv.getEventHandler match { - case eventHandler: RapidsHostColumnEventHandler => - eventHandler.deregister(this) - case t => - throw new IllegalStateException( - s"Invalid column event handler $t") - } - } - } - } - } -} - - diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsSerializerManager.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsSerializerManager.scala index ab4a6398d32..74aed062142 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsSerializerManager.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsSerializerManager.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,8 +20,8 @@ import java.io.{InputStream, OutputStream} import org.apache.spark.SparkConf import org.apache.spark.io.CompressionCodec -import org.apache.spark.sql.rapids.TempSpillBufferId import org.apache.spark.sql.rapids.execution.TrampolineUtil +import org.apache.spark.storage.BlockId /** @@ -44,22 +44,20 @@ class RapidsSerializerManager (conf: SparkConf) { private lazy val compressionCodec: CompressionCodec = TrampolineUtil.createCodec(conf) - // Whether it really goes through crypto streams replies on Spark configuration - // (e.g., `` `spark.io.encryption.enabled` ``) and the existence of crypto keys. - def wrapStream(bufferId: RapidsBufferId, s: OutputStream): OutputStream = { - if(isRapidsSpill(bufferId)) wrapForCompression(bufferId, wrapForEncryption(s)) else s + def wrapStream(blockId: BlockId, s: OutputStream): OutputStream = { + if(isRapidsSpill(blockId)) wrapForCompression(blockId, wrapForEncryption(s)) else s } - def wrapStream(bufferId: RapidsBufferId, s: InputStream): InputStream = { - if(isRapidsSpill(bufferId)) wrapForCompression(bufferId, wrapForEncryption(s)) else s + def wrapStream(blockId: BlockId, s: InputStream): InputStream = { + if(isRapidsSpill(blockId)) wrapForCompression(blockId, wrapForEncryption(s)) else s } - private[this] def wrapForCompression(bufferId: RapidsBufferId, s: InputStream): InputStream = { - if (shouldCompress(bufferId)) compressionCodec.compressedInputStream(s) else s + private[this] def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { + if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s } - private[this] def wrapForCompression(bufferId: RapidsBufferId, s: OutputStream): OutputStream = { - if (shouldCompress(bufferId)) compressionCodec.compressedOutputStream(s) else s + private[this] def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { + if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s } private[this] def wrapForEncryption(s: InputStream): InputStream = { @@ -70,18 +68,15 @@ class RapidsSerializerManager (conf: SparkConf) { if (serializerManager != null) serializerManager.wrapForEncryption(s) else s } - def isRapidsSpill(bufferId: RapidsBufferId): Boolean = { - bufferId match { - case _: TempSpillBufferId => true - case _ => false - } + def isRapidsSpill(blockId: BlockId): Boolean = { + !blockId.isShuffle } - private[this] def shouldCompress(bufferId: RapidsBufferId): Boolean = { - bufferId match { - case _: TempSpillBufferId => compressSpill - case _: ShuffleBufferId | _: ShuffleReceivedBufferId => false - case _ => false + private[this] def shouldCompress(blockId: BlockId): Boolean = { + if (!blockId.isShuffle) { + compressSpill + } else { + false } } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala index a587c5cd7ae..542fb2deb2d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,52 +16,76 @@ package com.nvidia.spark.rapids -import java.io.File import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicInteger import java.util.function.{Consumer, IntUnaryOperator} import scala.collection.mutable.ArrayBuffer -import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer} +import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, Table} import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.format.TableMeta +import com.nvidia.spark.rapids.spill.{SpillableDeviceBufferHandle, SpillableHandle} -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.internal.Logging -import org.apache.spark.sql.rapids.RapidsDiskBlockManager import org.apache.spark.sql.rapids.execution.TrampolineUtil +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.storage.ShuffleBlockId /** Identifier for a shuffle buffer that holds the data for a table */ case class ShuffleBufferId( blockId: ShuffleBlockId, - override val tableId: Int) extends RapidsBufferId { + tableId: Int) { val shuffleId: Int = blockId.shuffleId val mapId: Long = blockId.mapId - - override val canShareDiskPaths: Boolean = true - - override def getDiskPath(diskBlockManager: RapidsDiskBlockManager): File = { - diskBlockManager.getFile(blockId) - } } /** Catalog for lookup of shuffle buffers by block ID */ -class ShuffleBufferCatalog( - catalog: RapidsBufferCatalog, - diskBlockManager: RapidsDiskBlockManager) extends Logging { +class ShuffleBufferCatalog extends Logging { + /** + * Information stored for each active shuffle. + * A shuffle block can be comprised of multiple batches. Each batch + * is given a `ShuffleBufferId`. + */ + private type ShuffleInfo = + ConcurrentHashMap[ShuffleBlockId, ArrayBuffer[ShuffleBufferId]] + + private val bufferIdToHandle = + new ConcurrentHashMap[ + ShuffleBufferId, + (Option[SpillableDeviceBufferHandle], TableMeta)]() - private val bufferIdToHandle = new ConcurrentHashMap[RapidsBufferId, RapidsBufferHandle]() + /** shuffle information for each active shuffle */ + private[this] val activeShuffles = new ConcurrentHashMap[Int, ShuffleInfo] + + /** Mapping of table ID to shuffle buffer ID */ + private[this] val tableMap = new ConcurrentHashMap[Int, ShuffleBufferId] + + /** Tracks the next table identifier */ + private[this] val tableIdCounter = new AtomicInteger(0) private def trackCachedHandle( bufferId: ShuffleBufferId, - bufferHandle: RapidsBufferHandle): Unit = { - bufferIdToHandle.put(bufferId, bufferHandle) + handle: SpillableDeviceBufferHandle, + meta: TableMeta): Unit = { + bufferIdToHandle.put(bufferId, (Some(handle), meta)) + } + + private def trackDegenerate(bufferId: ShuffleBufferId, + meta: TableMeta): Unit = { + bufferIdToHandle.put(bufferId, (None, meta)) } def removeCachedHandles(): Unit = { - bufferIdToHandle.forEach { (_, handle) => removeBuffer(handle) } + val bufferIt = bufferIdToHandle.keySet().iterator() + while (bufferIt.hasNext) { + val buffer = bufferIt.next() + val (maybeHandle, _) = bufferIdToHandle.remove(buffer) + tableMap.remove(buffer.tableId) + maybeHandle.foreach(_.close()) + } } /** @@ -70,56 +94,47 @@ class ShuffleBufferCatalog( * The refcount of the underlying device buffer will be incremented so the contiguous table * can be closed before this buffer is destroyed. * - * @param blockId Spark's `ShuffleBlockId` that identifies this buffer - * @param contigTable contiguous table to track in storage + * @param blockId Spark's `ShuffleBlockId` that identifies this buffer + * @param contigTable contiguous table to track in storage * @param initialSpillPriority starting spill priority value for the buffer - * @param needsSync whether the spill framework should stream synchronize while adding - * this device buffer (defaults to true) * @return RapidsBufferHandle identifying this table */ - def addContiguousTable( - blockId: ShuffleBlockId, - contigTable: ContiguousTable, - initialSpillPriority: Long, - needsSync: Boolean): RapidsBufferHandle = { - val bufferId = nextShuffleBufferId(blockId) + def addContiguousTable(blockId: ShuffleBlockId, + contigTable: ContiguousTable, + initialSpillPriority: Long): Unit = { withResource(contigTable) { _ => - val handle = catalog.addContiguousTable( - bufferId, - contigTable, - initialSpillPriority, - needsSync) - trackCachedHandle(bufferId, handle) - handle + val bufferId = nextShuffleBufferId(blockId) + val tableMeta = MetaUtils.buildTableMeta(bufferId.tableId, contigTable) + val buff = contigTable.getBuffer + buff.incRefCount() + val handle = SpillableDeviceBufferHandle(buff) + trackCachedHandle(bufferId, handle, tableMeta) } } /** * Adds a buffer to the device storage, taking ownership of the buffer. * - * @param blockId Spark's `ShuffleBlockId` that identifies this buffer - * @param buffer buffer that will be owned by the store - * @param tableMeta metadata describing the buffer layout + * @param blockId Spark's `ShuffleBlockId` that identifies this buffer + * @param compressedBatch Compressed ColumnarBatch * @param initialSpillPriority starting spill priority value for the buffer * @return RapidsBufferHandle associated with this buffer */ - def addBuffer( - blockId: ShuffleBlockId, - buffer: DeviceMemoryBuffer, - tableMeta: TableMeta, - initialSpillPriority: Long, - needsSync: Boolean): RapidsBufferHandle = { - val bufferId = nextShuffleBufferId(blockId) - // update the table metadata for the buffer ID generated above - tableMeta.bufferMeta.mutateId(bufferId.tableId) - val handle = catalog.addBuffer( - bufferId, - buffer, - tableMeta, - initialSpillPriority, - needsSync) - trackCachedHandle(bufferId, handle) - handle + def addCompressedBatch( + blockId: ShuffleBlockId, + compressedBatch: ColumnarBatch, + initialSpillPriority: Long): Unit = { + withResource(compressedBatch) { _ => + val bufferId = nextShuffleBufferId(blockId) + val compressed = compressedBatch.column(0).asInstanceOf[GpuCompressedColumnVector] + val tableMeta = compressed.getTableMeta + // update the table metadata for the buffer ID generated above + tableMeta.bufferMeta().mutateId(bufferId.tableId) + val buff = compressed.getTableBuffer + buff.incRefCount() + val handle = SpillableDeviceBufferHandle(buff) + trackCachedHandle(bufferId, handle, tableMeta) + } } /** @@ -128,39 +143,18 @@ class ShuffleBufferCatalog( */ def addDegenerateRapidsBuffer( blockId: ShuffleBlockId, - meta: TableMeta): RapidsBufferHandle = { + meta: TableMeta): Unit = { val bufferId = nextShuffleBufferId(blockId) - val handle = catalog.registerDegenerateBuffer(bufferId, meta) - trackCachedHandle(bufferId, handle) - handle + trackDegenerate(bufferId, meta) } - /** - * Information stored for each active shuffle. - * NOTE: ArrayBuffer in blockMap must be explicitly locked when using it! - * - * @param blockMap mapping of block ID to array of buffers for the block - */ - private case class ShuffleInfo( - blockMap: ConcurrentHashMap[ShuffleBlockId, ArrayBuffer[ShuffleBufferId]]) - - /** shuffle information for each active shuffle */ - private[this] val activeShuffles = new ConcurrentHashMap[Int, ShuffleInfo] - - /** Mapping of table ID to shuffle buffer ID */ - private[this] val tableMap = new ConcurrentHashMap[Int, ShuffleBufferId] - - /** Tracks the next table identifier */ - private[this] val tableIdCounter = new AtomicInteger(0) - /** * Register a new shuffle. * This must be called before any buffer identifiers associated with this shuffle can be tracked. * @param shuffleId shuffle identifier */ def registerShuffle(shuffleId: Int): Unit = { - activeShuffles.computeIfAbsent(shuffleId, _ => ShuffleInfo( - new ConcurrentHashMap[ShuffleBlockId, ArrayBuffer[ShuffleBufferId]])) + activeShuffles.computeIfAbsent(shuffleId, _ => new ShuffleInfo) } /** Frees all buffers that correspond to the specified shuffle. */ @@ -174,22 +168,11 @@ class ShuffleBufferCatalog( // NOTE: Not synchronizing array buffer because this shuffle should be inactive. bufferIds.foreach { id => tableMap.remove(id.tableId) - val handle = bufferIdToHandle.remove(id) - if (handle != null) { - handle.close() - } - } - } - info.blockMap.forEachValue(Long.MaxValue, bufferRemover) - - val fileRemover: Consumer[ShuffleBlockId] = { blockId => - val file = diskBlockManager.getFile(blockId) - logDebug(s"Deleting file $file") - if (!file.delete() && file.exists()) { - logWarning(s"Unable to delete $file") + val handleAndMeta = bufferIdToHandle.remove(id) + handleAndMeta._1.foreach(_.close()) } } - info.blockMap.forEachKey(Long.MaxValue, fileRemover) + info.forEachValue(Long.MaxValue, bufferRemover) } else { // currently shuffle unregister can get called on the driver which never saw a register if (!TrampolineUtil.isDriver(SparkEnv.get)) { @@ -201,12 +184,12 @@ class ShuffleBufferCatalog( def hasActiveShuffle(shuffleId: Int): Boolean = activeShuffles.containsKey(shuffleId) /** Get all the buffer IDs that correspond to a shuffle block identifier. */ - def blockIdToBuffersIds(blockId: ShuffleBlockId): Array[ShuffleBufferId] = { + private def blockIdToBuffersIds(blockId: ShuffleBlockId): Array[ShuffleBufferId] = { val info = activeShuffles.get(blockId.shuffleId) if (info == null) { - throw new NoSuchElementException(s"unknown shuffle $blockId.shuffleId") + throw new NoSuchElementException(s"unknown shuffle ${blockId.shuffleId}") } - val entries = info.blockMap.get(blockId) + val entries = info.get(blockId) if (entries == null) { throw new NoSuchElementException(s"unknown shuffle block $blockId") } @@ -215,27 +198,61 @@ class ShuffleBufferCatalog( } } - def blockIdToBufferHandles(blockId: ShuffleBlockId): Array[RapidsBufferHandle] = { + def getColumnarBatchIterator( + blockId: ShuffleBlockId, + sparkTypes: Array[DataType]): Iterator[ColumnarBatch] = { + val bufferIDs = blockIdToBuffersIds(blockId) + bufferIDs.iterator.map { bId => + GpuSemaphore.acquireIfNecessary(TaskContext.get) + val (maybeHandle, meta) = bufferIdToHandle.get(bId) + maybeHandle.map { handle => + withResource(handle.materialize()) { buff => + val bufferMeta = meta.bufferMeta() + if (bufferMeta == null || bufferMeta.codecBufferDescrsLength == 0) { + MetaUtils.getBatchFromMeta(buff, meta, sparkTypes) + } else { + GpuCompressedColumnVector.from(buff, meta) + } + } + }.getOrElse { + // degenerate table (handle is None) + // make a batch out of denegerate meta + val rowCount = meta.rowCount + val packedMeta = meta.packedMetaAsByteBuffer() + if (packedMeta != null) { + withResource(DeviceMemoryBuffer.allocate(0)) { deviceBuffer => + withResource(Table.fromPackedTable( + meta.packedMetaAsByteBuffer(), deviceBuffer)) { table => + GpuColumnVectorFromBuffer.from(table, deviceBuffer, meta, sparkTypes) + } + } + } else { + // no packed metadata, must be a table with zero columns + new ColumnarBatch(Array.empty, rowCount.toInt) + } + } + } + } + + /** Get all the buffer metadata that correspond to a shuffle block identifier. */ + def blockIdToMetas(blockId: ShuffleBlockId): Seq[TableMeta] = { val info = activeShuffles.get(blockId.shuffleId) if (info == null) { - throw new NoSuchElementException(s"unknown shuffle $blockId.shuffleId") + throw new NoSuchElementException(s"unknown shuffle ${blockId.shuffleId}") } - val entries = info.blockMap.get(blockId) + val entries = info.get(blockId) if (entries == null) { throw new NoSuchElementException(s"unknown shuffle block $blockId") } - entries.synchronized { - entries.map(bufferIdToHandle.get).toArray - } - } - - /** Get all the buffer metadata that correspond to a shuffle block identifier. */ - def blockIdToMetas(blockId: ShuffleBlockId): Seq[TableMeta] = { - blockIdToBuffersIds(blockId).map(catalog.getBufferMeta) + entries.synchronized { + entries.map(bufferIdToHandle.get).map { case (_, meta) => + meta + } + }.toSeq } /** Allocate a new shuffle buffer identifier and update the shuffle block mapping. */ - def nextShuffleBufferId(blockId: ShuffleBlockId): ShuffleBufferId = { + private def nextShuffleBufferId(blockId: ShuffleBlockId): ShuffleBufferId = { val info = activeShuffles.get(blockId.shuffleId) if (info == null) { throw new IllegalStateException(s"unknown shuffle ${blockId.shuffleId}") @@ -249,7 +266,7 @@ class ShuffleBufferCatalog( } // associate this new buffer with the shuffle block - val blockBufferIds = info.blockMap.computeIfAbsent(blockId, _ => + val blockBufferIds = info.computeIfAbsent(blockId, _ => new ArrayBuffer[ShuffleBufferId]) blockBufferIds.synchronized { blockBufferIds.append(id) @@ -258,35 +275,29 @@ class ShuffleBufferCatalog( } /** Lookup the shuffle buffer handle that corresponds to the specified table identifier. */ - def getShuffleBufferHandle(tableId: Int): RapidsBufferHandle = { + def getShuffleBufferHandle(tableId: Int): RapidsShuffleHandle = { val shuffleBufferId = tableMap.get(tableId) if (shuffleBufferId == null) { throw new NoSuchElementException(s"unknown table ID $tableId") } - bufferIdToHandle.get(shuffleBufferId) + val (maybeHandle, meta) = bufferIdToHandle.get(shuffleBufferId) + maybeHandle match { + case Some(spillable) => + RapidsShuffleHandle(spillable, meta) + case None => + throw new IllegalStateException( + "a buffer handle could not be obtained for a degenerate buffer") + } } /** * Update the spill priority of a shuffle buffer that soon will be read locally. * @param handle shuffle buffer handle of buffer to update */ - def updateSpillPriorityForLocalRead(handle: RapidsBufferHandle): Unit = { - handle.setSpillPriority(SpillPriorities.INPUT_FROM_SHUFFLE_PRIORITY) - } - - /** - * Lookup the shuffle buffer that corresponds to the specified buffer handle and acquire it. - * NOTE: It is the responsibility of the caller to close the buffer. - * @param handle shuffle buffer handle - * @return shuffle buffer that has been acquired - */ - def acquireBuffer(handle: RapidsBufferHandle): RapidsBuffer = { - val buffer = catalog.acquireBuffer(handle) - // Shuffle buffers that have been read are less likely to be read again, - // so update the spill priority based on this access - handle.setSpillPriority(SpillPriorities.getShuffleOutputBufferReadPriority) - buffer - } + // TODO: AB: priorities + //def updateSpillPriorityForLocalRead(handle: RapidsBufferHandle): Unit = { + // handle.setSpillPriority(SpillPriorities.INPUT_FROM_SHUFFLE_PRIORITY) + //} /** * Remove a buffer and table given a buffer handle @@ -294,9 +305,7 @@ class ShuffleBufferCatalog( * the handle being removed is not being utilized by another thread. * @param handle buffer handle */ - def removeBuffer(handle: RapidsBufferHandle): Unit = { - val id = handle.id - tableMap.remove(id.tableId) + def removeBuffer(handle: SpillableHandle): Unit = { handle.close() } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala index 0ff4f9278be..450622ef3ba 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,49 +16,25 @@ package com.nvidia.spark.rapids -import java.io.File -import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.atomic.AtomicInteger -import java.util.function.IntUnaryOperator - -import ai.rapids.cudf.DeviceMemoryBuffer +import ai.rapids.cudf.{DeviceMemoryBuffer, Table} import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableColumn import com.nvidia.spark.rapids.format.TableMeta +import com.nvidia.spark.rapids.spill.SpillableDeviceBufferHandle import org.apache.spark.internal.Logging -import org.apache.spark.sql.rapids.RapidsDiskBlockManager - -/** Identifier for a shuffle buffer that holds the data for a table on the read side */ - -case class ShuffleReceivedBufferId( - override val tableId: Int) extends RapidsBufferId { - override val canShareDiskPaths: Boolean = false +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.ColumnarBatch - override def getDiskPath(diskBlockManager: RapidsDiskBlockManager): File = { - diskBlockManager.getFile(s"temp_shuffle_${tableId}") +case class RapidsShuffleHandle( + spillable: SpillableDeviceBufferHandle, tableMeta: TableMeta) extends AutoCloseable { + override def close(): Unit = { + spillable.safeClose() } } /** Catalog for lookup of shuffle buffers by block ID */ -class ShuffleReceivedBufferCatalog( - catalog: RapidsBufferCatalog) extends Logging { - - /** Mapping of table ID to shuffle buffer ID */ - private[this] val tableMap = new ConcurrentHashMap[Int, ShuffleReceivedBufferId] - - /** Tracks the next table identifier */ - private[this] val tableIdCounter = new AtomicInteger(0) - - /** Allocate a new shuffle buffer identifier and update the shuffle block mapping. */ - private def nextShuffleReceivedBufferId(): ShuffleReceivedBufferId = { - val tableId = tableIdCounter.getAndUpdate(ShuffleReceivedBufferCatalog.TABLE_ID_UPDATER) - val id = ShuffleReceivedBufferId(tableId) - val prev = tableMap.put(tableId, id) - if (prev != null) { - throw new IllegalStateException(s"table ID $tableId is already in use") - } - id - } +class ShuffleReceivedBufferCatalog() extends Logging { /** * Adds a buffer to the device storage, taking ownership of the buffer. @@ -70,64 +46,52 @@ class ShuffleReceivedBufferCatalog( * @param initialSpillPriority starting spill priority value for the buffer * @param needsSync tells the store a synchronize in the current stream is required * before storing this buffer - * @return RapidsBufferHandle associated with this buffer + * @return RapidsShuffleHandle associated with this buffer */ def addBuffer( buffer: DeviceMemoryBuffer, tableMeta: TableMeta, - initialSpillPriority: Long, - needsSync: Boolean): RapidsBufferHandle = { - val bufferId = nextShuffleReceivedBufferId() - tableMeta.bufferMeta.mutateId(bufferId.tableId) - // when we call `addBuffer` the store will incRefCount - withResource(buffer) { _ => - catalog.addBuffer( - bufferId, - buffer, - tableMeta, - initialSpillPriority, - needsSync) - } + initialSpillPriority: Long): RapidsShuffleHandle = { + RapidsShuffleHandle(SpillableDeviceBufferHandle(buffer), tableMeta) } /** - * Adds a degenerate buffer (zero rows or columns) + * Adds a degenerate batch (zero rows or columns), described only by metadata. * * @param meta metadata describing the buffer layout - * @return RapidsBufferHandle associated with this buffer + * @return RapidsShuffleHandle associated with this buffer */ - def addDegenerateRapidsBuffer( - meta: TableMeta): RapidsBufferHandle = { - val bufferId = nextShuffleReceivedBufferId() - catalog.registerDegenerateBuffer(bufferId, meta) + def addDegenerateBatch(meta: TableMeta): RapidsShuffleHandle = { + RapidsShuffleHandle(null, meta) } - /** - * Lookup the shuffle buffer that corresponds to the specified shuffle buffer - * handle and acquire it. - * NOTE: It is the responsibility of the caller to close the buffer. - * - * @param handle shuffle buffer handle - * @return shuffle buffer that has been acquired - */ - def acquireBuffer(handle: RapidsBufferHandle): RapidsBuffer = catalog.acquireBuffer(handle) - - /** - * Remove a buffer and table given a buffer handle - * NOTE: This function is not thread safe! The caller should only invoke if - * the handle being removed is not being utilized by another thread. - * @param handle buffer handle - */ - def removeBuffer(handle: RapidsBufferHandle): Unit = { - val id = handle.id - tableMap.remove(id.tableId) - handle.close() - } -} - -object ShuffleReceivedBufferCatalog{ - private val MAX_TABLE_ID = Integer.MAX_VALUE - private val TABLE_ID_UPDATER = new IntUnaryOperator { - override def applyAsInt(i: Int): Int = if (i < MAX_TABLE_ID) i + 1 else 0 + def getColumnarBatchAndRemove(handle: RapidsShuffleHandle, + sparkTypes: Array[DataType]): (ColumnarBatch, Long) = { + withResource(handle) { _ => + val spillable = handle.spillable + var memoryUsedBytes = 0L + val cb = if (spillable != null) { + memoryUsedBytes = spillable.sizeInBytes + withResource(spillable.materialize()) { buff => + MetaUtils.getBatchFromMeta(buff, handle.tableMeta, sparkTypes) + } + } else { + val rowCount = handle.tableMeta.rowCount + val packedMeta = handle.tableMeta.packedMetaAsByteBuffer() + if (packedMeta != null) { + withResource(DeviceMemoryBuffer.allocate(0)) { deviceBuffer => + withResource(Table.fromPackedTable( + handle.tableMeta.packedMetaAsByteBuffer(), deviceBuffer)) { table => + GpuColumnVectorFromBuffer.from( + table, deviceBuffer, handle.tableMeta, sparkTypes) + } + } + } else { + // no packed metadata, must be a table with zero columns + new ColumnarBatch(Array.empty, rowCount.toInt) + } + } + (cb, memoryUsedBytes) + } } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala index e1f45c34180..7b247af42d3 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala @@ -16,8 +16,9 @@ package com.nvidia.spark.rapids -import ai.rapids.cudf.{ContiguousTable, DeviceMemoryBuffer, HostMemoryBuffer} -import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} +import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, HostMemoryBuffer} +import com.nvidia.spark.rapids.Arm.closeOnExcept +import com.nvidia.spark.rapids.spill.{SpillableColumnarBatchFromBufferHandle, SpillableColumnarBatchHandle, SpillableCompressedColumnarBatchHandle, SpillableDeviceBufferHandle, SpillableHostBufferHandle, SpillableHostColumnarBatchHandle} import org.apache.spark.TaskContext import org.apache.spark.sql.types.DataType @@ -93,7 +94,7 @@ class JustRowsColumnarBatch(numRows: Int) * use `SpillableColumnarBatch.apply` instead. */ class SpillableColumnarBatchImpl ( - handle: RapidsBufferHandle, + handle: SpillableColumnarBatchHandle, rowCount: Int, sparkTypes: Array[DataType]) extends SpillableColumnarBatch { @@ -105,27 +106,128 @@ class SpillableColumnarBatchImpl ( */ override def numRows(): Int = rowCount - private def withRapidsBuffer[T](fn: RapidsBuffer => T): T = { - withResource(RapidsBufferCatalog.acquireBuffer(handle)) { rapidsBuffer => - fn(rapidsBuffer) + override lazy val sizeInBytes: Long = handle.approxSizeInBytes + + /** + * Set a new spill priority. + */ + override def setSpillPriority(priority: Long): Unit = { + // TODO: handle.setSpillPriority(priority) + } + + override def getColumnarBatch(): ColumnarBatch = { + GpuSemaphore.acquireIfNecessary(TaskContext.get()) + handle.materialize(sparkTypes) + } + + override def incRefCount(): SpillableColumnarBatch = { + if (refCount <= 0) { + throw new IllegalStateException("Use after free on SpillableColumnarBatchImpl") + } + refCount += 1 + this + } + + /** + * Remove the `ColumnarBatch` from the cache. + */ + override def close(): Unit = { + refCount -= 1 + if (refCount == 0) { + // closing my reference + handle.close() } + // TODO this is causing problems so we need to look into this + // https://github.com/NVIDIA/spark-rapids/issues/10161 + //else if (refCount < 0) { + // throw new IllegalStateException("Double free on SpillableColumnarBatchImpl") + //} } - override lazy val sizeInBytes: Long = - withRapidsBuffer(_.memoryUsedBytes) + override def toString: String = + s"SCB $handle $rowCount ${sparkTypes.toList} $refCount" +} + +class SpillableCompressedColumnarBatchImpl( + handle: SpillableCompressedColumnarBatchHandle, rowCount: Int) + extends SpillableColumnarBatch { + + private var refCount = 1 + + /** + * The number of rows stored in this batch. + */ + override def numRows(): Int = rowCount + + override lazy val sizeInBytes: Long = handle.compressedSizeInBytes /** * Set a new spill priority. */ override def setSpillPriority(priority: Long): Unit = { - handle.setSpillPriority(priority) + // TODO: handle.setSpillPriority(priority) } override def getColumnarBatch(): ColumnarBatch = { - withRapidsBuffer { rapidsBuffer => - GpuSemaphore.acquireIfNecessary(TaskContext.get()) - rapidsBuffer.getColumnarBatch(sparkTypes) + GpuSemaphore.acquireIfNecessary(TaskContext.get()) + handle.materialize() + } + + override def incRefCount(): SpillableColumnarBatch = { + if (refCount <= 0) { + throw new IllegalStateException("Use after free on SpillableColumnarBatchImpl") } + refCount += 1 + this + } + + /** + * Remove the `ColumnarBatch` from the cache. + */ + override def close(): Unit = { + refCount -= 1 + if (refCount == 0) { + // closing my reference + handle.close() + } + // TODO this is causing problems so we need to look into this + // https://github.com/NVIDIA/spark-rapids/issues/10161 + //else if (refCount < 0) { + // throw new IllegalStateException("Double free on SpillableColumnarBatchImpl") + //} + } + + override def toString: String = + s"SCCB $handle $rowCount $refCount" + + override def dataTypes: Array[DataType] = null +} + +class SpillableColumnarBatchFromBufferImpl( + handle: SpillableColumnarBatchFromBufferHandle, + rowCount: Int, + sparkTypes: Array[DataType]) + extends SpillableColumnarBatch { + private var refCount = 1 + + override def dataTypes: Array[DataType] = sparkTypes + /** + * The number of rows stored in this batch. + */ + override def numRows(): Int = rowCount + + override lazy val sizeInBytes: Long = handle.sizeInBytes + + /** + * Set a new spill priority. + */ + override def setSpillPriority(priority: Long): Unit = { + // TODO: handle.setSpillPriority(priority) + } + + override def getColumnarBatch(): ColumnarBatch = { + GpuSemaphore.acquireIfNecessary(TaskContext.get()) + handle.materialize(dataTypes) } override def incRefCount(): SpillableColumnarBatch = { @@ -147,9 +249,9 @@ class SpillableColumnarBatchImpl ( } // TODO this is causing problems so we need to look into this // https://github.com/NVIDIA/spark-rapids/issues/10161 -// else if (refCount < 0) { -// throw new IllegalStateException("Double free on SpillableColumnarBatchImpl") -// } + //else if (refCount < 0) { + // throw new IllegalStateException("Double free on SpillableColumnarBatchImpl") + //} } override def toString: String = @@ -184,10 +286,9 @@ class JustRowsHostColumnarBatch(numRows: Int) * use `SpillableHostColumnarBatch.apply` instead. */ class SpillableHostColumnarBatchImpl ( - handle: RapidsBufferHandle, + handle: SpillableHostColumnarBatchHandle, rowCount: Int, - sparkTypes: Array[DataType], - catalog: RapidsBufferCatalog) + sparkTypes: Array[DataType]) extends SpillableColumnarBatch { private var refCount = 1 @@ -198,27 +299,17 @@ class SpillableHostColumnarBatchImpl ( */ override def numRows(): Int = rowCount - private def withRapidsHostBatchBuffer[T](fn: RapidsHostBatchBuffer => T): T = { - withResource(catalog.acquireHostBatchBuffer(handle)) { rapidsBuffer => - fn(rapidsBuffer) - } - } - - override lazy val sizeInBytes: Long = { - withRapidsHostBatchBuffer(_.memoryUsedBytes) - } + override lazy val sizeInBytes: Long = handle.approxSizeInBytes /** * Set a new spill priority. */ override def setSpillPriority(priority: Long): Unit = { - handle.setSpillPriority(priority) + // TODO: handle.setSpillPriority(priority) } override def getColumnarBatch(): ColumnarBatch = { - withRapidsHostBatchBuffer { hostBatchBuffer => - hostBatchBuffer.getHostColumnarBatch(sparkTypes) - } + handle.materialize(sparkTypes) } override def incRefCount(): SpillableColumnarBatch = { @@ -257,18 +348,29 @@ object SpillableColumnarBatch { */ def apply(batch: ColumnarBatch, priority: Long): SpillableColumnarBatch = { + Cuda.DEFAULT_STREAM.sync() val numRows = batch.numRows() if (batch.numCols() <= 0) { // We consumed it batch.close() new JustRowsColumnarBatch(numRows) } else { - val types = GpuColumnVector.extractTypes(batch) - val handle = addBatch(batch, priority) - new SpillableColumnarBatchImpl( - handle, - numRows, - types) + if (GpuCompressedColumnVector.isBatchCompressed(batch)) { + new SpillableCompressedColumnarBatchImpl( + SpillableCompressedColumnarBatchHandle(batch), + numRows) + } else if (GpuColumnVectorFromBuffer.isFromBuffer(batch)) { + new SpillableColumnarBatchFromBufferImpl( + SpillableColumnarBatchFromBufferHandle(batch), + numRows, + GpuColumnVector.extractTypes(batch) + ) + } else { + new SpillableColumnarBatchImpl( + SpillableColumnarBatchHandle(batch), + numRows, + GpuColumnVector.extractTypes(batch)) + } } } @@ -283,54 +385,11 @@ object SpillableColumnarBatch { ct: ContiguousTable, sparkTypes: Array[DataType], priority: Long): SpillableColumnarBatch = { - withResource(ct) { _ => - val handle = RapidsBufferCatalog.addContiguousTable(ct, priority) - new SpillableColumnarBatchImpl(handle, ct.getRowCount.toInt, sparkTypes) - } - } - - private[this] def allFromSameBuffer(batch: ColumnarBatch): Boolean = { - var bufferAddr = 0L - var isSet = false - val numColumns = batch.numCols() - (0 until numColumns).forall { i => - batch.column(i) match { - case fb: GpuColumnVectorFromBuffer => - if (!isSet) { - bufferAddr = fb.getBuffer.getAddress - isSet = true - true - } else { - bufferAddr == fb.getBuffer.getAddress - } - case _ => false - } - } - } - - private[this] def addBatch( - batch: ColumnarBatch, - initialSpillPriority: Long): RapidsBufferHandle = { - withResource(batch) { batch => - val numColumns = batch.numCols() - if (GpuCompressedColumnVector.isBatchCompressed(batch)) { - val cv = batch.column(0).asInstanceOf[GpuCompressedColumnVector] - val buff = cv.getTableBuffer - RapidsBufferCatalog.addBuffer(buff, cv.getTableMeta, initialSpillPriority) - } else if (GpuPackedTableColumn.isBatchPacked(batch)) { - val cv = batch.column(0).asInstanceOf[GpuPackedTableColumn] - RapidsBufferCatalog.addContiguousTable( - cv.getContiguousTable, - initialSpillPriority) - } else if (numColumns > 0 && - allFromSameBuffer(batch)) { - val cv = batch.column(0).asInstanceOf[GpuColumnVectorFromBuffer] - val buff = cv.getBuffer - RapidsBufferCatalog.addBuffer(buff, cv.getTableMeta, initialSpillPriority) - } else { - RapidsBufferCatalog.addBatch(batch, initialSpillPriority) - } - } + Cuda.DEFAULT_STREAM.sync() + new SpillableColumnarBatchFromBufferImpl( + SpillableColumnarBatchFromBufferHandle(ct, sparkTypes), + ct.getRowCount.toInt, + sparkTypes) } } @@ -342,10 +401,7 @@ object SpillableHostColumnarBatch { * @param batch the batch to make spillable * @param priority the initial spill priority of this batch */ - def apply( - batch: ColumnarBatch, - priority: Long, - catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton): SpillableColumnarBatch = { + def apply(batch: ColumnarBatch, priority: Long): SpillableColumnarBatch = { val numRows = batch.numRows() if (batch.numCols() <= 0) { // We consumed it @@ -353,45 +409,30 @@ object SpillableHostColumnarBatch { new JustRowsHostColumnarBatch(numRows) } else { val types = RapidsHostColumnVector.extractColumns(batch).map(_.dataType()) - val handle = addHostBatch(batch, priority, catalog) - new SpillableHostColumnarBatchImpl( - handle, - numRows, - types, - catalog) - } - } - - private[this] def addHostBatch( - batch: ColumnarBatch, - initialSpillPriority: Long, - catalog: RapidsBufferCatalog): RapidsBufferHandle = { - withResource(batch) { batch => - catalog.addBatch(batch, initialSpillPriority) + val handle = SpillableHostColumnarBatchHandle(batch) + new SpillableHostColumnarBatchImpl(handle, numRows, types) } } - } + /** * Just like a SpillableColumnarBatch but for buffers. */ class SpillableBuffer( - handle: RapidsBufferHandle) extends AutoCloseable { + handle: SpillableDeviceBufferHandle) extends AutoCloseable { /** * Set a new spill priority. */ def setSpillPriority(priority: Long): Unit = { - handle.setSpillPriority(priority) + // TODO: handle.setSpillPriority(priority) } /** * Use the device buffer. */ def getDeviceBuffer(): DeviceMemoryBuffer = { - withResource(RapidsBufferCatalog.acquireBuffer(handle)) { rapidsBuffer => - rapidsBuffer.getDeviceMemoryBuffer - } + handle.materialize() } /** @@ -402,9 +443,7 @@ class SpillableBuffer( } override def toString: String = { - val size = withResource(RapidsBufferCatalog.acquireBuffer(handle)) { rapidsBuffer => - rapidsBuffer.memoryUsedBytes - } + val size = handle.sizeInBytes s"SpillableBuffer size:$size, handle:$handle" } } @@ -416,17 +455,15 @@ class SpillableBuffer( * @param length a metadata-only length that is kept in the `SpillableHostBuffer` * instance. Used in cases where the backing host buffer is larger * than the number of usable bytes. - * @param catalog this was added for tests, it defaults to - * `RapidsBufferCatalog.singleton` in the companion object. */ -class SpillableHostBuffer(handle: RapidsBufferHandle, - val length: Long, - catalog: RapidsBufferCatalog) extends AutoCloseable { +class SpillableHostBuffer(handle: SpillableHostBufferHandle, + val length: Long) + extends AutoCloseable { /** * Set a new spill priority. */ def setSpillPriority(priority: Long): Unit = { - handle.setSpillPriority(priority) + // TODO: handle.setSpillPriority(priority) } /** @@ -437,9 +474,7 @@ class SpillableHostBuffer(handle: RapidsBufferHandle, } def getHostBuffer(): HostMemoryBuffer = { - withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => - rapidsBuffer.getHostMemoryBuffer - } + handle.materialize() } override def toString: String = @@ -457,10 +492,8 @@ object SpillableBuffer { def apply( buffer: DeviceMemoryBuffer, priority: Long): SpillableBuffer = { - val meta = MetaUtils.getTableMetaNoTable(buffer.getLength) - val handle = withResource(buffer) { _ => - RapidsBufferCatalog.addBuffer(buffer, meta, priority) - } + Cuda.DEFAULT_STREAM.sync() + val handle = SpillableDeviceBufferHandle(buffer) // TODO: AB: priority new SpillableBuffer(handle) } } @@ -478,17 +511,12 @@ object SpillableHostBuffer { */ def apply(buffer: HostMemoryBuffer, length: Long, - priority: Long, - catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton): SpillableHostBuffer = { + priority: Long): SpillableHostBuffer = { closeOnExcept(buffer) { _ => require(length <= buffer.getLength, s"Attempted to add a host spillable with a length ${length} B which is " + s"greater than the backing host buffer length ${buffer.getLength} B") } - val meta = MetaUtils.getTableMetaNoTable(buffer.getLength) - val handle = withResource(buffer) { _ => - catalog.addBuffer(buffer, meta, priority) - } - new SpillableHostBuffer(handle, length, catalog) + new SpillableHostBuffer(SpillableHostBufferHandle(buffer), length) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala index 08a1ae22f5e..0a7942bd581 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,11 +16,9 @@ package com.nvidia.spark.rapids.shuffle -import java.io.IOException - -import ai.rapids.cudf.{Cuda, MemoryBuffer} -import com.nvidia.spark.rapids.{RapidsBuffer, ShuffleMetadata, StorageTier} -import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} +import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, MemoryBuffer} +import com.nvidia.spark.rapids.{RapidsShuffleHandle, ShuffleMetadata} +import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.format.{BufferMeta, BufferTransferRequest} @@ -60,8 +58,17 @@ class BufferSendState( serverStream: Cuda.Stream = Cuda.DEFAULT_STREAM) extends AutoCloseable with Logging { - class SendBlock(val bufferId: Int, tableSize: Long) extends BlockWithSize { - override def size: Long = tableSize + class SendBlock(val bufferHandle: RapidsShuffleHandle) extends BlockWithSize { + // we assume that the size of the buffer won't change as it goes to host/disk + // we also are likely to assume this is just a device buffer, and so we should + // copy to device and then send. + override def size: Long = { + if (bufferHandle.spillable != null) { + bufferHandle.spillable.sizeInBytes + } else { + 0L // degenerate + } + } } val peerExecutorId: Long = transaction.peerExecutorId() @@ -80,13 +87,10 @@ class BufferSendState( val btr = new BufferTransferRequest() // for reuse val blocksToSend = (0 until transferRequest.requestsLength()).map { ix => val bufferTransferRequest = transferRequest.requests(btr, ix) - withResource(requestHandler.acquireShuffleBuffer( - bufferTransferRequest.bufferId())) { table => - bufferMetas(ix) = table.meta.bufferMeta() - new SendBlock(bufferTransferRequest.bufferId(), table.getPackedSizeBytes) - } + val handle = requestHandler.getShuffleHandle(bufferTransferRequest.bufferId()) + bufferMetas(ix) = handle.tableMeta.bufferMeta() + new SendBlock(handle) } - (peerBufferReceiveHeader, bufferMetas, blocksToSend) } } @@ -145,7 +149,7 @@ class BufferSendState( } case class RangeBuffer( - range: BlockRange[SendBlock], rapidsBuffer: RapidsBuffer) + range: BlockRange[SendBlock], rapidsBuffer: MemoryBuffer) extends AutoCloseable { override def close(): Unit = { rapidsBuffer.close() @@ -170,50 +174,50 @@ class BufferSendState( if (hasMoreBlocks) { var deviceBuffs = 0L var hostBuffs = 0L - acquiredBuffs = blockRanges.safeMap { blockRange => - val bufferId = blockRange.block.bufferId - // we acquire these buffers now, and keep them until the caller releases them - // using `releaseAcquiredToCatalog` - closeOnExcept( - requestHandler.acquireShuffleBuffer(bufferId)) { rapidsBuffer => + var needsCleanup = false + try { + acquiredBuffs = blockRanges.safeMap { blockRange => + // we acquire these buffers now, and keep them until the caller releases them + // using `releaseAcquiredToCatalog` //these are closed later, after we synchronize streams - rapidsBuffer.storageTier match { - case StorageTier.DEVICE => + val spillable = blockRange.block.bufferHandle.spillable + val buff = spillable.materialize() + buff match { + case _: DeviceMemoryBuffer => deviceBuffs += blockRange.rangeSize() - case _ => // host/disk + case _ => hostBuffs += blockRange.rangeSize() } - RangeBuffer(blockRange, rapidsBuffer) + RangeBuffer(blockRange, buff) } - } - logDebug(s"Occupancy for bounce buffer is [device=${deviceBuffs}, host=${hostBuffs}] Bytes") + logDebug(s"Occupancy for bounce buffer is " + + s"[device=${deviceBuffs}, host=${hostBuffs}] Bytes") - bounceBuffToUse = if (deviceBuffs >= hostBuffs || hostBounceBuffer == null) { - deviceBounceBuffer.buffer - } else { - hostBounceBuffer.buffer - } + bounceBuffToUse = if (deviceBuffs >= hostBuffs || hostBounceBuffer == null) { + deviceBounceBuffer.buffer + } else { + hostBounceBuffer.buffer + } - // `copyToMemoryBuffer` can throw if the `RapidsBuffer` is in the DISK tier and - // the file fails to mmap. We catch the `IOException` and attempt a retry - // in the server. - var needsCleanup = false - try { - acquiredBuffs.foreach { case RangeBuffer(blockRange, rapidsBuffer) => + acquiredBuffs.foreach { case RangeBuffer(blockRange, memoryBuffer) => needsCleanup = true require(blockRange.rangeSize() <= bounceBuffToUse.getLength - buffOffset) - rapidsBuffer.copyToMemoryBuffer(blockRange.rangeStart, bounceBuffToUse, buffOffset, - blockRange.rangeSize(), serverStream) + bounceBuffToUse.copyFromMemoryBufferAsync( + buffOffset, + memoryBuffer, + blockRange.rangeStart, + blockRange.rangeSize(), + serverStream) buffOffset += blockRange.rangeSize() } needsCleanup = false } catch { - case ioe: IOException => + case ex: Exception => throw new RapidsShuffleSendPrepareException( s"Error while copying to bounce buffer for executor ${peerExecutorId} and " + - s"header ${TransportUtils.toHex(peerBufferReceiveHeader)}", ioe) + s"header ${TransportUtils.toHex(peerBufferReceiveHeader)}", ex) } finally { if (needsCleanup) { // we likely failed in `copyToMemoryBuffer` @@ -251,4 +255,4 @@ class BufferSendState( acquiredBuffs.foreach(_.close()) acquiredBuffs = Seq.empty } -} \ No newline at end of file +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala index b73f9820bad..2723e24b0f0 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,7 +47,7 @@ trait RapidsShuffleFetchHandler { * @return a boolean that lets the caller know the batch was accepted (true), or * rejected (false), in which case the caller should dispose of the batch. */ - def batchReceived(handle: RapidsBufferHandle): Boolean + def batchReceived(handle: RapidsShuffleHandle): Boolean /** * Called when the transport layer is not able to handle a fetch error for metadata @@ -390,7 +390,7 @@ class RapidsShuffleClient( buffMetas.foreach { consumed: ConsumedBatchFromBounceBuffer => val handle = track(consumed.contigBuffer, consumed.meta) if (!consumed.handler.batchReceived(handle)) { - catalog.removeBuffer(handle) + handle.close() numBatchesRejected += 1 } transport.doneBytesInFlight(consumed.contigBuffer.getLength) @@ -431,25 +431,19 @@ class RapidsShuffleClient( * used to look up the buffer from the catalog going (e.g. from the iterator) * @param buffer contiguous [[DeviceMemoryBuffer]] with the tables' data * @param meta [[TableMeta]] describing [[buffer]] - * @return the [[RapidsBufferId]] to be used to look up the buffer from catalog + * @return a [[RapidsShuffleHandle]] with a spillable and metadata */ private[shuffle] def track( - buffer: DeviceMemoryBuffer, meta: TableMeta): RapidsBufferHandle = { + buffer: DeviceMemoryBuffer, meta: TableMeta): RapidsShuffleHandle = { if (buffer != null) { // add the buffer to the catalog so it is available for spill catalog.addBuffer( buffer, meta, - SpillPriorities.INPUT_FROM_SHUFFLE_PRIORITY, - // set needsSync to false because we already have stream synchronized after - // consuming the bounce buffer, so we know these buffers are synchronized - // w.r.t. the CPU - needsSync = false) + SpillPriorities.INPUT_FROM_SHUFFLE_PRIORITY) } else { // no device data, just tracking metadata - catalog.addDegenerateRapidsBuffer( - meta) - + catalog.addDegenerateBatch(meta) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala index 126b9200c90..95c411c64f2 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ import java.util.concurrent.{ConcurrentLinkedQueue, Executor} import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{Cuda, MemoryBuffer, NvtxColor, NvtxRange} -import com.nvidia.spark.rapids.{RapidsBuffer, RapidsConf, ShuffleMetadata} +import com.nvidia.spark.rapids.{RapidsConf, RapidsShuffleHandle, ShuffleMetadata} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.format.TableMeta @@ -49,7 +49,7 @@ trait RapidsShuffleRequestHandler { * @param tableId the unique id for a table in the catalog * @return a [[RapidsBuffer]] which is reference counted, and should be closed by the acquirer */ - def acquireShuffleBuffer(tableId: Int): RapidsBuffer + def getShuffleHandle(tableId: Int): RapidsShuffleHandle } /** diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala new file mode 100644 index 00000000000..57f2a823432 --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -0,0 +1,1743 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.spill + +import java.io._ +import java.nio.ByteBuffer +import java.nio.channels.{Channels, FileChannel, WritableByteChannel} +import java.nio.file.StandardOpenOption +import java.util +import java.util.UUID +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.mutable + +import ai.rapids.cudf._ +import com.nvidia.spark.rapids.{GpuColumnVector, GpuColumnVectorFromBuffer, GpuCompressedColumnVector, GpuDeviceManager, HostAlloc, HostMemoryOutputStream, MemoryBufferToHostByteBufferIterator, RapidsConf, RapidsHostColumnVector} +import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} +import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableSeq +import com.nvidia.spark.rapids.format.TableMeta +import com.nvidia.spark.rapids.internal.HostByteBufferIterator +import org.apache.commons.io.IOUtils + +import org.apache.spark.{SparkConf, SparkEnv, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.rapids.{GpuTaskMetrics, RapidsDiskBlockManager} +import org.apache.spark.sql.rapids.execution.SerializedHostTableUtils +import org.apache.spark.sql.rapids.storage.RapidsStorageUtils +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.storage.BlockId + +/** + * Spark-RAPIDS Spill Framework + * + * The spill framework tracks device/host/disk object lifecycle in the RAPIDS Accelerator + * for Apache Spark. A set of stores is used to track these objects, which are wrapped in + * "handles" that describe the state of each to the user and to the framework. + * + * This file comment covers some pieces of the framework that are worth knowing up front. + * + * Ownership: + * + * Any object handed to the framework via the factory methods for each of the handles + * should not be used directly by the user. The framework takes ownership of all objects. + * To get a reference back, call the `materialize` method, and always close what the framework + * returns. + * + * CUDA/Host synchronization: + * + * We assume all device backed handles are completely materialized on the device (before adding + * to the store, the CUDA stream has been synchronized with the CPU thread creating the handle), + * and that all host memory backed handles are completely materialized and not mutated by + * other CPU threads, because the contents of the handle may spill at any time, using any CUDA + * stream or thread, without synchronization. If handles added to the store are not synchronized + * we could write incomplete data to host memory or to disk. + * + * Spillability: + * + * An object is spillable (it will be copied to host or disk during OOM) if: + * - it has a approxSizeInBytes > 0 + * - it is not actively being referenced by the user (call to `materialize`, or aliased) + * - it hasn't already spilled + * - it hasn't been closed + * + * Aliasing: + * + * We handle aliasing of objects, either in the spill framework or outside, by looking at the + * reference count. All objects added to the store should support a ref count. + * If the ref count is greater than the expected value, we assume it is being aliased, + * and therefore we don't waste time spilling the aliased object. Please take a look at the + * `spillable` method in each of the handles on how this is implemented. + * + * Materialization: + * + * Every store handle supports a `materialize` method that isn't part of the interface. + * The reason is that to materialize certain objects, you may need some data (for example, + * Spark schema descriptors). `materialize` incRefCounts the object if it's resident in the + * intended store (`DeviceSpillableHandle` incRefCounts an object if it is in the device store), + * and otherwise it will create a new copy from the spilled version and hand it to the user. + * Any time a user calls `materialize`, they are responsible for closing the returned object. + * + * Spilling: + * + * A `SpillableHandle` will track an object in a specific store (`DeviceSpillable` tracks + * device "intended" objects) for example. If the handle is asked to spill, it is the handle's + * responsibility to initiate that spill, and to track the spilled handle (a device spillable + * would have a `host` handle, which tracks the host spilled object). + * + * Spill is broken down into two methods: `spill` and `releaseSpilled`. This is a two stage + * process because we need to make sure that there is no code running kernels on the spilled + * data before we actually free it. See method documentations for `spill` and `releasedSpilled` + * for more info. + * + * A cascade of spills can occur device -> host -> disk, given that host allocations can fail, or + * could not fit in the SpillableHostStore's limit (if defined). In this case, the call to spill + * will either create a host handle tracking an object on the host store (if we made room), or it + * will create a host handle that points to a disk handle, tracking a file on disk. + * + * Host handles created directly, via the factory methods `SpillableHostBufferHandle(...)` or + * `SpillableHostColumnarBatchHandle(...)`, do not trigger immediate spills. For example: + * if the host store limit is set to 1GB, and we add a 1.5GB host buffer via + * its factory method, we are going to have 2.5GB worth of host memory in the host store. + * That said, if we run out of device memory and we need to spill to host, 1.5GB will be spilled + * to disk, as device OOM triggers the pipeline spill. + * + * If we don't have a host store limit, spilling from the host store is done entirely via + * host memory allocation failure callbacks. All objects added to the host store are tracked + * immediately, since they were successfully allocated. If we fail to allocate host memory + * during a device->host spill, however, the spill framework will bypass host memory and + * go straight to disk (this last part works the same whether there are host limits or not). + * + * If the disk is full, we do not handle this in any special way. We expect this to be a + * terminal state to the executor. Every handle spills to its own file on disk, identified + * as a "temporary block" `BlockId` from Spark. + * + * Notes on locking: + * + * All stores use a concurrent hash map to store instances of `StoreHandle`. The only store + * with extra locking is the `SpillableHostStore`, to maintain a `totalSize` number that is + * used to figure out cheaply when it is full. + * + * All handles, except for disk handles, hold a reference to an object in their respective store: + * `SpillableDeviceBufferHandle` has a `dev` reference that holds a `DeviceMemoryBuffer`, and a + * `host` reference to `SpillableHostBufferHandle` that is only set if spilled. Disk handles are + * different because they don't spill, as disk is considered the final store. When a user calls + * `materialize` on a handle, the handle must guarantee that it can satisfy that, even if the caller + * should wait until a spill happens. This is currently implemented using the handle lock. + * + * Note that we hold the handle lock while we are spilling (performing IO). That means that no other + * consumer can access this spillable device handle while it is being spilled, including a second + * thread that is trying to spill and is generating a spill plan, as the handle lock is likely held + * up with IO. We will relax this likely in follow on work. + * + * We never hold a store-wide coarse grain lock in the stores when we do IO. + */ + +/** + * Common interface for all handles in the spill framework. + */ +trait StoreHandle extends AutoCloseable { + /** + * Approximate size of this handle, used in three scenarios: + * - Used by callers when accumulating up to a batch size for size goals. + * - Used from the host store to figure out how much host memory total it is tracking. + * - If approxSizeInBytes is 0, the object is tracked by the stores so it can be + * removed on shutdown, or by handle.close, but 0-byte handles are not spillable. + */ + val approxSizeInBytes: Long +} + +trait SpillableHandle extends StoreHandle { + /** + * Method called to spill this handle. It can be triggered from the spill store, + * or directly against the handle. + * + * This will not free the spilled data. If you would like to free the spill + * call `releaseSpilled` + * + * @note The size returned from this method is only used by the spill framework + * to track the approximate size. It should just return `approxSizeInBytes`, as + * that's the size that it used when it first started tracking the object. + * @return approxSizeInBytes if spilled, 0 for any other reason (not spillable, closed) + */ + def spill(): Long + + /** + * Method used to determine whether a handle tracks an object that could be spilled + * @note At the level of `SpillableHandle`, the only requirement of spillability + * is that the size of the handle is > 0. `approxSizeInBytes` is known at + * construction, and is immutable. + * @return true if currently spillable, false otherwise + */ + private[spill] def spillable: Boolean = approxSizeInBytes > 0 +} + +/** + * Spillable handles that can be materialized on the device. + * @tparam T an auto closeable subclass. `dev` tracks an instance of this object, + * on the device. + */ +trait DeviceSpillableHandle[T <: AutoCloseable] extends SpillableHandle { + private[spill] var dev: Option[T] + + private[spill] override def spillable: Boolean = synchronized { + super.spillable && dev.isDefined + } + + protected def releaseDeviceResource(): Unit = { + SpillFramework.removeFromDeviceStore(this) + synchronized { + dev.foreach(_.close()) + dev = None + } + } + + /** + * Part two of the two-stage process for spilling device buffers. We call `releaseSpilled` after + * a handle has spilled, and after a device synchronize. This prevents a race + * between threads working on cuDF kernels, that did not synchronize while holding the + * materialized handle's refCount, and the spiller thread (the spiller thread cannot + * free a device buffer that the worker thread isn't done with). + * See https://github.com/NVIDIA/spark-rapids/issues/8610 for more info. + */ + def releaseSpilled(): Unit = { + releaseDeviceResource() + } +} + +/** + * Spillable handles that can be materialized on the host. + * @tparam T an auto closeable subclass. `host` tracks an instance of this object, + * on the host. + */ +trait HostSpillableHandle[T <: AutoCloseable] extends SpillableHandle { + private[spill] var host: Option[T] + + private[spill] override def spillable: Boolean = synchronized { + super.spillable && host.isDefined + } + + protected def releaseHostResource(): Unit = { + SpillFramework.removeFromHostStore(this) + synchronized { + host.foreach(_.close()) + host = None + } + } +} + +object SpillableHostBufferHandle extends Logging { + def apply(hmb: HostMemoryBuffer): SpillableHostBufferHandle = { + val handle = new SpillableHostBufferHandle(hmb.getLength, host = Some(hmb)) + SpillFramework.stores.hostStore.trackNoSpill(handle) + handle + } + + private[spill] def createHostHandleWithPacker( + chunkedPacker: ChunkedPacker): SpillableHostBufferHandle = { + val handle = new SpillableHostBufferHandle(chunkedPacker.getTotalContiguousSize) + withResource( + SpillFramework.stores.hostStore.makeBuilder(handle)) { builder => + while (chunkedPacker.hasNext) { + val (bb, len) = chunkedPacker.next() + withResource(bb) { _ => + builder.copyNext(bb.dmb, len, Cuda.DEFAULT_STREAM) + // copyNext is synchronous w.r.t. the cuda stream passed, + // no need to synchronize here. + } + } + builder.build + } + } + + private[spill] def createHostHandleFromDeviceBuff( + buff: DeviceMemoryBuffer): SpillableHostBufferHandle = { + val handle = new SpillableHostBufferHandle(buff.getLength) + withResource( + SpillFramework.stores.hostStore.makeBuilder(handle)) { builder => + builder.copyNext(buff, buff.getLength, Cuda.DEFAULT_STREAM) + builder.build + } + } +} + +class SpillableHostBufferHandle private ( + val sizeInBytes: Long, + private[spill] override var host: Option[HostMemoryBuffer] = None, + private[spill] var disk: Option[DiskHandle] = None) + extends HostSpillableHandle[HostMemoryBuffer] { + + override val approxSizeInBytes: Long = sizeInBytes + + private[spill] override def spillable: Boolean = synchronized { + if (super.spillable) { + host.getOrElse { + throw new IllegalStateException( + s"$this is spillable but it doesn't have a materialized host buffer!") + }.getRefCount == 1 + } else { + false + } + } + + def materialize(): HostMemoryBuffer = { + var materialized: HostMemoryBuffer = null + var diskHandle: DiskHandle = null + synchronized { + if (host.isDefined) { + materialized = host.get + materialized.incRefCount() + } else if (disk.isDefined) { + diskHandle = disk.get + } else { + throw new IllegalStateException( + "attempting to materialize a closed handle") + } + } + if (materialized == null) { + materialized = closeOnExcept(HostMemoryBuffer.allocate(sizeInBytes)) { hmb => + diskHandle.materializeToHostMemoryBuffer(hmb) + hmb + } + } + materialized + } + + override def spill(): Long = { + if (!spillable) { + 0L + } else { + val spilled = synchronized { + if (disk.isEmpty && host.isDefined) { + withResource(DiskHandleStore.makeBuilder) { diskHandleBuilder => + val outputChannel = diskHandleBuilder.getChannel + GpuTaskMetrics.get.spillToDiskTime { + val iter = new HostByteBufferIterator(host.get) + iter.foreach { bb => + try { + while (bb.hasRemaining) { + outputChannel.write(bb) + } + } finally { + RapidsStorageUtils.dispose(bb) + } + } + } + disk = Some(diskHandleBuilder.build) + sizeInBytes + } + } else { + 0L + } + } + releaseHostResource() + spilled + } + } + + override def close(): Unit = { + releaseHostResource() + synchronized { + disk.foreach(_.close()) + disk = None + } + } + + private[spill] def materializeToDeviceMemoryBuffer(dmb: DeviceMemoryBuffer): Unit = { + var hostBuffer: HostMemoryBuffer = null + var diskHandle: DiskHandle = null + synchronized { + if (host.isDefined) { + hostBuffer = host.get + hostBuffer.incRefCount() + } else if (disk.isDefined) { + diskHandle = disk.get + } else { + throw new IllegalStateException( + "attempting to materialize a closed handle") + } + } + if (hostBuffer != null) { + GpuTaskMetrics.get.readSpillFromHostTime { + withResource(hostBuffer) { _ => + dmb.copyFromHostBuffer( + /*dstOffset*/ 0, + /*src*/ hostBuffer, + /*srcOffset*/ 0, + /*length*/ hostBuffer.getLength) + } + } + } else { + // cannot find a full host buffer, get chunked api + // from disk + diskHandle.materializeToDeviceMemoryBuffer(dmb) + } + } + + private[spill] def setHost(singleShotBuffer: HostMemoryBuffer): Unit = synchronized { + host = Some(singleShotBuffer) + } + + private[spill] def setDisk(handle: DiskHandle): Unit = synchronized { + disk = Some(handle) + } +} + +object SpillableDeviceBufferHandle { + def apply(dmb: DeviceMemoryBuffer): SpillableDeviceBufferHandle = { + val handle = new SpillableDeviceBufferHandle(dmb.getLength, dev = Some(dmb)) + SpillFramework.stores.deviceStore.track(handle) + handle + } +} + +class SpillableDeviceBufferHandle private ( + val sizeInBytes: Long, + private[spill] override var dev: Option[DeviceMemoryBuffer], + private[spill] var host: Option[SpillableHostBufferHandle] = None) + extends DeviceSpillableHandle[DeviceMemoryBuffer] { + + override val approxSizeInBytes: Long = sizeInBytes + + private[spill] override def spillable: Boolean = synchronized { + if (super.spillable) { + dev.getOrElse { + throw new IllegalStateException( + s"$this is spillable but it doesn't have a dev buffer!") + }.getRefCount == 1 + } else { + false + } + } + + def materialize(): DeviceMemoryBuffer = { + var materialized: DeviceMemoryBuffer = null + var hostHandle: SpillableHostBufferHandle = null + synchronized { + if (host.isDefined) { + // since we spilled, host must be set. + hostHandle = host.get + } else if (dev.isDefined) { + materialized = dev.get + materialized.incRefCount() + } else { + throw new IllegalStateException( + "attempting to materialize a closed handle") + } + } + // if `materialized` is null, we spilled. This is a terminal + // state, as we are not allowing unspill, and we don't need + // to hold locks while we copy back from here. + if (materialized == null) { + materialized = closeOnExcept(DeviceMemoryBuffer.allocate(sizeInBytes)) { dmb => + hostHandle.materializeToDeviceMemoryBuffer(dmb) + dmb + } + } + materialized + } + + override def spill(): Long = { + if (!spillable) { + 0L + } else { + synchronized { + if (host.isEmpty && dev.isDefined) { + host = Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff(dev.get)) + sizeInBytes + } else { + 0L + } + } + } + } + + override def close(): Unit = { + releaseDeviceResource() + synchronized { + host.foreach(_.close()) + host = None + } + } +} + +class SpillableColumnarBatchHandle private ( + override val approxSizeInBytes: Long, + private[spill] override var dev: Option[ColumnarBatch], + private[spill] var host: Option[SpillableHostBufferHandle] = None) + extends DeviceSpillableHandle[ColumnarBatch] with Logging { + + override def spillable: Boolean = synchronized { + if (super.spillable) { + val dcvs = GpuColumnVector.extractBases(dev.get) + val colRepetition = mutable.HashMap[ColumnVector, Int]() + dcvs.foreach { hcv => + colRepetition.put(hcv, colRepetition.getOrElse(hcv, 0) + 1) + } + dcvs.forall(dcv => { + colRepetition(dcv) == dcv.getRefCount + }) + } else { + false + } + } + + private var meta: Option[ByteBuffer] = None + + def materialize(dt: Array[DataType]): ColumnarBatch = { + var materialized: ColumnarBatch = null + var hostHandle: SpillableHostBufferHandle = null + synchronized { + if (host.isDefined) { + hostHandle = host.get + } else if (dev.isDefined) { + materialized = GpuColumnVector.incRefCounts(dev.get) + } else { + throw new IllegalStateException( + "attempting to materialize a closed handle") + } + } + if (materialized == null) { + val devBuffer = closeOnExcept(DeviceMemoryBuffer.allocate(hostHandle.sizeInBytes)) { dmb => + hostHandle.materializeToDeviceMemoryBuffer(dmb) + dmb + } + val cb = withResource(devBuffer) { _ => + withResource(Table.fromPackedTable(meta.get, devBuffer)) { tbl => + GpuColumnVector.from(tbl, dt) + } + } + materialized = cb + } + materialized + } + + override def spill(): Long = { + if (!spillable) { + 0L + } else { + synchronized { + if (host.isEmpty && dev.isDefined) { + withChunkedPacker { chunkedPacker => + meta = Some(chunkedPacker.getPackedMeta) + host = Some(SpillableHostBufferHandle.createHostHandleWithPacker(chunkedPacker)) + } + // We return the size we were created with. This is not the actual size + // of this batch when it is packed, and it is used by the calling code + // to figure out more or less how much did we free in the device. + approxSizeInBytes + } else { + 0L + } + } + } + } + + private def withChunkedPacker[T](body: ChunkedPacker => T): T = { + val tbl = synchronized { + if (dev.isEmpty) { + throw new IllegalStateException("cannot get copier without a batch") + } + GpuColumnVector.from(dev.get) + } + withResource(tbl) { _ => + withResource(new ChunkedPacker(tbl, SpillFramework.chunkedPackBounceBufferPool)) { packer => + body(packer) + } + } + } + + override def close(): Unit = { + releaseDeviceResource() + synchronized { + host.foreach(_.close()) + host = None + } + } +} + +object SpillableColumnarBatchFromBufferHandle { + def apply( + ct: ContiguousTable, + dataTypes: Array[DataType]): SpillableColumnarBatchFromBufferHandle = { + withResource(ct) { _ => + val sizeInBytes = ct.getBuffer.getLength + val cb = GpuColumnVectorFromBuffer.from(ct, dataTypes) + val handle = new SpillableColumnarBatchFromBufferHandle( + sizeInBytes, dev = Some(cb)) + SpillFramework.stores.deviceStore.track(handle) + handle + } + } + + def apply(cb: ColumnarBatch): SpillableColumnarBatchFromBufferHandle = { + require(GpuColumnVectorFromBuffer.isFromBuffer(cb), + "Columnar batch isn't a batch from buffer") + val sizeInBytes = + cb.column(0).asInstanceOf[GpuColumnVectorFromBuffer].getBuffer.getLength + val handle = new SpillableColumnarBatchFromBufferHandle( + sizeInBytes, dev = Some(cb)) + SpillFramework.stores.deviceStore.track(handle) + handle + } +} + +class SpillableColumnarBatchFromBufferHandle private ( + val sizeInBytes: Long, + private[spill] override var dev: Option[ColumnarBatch], + private[spill] var host: Option[SpillableHostBufferHandle] = None) + extends DeviceSpillableHandle[ColumnarBatch] { + + override val approxSizeInBytes: Long = sizeInBytes + + private var meta: Option[TableMeta] = None + + private[spill] override def spillable: Boolean = synchronized { + if (super.spillable) { + val dcvs = GpuColumnVector.extractBases(dev.get) + val colRepetition = mutable.HashMap[ColumnVector, Int]() + dcvs.foreach { hcv => + colRepetition.put(hcv, colRepetition.getOrElse(hcv, 0) + 1) + } + dcvs.forall(dcv => { + colRepetition(dcv) == dcv.getRefCount + }) + } else { + false + } + } + + def materialize(dt: Array[DataType]): ColumnarBatch = { + var materialized: ColumnarBatch = null + var hostHandle: SpillableHostBufferHandle = null + synchronized { + if (host.isDefined) { + hostHandle = host.get + } else if (dev.isDefined) { + materialized = GpuColumnVector.incRefCounts(dev.get) + } else { + throw new IllegalStateException( + "attempting to materialize a closed handle") + } + } + if (materialized == null) { + val devBuffer = closeOnExcept(DeviceMemoryBuffer.allocate(hostHandle.sizeInBytes)) { dmb => + hostHandle.materializeToDeviceMemoryBuffer(dmb) + dmb + } + val cb = withResource(devBuffer) { _ => + withResource(Table.fromPackedTable(meta.get.packedMetaAsByteBuffer(), devBuffer)) { tbl => + GpuColumnVector.from(tbl, dt) + } + } + materialized = cb + } + materialized + } + + override def spill(): Long = { + if (!spillable) { + 0 + } else { + synchronized { + if (host.isEmpty && dev.isDefined) { + val cvFromBuffer = dev.get.column(0).asInstanceOf[GpuColumnVectorFromBuffer] + meta = Some(cvFromBuffer.getTableMeta) + host = Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff( + cvFromBuffer.getBuffer)) + sizeInBytes + } else { + 0L + } + } + } + } + + override def close(): Unit = { + releaseDeviceResource() + synchronized { + host.foreach(_.close()) + host = None + } + } +} + +object SpillableCompressedColumnarBatchHandle { + def apply(cb: ColumnarBatch): SpillableCompressedColumnarBatchHandle = { + require(GpuCompressedColumnVector.isBatchCompressed(cb), + "Tried to track a compressed batch, but the batch wasn't compressed") + val compressedSize = + cb.column(0).asInstanceOf[GpuCompressedColumnVector].getTableBuffer.getLength + val handle = new SpillableCompressedColumnarBatchHandle(compressedSize, dev = Some(cb)) + SpillFramework.stores.deviceStore.track(handle) + handle + } +} + +class SpillableCompressedColumnarBatchHandle private ( + val compressedSizeInBytes: Long, + private[spill] override var dev: Option[ColumnarBatch], + private[spill] var host: Option[SpillableHostBufferHandle] = None) + extends DeviceSpillableHandle[ColumnarBatch] { + + override val approxSizeInBytes: Long = compressedSizeInBytes + + protected var meta: Option[TableMeta] = None + + override def spillable: Boolean = synchronized { + if (super.spillable) { + val cb = dev.get + val buff = cb.column(0).asInstanceOf[GpuCompressedColumnVector].getTableBuffer + buff.getRefCount == 1 + } else { + false + } + } + + def materialize(): ColumnarBatch = { + var materialized: ColumnarBatch = null + var hostHandle: SpillableHostBufferHandle = null + synchronized { + if (host.isDefined) { + hostHandle = host.get + } else if (dev.isDefined) { + materialized = GpuCompressedColumnVector.incRefCounts(dev.get) + } else { + throw new IllegalStateException( + "attempting to materialize a closed handle") + } + } + if (materialized == null) { + val devBuffer = closeOnExcept(DeviceMemoryBuffer.allocate(hostHandle.sizeInBytes)) { dmb => + hostHandle.materializeToDeviceMemoryBuffer(dmb) + dmb + } + materialized = withResource(devBuffer) { _ => + GpuCompressedColumnVector.from(devBuffer, meta.get) + } + } + materialized + } + + override def spill(): Long = { + if (!spillable) { + 0L + } else { + synchronized { + if (host.isEmpty && dev.isDefined) { + val cvFromBuffer = dev.get.column(0).asInstanceOf[GpuCompressedColumnVector] + meta = Some(cvFromBuffer.getTableMeta) + host = Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff( + cvFromBuffer.getTableBuffer)) + compressedSizeInBytes + } else { + 0L + } + } + } + } + + override def close(): Unit = { + releaseDeviceResource() + synchronized { + host.foreach(_.close()) + host = None + meta = None + } + } +} + +object SpillableHostColumnarBatchHandle { + def apply(cb: ColumnarBatch): SpillableHostColumnarBatchHandle = { + val sizeInBytes = RapidsHostColumnVector.getTotalHostMemoryUsed(cb) + val handle = new SpillableHostColumnarBatchHandle(sizeInBytes, cb.numRows(), host = Some(cb)) + SpillFramework.stores.hostStore.trackNoSpill(handle) + handle + } +} + +class SpillableHostColumnarBatchHandle private ( + override val approxSizeInBytes: Long, + val numRows: Int, + private[spill] override var host: Option[ColumnarBatch], + private[spill] var disk: Option[DiskHandle] = None) + extends HostSpillableHandle[ColumnarBatch] { + + override def spillable: Boolean = synchronized { + if (super.spillable) { + val hcvs = RapidsHostColumnVector.extractBases(host.get) + val colRepetition = mutable.HashMap[HostColumnVector, Int]() + hcvs.foreach { hcv => + colRepetition.put(hcv, colRepetition.getOrElse(hcv, 0) + 1) + } + hcvs.forall(hcv => { + colRepetition(hcv) == hcv.getRefCount + }) + } else { + false + } + } + + def materialize(sparkTypes: Array[DataType]): ColumnarBatch = { + var materialized: ColumnarBatch = null + var diskHandle: DiskHandle = null + synchronized { + if (host.isDefined) { + materialized = RapidsHostColumnVector.incRefCounts(host.get) + } else if (disk.isDefined) { + diskHandle = disk.get + } else { + throw new IllegalStateException( + "attempting to materialize a closed handle") + } + } + if (materialized == null) { + materialized = diskHandle.withInputWrappedStream { inputStream => + val (header, hostBuffer) = SerializedHostTableUtils.readTableHeaderAndBuffer(inputStream) + val hostCols = withResource(hostBuffer) { _ => + SerializedHostTableUtils.buildHostColumns(header, hostBuffer, sparkTypes) + } + new ColumnarBatch(hostCols.toArray, numRows) + } + } + materialized + } + + override def spill(): Long = { + if (!spillable) { + 0L + } else { + val spilled = synchronized { + if (disk.isEmpty && host.isDefined) { + withResource(DiskHandleStore.makeBuilder) { diskHandleBuilder => + GpuTaskMetrics.get.spillToDiskTime { + val dos = diskHandleBuilder.getDataOutputStream + val columns = RapidsHostColumnVector.extractBases(host.get) + JCudfSerialization.writeToStream(columns, dos, 0, host.get.numRows()) + } + disk = Some(diskHandleBuilder.build) + approxSizeInBytes + } + } else { + 0L + } + } + releaseHostResource() + spilled + } + } + + override def close(): Unit = { + releaseHostResource() + synchronized { + disk.foreach(_.close()) + disk = None + } + } +} + +object DiskHandle { + def apply(blockId: BlockId, + offset: Long, + diskSizeInBytes: Long): DiskHandle = { + val handle = new DiskHandle( + blockId, offset, diskSizeInBytes) + SpillFramework.stores.diskStore.track(handle) + handle + } +} + +/** + * A disk buffer handle helps us track spill-framework originated data on disk. + * This type of handle isn't spillable, and therefore it just implements `StoreHandle` + * @param blockId - a spark `BlockId` obtained from the configured `BlockManager` + * @param offset - starting offset for the data within the file backing `blockId` + * @param sizeInBytes - amount of bytes on disk (usually compressed and could also be encrypted). + */ +class DiskHandle private( + val blockId: BlockId, + val offset: Long, + val sizeInBytes: Long) + extends StoreHandle { + + override val approxSizeInBytes: Long = sizeInBytes + + private def withInputChannel[T](body: FileChannel => T): T = synchronized { + val file = SpillFramework.stores.diskStore.diskBlockManager.getFile(blockId) + GpuTaskMetrics.get.readSpillFromDiskTime { + withResource(new FileInputStream(file)) { fs => + withResource(fs.getChannel) { channel => + body(channel) + } + } + } + } + + def withInputWrappedStream[T](body: InputStream => T): T = synchronized { + val diskBlockManager = SpillFramework.stores.diskStore.diskBlockManager + val serializerManager = diskBlockManager.getSerializerManager() + GpuTaskMetrics.get.readSpillFromDiskTime { + withInputChannel { inputChannel => + inputChannel.position(offset) + withResource(Channels.newInputStream(inputChannel)) { compressed => + withResource(serializerManager.wrapStream(blockId, compressed)) { in => + body(in) + } + } + } + } + } + + override def close(): Unit = { + SpillFramework.removeFromDiskStore(this) + SpillFramework.stores.diskStore.deleteFile(blockId) + } + + def materializeToHostMemoryBuffer(mb: HostMemoryBuffer): Unit = { + withInputWrappedStream { in => + withResource(new HostMemoryOutputStream(mb)) { out => + IOUtils.copy(in, out) + } + } + } + + def materializeToDeviceMemoryBuffer(dmb: DeviceMemoryBuffer): Unit = { + var copyOffset = 0L + withInputWrappedStream { in => + SpillFramework.withHostSpillBounceBuffer { hmb => + val bbLength = hmb.getLength.toInt + withResource(new HostMemoryOutputStream(hmb)) { out => + var sizeRead = IOUtils.copyLarge(in, out, 0, bbLength) + while (sizeRead > 0) { + // this syncs at every copy, since for now we are + // reusing a single host spill bounce buffer + dmb.copyFromHostBuffer( + /*dstOffset*/ copyOffset, + /*src*/ hmb, + /*srcOffset*/ 0, + /*length*/ sizeRead) + out.seek(0) // start over + copyOffset += sizeRead + sizeRead = IOUtils.copyLarge(in, out, 0, bbLength) + } + } + } + } + } +} + +trait HandleStore[T <: StoreHandle] extends AutoCloseable with Logging { + protected val handles = new ConcurrentHashMap[T, java.lang.Boolean]() + + def numHandles: Int = { + handles.size() + } + + def track(handle: T): Unit = { + doTrack(handle) + } + + def remove(handle: T): Unit = { + doRemove(handle) + } + + protected def doTrack(handle: T): Boolean = { + handles.put(handle, true) == null + } + + protected def doRemove(handle: T): Boolean = { + handles.remove(handle) != null + } + + override def close(): Unit = { + handles.forEach((handle, _ )=> { + handle.close() + }) + handles.clear() + } +} + +trait SpillableStore[T <: SpillableHandle] + extends HandleStore[T] with Logging { + protected def spillNvtxRange: NvtxRange + + /** + * Internal class to provide an interface to our plan for this spill. + * + * We will build up this SpillPlan by adding spillables: handles + * that are marked spillable given the `spillable` method returning true. + * The spill store will call `trySpill`, which moves handles from the + * `spillableHandles` array to the `spilledHandles` array. + * + * At any point in time, a spill framework can call `getSpilled` + * to obtain the list of spilled handles. The device store does this + * to inject CUDA synchronization before actually releasing device handles. + */ + class SpillPlan { + private val spillableHandles = new util.ArrayList[T]() + private val spilledHandles = new util.ArrayList[T]() + + def add(spillable: T): Unit = { + spillableHandles.add(spillable) + } + + def trySpill(): Long = { + var amountSpilled = 0L + val it = spillableHandles.iterator() + while (it.hasNext) { + val handle = it.next() + val spilled = handle.spill() + if (spilled > 0) { + // this thread was successful at spilling handle. + amountSpilled += spilled + spilledHandles.add(handle) + } else { + // else, either: + // - this thread lost the race and the handle was closed + // - another thread spilled it + // - the handle isn't spillable anymore, due to ref count. + it.remove() + } + } + amountSpilled + } + + def getSpilled: util.ArrayList[T] = { + spilledHandles + } + } + + private def makeSpillPlan(spillNeeded: Long): SpillPlan = { + val plan = new SpillPlan() + var amountToSpill = 0L + val allHandles = handles.keySet().iterator() + // two threads could be here trying to spill and creating a list of spillables + while (allHandles.hasNext && amountToSpill < spillNeeded) { + val handle = allHandles.next() + if (handle.spillable) { + amountToSpill += handle.approxSizeInBytes + plan.add(handle) + } + } + plan + } + + protected def postSpill(plan: SpillPlan): Unit = {} + + def spill(spillNeeded: Long): Long = { + if (spillNeeded == 0) { + 0L + } else { + withResource(spillNvtxRange) { _ => + val plan = makeSpillPlan(spillNeeded) + val amountSpilled = plan.trySpill() + postSpill(plan) + amountSpilled + } + } + } +} + +class SpillableHostStore(val maxSize: Option[Long] = None) + extends SpillableStore[HostSpillableHandle[_]] + with Logging { + + private[spill] var totalSize: Long = 0L + + private def tryTrack(handle: HostSpillableHandle[_]): Boolean = { + if (maxSize.isEmpty || handle.approxSizeInBytes == 0) { + super.doTrack(handle) + // for now, keep this totalSize part, we technically + // do not need to track `totalSize` if we don't have a limit + synchronized { + totalSize += handle.approxSizeInBytes + } + true + } else { + synchronized { + val storeMaxSize = maxSize.get + if (totalSize > 0 && totalSize + handle.approxSizeInBytes > storeMaxSize) { + // we want to try to make room for this buffer + false + } else { + // it fits + if (super.doTrack(handle)) { + totalSize += handle.approxSizeInBytes + } + true + } + } + } + } + + override def track(handle: HostSpillableHandle[_]): Unit = { + trackInternal(handle) + } + + private def trackInternal(handle: HostSpillableHandle[_]): Boolean = { + // try to track the handle: in the case of no limits + // this should just be add to the store + var tracked = false + tracked = tryTrack(handle) + if (!tracked) { + // we only end up here if we have host store limits. + var numRetries = 0 + // we are going to try to track again, in a loop, + // since we want to release + var canFit = true + val handleSize = handle.approxSizeInBytes + var amountSpilled = 0L + val hadHandlesToSpill = !handles.isEmpty + while (canFit && !tracked && numRetries < 5) { + // if we are trying to add a handle larger than our limit + if (maxSize.get < handleSize) { + // no point in checking how much is free, just spill all + // we have + amountSpilled += spill(maxSize.get) + } else { + // handleSize is within the limits + val freeAmount = synchronized { + maxSize.get - totalSize + } + val spillNeeded = handleSize - freeAmount + if (spillNeeded > 0) { + amountSpilled += spill(spillNeeded) + } + } + tracked = tryTrack(handle) + if (!tracked) { + // we tried to spill, and we still couldn't fit this buffer + // if we have a totalSize > 0, we could try some more + // the disk api + synchronized { + canFit = totalSize > 0 + } + } + numRetries += 1 + } + val taskId = Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(0) + if (hadHandlesToSpill) { + logInfo(s"Task $taskId spilled $amountSpilled bytes while trying to " + + s"track $handleSize bytes.") + } + } + tracked + } + + /** + * This is a special method in the host store where spillable handles can be added + * but they will not trigger the cascade host->disk spill logic. This is to replicate + * how the stores used to work in the past, and is only called from factory + * methods that are used by client code. + */ + def trackNoSpill(handle: HostSpillableHandle[_]): Unit = { + synchronized { + if (doTrack(handle)) { + totalSize += handle.approxSizeInBytes + } + } + } + + override def remove(handle: HostSpillableHandle[_]): Unit = { + synchronized { + if (doRemove(handle)) { + totalSize -= handle.approxSizeInBytes + } + } + } + + /** + * Makes a builder object for `SpillableHostBufferHandle`. The builder will + * either copy ot host or disk, if the host buffer fits in the host store (if tracking + * is enabled). + * + * Host store locks and disk store locks will be taken/released during this call, but + * after the builder is created, no locks are held in the store. + * + * @note When creating the host buffer handle, never call the factory Spillable* methods, + * instead, construct the handles directly. This is because the factory methods + * trigger a spill to disk, and that standard behavior of the spill framework so far. + * @param handle a host handle that only has a size set, and no backing store. + * @return the builder to be closed by caller + */ + def makeBuilder(handle: SpillableHostBufferHandle): SpillableHostBufferHandleBuilder = { + var builder: Option[SpillableHostBufferHandleBuilder] = None + if (handle.sizeInBytes <= maxSize.getOrElse(Long.MaxValue)) { + HostAlloc.tryAlloc(handle.sizeInBytes).foreach { hmb => + withResource(hmb) { _ => + if (trackInternal(handle)) { + hmb.incRefCount() + // the host store made room or fit this buffer + builder = Some(new SpillableHostBufferHandleBuilderForHost(handle, hmb)) + } + } + } + } + builder.getOrElse { + // the disk store will track this when we call .build + new SpillableHostBufferHandleBuilderForDisk(handle) + } + } + + trait SpillableHostBufferHandleBuilder extends AutoCloseable { + /** + * Copy `mb` from offset 0 to len to host or disk. + * + * We synchronize after each copy since we do not manage the lifetime + * of `mb`. + * + * @param mb buffer to copy from + * @param len the amount of bytes that should be copied from `mb` + * @param stream CUDA stream to use, and synchronize against + */ + def copyNext(mb: DeviceMemoryBuffer, len: Long, stream: Cuda.Stream): Unit + + /** + * Returns a usable `SpillableHostBufferHandle` with either the + * `host` or `disk` set with the appropriate object. + * + * Note that if we are writing to disk, we are going to add a + * new `DiskHandle` in the disk store's concurrent collection. + * + * @return host handle with data in host or disk + */ + def build: SpillableHostBufferHandle + } + + private class SpillableHostBufferHandleBuilderForHost( + var handle: SpillableHostBufferHandle, + var singleShotBuffer: HostMemoryBuffer) + extends SpillableHostBufferHandleBuilder with Logging { + private var copied = 0L + + override def copyNext(mb: DeviceMemoryBuffer, len: Long, stream: Cuda.Stream): Unit = { + GpuTaskMetrics.get.spillToHostTime { + singleShotBuffer.copyFromMemoryBuffer( + copied, + mb, + 0, + len, + stream) + copied += len + } + } + + override def build: SpillableHostBufferHandle = { + // add some sort of setter method to Host Handle + require(handle != null, "Called build too many times") + require(copied == handle.sizeInBytes, + s"Expected ${handle.sizeInBytes} B but copied $copied B instead") + handle.setHost(singleShotBuffer) + singleShotBuffer = null + val res = handle + handle = null + res + } + + override def close(): Unit = { + if (handle != null) { + handle.close() + handle = null + } + if (singleShotBuffer != null) { + singleShotBuffer.close() + singleShotBuffer = null + } + } + } + + private class SpillableHostBufferHandleBuilderForDisk( + var handle: SpillableHostBufferHandle) + extends SpillableHostBufferHandleBuilder { + private var copied = 0L + private var diskHandleBuilder = DiskHandleStore.makeBuilder + + override def copyNext(mb: DeviceMemoryBuffer, len: Long, stream: Cuda.Stream): Unit = { + SpillFramework.withHostSpillBounceBuffer { hostSpillBounceBuffer => + GpuTaskMetrics.get.spillToDiskTime { + val outputChannel = diskHandleBuilder.getChannel + withResource(mb.slice(0, len)) { slice => + val iter = new MemoryBufferToHostByteBufferIterator( + slice, + hostSpillBounceBuffer, + Cuda.DEFAULT_STREAM) + iter.foreach { byteBuff => + try { + while (byteBuff.hasRemaining) { + outputChannel.write(byteBuff) + } + copied += byteBuff.capacity() + } finally { + RapidsStorageUtils.dispose(byteBuff) + } + } + } + } + } + } + + override def build: SpillableHostBufferHandle = { + // add some sort of setter method to Host Handle + require(handle != null, "Called build too many times") + require(copied == handle.sizeInBytes, + s"Expected ${handle.sizeInBytes} B but copied $copied B instead") + handle.setDisk(diskHandleBuilder.build) + val res = handle + handle = null + res + } + + override def close(): Unit = { + if (handle != null) { + handle.close() + handle = null + } + if (diskHandleBuilder!= null) { + diskHandleBuilder.close() + diskHandleBuilder = null + } + } + } + + override protected def spillNvtxRange: NvtxRange = + new NvtxRange("disk spill", NvtxColor.RED) +} + +class SpillableDeviceStore extends SpillableStore[DeviceSpillableHandle[_]] { + override protected def spillNvtxRange: NvtxRange = + new NvtxRange("device spill", NvtxColor.ORANGE) + + override def postSpill(plan: SpillPlan): Unit = { + // spillables is the list of handles that have to be closed + // we synchronize every thread before we release what was spilled + Cuda.deviceSynchronize() + // this is safe to be called unconditionally if another thread spilled + plan.getSpilled.forEach(_.releaseSpilled()) + } +} + +class DiskHandleStore(conf: SparkConf) + extends HandleStore[DiskHandle] with Logging { + val diskBlockManager: RapidsDiskBlockManager = new RapidsDiskBlockManager(conf) + + def getFile(blockId: BlockId): File = { + diskBlockManager.getFile(blockId) + } + + def deleteFile(blockId: BlockId): Unit = { + val file = getFile(blockId) + file.delete() + if (file.exists()) { + logWarning(s"Unable to delete $file") + } + } + + override def track(handle: DiskHandle): Unit = { + // protects the off chance that someone adds this handle twice.. + if (doTrack(handle)) { + GpuTaskMetrics.get.incDiskBytesAllocated(handle.sizeInBytes) + } + } + + override def remove(handle: DiskHandle): Unit = { + // protects the off chance that someone removes this handle twice.. + if (doRemove(handle)) { + GpuTaskMetrics.get.decDiskBytesAllocated(handle.sizeInBytes) + } + } +} + +object DiskHandleStore { + /** + * An object that knows how to write a block to disk in Spark. + * It supports + * @param blockId the BlockManager `BlockId` to use. + * @param startPos the position to start writing from, useful if we can + * share files + */ + class DiskHandleBuilder(val blockId: BlockId, + val startPos: Long = 0L) extends AutoCloseable { + private val file = SpillFramework.stores.diskStore.getFile(blockId) + + private val serializerManager = + SpillFramework.stores.diskStore.diskBlockManager.getSerializerManager() + + // this is just to make sure we use DiskWriter once and we are not leaking + // as it is, we could use `DiskWriter` to start writing at other offsets + private var closed = false + + private var fc: FileChannel = _ + + private def getFileChannel: FileChannel = { + val options = Seq(StandardOpenOption.CREATE, StandardOpenOption.WRITE) + fc = FileChannel.open(file.toPath, options:_*) + // seek to the starting pos + fc.position(startPos) + fc + } + + private def wrapChannel(channel: FileChannel): OutputStream = { + val os = Channels.newOutputStream(channel) + serializerManager.wrapStream(blockId, os) + } + + private var outputChannel: WritableByteChannel = _ + private var outputStream: DataOutputStream = _ + + def getChannel: WritableByteChannel = { + require(!closed, "Cannot write to closed DiskWriter") + require(outputStream == null, + "either channel or data output stream supported, but not both") + if (outputChannel != null) { + outputChannel + } else { + val fc = getFileChannel + val wrappedStream = closeOnExcept(fc)(wrapChannel) + outputChannel = closeOnExcept(wrappedStream)(Channels.newChannel) + outputChannel + } + } + + def getDataOutputStream: DataOutputStream = { + require(!closed, "Cannot write to closed DiskWriter") + require(outputStream == null, + "either channel or data output stream supported, but not both") + if (outputStream != null) { + outputStream + } else { + val fc = getFileChannel + val wrappedStream = closeOnExcept(fc)(wrapChannel) + outputStream = new DataOutputStream(wrappedStream) + outputStream + } + } + + override def close(): Unit = { + if (closed) { + throw new IllegalStateException("already closed DiskWriter") + } + if (outputStream != null) { + outputStream.close() + outputStream = null + } + if (outputChannel != null) { + outputChannel.close() + outputChannel = null + } + closed = true + } + + def build: DiskHandle = + DiskHandle( + blockId, + startPos, + fc.position() - startPos) + } + + def makeBuilder: DiskHandleBuilder = { + val blockId = BlockId(s"temp_local_${UUID.randomUUID().toString}") + new DiskHandleBuilder(blockId) + } +} + +trait SpillableStores extends AutoCloseable { + var deviceStore: SpillableDeviceStore + var hostStore: SpillableHostStore + var diskStore: DiskHandleStore + override def close(): Unit = { + Seq(deviceStore, hostStore, diskStore).safeClose() + } +} + +/** + * A spillable that is meant to be interacted with from the device. + */ +object SpillableColumnarBatchHandle { + def apply(tbl: Table, dataTypes: Array[DataType]): SpillableColumnarBatchHandle = { + withResource(tbl) { _ => + SpillableColumnarBatchHandle(GpuColumnVector.from(tbl, dataTypes)) + } + } + + def apply(cb: ColumnarBatch): SpillableColumnarBatchHandle = { + require(!GpuColumnVectorFromBuffer.isFromBuffer(cb), + "A SpillableColumnarBatchHandle doesn't support cuDF packed batches") + require(!GpuCompressedColumnVector.isBatchCompressed(cb), + "A SpillableColumnarBatchHandle doesn't support comprssed batches") + val sizeInBytes = GpuColumnVector.getTotalDeviceMemoryUsed(cb) + val handle = new SpillableColumnarBatchHandle(sizeInBytes, dev = Some(cb)) + SpillFramework.stores.deviceStore.track(handle) + handle + } +} + +object SpillFramework extends Logging { + // public for tests. Some tests not in the `spill` package require setting this + // because they need fine control over allocations. + var storesInternal: SpillableStores = _ + + def stores: SpillableStores = { + if (storesInternal == null) { + throw new IllegalStateException( + "Cannot use SpillFramework without calling SpillFramework.initialize first") + } + storesInternal + } + + // TODO: these should be pools, instead of individual buffers + private var hostSpillBounceBuffer: HostMemoryBuffer = _ + + private lazy val conf: SparkConf = { + val env = SparkEnv.get + if (env != null) { + env.conf + } else { + // For some unit tests + new SparkConf() + } + } + + def initialize(rapidsConf: RapidsConf): Unit = synchronized { + require(storesInternal == null, + s"cannot initialize SpillFramework multiple times.") + + val hostSpillStorageSize = if (rapidsConf.offHeapLimitEnabled) { + // Disable the limit because it is handled by the RapidsHostMemoryStore + None + } else if (rapidsConf.hostSpillStorageSize == -1) { + // + 1 GiB by default to match backwards compatibility + Some(rapidsConf.pinnedPoolSize + (1024L * 1024 * 1024)) + } else { + Some(rapidsConf.hostSpillStorageSize) + } + // this should hopefully be pinned, but it would work without + hostSpillBounceBuffer = HostMemoryBuffer.allocate(rapidsConf.spillToDiskBounceBufferSize) + + chunkedPackBounceBufferPool = new DeviceBounceBufferPool { + private val bounceBuffer: DeviceBounceBuffer = + DeviceBounceBuffer(DeviceMemoryBuffer.allocate(rapidsConf.chunkedPackBounceBufferSize)) + override def bufferSize: Long = rapidsConf.chunkedPackBounceBufferSize + override def nextBuffer(): DeviceBounceBuffer = { + // can block waiting for bounceBuffer to be released + bounceBuffer.acquire() + } + override def close(): Unit = { + // this closes the DeviceMemoryBuffer wrapped by the bounce buffer class + bounceBuffer.release() + } + } + storesInternal = new SpillableStores { + override var deviceStore: SpillableDeviceStore = new SpillableDeviceStore + override var hostStore: SpillableHostStore = new SpillableHostStore(hostSpillStorageSize) + override var diskStore: DiskHandleStore = new DiskHandleStore(conf) + } + val hostSpillStorageSizeStr = hostSpillStorageSize.map(sz => s"$sz B").getOrElse("unlimited") + logInfo(s"Initialized SpillFramework. Host spill store max size is: $hostSpillStorageSizeStr.") + } + + def shutdown(): Unit = { + if (hostSpillBounceBuffer != null) { + hostSpillBounceBuffer.close() + hostSpillBounceBuffer = null + } + if (chunkedPackBounceBufferPool != null) { + chunkedPackBounceBufferPool.close() + chunkedPackBounceBufferPool = null + } + if (storesInternal != null) { + storesInternal.close() + storesInternal = null + } + } + + def withHostSpillBounceBuffer[T](body: HostMemoryBuffer => T): T = + hostSpillBounceBuffer.synchronized { + body(hostSpillBounceBuffer) + } + + var chunkedPackBounceBufferPool: DeviceBounceBufferPool = _ + + // if the stores have already shut down, we don't want to create them here + // so we use `storesInternal` directly in these remove functions. + + private[spill] def removeFromDeviceStore(handle: DeviceSpillableHandle[_]): Unit = { + synchronized { + Option(storesInternal).map(_.deviceStore) + }.foreach(_.remove(handle)) + } + + private[spill] def removeFromHostStore(handle: HostSpillableHandle[_]): Unit = { + synchronized { + Option(storesInternal).map(_.hostStore) + }.foreach(_.remove(handle)) + } + + private[spill] def removeFromDiskStore(handle: DiskHandle): Unit = { + synchronized { + Option(storesInternal).map(_.diskStore) + }.foreach(_.remove(handle)) + } +} + +/** + * A bounce buffer wrapper class that supports the concept of acquisition. + * + * The bounce buffer is acquired exclusively. So any calls to acquire while the + * buffer is in use will block at `acquire`. Calls to `release` notify the blocked + * threads, and they will check to see if they can acquire. + * + * `close` is the interface to unacquire the bounce buffer. + * + * `release` actually closes the underlying DeviceMemoryBuffer, and should be called + * once at the end of the lifetime of the executor. + * + * @param dmb - actual cudf DeviceMemoryBuffer that this class is protecting. + */ +private[spill] case class DeviceBounceBuffer(var dmb: DeviceMemoryBuffer) extends AutoCloseable { + private var acquired: Boolean = false + def acquire(): DeviceBounceBuffer = synchronized { + while (acquired) { + wait() + } + acquired = true + this + } + + private def unaquire(): Unit = synchronized { + acquired = false + notifyAll() + } + + override def close(): Unit = { + unaquire() + } + + def release(): Unit = synchronized { + if (acquired) { + throw new IllegalStateException( + "closing device buffer pool, but some bounce buffers are in use.") + } + if (dmb != null) { + dmb.close() + dmb = null + } + } +} + +/** + * A bounce buffer pool with buffers of size `bufferSize` + * + * This pool returns instances of `DeviceBounceBuffer`, that should + * be closed in order to be reused. + * + * Callers should synchronize before calling close on their `DeviceMemoryBuffer`s. + */ +trait DeviceBounceBufferPool extends AutoCloseable { + def bufferSize: Long + def nextBuffer(): DeviceBounceBuffer +} + +/** + * ChunkedPacker is an Iterator-like class that uses a cudf::chunked_pack to copy a cuDF `Table` + * to a target buffer in chunks. It implements a next method that takes a DeviceMemoryBuffer + * as an argument to be used for the copy. + * + * Each chunk is sized at most `bounceBuffer.getLength`, and the caller should cudaMemcpy + * bytes from `bounceBuffer` to a target buffer after each call to `next()`. + * + * @note `ChunkedPacker` must be closed by the caller as it has GPU and host resources + * associated with it. + * + * @param table cuDF Table to chunk_pack + * @param bounceBufferPool bounce buffer pool to use during the lifetime of this packer. + */ +class ChunkedPacker(table: Table, + bounceBufferPool: DeviceBounceBufferPool) + extends Iterator[(DeviceBounceBuffer, Long)] with Logging with AutoCloseable { + + private var closed: Boolean = false + + // When creating cudf::chunked_pack use a pool if available, otherwise default to the + // per-device memory resource + private val chunkedPack = { + val pool = GpuDeviceManager.chunkedPackMemoryResource + val cudfChunkedPack = try { + pool.flatMap { chunkedPool => + Some(table.makeChunkedPack(bounceBufferPool.bufferSize, chunkedPool)) + } + } catch { + case _: OutOfMemoryError => + if (!ChunkedPacker.warnedAboutPoolFallback) { + ChunkedPacker.warnedAboutPoolFallback = true + logWarning( + s"OOM while creating chunked_pack using pool sized ${pool.map(_.getMaxSize)}B. " + + "Falling back to the per-device memory resource.") + } + None + } + + // if the pool is not configured, or we got an OOM, try again with the per-device pool + cudfChunkedPack.getOrElse { + table.makeChunkedPack(bounceBufferPool.bufferSize) + } + } + + private val packedMeta = withResource(chunkedPack.buildMetadata()) { packedMeta => + val tmpBB = packedMeta.getMetadataDirectBuffer + val metaCopy = ByteBuffer.allocateDirect(tmpBB.capacity()) + metaCopy.put(tmpBB) + metaCopy.flip() + metaCopy + } + + def getTotalContiguousSize: Long = chunkedPack.getTotalContiguousSize + + def getPackedMeta: ByteBuffer = { + packedMeta + } + + override def hasNext: Boolean = { + if (closed) { + throw new IllegalStateException(s"ChunkedPacker is closed") + } + chunkedPack.hasNext + } + + override def next(): (DeviceBounceBuffer, Long) = { + withResource(bounceBufferPool.nextBuffer()) { bounceBuffer => + if (closed) { + throw new IllegalStateException(s"ChunkedPacker is closed") + } + val bytesWritten = chunkedPack.next(bounceBuffer.dmb) + // we increment the refcount because the caller has no idea where + // this memory came from, so it should close it. + (bounceBuffer, bytesWritten) + } + } + + override def close(): Unit = { + if (!closed) { + closed = true + chunkedPack.close() + } + } +} + +private object ChunkedPacker { + private var warnedAboutPoolFallback: Boolean = false +} \ No newline at end of file diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala index 1b0ee21d494..7f8733b9e00 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala @@ -44,12 +44,12 @@ class GpuShuffleEnv(rapidsConf: RapidsConf) extends Logging { } } - def init(diskBlockManager: RapidsDiskBlockManager): Unit = { + def init(): Unit = { if (isRapidsShuffleConfigured) { shuffleCatalog = - new ShuffleBufferCatalog(RapidsBufferCatalog.singleton, diskBlockManager) + new ShuffleBufferCatalog() shuffleReceivedBufferCatalog = - new ShuffleReceivedBufferCatalog(RapidsBufferCatalog.singleton) + new ShuffleReceivedBufferCatalog() } } @@ -172,9 +172,9 @@ object GpuShuffleEnv extends Logging { // Functions below only get called from the executor // - def init(conf: RapidsConf, diskBlockManager: RapidsDiskBlockManager): Unit = { + def init(conf: RapidsConf): Unit = { val shuffleEnv = new GpuShuffleEnv(conf) - shuffleEnv.init(diskBlockManager) + shuffleEnv.init() env = shuffleEnv } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala index 84ca5e2ac51..a8892c85194 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala @@ -145,7 +145,6 @@ class GpuTaskMetrics extends Serializable { GpuTaskMetrics.decHostBytesAllocated(bytes) } - def incDiskBytesAllocated(bytes: Long): Unit = { GpuTaskMetrics.incDiskBytesAllocated(bytes) maxDiskBytesAllocated = maxDiskBytesAllocated.max(GpuTaskMetrics.diskBytesAllocated) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala index 05bc76c3fab..ee7f0380331 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala @@ -1073,7 +1073,7 @@ class RapidsCachingWriter[K, V]( val blockId = ShuffleBlockId(handle.shuffleId, mapId, partId) if (batch.numRows > 0 && batch.numCols > 0) { // Add the table to the shuffle store - val handle = batch.column(0) match { + batch.column(0) match { case c: GpuPackedTableColumn => val contigTable = c.getContiguousTable partSize = c.getTableBuffer.getLength @@ -1081,23 +1081,14 @@ class RapidsCachingWriter[K, V]( catalog.addContiguousTable( blockId, contigTable, - SpillPriorities.OUTPUT_FOR_SHUFFLE_INITIAL_PRIORITY, - // we don't need to sync here, because we sync on the cuda - // stream after sliceInternalOnGpu (contiguous_split) - needsSync = false) + SpillPriorities.OUTPUT_FOR_SHUFFLE_INITIAL_PRIORITY) case c: GpuCompressedColumnVector => - val buffer = c.getTableBuffer - partSize = buffer.getLength - val tableMeta = c.getTableMeta - uncompressedMetric += tableMeta.bufferMeta().uncompressedSize() - catalog.addBuffer( + partSize = c.getTableBuffer.getLength + uncompressedMetric += c.getTableMeta.bufferMeta().uncompressedSize() + catalog.addCompressedBatch( blockId, - buffer, - tableMeta, - SpillPriorities.OUTPUT_FOR_SHUFFLE_INITIAL_PRIORITY, - // we don't need to sync here, because we sync on the cuda - // stream after compression. - needsSync = false) + batch, + SpillPriorities.OUTPUT_FOR_SHUFFLE_INITIAL_PRIORITY) case c => throw new IllegalStateException(s"Unexpected column type: ${c.getClass}") } @@ -1111,21 +1102,18 @@ class RapidsCachingWriter[K, V]( } else { sizes(partId) += partSize } - handle } else { // no device data, tracking only metadata val tableMeta = MetaUtils.buildDegenerateTableMeta(batch) - val handle = - catalog.addDegenerateRapidsBuffer( - blockId, - tableMeta) + catalog.addDegenerateRapidsBuffer( + blockId, + tableMeta) // ensure that we set the partition size to the default in this case if // we have non-zero rows, so this degenerate batch is shuffled. if (batch.numRows > 0) { sizes(partId) += DEGENERATE_PARTITION_BYTE_SIZE_DEFAULT } - handle } } metricsReporter.incBytesWritten(bytesWritten) @@ -1279,9 +1267,8 @@ class RapidsShuffleInternalManagerBase(conf: SparkConf, val isDriver: Boolean) if (rapidsConf.isGPUShuffle && !isDriver) { val catalog = getCatalogOrThrow val requestHandler = new RapidsShuffleRequestHandler() { - override def acquireShuffleBuffer(tableId: Int): RapidsBuffer = { - val handle = catalog.getShuffleBufferHandle(tableId) - catalog.acquireBuffer(handle) + override def getShuffleHandle(tableId: Int): RapidsShuffleHandle = { + catalog.getShuffleBufferHandle(tableId) } override def getShuffleBufferMetas(sbbId: ShuffleBlockBatchId): Seq[TableMeta] = { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/TempSpillBufferId.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/TempSpillBufferId.scala deleted file mode 100644 index 0f9510a28ba..00000000000 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/TempSpillBufferId.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright (c) 2020, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.rapids - -import java.io.File -import java.util.UUID -import java.util.concurrent.atomic.AtomicInteger -import java.util.function.IntUnaryOperator - -import com.nvidia.spark.rapids.RapidsBufferId - -import org.apache.spark.storage.TempLocalBlockId - -object TempSpillBufferId { - private val MAX_TABLE_ID = Integer.MAX_VALUE - private val TABLE_ID_UPDATER = new IntUnaryOperator { - override def applyAsInt(i: Int): Int = if (i < MAX_TABLE_ID) i + 1 else 0 - } - - /** Tracks the next table identifier */ - private[this] val tableIdCounter = new AtomicInteger(0) - - def apply(): TempSpillBufferId = { - val tableId = tableIdCounter.getAndUpdate(TABLE_ID_UPDATER) - val tempBlockId = TempLocalBlockId(UUID.randomUUID()) - new TempSpillBufferId(tableId, tempBlockId) - } -} - -case class TempSpillBufferId private( - override val tableId: Int, - bufferId: TempLocalBlockId) extends RapidsBufferId { - - override def getDiskPath(diskBlockManager: RapidsDiskBlockManager): File = - diskBlockManager.getFile(bufferId) -} \ No newline at end of file diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala index bd30459d63e..9290a8a9482 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala @@ -167,7 +167,7 @@ class SerializeConcatHostBuffersDeserializeBatch( * This will populate `data` before any task has had a chance to call `.batch` on this class. * * If `batchInternal` is defined we are in the executor, and there is no work to be done. - * This broadcast has been materialized on the GPU/RapidsBufferCatalog, and it is completely + * This broadcast has been materialized on the GPU/spill store, and it is completely * managed by the plugin. * * Public for unit tests. diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHelper.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHelper.scala index b2bb5461a40..cfd7e4e60b8 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHelper.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHelper.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,7 +47,8 @@ object GpuBroadcastHelper { case broadcastBatch: SerializeConcatHostBuffersDeserializeBatch => RmmRapidsRetryIterator.withRetryNoSplit { withResource(new NvtxRange("getBroadcastBatch", NvtxColor.YELLOW)) { _ => - broadcastBatch.batch.getColumnarBatch() + val spillable = broadcastBatch.batch + spillable.getColumnarBatch() } } case v if SparkShimImpl.isEmptyRelation(v) => diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala index 70942001cae..f2f655efef5 100644 --- a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala +++ b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala @@ -36,10 +36,11 @@ package com.nvidia.spark.rapids.shuffle import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} +import scala.collection import scala.collection.mutable import ai.rapids.cudf.{NvtxColor, NvtxRange} -import com.nvidia.spark.rapids.{GpuSemaphore, RapidsBuffer, RapidsBufferHandle, RapidsConf, ShuffleReceivedBufferCatalog} +import com.nvidia.spark.rapids.{GpuSemaphore, RapidsConf, RapidsShuffleHandle, ShuffleReceivedBufferCatalog} import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion import com.nvidia.spark.rapids.jni.RmmSpark @@ -71,7 +72,7 @@ class RapidsShuffleIterator( localBlockManagerId: BlockManagerId, rapidsConf: RapidsConf, transport: RapidsShuffleTransport, - blocksByAddress: Array[(BlockManagerId, Seq[(BlockId, Long, Int)])], + blocksByAddress: Array[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])], metricsUpdater: ShuffleMetricsUpdater, sparkTypes: Array[DataType], taskAttemptId: Long, @@ -90,7 +91,7 @@ class RapidsShuffleIterator( * A result for a successful buffer received * @param handle - the shuffle received buffer handle as tracked in the catalog */ - case class BufferReceived(handle: RapidsBufferHandle) extends ShuffleClientResult + case class BufferReceived(handle: RapidsShuffleHandle) extends ShuffleClientResult /** * A result for a failed attempt at receiving block metadata, or corresponding batches. @@ -180,7 +181,7 @@ class RapidsShuffleIterator( val (local, remote) = blocksByAddress.partition(ba => ba._1.host == localHost) (local ++ remote).foreach { - case (blockManagerId: BlockManagerId, blockIds: Seq[(BlockId, Long, Int)]) => { + case (blockManagerId: BlockManagerId, blockIds: collection.Seq[(BlockId, Long, Int)]) => { val shuffleRequestsMapIndex: Seq[BlockIdMapIndex] = blockIds.map { case (blockId, _, mapIndex) => /** @@ -200,7 +201,7 @@ class RapidsShuffleIterator( throw new IllegalArgumentException( s"${blockId.getClass} $blockId is not currently supported") } - } + }.toSeq val client = try { transport.makeClient(blockManagerId) @@ -245,7 +246,7 @@ class RapidsShuffleIterator( def clientDone: Boolean = clientExpectedBatches > 0 && clientExpectedBatches == clientResolvedBatches - override def batchReceived(handle: RapidsBufferHandle): Boolean = { + override def batchReceived(handle: RapidsShuffleHandle): Boolean = { resolvedBatches.synchronized { if (taskComplete) { false @@ -310,8 +311,7 @@ class RapidsShuffleIterator( logWarning(s"Iterator for task ${taskAttemptIdStr} closing, " + s"but it is not done. Closing ${resolvedBatches.size()} resolved batches!!") resolvedBatches.forEach { - case BufferReceived(handle) => - GpuShuffleEnv.getReceivedCatalog.removeBuffer(handle) + case BufferReceived(handle) => handle.close() case _ => } // tell the client to cancel pending requests @@ -337,8 +337,6 @@ class RapidsShuffleIterator( } override def next(): ColumnarBatch = { - var cb: ColumnarBatch = null - var sb: RapidsBuffer = null val range = new NvtxRange(s"RapidshuffleIterator.next", NvtxColor.RED) // If N tasks downstream are accumulating memory we run the risk OOM @@ -356,6 +354,7 @@ class RapidsShuffleIterator( // fetches and so it could produce device memory. Note this is not allowing for some external // thread to schedule the fetches for us, it may be something we consider in the future, given // memory pressure. + // No good way to get a metric in here for semaphore time. taskContext.foreach(GpuSemaphore.acquireIfNecessary) if (!started) { @@ -379,16 +378,12 @@ class RapidsShuffleIterator( val nvtxRangeAfterGettingBatch = new NvtxRange("RapidsShuffleIterator.gotBatch", NvtxColor.PURPLE) try { - sb = catalog.acquireBuffer(handle) - cb = sb.getColumnarBatch(sparkTypes) - metricsUpdater.update(blockedTime, 1, sb.memoryUsedBytes, cb.numRows()) + val (cb, memoryUsedBytes) = catalog.getColumnarBatchAndRemove(handle, sparkTypes) + metricsUpdater.update(blockedTime, 1, memoryUsedBytes, cb.numRows()) + cb } finally { nvtxRangeAfterGettingBatch.close() range.close() - if (sb != null) { - sb.close() - } - catalog.removeBuffer(handle) } case Some( TransferError(blockManagerId, shuffleBlockBatchId, mapIndex, errorMessage, throwable)) => @@ -414,6 +409,5 @@ class RapidsShuffleIterator( case _ => throw new IllegalStateException(s"Invalid result type $result") } - cb } } diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala index 3334187bb16..1905190f30e 100644 --- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala @@ -38,13 +38,13 @@ import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{NvtxColor, NvtxRange} import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.shuffle.{RapidsShuffleIterator, RapidsShuffleTransport} import org.apache.spark.{InterruptibleIterator, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.shuffle.{ShuffleReader, ShuffleReadMetricsReporter} import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockBatchId, ShuffleBlockId} import org.apache.spark.util.CompletionIterator @@ -79,10 +79,12 @@ class RapidsCachingReader[K, C]( override def read(): Iterator[Product2[K, C]] = { val readRange = new NvtxRange(s"RapidsCachingReader.read", NvtxColor.DARK_GREEN) try { - val blocksForRapidsTransport = new ArrayBuffer[(BlockManagerId, Seq[(BlockId, Long, Int)])]() - val cachedBlocks = new ArrayBuffer[BlockId]() - val cachedBufferHandles = new ArrayBuffer[RapidsBufferHandle]() - val blocksByAddressMap: Map[BlockManagerId, Seq[(BlockId, Long, Int)]] = blocksByAddress.toMap + val blocksForRapidsTransport = + new ArrayBuffer[(BlockManagerId, Seq[(BlockId, Long, Int)])]() + var cachedBatchIterator: Iterator[ColumnarBatch] = Iterator.empty + val blocksByAddressMap: Map[BlockManagerId, Seq[(BlockId, Long, Int)]] = + blocksByAddress.toMap + var numCachedBlocks: Int = 0 blocksByAddressMap.keys.foreach(blockManagerId => { val blockInfos: Seq[(BlockId, Long, Int)] = blocksByAddressMap(blockManagerId) @@ -91,33 +93,29 @@ class RapidsCachingReader[K, C]( if (blockManagerId.executorId == localId.executorId) { val readLocalRange = new NvtxRange("Read Local", NvtxColor.GREEN) try { - blockInfos.foreach( - blockInfo => { - val blockId = blockInfo._1 - val shuffleBufferHandles: IndexedSeq[RapidsBufferHandle] = blockId match { - case sbbid: ShuffleBlockBatchId => - (sbbid.startReduceId to sbbid.endReduceId).flatMap { reduceId => - cachedBlocks.append(blockId) - val sBlockId = ShuffleBlockId(sbbid.shuffleId, sbbid.mapId, reduceId) - catalog.blockIdToBufferHandles(sBlockId) - } - case sbid: ShuffleBlockId => - cachedBlocks.append(blockId) - catalog.blockIdToBufferHandles(sbid) - case _ => throw new IllegalArgumentException( - s"${blockId.getClass} $blockId is not currently supported") - } - - cachedBufferHandles ++= shuffleBufferHandles - - // Update the spill priorities of these buffers to indicate they are about - // to be read and therefore should not be spilled if possible. - shuffleBufferHandles.foreach(catalog.updateSpillPriorityForLocalRead) - - if (shuffleBufferHandles.nonEmpty) { - metrics.incLocalBlocksFetched(1) - } - }) + cachedBatchIterator = blockInfos.iterator.flatMap { blockInfo => + val blockId = blockInfo._1 + val shuffleBufferHandles = blockId match { + case sbbid: ShuffleBlockBatchId => + (sbbid.startReduceId to sbbid.endReduceId).iterator.flatMap { reduceId => + val sbid = ShuffleBlockId(sbbid.shuffleId, sbbid.mapId, reduceId) + numCachedBlocks += 1 + catalog.getColumnarBatchIterator(sbid, sparkTypes) + } + case sbid: ShuffleBlockId => + numCachedBlocks += 1 + catalog.getColumnarBatchIterator(sbid, sparkTypes) + case _ => throw new IllegalArgumentException( + s"${blockId.getClass} $blockId is not currently supported") + } + + shuffleBufferHandles + } + + // Update the spill priorities of these buffers to indicate they are about + // to be read and therefore should not be spilled if possible. + // TODO: AB: shuffleBufferHandles.foreach(catalog.updateSpillPriorityForLocalRead) + metrics.incLocalBlocksFetched(numCachedBlocks) } finally { readLocalRange.close() } @@ -139,7 +137,7 @@ class RapidsCachingReader[K, C]( } }) - logInfo(s"Will read ${cachedBlocks.size} cached blocks, " + + logInfo(s"Will read ${numCachedBlocks} cached blocks, " + s"${blocksForRapidsTransport.size} remote blocks from the RapidsShuffleTransport. ") if (transport.isEmpty && blocksForRapidsTransport.nonEmpty) { @@ -159,17 +157,12 @@ class RapidsCachingReader[K, C]( val itRange = new NvtxRange("Shuffle Iterator prep", NvtxColor.BLUE) try { - val cachedIt = cachedBufferHandles.iterator.map(bufferHandle => { - // No good way to get a metric in here for semaphore wait time - GpuSemaphore.acquireIfNecessary(context) - val cb = withResource(catalog.acquireBuffer(bufferHandle)) { buffer => - buffer.getColumnarBatch(sparkTypes) - } + val cachedIt = cachedBatchIterator.map { cb => val cachedBytesRead = GpuColumnVector.getTotalDeviceMemoryUsed(cb) metrics.incLocalBytesRead(cachedBytesRead) metrics.incRecordsRead(cb.numRows()) (0, cb) - }).asInstanceOf[Iterator[(K, C)]] + }.asInstanceOf[Iterator[(K, C)]] val cbArrayFromUcx: Iterator[(K, C)] = if (blocksForRapidsTransport.nonEmpty) { val rapidsShuffleIterator = new RapidsShuffleIterator(localId, rapidsConf, transport.get, diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala index 56552dac7b7..868e5492e2b 100644 --- a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala +++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala @@ -36,7 +36,7 @@ import scala.collection import scala.collection.mutable import ai.rapids.cudf.{NvtxColor, NvtxRange} -import com.nvidia.spark.rapids.{GpuSemaphore, RapidsBuffer, RapidsBufferHandle, RapidsConf, ShuffleReceivedBufferCatalog} +import com.nvidia.spark.rapids.{GpuSemaphore, RapidsConf, RapidsShuffleHandle, ShuffleReceivedBufferCatalog} import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion import com.nvidia.spark.rapids.jni.RmmSpark @@ -87,7 +87,7 @@ class RapidsShuffleIterator( * A result for a successful buffer received * @param handle - the shuffle received buffer handle as tracked in the catalog */ - case class BufferReceived(handle: RapidsBufferHandle) extends ShuffleClientResult + case class BufferReceived(handle: RapidsShuffleHandle) extends ShuffleClientResult /** * A result for a failed attempt at receiving block metadata, or corresponding batches. @@ -223,7 +223,7 @@ class RapidsShuffleIterator( override def getTaskIds: Array[Long] = taskIds - def start(expectedBatches: Int): Unit = resolvedBatches.synchronized { + override def start(expectedBatches: Int): Unit = resolvedBatches.synchronized { if (expectedBatches == 0) { throw new IllegalStateException( s"Received an invalid response from shuffle server: " + @@ -242,7 +242,7 @@ class RapidsShuffleIterator( def clientDone: Boolean = clientExpectedBatches > 0 && clientExpectedBatches == clientResolvedBatches - def batchReceived(handle: RapidsBufferHandle): Boolean = { + override def batchReceived(handle: RapidsShuffleHandle): Boolean = { resolvedBatches.synchronized { if (taskComplete) { false @@ -307,8 +307,7 @@ class RapidsShuffleIterator( logWarning(s"Iterator for task ${taskAttemptIdStr} closing, " + s"but it is not done. Closing ${resolvedBatches.size()} resolved batches!!") resolvedBatches.forEach { - case BufferReceived(handle) => - GpuShuffleEnv.getReceivedCatalog.removeBuffer(handle) + case BufferReceived(handle) => handle.close() case _ => } // tell the client to cancel pending requests @@ -334,8 +333,6 @@ class RapidsShuffleIterator( } override def next(): ColumnarBatch = { - var cb: ColumnarBatch = null - var sb: RapidsBuffer = null val range = new NvtxRange(s"RapidshuffleIterator.next", NvtxColor.RED) // If N tasks downstream are accumulating memory we run the risk OOM @@ -377,16 +374,12 @@ class RapidsShuffleIterator( val nvtxRangeAfterGettingBatch = new NvtxRange("RapidsShuffleIterator.gotBatch", NvtxColor.PURPLE) try { - sb = catalog.acquireBuffer(handle) - cb = sb.getColumnarBatch(sparkTypes) - metricsUpdater.update(blockedTime, 1, sb.memoryUsedBytes, cb.numRows()) + val (cb, memoryUsedBytes) = catalog.getColumnarBatchAndRemove(handle, sparkTypes) + metricsUpdater.update(blockedTime, 1, memoryUsedBytes, cb.numRows()) + cb } finally { nvtxRangeAfterGettingBatch.close() range.close() - if (sb != null) { - sb.close() - } - catalog.removeBuffer(handle) } case Some( TransferError(blockManagerId, shuffleBlockBatchId, mapIndex, errorMessage, throwable)) => @@ -412,6 +405,5 @@ class RapidsShuffleIterator( case _ => throw new IllegalStateException(s"Invalid result type $result") } - cb } } diff --git a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala index f7afe6aeba4..bc962e1bf5c 100644 --- a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala +++ b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala @@ -35,13 +35,13 @@ import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{NvtxColor, NvtxRange} import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.shuffle.{RapidsShuffleIterator, RapidsShuffleTransport} import org.apache.spark.{InterruptibleIterator, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.shuffle.{ShuffleReader, ShuffleReadMetricsReporter} import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockBatchId, ShuffleBlockId} import org.apache.spark.util.CompletionIterator @@ -78,10 +78,10 @@ class RapidsCachingReader[K, C]( try { val blocksForRapidsTransport = new ArrayBuffer[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])]() - val cachedBlocks = new ArrayBuffer[BlockId]() - val cachedBufferHandles = new ArrayBuffer[RapidsBufferHandle]() - val blocksByAddressMap: Map[BlockManagerId, collection.Seq[(BlockId, Long, Int)]] = + var cachedBatchIterator: Iterator[ColumnarBatch] = Iterator.empty + val blocksByAddressMap: Map[BlockManagerId, collection.Seq[(BlockId, Long, Int)]] = blocksByAddress.toMap + var numCachedBlocks: Int = 0 blocksByAddressMap.keys.foreach(blockManagerId => { val blockInfos: collection.Seq[(BlockId, Long, Int)] = blocksByAddressMap(blockManagerId) @@ -90,33 +90,29 @@ class RapidsCachingReader[K, C]( if (blockManagerId.executorId == localId.executorId) { val readLocalRange = new NvtxRange("Read Local", NvtxColor.GREEN) try { - blockInfos.foreach( - blockInfo => { - val blockId = blockInfo._1 - val shuffleBufferHandles: IndexedSeq[RapidsBufferHandle] = blockId match { - case sbbid: ShuffleBlockBatchId => - (sbbid.startReduceId to sbbid.endReduceId).flatMap { reduceId => - cachedBlocks.append(blockId) - val sBlockId = ShuffleBlockId(sbbid.shuffleId, sbbid.mapId, reduceId) - catalog.blockIdToBufferHandles(sBlockId) - } - case sbid: ShuffleBlockId => - cachedBlocks.append(blockId) - catalog.blockIdToBufferHandles(sbid) - case _ => throw new IllegalArgumentException( - s"${blockId.getClass} $blockId is not currently supported") - } - - cachedBufferHandles ++= shuffleBufferHandles - - // Update the spill priorities of these buffers to indicate they are about - // to be read and therefore should not be spilled if possible. - shuffleBufferHandles.foreach(catalog.updateSpillPriorityForLocalRead) - - if (shuffleBufferHandles.nonEmpty) { - metrics.incLocalBlocksFetched(1) - } - }) + cachedBatchIterator = blockInfos.iterator.flatMap { blockInfo => + val blockId = blockInfo._1 + val shuffleBufferHandles = blockId match { + case sbbid: ShuffleBlockBatchId => + (sbbid.startReduceId to sbbid.endReduceId).iterator.flatMap { reduceId => + val sbid = ShuffleBlockId(sbbid.shuffleId, sbbid.mapId, reduceId) + numCachedBlocks += 1 + catalog.getColumnarBatchIterator(sbid, sparkTypes) + } + case sbid: ShuffleBlockId => + numCachedBlocks += 1 + catalog.getColumnarBatchIterator(sbid, sparkTypes) + case _ => throw new IllegalArgumentException( + s"${blockId.getClass} $blockId is not currently supported") + } + + shuffleBufferHandles + } + + // Update the spill priorities of these buffers to indicate they are about + // to be read and therefore should not be spilled if possible. + // TODO: AB: shuffleBufferHandles.foreach(catalog.updateSpillPriorityForLocalRead) + metrics.incLocalBlocksFetched(numCachedBlocks) } finally { readLocalRange.close() } @@ -138,7 +134,7 @@ class RapidsCachingReader[K, C]( } }) - logInfo(s"Will read ${cachedBlocks.size} cached blocks, " + + logInfo(s"Will read ${numCachedBlocks} cached blocks, " + s"${blocksForRapidsTransport.size} remote blocks from the RapidsShuffleTransport. ") if (transport.isEmpty && blocksForRapidsTransport.nonEmpty) { @@ -158,17 +154,12 @@ class RapidsCachingReader[K, C]( val itRange = new NvtxRange("Shuffle Iterator prep", NvtxColor.BLUE) try { - val cachedIt = cachedBufferHandles.iterator.map(bufferHandle => { - // No good way to get a metric in here for semaphore wait time - GpuSemaphore.acquireIfNecessary(context) - val cb = withResource(catalog.acquireBuffer(bufferHandle)) { buffer => - buffer.getColumnarBatch(sparkTypes) - } + val cachedIt = cachedBatchIterator.map { cb => val cachedBytesRead = GpuColumnVector.getTotalDeviceMemoryUsed(cb) metrics.incLocalBytesRead(cachedBytesRead) metrics.incRecordsRead(cb.numRows()) (0, cb) - }).asInstanceOf[Iterator[(K, C)]] + }.asInstanceOf[Iterator[(K, C)]] val cbArrayFromUcx: Iterator[(K, C)] = if (blocksForRapidsTransport.nonEmpty) { val rapidsShuffleIterator = new RapidsShuffleIterator(localId, rapidsConf, transport.get, diff --git a/sql-plugin/src/test/scala/com/nvidia/spark/rapids/HostAllocSuite.scala b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/HostAllocSuite.scala index 24755c2c0a1..fbd61b9a7fb 100644 --- a/sql-plugin/src/test/scala/com/nvidia/spark/rapids/HostAllocSuite.scala +++ b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/HostAllocSuite.scala @@ -21,6 +21,7 @@ import java.util.concurrent.{ExecutionException, Future, LinkedBlockingQueue, Ti import ai.rapids.cudf.{HostMemoryBuffer, PinnedMemoryPool, Rmm, RmmAllocationMode} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.jni.{RmmSpark, RmmSparkThreadState} +import com.nvidia.spark.rapids.spill._ import org.mockito.Mockito.when import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.scalatest.concurrent.{Signaler, TimeLimits} @@ -28,7 +29,7 @@ import org.scalatest.funsuite.AnyFunSuite import org.scalatest.time._ import org.scalatestplus.mockito.MockitoSugar.mock -import org.apache.spark.TaskContext +import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids.execution.TrampolineUtil @@ -36,7 +37,7 @@ import org.apache.spark.sql.rapids.execution.TrampolineUtil class HostAllocSuite extends AnyFunSuite with BeforeAndAfterEach with BeforeAndAfterAll with TimeLimits { private val sqlConf = new SQLConf() - private val rc = new RapidsConf(sqlConf) + Rmm.shutdown() private val timeoutMs = 10000 def setMockContext(taskAttemptId: Long): Unit = { @@ -316,23 +317,34 @@ class HostAllocSuite extends AnyFunSuite with BeforeAndAfterEach with private var rmmWasInitialized = false override def beforeEach(): Unit = { - RapidsBufferCatalog.close() + val sc = new SparkConf SparkSession.getActiveSession.foreach(_.stop()) SparkSession.clearActiveSession() + SpillFramework.shutdown() if (Rmm.isInitialized) { rmmWasInitialized = true Rmm.shutdown() } Rmm.initialize(RmmAllocationMode.CUDA_DEFAULT, null, 512 * 1024 * 1024) + // this doesn't allocate memory for bounce buffers, as HostAllocSuite + // is playing games with the pools. + SpillFramework.storesInternal = new SpillableStores { + override var deviceStore: SpillableDeviceStore = new SpillableDeviceStore + override var hostStore: SpillableHostStore = new SpillableHostStore(None) + override var diskStore: DiskHandleStore = new DiskHandleStore(sc) + } + // some tests need an event handler + RmmSpark.setEventHandler( + new DeviceMemoryEventHandler(SpillFramework.stores.deviceStore, None, 0)) PinnedMemoryPool.shutdown() HostAlloc.initialize(-1) - RapidsBufferCatalog.init(rc) } override def afterAll(): Unit = { - RapidsBufferCatalog.close() + SpillFramework.shutdown() PinnedMemoryPool.shutdown() Rmm.shutdown() + RmmSpark.clearEventHandler() if (rmmWasInitialized) { // put RMM back for other tests to use Rmm.initialize(RmmAllocationMode.CUDA_DEFAULT, null, 512 * 1024 * 1024) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala index 0b531adabb7..9ba0147d878 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package com.nvidia.spark.rapids +import com.nvidia.spark.rapids.spill.SpillableDeviceStore import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.when import org.scalatestplus.mockito.MockitoSugar @@ -23,12 +24,9 @@ import org.scalatestplus.mockito.MockitoSugar class DeviceMemoryEventHandlerSuite extends RmmSparkRetrySuiteBase with MockitoSugar { test("a failed allocation should be retried if we spilled enough") { - val mockCatalog = mock[RapidsBufferCatalog] - val mockStore = mock[RapidsDeviceMemoryStore] - when(mockStore.currentSpillableSize).thenReturn(1024) - when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(1024L)) + val mockStore = mock[SpillableDeviceStore] + when(mockStore.spill(any())).thenAnswer(_ => 1024L) val handler = new DeviceMemoryEventHandler( - mockCatalog, mockStore, None, 2) @@ -36,12 +34,9 @@ class DeviceMemoryEventHandlerSuite extends RmmSparkRetrySuiteBase with MockitoS } test("when we deplete the store, retry up to max failed OOM retries") { - val mockCatalog = mock[RapidsBufferCatalog] - val mockStore = mock[RapidsDeviceMemoryStore] - when(mockStore.currentSpillableSize).thenReturn(0) - when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(0L)) + val mockStore = mock[SpillableDeviceStore] + when(mockStore.spill(any())).thenAnswer(_ => 0L) val handler = new DeviceMemoryEventHandler( - mockCatalog, mockStore, None, 2) @@ -51,12 +46,9 @@ class DeviceMemoryEventHandlerSuite extends RmmSparkRetrySuiteBase with MockitoS } test("we reset our OOM state after a successful retry") { - val mockCatalog = mock[RapidsBufferCatalog] - val mockStore = mock[RapidsDeviceMemoryStore] - when(mockStore.currentSpillableSize).thenReturn(0) - when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(0L)) + val mockStore = mock[SpillableDeviceStore] + when(mockStore.spill(any())).thenAnswer(_ => 0L) val handler = new DeviceMemoryEventHandler( - mockCatalog, mockStore, None, 2) @@ -69,12 +61,9 @@ class DeviceMemoryEventHandlerSuite extends RmmSparkRetrySuiteBase with MockitoS } test("a negative allocation cannot be retried and handler throws") { - val mockCatalog = mock[RapidsBufferCatalog] - val mockStore = mock[RapidsDeviceMemoryStore] - when(mockStore.currentSpillableSize).thenReturn(1024) - when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(1024L)) + val mockStore = mock[SpillableDeviceStore] + when(mockStore.spill(any())).thenAnswer(_ => 1024L) val handler = new DeviceMemoryEventHandler( - mockCatalog, mockStore, None, 2) @@ -82,12 +71,9 @@ class DeviceMemoryEventHandlerSuite extends RmmSparkRetrySuiteBase with MockitoS } test("a negative retry count is invalid") { - val mockCatalog = mock[RapidsBufferCatalog] - val mockStore = mock[RapidsDeviceMemoryStore] - when(mockStore.currentSpillableSize).thenReturn(1024) - when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(1024L)) + val mockStore = mock[SpillableDeviceStore] + when(mockStore.spill(any())).thenAnswer(_ => 1024L) val handler = new DeviceMemoryEventHandler( - mockCatalog, mockStore, None, 2) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GeneratedInternalRowToCudfRowIteratorRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GeneratedInternalRowToCudfRowIteratorRetrySuite.scala index 725a2e37032..da03992f72f 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GeneratedInternalRowToCudfRowIteratorRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GeneratedInternalRowToCudfRowIteratorRetrySuite.scala @@ -19,8 +19,9 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.Table import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.jni.{GpuSplitAndRetryOOM, RmmSpark} +import com.nvidia.spark.rapids.spill.{SpillableColumnarBatchHandle, SpillableDeviceStore, SpillFramework} import org.mockito.ArgumentMatchers.any -import org.mockito.Mockito.{doAnswer, spy, times, verify} +import org.mockito.Mockito.{doAnswer, spy} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatestplus.mockito.MockitoSugar @@ -41,6 +42,14 @@ class GeneratedInternalRowToCudfRowIteratorRetrySuite } } + override def beforeEach(): Unit = { + // some tests in this suite will want to perform `verify` calls on the device store + // so we close it and create a spy around one. + super.beforeEach() + SpillFramework.storesInternal.deviceStore.close() + SpillFramework.storesInternal.deviceStore = spy(new SpillableDeviceStore) + } + private def getAndResetNumRetryThrowCurrentTask: Int = { // taskId 1 was associated with the current thread in RmmSparkRetrySuiteBase RmmSpark.getAndResetNumRetryThrow(/*taskId*/ 1) @@ -65,26 +74,24 @@ class GeneratedInternalRowToCudfRowIteratorRetrySuite } assert(!GpuColumnVector.extractBases(batch).exists(_.getRefCount > 0)) assert(!myIter.hasNext) - assertResult(0)(RapidsBufferCatalog.getDeviceStorage.currentSize) + assertResult(0)(SpillFramework.stores.deviceStore.spill(1)) } } test("a retry when converting to a table is handled") { val batch = buildBatch() val batchIter = Seq(batch).iterator - var rapidsBufferSpy: RapidsBuffer = null - doAnswer(new Answer[AnyRef]() { - override def answer(invocation: InvocationOnMock): AnyRef = { - val res = invocation.callRealMethod() + doAnswer(new Answer[Boolean]() { + override def answer(invocation: InvocationOnMock): Boolean = { + invocation.callRealMethod() // we mock things this way due to code generation issues with mockito. // when we add a table we have RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 3, RmmSpark.OomInjectionType.GPU.ordinal, 0) - rapidsBufferSpy = spy(res.asInstanceOf[RapidsBuffer]) - rapidsBufferSpy + true } - }).when(deviceStorage) - .addTable(any(), any(), any(), any()) + }).when(SpillFramework.stores.deviceStore) + .track(any()) withResource(new ColumnarToRowIterator(batchIter, NoopMetric, NoopMetric, NoopMetric, NoopMetric)) { ctriter => @@ -102,35 +109,27 @@ class GeneratedInternalRowToCudfRowIteratorRetrySuite } assertResult(6)(getAndResetNumRetryThrowCurrentTask) assert(!myIter.hasNext) - assertResult(0)(RapidsBufferCatalog.getDeviceStorage.currentSize) - // This is my wrap around of checking that we did retry the last part - // where we are converting the device column of rows into an actual column. - // Because we asked for 3 retries, we would ask the spill framework 4 times to materialize - // a batch. - verify(rapidsBufferSpy, times(4)) - .getColumnarBatch(any()) + assertResult(0)(SpillFramework.stores.deviceStore.spill(1)) } } test("spilling the device column of rows works") { val batch = buildBatch() val batchIter = Seq(batch).iterator - var rapidsBufferSpy: RapidsBuffer = null - doAnswer(new Answer[AnyRef]() { - override def answer(invocation: InvocationOnMock): AnyRef = { - val res = invocation.callRealMethod() + doAnswer(new Answer[Boolean]() { + override def answer(invocation: InvocationOnMock): Boolean = { + val handle = invocation.getArgument(0).asInstanceOf[SpillableColumnarBatchHandle] // we mock things this way due to code generation issues with mockito. // when we add a table we have RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 3, RmmSpark.OomInjectionType.GPU.ordinal, 0) - rapidsBufferSpy = spy(res.asInstanceOf[RapidsBuffer]) // at this point we have created a buffer in the Spill Framework // lets spill it - RapidsBufferCatalog.singleton.synchronousSpill(deviceStorage, 0) - rapidsBufferSpy + SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes) + true } - }).when(deviceStorage) - .addTable(any(), any(), any(), any()) + }).when(SpillFramework.stores.deviceStore) + .track(any()) withResource(new ColumnarToRowIterator(batchIter, NoopMetric, NoopMetric, NoopMetric, NoopMetric)) { ctriter => @@ -148,13 +147,7 @@ class GeneratedInternalRowToCudfRowIteratorRetrySuite } assertResult(6)(getAndResetNumRetryThrowCurrentTask) assert(!myIter.hasNext) - assertResult(0)(RapidsBufferCatalog.getDeviceStorage.currentSize) - // This is my wrap around of checking that we did retry the last part - // where we are converting the device column of rows into an actual column. - // Because we asked for 3 retries, we would ask the spill framework 4 times to materialize - // a batch. - verify(rapidsBufferSpy, times(4)) - .getColumnarBatch(any()) + assertResult(0)(SpillFramework.stores.deviceStore.spill(1)) } } @@ -173,7 +166,7 @@ class GeneratedInternalRowToCudfRowIteratorRetrySuite assertThrows[GpuSplitAndRetryOOM] { myIter.next() } - assertResult(0)(RapidsBufferCatalog.getDeviceStorage.currentSize) + assertResult(0)(SpillFramework.stores.deviceStore.spill(1)) } } @@ -199,7 +192,7 @@ class GeneratedInternalRowToCudfRowIteratorRetrySuite } assert(!GpuColumnVector.extractBases(batch).exists(_.getRefCount > 0)) assert(!myIter.hasNext) - assertResult(0)(RapidsBufferCatalog.getDeviceStorage.currentSize) + assertResult(0)(SpillFramework.stores.deviceStore.spill(1)) } } @@ -225,7 +218,7 @@ class GeneratedInternalRowToCudfRowIteratorRetrySuite } assert(!GpuColumnVector.extractBases(batch).exists(_.getRefCount > 0)) assert(!myIter.hasNext) - assertResult(0)(RapidsBufferCatalog.getDeviceStorage.currentSize) + assertResult(0)(SpillFramework.stores.deviceStore.spill(1)) } } } \ No newline at end of file diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesRetrySuite.scala index 34a9fa984d5..4f0027d4c2a 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesRetrySuite.scala @@ -199,7 +199,7 @@ class GpuCoalesceBatchesRetrySuite val batches = iter.asInstanceOf[CoalesceIteratorMocks].getBatches() assertResult(10)(batches.length) batches.foreach(b => - verify(b, times(1)).close() + GpuColumnVector.extractBases(b).forall(_.getRefCount == 0) ) } } @@ -209,7 +209,7 @@ class GpuCoalesceBatchesRetrySuite var refCount = 1 override def numRows(): Int = 0 override def setSpillPriority(priority: Long): Unit = {} - override def getColumnarBatch(): ColumnarBatch = { + override def getColumnarBatch: ColumnarBatch = { throw new GpuSplitAndRetryOOM() } override def sizeInBytes: Long = 0 diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuGenerateSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuGenerateSuite.scala index e69f7c75118..fbbf0acc20d 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuGenerateSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuGenerateSuite.scala @@ -268,12 +268,12 @@ class GpuGenerateSuite var forceOOM: Boolean) extends SpillableColumnarBatch { override def numRows(): Int = spillable.numRows() override def setSpillPriority(priority: Long): Unit = spillable.setSpillPriority(priority) - override def getColumnarBatch(): ColumnarBatch = { + override def getColumnarBatch: ColumnarBatch = { if (forceOOM) { forceOOM = false throw new GpuSplitAndRetryOOM(s"mock split and retry") } - spillable.getColumnarBatch() + spillable.getColumnarBatch } override def sizeInBytes: Long = spillable.sizeInBytes override def dataTypes: Array[DataType] = spillable.dataTypes diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala index 1e3c0f699da..fdbe316a394 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala @@ -16,20 +16,20 @@ package com.nvidia.spark.rapids -import java.io.File import java.math.RoundingMode import ai.rapids.cudf.{ColumnVector, Cuda, DType, Table} import com.nvidia.spark.rapids.Arm.withResource +import org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite import org.apache.spark.SparkConf -import org.apache.spark.sql.rapids.{GpuShuffleEnv, RapidsDiskBlockManager} +import org.apache.spark.sql.rapids.GpuShuffleEnv import org.apache.spark.sql.rapids.execution.TrampolineUtil import org.apache.spark.sql.types.{DecimalType, DoubleType, IntegerType, StringType} import org.apache.spark.sql.vectorized.ColumnarBatch -class GpuPartitioningSuite extends AnyFunSuite { +class GpuPartitioningSuite extends AnyFunSuite with BeforeAndAfterEach { var rapidsConf = new RapidsConf(Map[String, String]()) private def buildBatch(): ColumnarBatch = { @@ -113,7 +113,7 @@ class GpuPartitioningSuite extends AnyFunSuite { TrampolineUtil.cleanupAnyExistingSession() val conf = new SparkConf().set(RapidsConf.SHUFFLE_COMPRESSION_CODEC.key, "none") TestUtils.withGpuSparkSession(conf) { _ => - GpuShuffleEnv.init(new RapidsConf(conf), new RapidsDiskBlockManager(conf)) + GpuShuffleEnv.init(new RapidsConf(conf)) val partitionIndices = Array(0, 2, 2) val gp = new GpuPartitioning { override val numPartitions: Int = partitionIndices.length @@ -157,61 +157,53 @@ class GpuPartitioningSuite extends AnyFunSuite { val conf = new SparkConf() .set(RapidsConf.SHUFFLE_COMPRESSION_CODEC.key, codecName) TestUtils.withGpuSparkSession(conf) { _ => - GpuShuffleEnv.init(new RapidsConf(conf), new RapidsDiskBlockManager(conf)) - val spillPriority = 7L - - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = new RapidsBufferCatalog(store) - val partitionIndices = Array(0, 2, 2) - val gp = new GpuPartitioning { - override val numPartitions: Int = partitionIndices.length - } - withResource(buildBatch()) { batch => - // `sliceInternalOnGpuAndClose` will close the batch, but in this test we want to - // reuse it - GpuColumnVector.incRefCounts(batch) - val columns = GpuColumnVector.extractColumns(batch) - val sparkTypes = GpuColumnVector.extractTypes(batch) - val numRows = batch.numRows - withResource( - gp.sliceInternalOnGpuAndClose(numRows, partitionIndices, columns)) { partitions => - partitions.zipWithIndex.foreach { case (partBatch, partIndex) => - val startRow = partitionIndices(partIndex) - val endRow = if (partIndex < partitionIndices.length - 1) { - partitionIndices(partIndex + 1) - } else { - batch.numRows - } - val expectedRows = endRow - startRow - assertResult(expectedRows)(partBatch.numRows) - val columns = (0 until partBatch.numCols).map(i => partBatch.column(i)) - columns.foreach { column => - // batches with any rows should be compressed, and - // batches with no rows should not be compressed. - val actualRows = column match { - case c: GpuCompressedColumnVector => - val rows = c.getTableMeta.rowCount - assert(rows != 0) - rows - case c: GpuPackedTableColumn => - val rows = c.getContiguousTable.getRowCount - assert(rows == 0) - rows - case _ => - throw new IllegalStateException("column should either be compressed or packed") - } - assertResult(expectedRows)(actualRows) + GpuShuffleEnv.init(new RapidsConf(conf)) + val partitionIndices = Array(0, 2, 2) + val gp = new GpuPartitioning { + override val numPartitions: Int = partitionIndices.length + } + withResource(buildBatch()) { batch => + // `sliceInternalOnGpuAndClose` will close the batch, but in this test we want to + // reuse it + GpuColumnVector.incRefCounts(batch) + val columns = GpuColumnVector.extractColumns(batch) + val numRows = batch.numRows + withResource( + gp.sliceInternalOnGpuAndClose(numRows, partitionIndices, columns)) { partitions => + partitions.zipWithIndex.foreach { case (partBatch, partIndex) => + val startRow = partitionIndices(partIndex) + val endRow = if (partIndex < partitionIndices.length - 1) { + partitionIndices(partIndex + 1) + } else { + batch.numRows + } + val expectedRows = endRow - startRow + assertResult(expectedRows)(partBatch.numRows) + val columns = (0 until partBatch.numCols).map(i => partBatch.column(i)) + columns.foreach { column => + // batches with any rows should be compressed, and + // batches with no rows should not be compressed. + val actualRows = column match { + case c: GpuCompressedColumnVector => + val rows = c.getTableMeta.rowCount + assert(rows != 0) + rows + case c: GpuPackedTableColumn => + val rows = c.getContiguousTable.getRowCount + assert(rows == 0) + rows + case _ => + throw new IllegalStateException( + "column should either be compressed or packed") } - if (GpuCompressedColumnVector.isBatchCompressed(partBatch)) { - val gccv = columns.head.asInstanceOf[GpuCompressedColumnVector] - val devBuffer = gccv.getTableBuffer - val handle = catalog.addBuffer(devBuffer, gccv.getTableMeta, spillPriority) - withResource(buildSubBatch(batch, startRow, endRow)) { expectedBatch => - withResource(catalog.acquireBuffer(handle)) { buffer => - withResource(buffer.getColumnarBatch(sparkTypes)) { batch => - compareBatches(expectedBatch, batch) - } - } + assertResult(expectedRows)(actualRows) + } + if (GpuCompressedColumnVector.isBatchCompressed(partBatch)) { + GpuCompressedColumnVector.incRefCounts(partBatch) + val handle = SpillableColumnarBatch(partBatch, -1) + withResource(buildSubBatch(batch, startRow, endRow)) { expectedBatch => + withResource(handle.getColumnarBatch()) { cb => + compareBatches(expectedBatch, cb) } } } @@ -220,9 +212,4 @@ class GpuPartitioningSuite extends AnyFunSuite { } } } -} - -case class MockRapidsBufferId(tableId: Int) extends RapidsBufferId { - override def getDiskPath(diskBlockManager: RapidsDiskBlockManager): File = - throw new UnsupportedOperationException -} +} \ No newline at end of file diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuSinglePartitioningSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuSinglePartitioningSuite.scala index 9211c32e142..6c5ff16ffcb 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuSinglePartitioningSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuSinglePartitioningSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ import com.nvidia.spark.rapids.Arm.withResource import org.scalatest.funsuite.AnyFunSuite import org.apache.spark.SparkConf -import org.apache.spark.sql.rapids.{GpuShuffleEnv, RapidsDiskBlockManager} +import org.apache.spark.sql.rapids.GpuShuffleEnv import org.apache.spark.sql.types.{DecimalType, DoubleType, IntegerType, StringType} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -48,7 +48,7 @@ class GpuSinglePartitioningSuite extends AnyFunSuite { .set("spark.rapids.shuffle.mode", RapidsConf.RapidsShuffleManagerMode.UCX.toString) .set(RapidsConf.SHUFFLE_COMPRESSION_CODEC.key, "none") TestUtils.withGpuSparkSession(conf) { _ => - GpuShuffleEnv.init(new RapidsConf(conf), new RapidsDiskBlockManager(conf)) + GpuShuffleEnv.init(new RapidsConf(conf)) val partitioner = GpuSinglePartitioning withResource(buildBatch()) { batch => withResource(GpuColumnVector.from(batch)) { table => diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/HashAggregateRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/HashAggregateRetrySuite.scala index 58608ed132c..82fb1dd4154 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/HashAggregateRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/HashAggregateRetrySuite.scala @@ -70,10 +70,8 @@ class HashAggregateRetrySuite when(aggHelper.aggOrdinals).thenReturn(aggOrdinals) // attempt a cuDF reduction - withResource(input) { _ => - GpuAggregateIterator.aggregate( - aggHelper, input, mockMetrics) - } + GpuAggregateIterator.aggregate( + aggHelper, input, mockMetrics) } def makeGroupByAggHelper(forceMerge: Boolean): AggHelper = { @@ -118,172 +116,187 @@ class HashAggregateRetrySuite test("computeAndAggregate reduction with retry") { val reductionBatch = buildReductionBatch() - RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1, - RmmSpark.OomInjectionType.GPU.ordinal, 0) - val result = doReduction(reductionBatch) - withResource(result) { spillable => - withResource(spillable.getColumnarBatch) { cb => - assertResult(1)(cb.numRows) - val gcv = cb.column(0).asInstanceOf[GpuColumnVector] - withResource(gcv.getBase.copyToHost()) { hcv => - assertResult(9)(hcv.getLong(0)) + withResource(reductionBatch.incRefCount()) { _ => + RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1, + RmmSpark.OomInjectionType.GPU.ordinal, 0) + val result = doReduction(reductionBatch) + withResource(result) { spillable => + withResource(spillable.getColumnarBatch) { cb => + assertResult(1)(cb.numRows) + val gcv = cb.column(0).asInstanceOf[GpuColumnVector] + withResource(gcv.getBase.copyToHost()) { hcv => + assertResult(9)(hcv.getLong(0)) + } } } + // we need to request a ColumnarBatch twice here for the retry + // why is this invoking the underlying method + verify(reductionBatch, times(2)).getColumnarBatch } - // we need to request a ColumnarBatch twice here for the retry - verify(reductionBatch, times(2)).getColumnarBatch() } test("computeAndAggregate reduction with two retries") { val reductionBatch = buildReductionBatch() - RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 2, - RmmSpark.OomInjectionType.GPU.ordinal, 0) - val result = doReduction(reductionBatch) - withResource(result) { spillable => - withResource(spillable.getColumnarBatch) { cb => - assertResult(1)(cb.numRows) - val gcv = cb.column(0).asInstanceOf[GpuColumnVector] - withResource(gcv.getBase.copyToHost()) { hcv => - assertResult(9)(hcv.getLong(0)) + withResource(reductionBatch.incRefCount()) { _ => + RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 2, + RmmSpark.OomInjectionType.GPU.ordinal, 0) + val result = doReduction(reductionBatch) + withResource(result) { spillable => + withResource(spillable.getColumnarBatch) { cb => + assertResult(1)(cb.numRows) + val gcv = cb.column(0).asInstanceOf[GpuColumnVector] + withResource(gcv.getBase.copyToHost()) { hcv => + assertResult(9)(hcv.getLong(0)) + } } } + // we need to request a ColumnarBatch three times, because of 1 regular attempt, + // and two retries + verify(reductionBatch, times(3)).getColumnarBatch } - // we need to request a ColumnarBatch three times, because of 1 regular attempt, - // and two retries - verify(reductionBatch, times(3)).getColumnarBatch() } test("computeAndAggregate reduction with cudf exception") { val reductionBatch = buildReductionBatch() - RmmSpark.forceCudfException(RmmSpark.getCurrentThreadId) - assertThrows[CudfException] { - doReduction(reductionBatch) + withResource(reductionBatch.incRefCount()) { _ => + RmmSpark.forceCudfException(RmmSpark.getCurrentThreadId) + assertThrows[CudfException] { + doReduction(reductionBatch) + } + // columnar batch was obtained once, but since this was not a retriable exception + // we don't retry it + verify(reductionBatch, times(1)).getColumnarBatch } - // columnar batch was obtained once, but since this was not a retriable exception - // we don't retry it - verify(reductionBatch, times(1)).getColumnarBatch() } test("computeAndAggregate group by with retry") { val groupByBatch = buildGroupByBatch() - RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1, - RmmSpark.OomInjectionType.GPU.ordinal, 0) - val result = doGroupBy(groupByBatch) - withResource(result) { spillable => - withResource(spillable.getColumnarBatch) { cb => - assertResult(3)(cb.numRows) - val gcv = cb.column(0).asInstanceOf[GpuColumnVector] - val aggv = cb.column(1).asInstanceOf[GpuColumnVector] - var rowsLeftToMatch = 3 - withResource(aggv.getBase.copyToHost()) { aggvh => - withResource(gcv.getBase.copyToHost()) { grph => - (0 until 3).foreach { row => - if (grph.isNull(row)) { - assertResult(2L)(aggvh.getLong(row)) - rowsLeftToMatch -= 1 - } else if (grph.getInt(row) == 5) { - assertResult(1L)(aggvh.getLong(row)) - rowsLeftToMatch -= 1 - } else if (grph.getInt(row) == 1) { - assertResult(7L)(aggvh.getLong(row)) - rowsLeftToMatch -= 1 + withResource(groupByBatch.incRefCount()) { _ => + RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1, + RmmSpark.OomInjectionType.GPU.ordinal, 0) + val result = doGroupBy(groupByBatch) + withResource(result) { spillable => + withResource(spillable.getColumnarBatch) { cb => + assertResult(3)(cb.numRows) + val gcv = cb.column(0).asInstanceOf[GpuColumnVector] + val aggv = cb.column(1).asInstanceOf[GpuColumnVector] + var rowsLeftToMatch = 3 + withResource(aggv.getBase.copyToHost()) { aggvh => + withResource(gcv.getBase.copyToHost()) { grph => + (0 until 3).foreach { row => + if (grph.isNull(row)) { + assertResult(2L)(aggvh.getLong(row)) + rowsLeftToMatch -= 1 + } else if (grph.getInt(row) == 5) { + assertResult(1L)(aggvh.getLong(row)) + rowsLeftToMatch -= 1 + } else if (grph.getInt(row) == 1) { + assertResult(7L)(aggvh.getLong(row)) + rowsLeftToMatch -= 1 + } } } } + assertResult(0)(rowsLeftToMatch) } - assertResult(0)(rowsLeftToMatch) } + // we need to request a ColumnarBatch twice here for the retry + verify(groupByBatch, times(2)).getColumnarBatch } - // we need to request a ColumnarBatch twice here for the retry - verify(groupByBatch, times(2)).getColumnarBatch() } test("computeAndAggregate reduction with split and retry") { val reductionBatch = buildReductionBatch() - RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1, - RmmSpark.OomInjectionType.GPU.ordinal, 0) - val result = doReduction(reductionBatch) - withResource(result) { spillable => - withResource(spillable.getColumnarBatch) { cb => - assertResult(1)(cb.numRows) - val gcv = cb.column(0).asInstanceOf[GpuColumnVector] + withResource(reductionBatch.incRefCount()) { _ => + RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1, + RmmSpark.OomInjectionType.GPU.ordinal, 0) + val result = doReduction(reductionBatch) + withResource(result) { spillable => + withResource(spillable.getColumnarBatch) { cb => + assertResult(1)(cb.numRows) + val gcv = cb.column(0).asInstanceOf[GpuColumnVector] - withResource(gcv.getBase.copyToHost()) { hcv => - assertResult(9L)(hcv.getLong(0)) + withResource(gcv.getBase.copyToHost()) { hcv => + assertResult(9L)(hcv.getLong(0)) + } } } + // the second time we access this batch is to split it + verify(reductionBatch, times(2)).getColumnarBatch } - // the second time we access this batch is to split it - verify(reductionBatch, times(2)).getColumnarBatch() } test("computeAndAggregate group by with split retry") { val groupByBatch = buildGroupByBatch() - RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1, - RmmSpark.OomInjectionType.GPU.ordinal, 0) - val result = doGroupBy(groupByBatch) - withResource(result) { spillable => - withResource(spillable.getColumnarBatch) { cb => - assertResult(3)(cb.numRows) - val gcv = cb.column(0).asInstanceOf[GpuColumnVector] - val aggv = cb.column(1).asInstanceOf[GpuColumnVector] - var rowsLeftToMatch = 3 - withResource(aggv.getBase.copyToHost()) { aggvh => - withResource(gcv.getBase.copyToHost()) { grph => - (0 until 3).foreach { row => - if (grph.isNull(row)) { - assertResult(2L)(aggvh.getLong(row)) - rowsLeftToMatch -= 1 - } else if (grph.getInt(row) == 5) { - assertResult(1L)(aggvh.getLong(row)) - rowsLeftToMatch -= 1 - } else if (grph.getInt(row) == 1) { - assertResult(7L)(aggvh.getLong(row)) - rowsLeftToMatch -= 1 + withResource(groupByBatch.incRefCount()) { _ => + RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1, + RmmSpark.OomInjectionType.GPU.ordinal, 0) + val result = doGroupBy(groupByBatch) + withResource(result) { spillable => + withResource(spillable.getColumnarBatch) { cb => + assertResult(3)(cb.numRows) + val gcv = cb.column(0).asInstanceOf[GpuColumnVector] + val aggv = cb.column(1).asInstanceOf[GpuColumnVector] + var rowsLeftToMatch = 3 + withResource(aggv.getBase.copyToHost()) { aggvh => + withResource(gcv.getBase.copyToHost()) { grph => + (0 until 3).foreach { row => + if (grph.isNull(row)) { + assertResult(2L)(aggvh.getLong(row)) + rowsLeftToMatch -= 1 + } else if (grph.getInt(row) == 5) { + assertResult(1L)(aggvh.getLong(row)) + rowsLeftToMatch -= 1 + } else if (grph.getInt(row) == 1) { + assertResult(7L)(aggvh.getLong(row)) + rowsLeftToMatch -= 1 + } } } } + assertResult(0)(rowsLeftToMatch) } - assertResult(0)(rowsLeftToMatch) } + // the second time we access this batch is to split it + verify(groupByBatch, times(2)).getColumnarBatch } - // the second time we access this batch is to split it - verify(groupByBatch, times(2)).getColumnarBatch() } test("computeAndAggregate group by with retry and forceMerge") { // with forceMerge we expect 1 batch to be returned at all costs val groupByBatch = buildGroupByBatch() - // we force a split because that would cause us to compute two aggs - RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1, - RmmSpark.OomInjectionType.GPU.ordinal, 0) - val result = doGroupBy(groupByBatch, forceMerge = true) - withResource(result) { spillable => - withResource(spillable.getColumnarBatch) { cb => - assertResult(3)(cb.numRows) - val gcv = cb.column(0).asInstanceOf[GpuColumnVector] - val aggv = cb.column(1).asInstanceOf[GpuColumnVector] - var rowsLeftToMatch = 3 - withResource(aggv.getBase.copyToHost()) { aggvh => - withResource(gcv.getBase.copyToHost()) { grph => - (0 until 3).foreach { row => - if (grph.isNull(row)) { - assertResult(2L)(aggvh.getLong(row)) - rowsLeftToMatch -= 1 - } else if (grph.getInt(row) == 5) { - assertResult(1L)(aggvh.getLong(row)) - rowsLeftToMatch -= 1 - } else if (grph.getInt(row) == 1) { - assertResult(7L)(aggvh.getLong(row)) - rowsLeftToMatch -= 1 + withResource(groupByBatch.incRefCount()) { _ => + // we force a split because that would cause us to compute two aggs + RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1, + RmmSpark.OomInjectionType.GPU.ordinal, 0) + val result = doGroupBy(groupByBatch, forceMerge = true) + withResource(result) { spillable => + withResource(spillable.getColumnarBatch) { cb => + assertResult(3)(cb.numRows) + val gcv = cb.column(0).asInstanceOf[GpuColumnVector] + val aggv = cb.column(1).asInstanceOf[GpuColumnVector] + var rowsLeftToMatch = 3 + withResource(aggv.getBase.copyToHost()) { aggvh => + withResource(gcv.getBase.copyToHost()) { grph => + (0 until 3).foreach { row => + if (grph.isNull(row)) { + assertResult(2L)(aggvh.getLong(row)) + rowsLeftToMatch -= 1 + } else if (grph.getInt(row) == 5) { + assertResult(1L)(aggvh.getLong(row)) + rowsLeftToMatch -= 1 + } else if (grph.getInt(row) == 1) { + assertResult(7L)(aggvh.getLong(row)) + rowsLeftToMatch -= 1 + } } } } + assertResult(0)(rowsLeftToMatch) } - assertResult(0)(rowsLeftToMatch) } + // we need to request a ColumnarBatch twice here for the retry + verify(groupByBatch, times(2)).getColumnarBatch } - // we need to request a ColumnarBatch twice here for the retry - verify(groupByBatch, times(2)).getColumnarBatch() } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala deleted file mode 100644 index 9b5b37af480..00000000000 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala +++ /dev/null @@ -1,368 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.nvidia.spark.rapids - -import java.io.File - -import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer} -import com.nvidia.spark.rapids.Arm.withResource -import com.nvidia.spark.rapids.StorageTier.{DEVICE, DISK, HOST, StorageTier} -import com.nvidia.spark.rapids.format.TableMeta -import org.mockito.ArgumentMatchers.any -import org.mockito.Mockito._ -import org.scalatest.funsuite.AnyFunSuite -import org.scalatestplus.mockito.MockitoSugar - -import org.apache.spark.sql.rapids.RapidsDiskBlockManager -import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.vectorized.ColumnarBatch - -class RapidsBufferCatalogSuite extends AnyFunSuite with MockitoSugar { - test("lookup unknown buffer") { - val catalog = new RapidsBufferCatalog - val bufferId = new RapidsBufferId { - override val tableId: Int = 10 - override def getDiskPath(m: RapidsDiskBlockManager): File = null - } - val bufferHandle = new RapidsBufferHandle { - override val id: RapidsBufferId = bufferId - override def setSpillPriority(newPriority: Long): Unit = {} - override def close(): Unit = {} - } - - assertThrows[NoSuchElementException](catalog.acquireBuffer(bufferHandle)) - assertThrows[NoSuchElementException](catalog.getBufferMeta(bufferId)) - } - - test("buffer double register throws") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId) - catalog.registerNewBuffer(buffer) - val buffer2 = mockBuffer(bufferId) - assertThrows[DuplicateBufferException](catalog.registerNewBuffer(buffer2)) - } - - test("a second handle prevents buffer to be removed") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId) - catalog.registerNewBuffer(buffer) - val handle1 = - catalog.makeNewHandle(bufferId, -1) - val handle2 = - catalog.makeNewHandle(bufferId, -1) - - handle1.close() - - // this does not throw - catalog.acquireBuffer(handle2).close() - // actually this doesn't throw either - catalog.acquireBuffer(handle1).close() - - handle2.close() - - assertThrows[NoSuchElementException](catalog.acquireBuffer(handle1)) - assertThrows[NoSuchElementException](catalog.acquireBuffer(handle2)) - } - - test("spill priorities are updated as handles are registered and unregistered") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId, initialPriority = -1) - catalog.registerNewBuffer(buffer) - val handle1 = - catalog.makeNewHandle(bufferId, -1) - withResource(catalog.acquireBuffer(handle1)) { buff => - assertResult(-1)(buff.getSpillPriority) - } - val handle2 = - catalog.makeNewHandle(bufferId, 0) - withResource(catalog.acquireBuffer(handle2)) { buff => - assertResult(0)(buff.getSpillPriority) - } - - // removing the lower priority handle, keeps the high priority spill - handle1.close() - withResource(catalog.acquireBuffer(handle2)) { buff => - assertResult(0)(buff.getSpillPriority) - } - - // adding a lower priority -1000 handle keeps the high priority (0) spill - val handle3 = - catalog.makeNewHandle(bufferId, -1000) - withResource(catalog.acquireBuffer(handle3)) { buff => - assertResult(0)(buff.getSpillPriority) - } - - // removing the high priority spill (0) brings us down to the - // low priority that is remaining - handle2.close() - withResource(catalog.acquireBuffer(handle2)) { buff => - assertResult(-1000)(buff.getSpillPriority) - } - - handle3.close() - } - - test("buffer registering slower tier does not hide faster tier") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId, tier = DEVICE) - catalog.registerNewBuffer(buffer) - val handle = catalog.makeNewHandle(bufferId, 0) - val buffer2 = mockBuffer(bufferId, tier = HOST) - catalog.registerNewBuffer(buffer2) - val buffer3 = mockBuffer(bufferId, tier = DISK) - catalog.registerNewBuffer(buffer3) - val acquired = catalog.acquireBuffer(handle) - assertResult(5)(acquired.id.tableId) - assertResult(buffer)(acquired) - - // registering the handle acquires the buffer - verify(buffer, times(2)).addReference() - } - - test("acquire buffer") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId) - catalog.registerNewBuffer(buffer) - val handle = catalog.makeNewHandle(bufferId, 0) - val acquired = catalog.acquireBuffer(handle) - assertResult(5)(acquired.id.tableId) - assertResult(buffer)(acquired) - - // registering the handle acquires the buffer - verify(buffer, times(2)).addReference() - } - - test("acquire buffer retries automatically") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId, acquireAttempts = 9) - catalog.registerNewBuffer(buffer) - val handle = catalog.makeNewHandle(bufferId, 0) - val acquired = catalog.acquireBuffer(handle) - assertResult(5)(acquired.id.tableId) - assertResult(buffer)(acquired) - - // registering the handle acquires the buffer - verify(buffer, times(10)).addReference() - } - - test("acquire buffer at specific tier") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId, tier = DEVICE) - catalog.registerNewBuffer(buffer) - val buffer2 = mockBuffer(bufferId, tier = HOST) - catalog.registerNewBuffer(buffer2) - val acquired = catalog.acquireBuffer(MockBufferId(5), HOST).get - assertResult(5)(acquired.id.tableId) - assertResult(buffer2)(acquired) - verify(buffer2).addReference() - } - - test("acquire buffer at nonexistent tier") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId, tier = HOST) - catalog.registerNewBuffer(buffer) - assert(catalog.acquireBuffer(MockBufferId(5), DEVICE).isEmpty) - assert(catalog.acquireBuffer(MockBufferId(5), DISK).isEmpty) - } - - test("get buffer meta") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val expectedMeta = new TableMeta - val buffer = mockBuffer(bufferId, tableMeta = expectedMeta) - catalog.registerNewBuffer(buffer) - val meta = catalog.getBufferMeta(bufferId) - assertResult(expectedMeta)(meta) - } - - test("buffer is spilled to slower tier only") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId, tier = DEVICE) - catalog.registerNewBuffer(buffer) - val buffer2 = mockBuffer(bufferId, tier = HOST) - catalog.registerNewBuffer(buffer2) - val buffer3 = mockBuffer(bufferId, tier = DISK) - catalog.registerNewBuffer(buffer3) - assert(catalog.isBufferSpilled(bufferId, DEVICE)) - assert(catalog.isBufferSpilled(bufferId, HOST)) - assert(!catalog.isBufferSpilled(bufferId, DISK)) - } - - test("multiple calls to unspill return existing DEVICE buffer") { - withResource(spy(new RapidsDeviceMemoryStore)) { deviceStore => - val mockStore = mock[RapidsBufferStore] - withResource( - new RapidsHostMemoryStore(Some(10000))) { hostStore => - deviceStore.setSpillStore(hostStore) - hostStore.setSpillStore(mockStore) - val catalog = new RapidsBufferCatalog(deviceStore) - val handle = withResource(DeviceMemoryBuffer.allocate(1024)) { buff => - val meta = MetaUtils.getTableMetaNoTable(buff.getLength) - catalog.addBuffer( - buff, meta, -1) - } - withResource(handle) { _ => - catalog.synchronousSpill(deviceStore, 0) - val acquiredHostBuffer = catalog.acquireBuffer(handle) - val unspilled = withResource(acquiredHostBuffer) { _ => - assertResult(HOST)(acquiredHostBuffer.storageTier) - val unspilled = - catalog.unspillBufferToDeviceStore( - acquiredHostBuffer, - Cuda.DEFAULT_STREAM) - withResource(unspilled) { _ => - assertResult(DEVICE)(unspilled.storageTier) - } - val unspilledSame = catalog.unspillBufferToDeviceStore( - acquiredHostBuffer, - Cuda.DEFAULT_STREAM) - withResource(unspilledSame) { _ => - assertResult(unspilled)(unspilledSame) - } - // verify that we invoked the copy function exactly once - verify(deviceStore, times(1)).copyBuffer(any(), any(), any()) - unspilled - } - val unspilledSame = catalog.unspillBufferToDeviceStore( - acquiredHostBuffer, - Cuda.DEFAULT_STREAM) - withResource(unspilledSame) { _ => - assertResult(unspilled)(unspilledSame) - } - // verify that we invoked the copy function exactly once - verify(deviceStore, times(1)).copyBuffer(any(), any(), any()) - } - } - } - } - - test("remove buffer tier") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId, tier = DEVICE) - catalog.registerNewBuffer(buffer) - val buffer2 = mockBuffer(bufferId, tier = HOST) - catalog.registerNewBuffer(buffer2) - val buffer3 = mockBuffer(bufferId, tier = DISK) - catalog.registerNewBuffer(buffer3) - catalog.removeBufferTier(bufferId, DEVICE) - catalog.removeBufferTier(bufferId, DISK) - assert(catalog.acquireBuffer(MockBufferId(5), DEVICE).isEmpty) - assert(catalog.acquireBuffer(MockBufferId(5), HOST).isDefined) - assert(catalog.acquireBuffer(MockBufferId(5), DISK).isEmpty) - } - - test("remove nonexistent buffer tier") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId, tier = DEVICE) - catalog.registerNewBuffer(buffer) - catalog.removeBufferTier(bufferId, HOST) - catalog.removeBufferTier(bufferId, DISK) - assert(catalog.acquireBuffer(MockBufferId(5), DEVICE).isDefined) - assert(catalog.acquireBuffer(MockBufferId(5), HOST).isEmpty) - assert(catalog.acquireBuffer(MockBufferId(5), DISK).isEmpty) - } - - test("remove buffer releases buffer resources") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId) - catalog.registerNewBuffer(buffer) - val handle = catalog.makeNewHandle( - bufferId, -1) - handle.close() - verify(buffer).free() - } - - test("remove buffer releases buffer resources at all tiers") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId, tier = DEVICE) - catalog.registerNewBuffer(buffer) - val handle = catalog.makeNewHandle( - bufferId, -1) - - // these next registrations don't get their own handle. This is an internal - // operation from the store where it has spilled to host and disk the RapidsBuffer - val buffer2 = mockBuffer(bufferId, tier = HOST) - catalog.registerNewBuffer(buffer2) - val buffer3 = mockBuffer(bufferId, tier = DISK) - catalog.registerNewBuffer(buffer3) - - // removing the original handle removes all buffers from all tiers. - handle.close() - verify(buffer).free() - verify(buffer2).free() - verify(buffer3).free() - } - - private def mockBuffer( - bufferId: RapidsBufferId, - tableMeta: TableMeta = null, - tier: StorageTier = StorageTier.DEVICE, - acquireAttempts: Int = 1, - initialPriority: Long = -1): RapidsBuffer = { - spy(new RapidsBuffer { - var _acquireAttempts: Int = acquireAttempts - var currentPriority: Long = initialPriority - override val id: RapidsBufferId = bufferId - override val memoryUsedBytes: Long = 0 - override def meta: TableMeta = tableMeta - override val storageTier: StorageTier = tier - override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = null - override def getMemoryBuffer: MemoryBuffer = null - override def copyToMemoryBuffer( - srcOffset: Long, - dst: MemoryBuffer, - dstOffset: Long, - length: Long, - stream: Cuda.Stream): Unit = {} - override def getDeviceMemoryBuffer: DeviceMemoryBuffer = null - override def getHostMemoryBuffer: HostMemoryBuffer = null - override def addReference(): Boolean = { - if (_acquireAttempts > 0) { - _acquireAttempts -= 1 - } - _acquireAttempts == 0 - } - override def free(): Unit = {} - override def getSpillPriority: Long = currentPriority - override def setSpillPriority(priority: Long): Unit = { - currentPriority = priority - } - - override def withMemoryBufferReadLock[K](body: MemoryBuffer => K): K = { body(null) } - override def withMemoryBufferWriteLock[K](body: MemoryBuffer => K): K = { body(null) } - override def close(): Unit = {} - }) - } -} - -case class MockBufferId(override val tableId: Int) extends RapidsBufferId { - override def getDiskPath(dbm: RapidsDiskBlockManager): File = - throw new UnsupportedOperationException -} diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala deleted file mode 100644 index 45d96be4cb6..00000000000 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala +++ /dev/null @@ -1,489 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.nvidia.spark.rapids - -import java.io.File -import java.math.RoundingMode - -import scala.collection.mutable.ArrayBuffer - -import ai.rapids.cudf.{ColumnVector, ContiguousTable, Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, Table} -import com.nvidia.spark.rapids.Arm.withResource -import com.nvidia.spark.rapids.StorageTier.StorageTier -import com.nvidia.spark.rapids.format.TableMeta -import org.mockito.ArgumentCaptor -import org.mockito.Mockito.{spy, verify} -import org.scalatest.funsuite.AnyFunSuite -import org.scalatestplus.mockito.MockitoSugar - -import org.apache.spark.sql.rapids.RapidsDiskBlockManager -import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType, IntegerType, StringType} - -class RapidsDeviceMemoryStoreSuite extends AnyFunSuite with MockitoSugar { - private def buildTable(): Table = { - new Table.TestBuilder() - .column(5, null.asInstanceOf[java.lang.Integer], 3, 1) - .column("five", "two", null, null) - .column(5.0D, 2.0D, 3.0D, 1.0D) - .decimal64Column(-5, RoundingMode.UNNECESSARY, 0, null, -1.4, 10.123) - .build() - } - - private def buildTableWithDuplicate(): Table = { - withResource(ColumnVector.fromInts(5, null.asInstanceOf[java.lang.Integer], 3, 1)) { intCol => - withResource(ColumnVector.fromStrings("five", "two", null, null)) { stringCol => - withResource(ColumnVector.fromDoubles(5.0, 2.0, 3.0, 1.0)) { doubleCol => - // add intCol twice - new Table(intCol, intCol, stringCol, doubleCol) - } - } - } - } - - private def buildContiguousTable(): ContiguousTable = { - withResource(buildTable()) { table => - table.contiguousSplit()(0) - } - } - - test("add table registers with catalog") { - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = spy(new RapidsBufferCatalog(store)) - val spillPriority = 3 - val bufferId = MockRapidsBufferId(7) - withResource(buildContiguousTable()) { ct => - catalog.addContiguousTable( - bufferId, ct, spillPriority, false) - } - val captor: ArgumentCaptor[RapidsBuffer] = ArgumentCaptor.forClass(classOf[RapidsBuffer]) - verify(catalog).registerNewBuffer(captor.capture()) - val resultBuffer = captor.getValue - assertResult(bufferId)(resultBuffer.id) - assertResult(spillPriority)(resultBuffer.getSpillPriority) - } - } - - test("a non-contiguous table is spillable and it is handed over to the store") { - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = spy(new RapidsBufferCatalog(store)) - val spillPriority = 3 - val table = buildTable() - catalog.addTable(table, spillPriority) - val buffSize = GpuColumnVector.getTotalDeviceMemoryUsed(table) - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - } - } - - test("a non-contiguous table becomes non-spillable when batch is obtained") { - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = spy(new RapidsBufferCatalog(store)) - val spillPriority = 3 - val table = buildTable() - val handle = catalog.addTable(table, spillPriority) - val types: Array[DataType] = - Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray - val buffSize = GpuColumnVector.getTotalDeviceMemoryUsed(table) - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - val batch = withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => - rapidsBuffer.getColumnarBatch(types) - } - withResource(batch) { _ => - assertResult(buffSize)(store.currentSize) - assertResult(0)(store.currentSpillableSize) - } - assertResult(buffSize)(store.currentSpillableSize) - } - } - - test("a non-contiguous table is non-spillable until all columns are returned") { - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = spy(new RapidsBufferCatalog(store)) - val spillPriority = 3 - val table = buildTable() - val handle = catalog.addTable(table, spillPriority) - val types: Array[DataType] = - Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray - val buffSize = GpuColumnVector.getTotalDeviceMemoryUsed(table) - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - // incRefCount all the columns via `batch` - val batch = withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => - rapidsBuffer.getColumnarBatch(types) - } - val columns = GpuColumnVector.extractBases(batch) - withResource(columns.head) { _ => - columns.head.incRefCount() - withResource(batch) { _ => - assertResult(buffSize)(store.currentSize) - assertResult(0)(store.currentSpillableSize) - } - // still 0 after the batch is closed, because of the extra incRefCount - // for columns.head - assertResult(0)(store.currentSpillableSize) - } - // columns.head is closed, so now our RapidsTable is spillable again - assertResult(buffSize)(store.currentSpillableSize) - } - } - - test("an aliased non-contiguous table is not spillable (until closing the alias) ") { - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = spy(new RapidsBufferCatalog(store)) - val spillPriority = 3 - val table = buildTable() - val handle = catalog.addTable(table, spillPriority) - val types: Array[DataType] = - Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray - val buffSize = GpuColumnVector.getTotalDeviceMemoryUsed(table) - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - val aliasHandle = withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => - // extract the batch from the table we added, and add it back as a batch - withResource(rapidsBuffer.getColumnarBatch(types)) { batch => - catalog.addBatch(batch, spillPriority) - } - } // we now have two copies in the store - assertResult(buffSize*2)(store.currentSize) - assertResult(0)(store.currentSpillableSize) - - aliasHandle.close() // remove the alias - - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - } - } - - test("an aliased non-contiguous table is not spillable (until closing the original) ") { - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = spy(new RapidsBufferCatalog(store)) - val spillPriority = 3 - val table = buildTable() - val handle = catalog.addTable(table, spillPriority) - val types: Array[DataType] = - Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray - val buffSize = GpuColumnVector.getTotalDeviceMemoryUsed(table) - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => - // extract the batch from the table we added, and add it back as a batch - withResource(rapidsBuffer.getColumnarBatch(types)) { batch => - catalog.addBatch(batch, spillPriority) - } - } // we now have two copies in the store - assertResult(buffSize * 2)(store.currentSize) - assertResult(0)(store.currentSpillableSize) - - handle.close() // remove the original - - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - } - } - - test("an non-contiguous table supports duplicated columns") { - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = spy(new RapidsBufferCatalog(store)) - val spillPriority = 3 - val table = buildTableWithDuplicate() - val handle = catalog.addTable(table, spillPriority) - val types: Array[DataType] = - Seq(IntegerType, IntegerType, StringType, DoubleType).toArray - val buffSize = GpuColumnVector.getTotalDeviceMemoryUsed(table) - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => - // extract the batch from the table we added, and add it back as a batch - withResource(rapidsBuffer.getColumnarBatch(types)) { batch => - catalog.addBatch(batch, spillPriority) - } - } // we now have two copies in the store - assertResult(buffSize * 2)(store.currentSize) - assertResult(0)(store.currentSpillableSize) - - handle.close() // remove the original - - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - } - } - - test("a contiguous table is not spillable until the owner closes it") { - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = spy(new RapidsBufferCatalog(store)) - val spillPriority = 3 - val bufferId = MockRapidsBufferId(7) - val ct = buildContiguousTable() - val buffSize = ct.getBuffer.getLength - withResource(ct) { _ => - catalog.addContiguousTable( - bufferId, - ct, - spillPriority, - false) - assertResult(buffSize)(store.currentSize) - assertResult(0)(store.currentSpillableSize) - } - // after closing the original table, the RapidsBuffer should be spillable - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - } - } - - test("a buffer is not spillable until the owner closes columns referencing it") { - withResource(new RapidsDeviceMemoryStore) { store => - val spillPriority = 3 - val bufferId = MockRapidsBufferId(7) - val ct = buildContiguousTable() - val buffSize = ct.getBuffer.getLength - withResource(ct) { _ => - val meta = MetaUtils.buildTableMeta(bufferId.tableId, ct) - withResource(ct) { _ => - store.addBuffer( - bufferId, - ct.getBuffer, - meta, - spillPriority, - false) - assertResult(buffSize)(store.currentSize) - assertResult(0)(store.currentSpillableSize) - } - } - // after closing the original table, the RapidsBuffer should be spillable - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - } - } - - test("a buffer is not spillable when the underlying device buffer is obtained from it") { - withResource(new RapidsDeviceMemoryStore) { store => - val spillPriority = 3 - val bufferId = MockRapidsBufferId(7) - val ct = buildContiguousTable() - val buffSize = ct.getBuffer.getLength - val buffer = withResource(ct) { _ => - val meta = MetaUtils.buildTableMeta(bufferId.tableId, ct) - val buffer = store.addBuffer( - bufferId, - ct.getBuffer, - meta, - spillPriority, - false) - assertResult(buffSize)(store.currentSize) - assertResult(0)(store.currentSpillableSize) - buffer - } - - // after closing the original table, the RapidsBuffer should be spillable - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - - // if a device memory buffer is obtained from the buffer, it is no longer spillable - withResource(buffer.getDeviceMemoryBuffer) { deviceBuffer => - assertResult(buffSize)(store.currentSize) - assertResult(0)(store.currentSpillableSize) - } - - // once the DeviceMemoryBuffer is closed, the RapidsBuffer should be spillable again - assertResult(buffSize)(store.currentSpillableSize) - } - } - - test("add buffer registers with catalog") { - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = spy(new RapidsBufferCatalog(store)) - val spillPriority = 3 - val bufferId = MockRapidsBufferId(7) - val meta = withResource(buildContiguousTable()) { ct => - val meta = MetaUtils.buildTableMeta(bufferId.tableId, ct) - withResource(ct) { _ => - catalog.addBuffer( - bufferId, - ct.getBuffer, - meta, - spillPriority, - false) - } - meta - } - val captor: ArgumentCaptor[RapidsBuffer] = ArgumentCaptor.forClass(classOf[RapidsBuffer]) - verify(catalog).registerNewBuffer(captor.capture()) - val resultBuffer = captor.getValue - assertResult(bufferId)(resultBuffer.id) - assertResult(spillPriority)(resultBuffer.getSpillPriority) - assertResult(meta)(resultBuffer.meta) - } - } - - test("get memory buffer") { - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = spy(new RapidsBufferCatalog(store)) - val bufferId = MockRapidsBufferId(7) - withResource(buildContiguousTable()) { ct => - withResource(HostMemoryBuffer.allocate(ct.getBuffer.getLength)) { expectedHostBuffer => - expectedHostBuffer.copyFromDeviceBuffer(ct.getBuffer) - val meta = MetaUtils.buildTableMeta(bufferId.tableId, ct) - val handle = withResource(ct) { _ => - catalog.addBuffer( - bufferId, - ct.getBuffer, - meta, - initialSpillPriority = 3, - needsSync = false) - } - withResource(catalog.acquireBuffer(handle)) { buffer => - withResource(buffer.getMemoryBuffer.asInstanceOf[DeviceMemoryBuffer]) { devbuf => - withResource(HostMemoryBuffer.allocate(devbuf.getLength)) { actualHostBuffer => - actualHostBuffer.copyFromDeviceBuffer(devbuf) - assertResult(expectedHostBuffer.asByteBuffer())(actualHostBuffer.asByteBuffer()) - } - } - } - } - } - } - } - - test("get column batch") { - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = new RapidsBufferCatalog(store) - val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType, - DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5)) - val bufferId = MockRapidsBufferId(7) - withResource(buildContiguousTable()) { ct => - withResource(GpuColumnVector.from(ct.getTable, sparkTypes)) { - expectedBatch => - val meta = MetaUtils.buildTableMeta(bufferId.tableId, ct) - val handle = withResource(ct) { _ => - catalog.addBuffer( - bufferId, - ct.getBuffer, - meta, - initialSpillPriority = 3, - false) - } - withResource(catalog.acquireBuffer(handle)) { buffer => - withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => - TestUtils.compareBatches(expectedBatch, actualBatch) - } - } - } - } - } - } - - test("size statistics") { - - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = new RapidsBufferCatalog(store) - assertResult(0)(store.currentSize) - val bufferSizes = new Array[Long](2) - val bufferHandles = new Array[RapidsBufferHandle](2) - bufferSizes.indices.foreach { i => - withResource(buildContiguousTable()) { ct => - bufferSizes(i) = ct.getBuffer.getLength - // store takes ownership of the table - bufferHandles(i) = - catalog.addContiguousTable( - MockRapidsBufferId(i), - ct, - initialSpillPriority = 0, - false) - } - assertResult(bufferSizes.take(i+1).sum)(store.currentSize) - } - bufferHandles(0).close() - assertResult(bufferSizes(1))(store.currentSize) - bufferHandles(1).close() - assertResult(0)(store.currentSize) - } - } - - test("spill") { - val spillStore = new MockSpillStore - val spillPriorities = Array(0, -1, 2) - val bufferSizes = new Array[Long](spillPriorities.length) - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = new RapidsBufferCatalog(store) - store.setSpillStore(spillStore) - spillPriorities.indices.foreach { i => - withResource(buildContiguousTable()) { ct => - bufferSizes(i) = ct.getBuffer.getLength - // store takes ownership of the table - catalog.addContiguousTable( - MockRapidsBufferId(i), ct, spillPriorities(i), - false) - } - } - assert(spillStore.spilledBuffers.isEmpty) - - // asking to spill 0 bytes should not spill - val sizeBeforeSpill = store.currentSize - catalog.synchronousSpill(store, sizeBeforeSpill) - assert(spillStore.spilledBuffers.isEmpty) - assertResult(sizeBeforeSpill)(store.currentSize) - catalog.synchronousSpill(store, sizeBeforeSpill + 1) - assert(spillStore.spilledBuffers.isEmpty) - assertResult(sizeBeforeSpill)(store.currentSize) - - // spilling 1 byte should force one buffer to spill in priority order - catalog.synchronousSpill(store, sizeBeforeSpill - 1) - assertResult(1)(spillStore.spilledBuffers.length) - assertResult(bufferSizes.drop(1).sum)(store.currentSize) - assertResult(1)(spillStore.spilledBuffers(0).tableId) - - // spilling to zero should force all buffers to spill in priority order - catalog.synchronousSpill(store, 0) - assertResult(3)(spillStore.spilledBuffers.length) - assertResult(0)(store.currentSize) - assertResult(0)(spillStore.spilledBuffers(1).tableId) - assertResult(2)(spillStore.spilledBuffers(2).tableId) - } - } - - case class MockRapidsBufferId(tableId: Int) extends RapidsBufferId { - override def getDiskPath(diskBlockManager: RapidsDiskBlockManager): File = - throw new UnsupportedOperationException - } - - class MockSpillStore extends RapidsBufferStore(StorageTier.HOST) { - val spilledBuffers = new ArrayBuffer[RapidsBufferId] - - override protected def createBuffer( - b: RapidsBuffer, - c: RapidsBufferCatalog, - s: Cuda.Stream): Option[RapidsBufferBase] = { - spilledBuffers += b.id - Some(new MockRapidsBuffer( - b.id, b.getPackedSizeBytes, b.meta, b.getSpillPriority)) - } - - class MockRapidsBuffer(id: RapidsBufferId, size: Long, meta: TableMeta, spillPriority: Long) - extends RapidsBufferBase(id, meta, spillPriority) { - override protected def releaseResources(): Unit = {} - - override val storageTier: StorageTier = StorageTier.HOST - - override def getMemoryBuffer: MemoryBuffer = - throw new UnsupportedOperationException - - /** The size of this buffer in bytes. */ - override val memoryUsedBytes: Long = size - } - } -} diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala deleted file mode 100644 index fce88f116b3..00000000000 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala +++ /dev/null @@ -1,607 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.nvidia.spark.rapids - -import java.io.File -import java.math.RoundingMode - -import ai.rapids.cudf.{ColumnVector, ContiguousTable, Cuda, DeviceMemoryBuffer, HostMemoryBuffer, Table} -import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} -import org.mockito.ArgumentMatchers -import org.mockito.ArgumentMatchers.any -import org.mockito.Mockito.{spy, times, verify, when} -import org.scalatestplus.mockito.MockitoSugar - -import org.apache.spark.SparkConf -import org.apache.spark.sql.rapids.{RapidsDiskBlockManager, TempSpillBufferId} -import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType, IntegerType, StringType} -import org.apache.spark.storage.{BlockId, ShuffleBlockId} - - -class RapidsDiskStoreSuite extends FunSuiteWithTempDir with MockitoSugar { - - private def buildContiguousTable(): ContiguousTable = { - withResource(buildTable()) { table => - table.contiguousSplit()(0) - } - } - - private def buildTable(): Table = { - new Table.TestBuilder() - .column(5, null.asInstanceOf[java.lang.Integer], 3, 1) - .column("five", "two", null, null) - .column(5.0, 2.0, 3.0, 1.0) - .decimal64Column(-5, RoundingMode.UNNECESSARY, 0, null, -1.4, 10.123) - .build() - } - - private def buildEmptyTable(): Table = { - withResource(buildTable()) { tbl => - withResource(ColumnVector.fromBooleans(false, false, false, false)) { mask => - tbl.filter(mask) // filter all out - } - } - } - - private val mockTableDataTypes: Array[DataType] = - Array(IntegerType, StringType, DoubleType, DecimalType(10, 5)) - - test("spill updates catalog") { - val bufferId = MockRapidsBufferId(7, canShareDiskPaths = false) - val mockDiskBlockManager = mock[RapidsDiskBlockManager] - when(mockDiskBlockManager.getSerializerManager()) - .thenReturn(new RapidsSerializerManager(new SparkConf())) - val spillPriority = -7 - val hostStoreMaxSize = 1L * 1024 * 1024 - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = spy(new RapidsBufferCatalog(devStore)) - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { - hostStore => - devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(mockDiskBlockManager)) { diskStore => - assertResult(0)(diskStore.currentSize) - hostStore.setSpillStore(diskStore) - val (bufferSize, handle) = - addContiguousTableToCatalog(catalog, bufferId, spillPriority) - val path = handle.id.getDiskPath(null) - assert(!path.exists()) - catalog.synchronousSpill(devStore, 0) - catalog.synchronousSpill(hostStore, 0) - assertResult(0)(hostStore.currentSize) - assertResult(bufferSize)(diskStore.currentSize) - assert(path.exists) - assertResult(bufferSize)(path.length) - verify(catalog, times(3)).registerNewBuffer(ArgumentMatchers.any[RapidsBuffer]) - verify(catalog).removeBufferTier( - ArgumentMatchers.eq(handle.id), ArgumentMatchers.eq(StorageTier.DEVICE)) - withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.DISK)(buffer.storageTier) - assertResult(bufferSize)(buffer.memoryUsedBytes) - assertResult(handle.id)(buffer.id) - assertResult(spillPriority)(buffer.getSpillPriority) - } - } - } - } - } - - test("Get columnar batch") { - val bufferId = MockRapidsBufferId(1, canShareDiskPaths = false) - val bufferPath = bufferId.getDiskPath(null) - val mockDiskBlockManager = mock[RapidsDiskBlockManager] - when(mockDiskBlockManager.getSerializerManager()) - .thenReturn(new RapidsSerializerManager(new SparkConf())) - assert(!bufferPath.exists) - val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType, - DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5)) - val spillPriority = -7 - val hostStoreMaxSize = 1L * 1024 * 1024 - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore) - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { - hostStore => - devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(mockDiskBlockManager)) { - diskStore => - hostStore.setSpillStore(diskStore) - val (_, handle) = addContiguousTableToCatalog(catalog, bufferId, spillPriority) - assert(!handle.id.getDiskPath(null).exists()) - val expectedTable = withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.DEVICE)(buffer.storageTier) - withResource(buffer.getColumnarBatch(sparkTypes)) { beforeSpill => - withResource(GpuColumnVector.from(beforeSpill)) { table => - table.contiguousSplit()(0) - } - } // closing the batch from the store so that we can spill it - } - withResource(expectedTable) { _ => - withResource( - GpuColumnVector.from(expectedTable.getTable, sparkTypes)) { expectedBatch => - catalog.synchronousSpill(devStore, 0) - catalog.synchronousSpill(hostStore, 0) - withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.DISK)(buffer.storageTier) - withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => - TestUtils.compareBatches(expectedBatch, actualBatch) - } - } - } - } - } - } - } - } - - test("get memory buffer") { - val bufferId = MockRapidsBufferId(1, canShareDiskPaths = false) - val bufferPath = bufferId.getDiskPath(null) - val mockDiskBlockManager = mock[RapidsDiskBlockManager] - when(mockDiskBlockManager.getSerializerManager()) - .thenReturn(new RapidsSerializerManager(new SparkConf())) - assert(!bufferPath.exists) - val spillPriority = -7 - val hostStoreMaxSize = 1L * 1024 * 1024 - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore) - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { - hostStore => - devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(mockDiskBlockManager)) { diskStore => - hostStore.setSpillStore(diskStore) - val (_, handle) = addContiguousTableToCatalog(catalog, bufferId, spillPriority) - assert(!handle.id.getDiskPath(mockDiskBlockManager).exists()) - val expectedBuffer = withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.DEVICE)(buffer.storageTier) - withResource(buffer.getMemoryBuffer) { devbuf => - closeOnExcept(HostMemoryBuffer.allocate(devbuf.getLength)) { hostbuf => - hostbuf.copyFromDeviceBuffer(devbuf.asInstanceOf[DeviceMemoryBuffer]) - hostbuf - } - } - } - withResource(expectedBuffer) { expectedBuffer => - catalog.synchronousSpill(devStore, 0) - catalog.synchronousSpill(hostStore, 0) - withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.DISK)(buffer.storageTier) - withResource(buffer.getMemoryBuffer) { actualBuffer => - assert(actualBuffer.isInstanceOf[HostMemoryBuffer]) - val actualHostBuffer = actualBuffer.asInstanceOf[HostMemoryBuffer] - assertResult(expectedBuffer. - asByteBuffer.limit())(actualHostBuffer.asByteBuffer.limit()) - } - } - } - } - } - } - } - - test("Compression on with or without encryption for spill block using single batch") { - Seq("true", "false").foreach { encryptionEnabled => - val conf = new SparkConf() - conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) - conf.set("spark.io.compression.codec", "zstd") - conf.set("spark.shuffle.spill.compress", "true") - conf.set("spark.shuffle.compress", "true") - readWriteTestWithBatches(conf, TempSpillBufferId.apply()) - } - } - - test("Compression off with or without encryption for spill block using single batch") { - Seq("true", "false").foreach { encryptionEnabled => - val conf = new SparkConf() - conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) - conf.set("spark.shuffle.spill.compress", "false") - conf.set("spark.shuffle.compress", "false") - readWriteTestWithBatches(conf, TempSpillBufferId.apply()) - } - } - - test("Compression on with or without encryption for spill block using multiple batches") { - Seq("true", "false").foreach { encryptionEnabled => - val conf = new SparkConf() - conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) - conf.set("spark.io.compression.codec", "zstd") - conf.set("spark.shuffle.spill.compress", "true") - conf.set("spark.shuffle.compress", "true") - readWriteTestWithBatches(conf, TempSpillBufferId.apply(), TempSpillBufferId.apply()) - } - } - - test("Compression off with or without encryption for spill block using multiple batches") { - Seq("true", "false").foreach { encryptionEnabled => - val conf = new SparkConf() - conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) - conf.set("spark.shuffle.spill.compress", "false") - conf.set("spark.shuffle.compress", "false") - readWriteTestWithBatches(conf, TempSpillBufferId.apply(), TempSpillBufferId.apply()) - } - } - - // ===== Tests for shuffle block ===== - - test("Compression on with or without encryption for shuffle block using single batch") { - Seq("true", "false").foreach { encryptionEnabled => - val conf = new SparkConf() - conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) - conf.set("spark.io.compression.codec", "zstd") - conf.set("spark.shuffle.spill.compress", "true") - conf.set("spark.shuffle.compress", "true") - readWriteTestWithBatches(conf, ShuffleBufferId(ShuffleBlockId(1, 1, 1), 1)) - } - } - - test("Compression off with or without encryption for shuffle block using single batch") { - Seq("true", "false").foreach { encryptionEnabled => - val conf = new SparkConf() - conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) - conf.set("spark.shuffle.spill.compress", "false") - conf.set("spark.shuffle.compress", "false") - readWriteTestWithBatches(conf, ShuffleBufferId(ShuffleBlockId(1, 1, 1), 1)) - } - } - - test("Compression on with or without encryption for shuffle block using multiple batches") { - Seq("true", "false").foreach { encryptionEnabled => - val conf = new SparkConf() - conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) - conf.set("spark.io.compression.codec", "zstd") - conf.set("spark.shuffle.spill.compress", "true") - conf.set("spark.shuffle.compress", "true") - readWriteTestWithBatches(conf, - ShuffleBufferId(ShuffleBlockId(1, 1, 1), 1), ShuffleBufferId(ShuffleBlockId(2, 2, 2), 2)) - } - } - - test("Compression off with or without encryption for shuffle block using multiple batches") { - Seq("true", "false").foreach { encryptionEnabled => - val conf = new SparkConf() - conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) - conf.set("spark.shuffle.spill.compress", "false") - conf.set("spark.shuffle.compress", "false") - readWriteTestWithBatches(conf, - ShuffleBufferId(ShuffleBlockId(1, 1, 1), 1), ShuffleBufferId(ShuffleBlockId(2, 2, 2), 2)) - } - } - - test("No encryption and compression for shuffle block using multiple batches") { - readWriteTestWithBatches(new SparkConf(), - ShuffleBufferId(ShuffleBlockId(1, 1, 1), 1), ShuffleBufferId(ShuffleBlockId(2, 2, 2), 2)) - } - - private def readWriteTestWithBatches(conf: SparkConf, bufferIds: RapidsBufferId*) = { - assert(bufferIds.size != 0) - val mockDiskBlockManager = mock[RapidsDiskBlockManager] - when(mockDiskBlockManager.getSerializerManager()) - .thenReturn(new RapidsSerializerManager(conf)) - - if (bufferIds(0).canShareDiskPaths) { - // Return the same path - val bufferPath = new File(TEST_FILES_ROOT, s"diskbuffer-${bufferIds(0).tableId}") - when(mockDiskBlockManager.getFile(any[BlockId]())).thenReturn(bufferPath) - if (bufferPath.exists) bufferPath.delete() - } else { - when(mockDiskBlockManager.getFile(any[BlockId]())) - .thenAnswer { invocation => - new File(TEST_FILES_ROOT, s"diskbuffer-${invocation.getArgument[BlockId](0).name}") - } - } - - val spillPriority = -7 - val hostStoreMaxSize = 1L * 1024 * 1024 - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore) - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { - hostStore => - devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(mockDiskBlockManager)) { diskStore => - hostStore.setSpillStore(diskStore) - bufferIds.foreach { bufferId => - val (_, handle) = addContiguousTableToCatalog(catalog, bufferId, spillPriority) - val expectedBuffer = withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.DEVICE)(buffer.storageTier) - withResource(buffer.getMemoryBuffer) { devbuf => - closeOnExcept(HostMemoryBuffer.allocate(devbuf.getLength)) { hostbuf => - hostbuf.copyFromDeviceBuffer(devbuf.asInstanceOf[DeviceMemoryBuffer]) - hostbuf - } - } - } - withResource(expectedBuffer) { expectedBuffer => - catalog.synchronousSpill(devStore, 0) - catalog.synchronousSpill(hostStore, 0) - withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.DISK)(buffer.storageTier) - withResource(buffer.getMemoryBuffer) { actualBuffer => - assert(actualBuffer.isInstanceOf[HostMemoryBuffer]) - val actualHostBuffer = actualBuffer.asInstanceOf[HostMemoryBuffer] - assertResult(expectedBuffer.asByteBuffer)(actualHostBuffer.asByteBuffer) - } - } - } - } - } - } - } - } - - test("skip host: spill device memory buffer to disk") { - val bufferId = MockRapidsBufferId(1, canShareDiskPaths = false) - val bufferPath = bufferId.getDiskPath(null) - assert(!bufferPath.exists) - val mockDiskBlockManager = mock[RapidsDiskBlockManager] - when(mockDiskBlockManager.getSerializerManager()) - .thenReturn(new RapidsSerializerManager(new SparkConf())) - val spillPriority = -7 - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore) - withResource(new AlwaysFailingRapidsHostMemoryStore) { - hostStore => - devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(mockDiskBlockManager)) { diskStore => - hostStore.setSpillStore(diskStore) - val (_, handle) = addContiguousTableToCatalog(catalog, bufferId, spillPriority) - assert(!handle.id.getDiskPath(null).exists()) - val expectedBuffer = withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.DEVICE)(buffer.storageTier) - withResource(buffer.getMemoryBuffer) { devbuf => - closeOnExcept(HostMemoryBuffer.allocate(devbuf.getLength)) { hostbuf => - hostbuf.copyFromDeviceBuffer(devbuf.asInstanceOf[DeviceMemoryBuffer]) - hostbuf - } - } - } - withResource(expectedBuffer) { expectedBuffer => - catalog.synchronousSpill(devStore, 0) - withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.DISK)(buffer.storageTier) - withResource(buffer.getMemoryBuffer) { actualBuffer => - assert(actualBuffer.isInstanceOf[HostMemoryBuffer]) - val actualHostBuffer = actualBuffer.asInstanceOf[HostMemoryBuffer] - assertResult(expectedBuffer.asByteBuffer)(actualHostBuffer.asByteBuffer) - } - } - } - } - } - } - } - - test("skip host: spill table to disk") { - val bufferId = MockRapidsBufferId(1, canShareDiskPaths = false) - val bufferPath = bufferId.getDiskPath(null) - assert(!bufferPath.exists) - val mockDiskBlockManager = mock[RapidsDiskBlockManager] - when(mockDiskBlockManager.getSerializerManager()) - .thenReturn(new RapidsSerializerManager(new SparkConf())) - val spillPriority = -7 - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore) - withResource(new AlwaysFailingRapidsHostMemoryStore) { - hostStore => - devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(mockDiskBlockManager)) { diskStore => - hostStore.setSpillStore(diskStore) - val handle = addTableToCatalog(catalog, bufferId, spillPriority) - withResource(buildTable()) { expectedTable => - withResource( - GpuColumnVector.from(expectedTable, mockTableDataTypes)) { expectedBatch => - catalog.synchronousSpill(devStore, 0) - withResource(catalog.acquireBuffer(handle)) { buffer => - assert(handle.id.getDiskPath(null).exists()) - assertResult(StorageTier.DISK)(buffer.storageTier) - withResource(buffer.getColumnarBatch(mockTableDataTypes)) { fromDiskBatch => - TestUtils.compareBatches(expectedBatch, fromDiskBatch) - } - } - } - } - } - } - } - } - - test("skip host: spill table to disk with small host bounce buffer") { - val bufferId = MockRapidsBufferId(1, canShareDiskPaths = false) - val bufferPath = bufferId.getDiskPath(null) - assert(!bufferPath.exists) - val mockDiskBlockManager = mock[RapidsDiskBlockManager] - when(mockDiskBlockManager.getSerializerManager()) - .thenReturn(new RapidsSerializerManager(new SparkConf())) - val spillPriority = -7 - withResource(new RapidsDeviceMemoryStore(1L*1024*1024, 10)) { devStore => - val catalog = new RapidsBufferCatalog(devStore) - withResource(new AlwaysFailingRapidsHostMemoryStore) { - hostStore => - devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(mockDiskBlockManager)) { diskStore => - hostStore.setSpillStore(diskStore) - val handle = addTableToCatalog(catalog, bufferId, spillPriority) - withResource(buildTable()) { expectedTable => - withResource( - GpuColumnVector.from(expectedTable, mockTableDataTypes)) { expectedBatch => - catalog.synchronousSpill(devStore, 0) - withResource(catalog.acquireBuffer(handle)) { buffer => - assert(handle.id.getDiskPath(null).exists()) - assertResult(StorageTier.DISK)(buffer.storageTier) - withResource(buffer.getColumnarBatch(mockTableDataTypes)) { fromDiskBatch => - TestUtils.compareBatches(expectedBatch, fromDiskBatch) - } - } - } - } - } - } - } - } - - - test("0-byte table is never spillable as we would fail to mmap") { - val bufferId = MockRapidsBufferId(1, canShareDiskPaths = false) - val bufferPath = bufferId.getDiskPath(null) - val bufferId2 = MockRapidsBufferId(2, canShareDiskPaths = false) - assert(!bufferPath.exists) - val mockDiskBlockManager = mock[RapidsDiskBlockManager] - when(mockDiskBlockManager.getSerializerManager()) - .thenReturn(new RapidsSerializerManager(new SparkConf())) - val spillPriority = -7 - val hostStoreMaxSize = 1L * 1024 * 1024 - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore) - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(mockDiskBlockManager)) { diskStore => - hostStore.setSpillStore(diskStore) - val handle = addZeroRowsTableToCatalog(catalog, bufferId, spillPriority - 1) - val handle2 = addTableToCatalog(catalog, bufferId2, spillPriority) - withResource(handle2) { _ => - assert(!handle.id.getDiskPath(null).exists()) - withResource(buildTable()) { expectedTable => - withResource(buildEmptyTable()) { expectedEmptyTable => - withResource( - GpuColumnVector.from( - expectedTable, mockTableDataTypes)) { expectedCb => - withResource( - GpuColumnVector.from( - expectedEmptyTable, mockTableDataTypes)) { expectedEmptyCb => - catalog.synchronousSpill(devStore, 0) - catalog.synchronousSpill(hostStore, 0) - withResource(catalog.acquireBuffer(handle2)) { buffer => - withResource(catalog.acquireBuffer(handle)) { emptyBuffer => - // the 0-byte table never moved from device. It is not spillable - assertResult(StorageTier.DEVICE)(emptyBuffer.storageTier) - withResource(emptyBuffer.getColumnarBatch(mockTableDataTypes)) { cb => - TestUtils.compareBatches(expectedEmptyCb, cb) - } - // the second table (with rows) did spill - assertResult(StorageTier.DISK)(buffer.storageTier) - withResource(buffer.getColumnarBatch(mockTableDataTypes)) { cb => - TestUtils.compareBatches(expectedCb, cb) - } - } - } - assertResult(0)(devStore.currentSize) - assertResult(0)(hostStore.currentSize) - } - } - } - } - } - } - } - } - } - - test("exclusive spill files are deleted when buffer deleted") { - testBufferFileDeletion(canShareDiskPaths = false) - } - - test("shared spill files are not deleted when a buffer is deleted") { - testBufferFileDeletion(canShareDiskPaths = true) - } - - class AlwaysFailingRapidsHostMemoryStore extends RapidsHostMemoryStore(Some(0L)){ - override def createBuffer( - other: RapidsBuffer, - catalog: RapidsBufferCatalog, - stream: Cuda.Stream): Option[RapidsBufferBase] = { - None - } - } - - private def testBufferFileDeletion(canShareDiskPaths: Boolean): Unit = { - val bufferId = MockRapidsBufferId(1, canShareDiskPaths) - val bufferPath = bufferId.getDiskPath(null) - assert(!bufferPath.exists) - val mockDiskBlockManager = mock[RapidsDiskBlockManager] - when(mockDiskBlockManager.getSerializerManager()) - .thenReturn(new RapidsSerializerManager(new SparkConf())) - val spillPriority = -7 - val hostStoreMaxSize = 1L * 1024 * 1024 - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore) - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { - hostStore => - devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(mockDiskBlockManager)) { diskStore => - hostStore.setSpillStore(diskStore) - val (_, handle) = addContiguousTableToCatalog(catalog, bufferId, spillPriority) - val bufferPath = handle.id.getDiskPath(null) - assert(!bufferPath.exists()) - catalog.synchronousSpill(devStore, 0) - catalog.synchronousSpill(hostStore, 0) - assert(bufferPath.exists) - handle.close() - if (canShareDiskPaths) { - assert(bufferPath.exists()) - } else { - assert(!bufferPath.exists) - } - } - } - } - } - - private def addContiguousTableToCatalog( - catalog: RapidsBufferCatalog, - bufferId: RapidsBufferId, - spillPriority: Long): (Long, RapidsBufferHandle) = { - withResource(buildContiguousTable()) { ct => - val bufferSize = ct.getBuffer.getLength - // store takes ownership of the table - val handle = catalog.addContiguousTable( - bufferId, - ct, - spillPriority, - false) - (bufferSize, handle) - } - } - - private def addTableToCatalog( - catalog: RapidsBufferCatalog, - bufferId: RapidsBufferId, - spillPriority: Long): RapidsBufferHandle = { - // store takes ownership of the table - catalog.addTable( - bufferId, - buildTable(), - spillPriority, - false) - } - - private def addZeroRowsTableToCatalog( - catalog: RapidsBufferCatalog, - bufferId: RapidsBufferId, - spillPriority: Long): RapidsBufferHandle = { - val table = buildEmptyTable() - // store takes ownership of the table - catalog.addTable( - bufferId, - table, - spillPriority, - false) - } - - case class MockRapidsBufferId( - tableId: Int, - override val canShareDiskPaths: Boolean) extends RapidsBufferId { - override def getDiskPath(diskBlockManager: RapidsDiskBlockManager): File = - new File(TEST_FILES_ROOT, s"diskbuffer-$tableId") - } -} diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala deleted file mode 100644 index 1ffad031451..00000000000 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala +++ /dev/null @@ -1,614 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.nvidia.spark.rapids - -import java.io.File -import java.math.RoundingMode - -import ai.rapids.cudf.{ContiguousTable, Cuda, HostColumnVector, HostMemoryBuffer, Table} -import com.nvidia.spark.rapids.Arm._ -import org.mockito.{ArgumentCaptor, ArgumentMatchers} -import org.mockito.Mockito.{spy, times, verify} -import org.scalatest.funsuite.AnyFunSuite -import org.scalatestplus.mockito.MockitoSugar - -import org.apache.spark.SparkConf -import org.apache.spark.sql.rapids.RapidsDiskBlockManager -import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType, IntegerType, LongType, StringType} -import org.apache.spark.sql.vectorized.ColumnarBatch - -class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar { - private def buildContiguousTable(): ContiguousTable = { - withResource(new Table.TestBuilder() - .column(5, null.asInstanceOf[java.lang.Integer], 3, 1) - .column("five", "two", null, null) - .column(5.0, 2.0, 3.0, 1.0) - .decimal64Column(-5, RoundingMode.UNNECESSARY, 0, null, -1.4, 10.123) - .build()) { table => - table.contiguousSplit()(0) - } - } - - private def buildContiguousTable(numRows: Int): ContiguousTable = { - val vals = (0 until numRows).map(_.toLong) - withResource(HostColumnVector.fromLongs(vals: _*)) { hcv => - withResource(hcv.copyToDevice()) { cv => - withResource(new Table(cv)) { table => - table.contiguousSplit()(0) - } - } - } - } - - private def buildHostBatch(): ColumnarBatch = { - val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType, - DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5)) - val hostCols = withResource(buildContiguousTable()) { ct => - withResource(ct.getTable) { tbl => - (0 until tbl.getNumberOfColumns) - .map(c => tbl.getColumn(c).copyToHost()) - } - }.toArray - new ColumnarBatch( - hostCols.zip(sparkTypes).map { case (hostCol, dataType) => - new RapidsHostColumnVector(dataType, hostCol) - }, hostCols.head.getRowCount.toInt) - } - - private def buildHostBatchWithDuplicate(): ColumnarBatch = { - val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType, - DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5)) - val hostCols = withResource(buildContiguousTable()) { ct => - withResource(ct.getTable) { tbl => - (0 until tbl.getNumberOfColumns) - .map(c => tbl.getColumn(c).copyToHost()) - } - }.toArray - hostCols.foreach(_.incRefCount()) - new ColumnarBatch( - (hostCols ++ hostCols).zip(sparkTypes ++ sparkTypes).map { case (hostCol, dataType) => - new RapidsHostColumnVector(dataType, hostCol) - }, hostCols.head.getRowCount.toInt) - } - - test("spill updates catalog") { - val spillPriority = -7 - val hostStoreMaxSize = 1L * 1024 * 1024 - val mockStore = mock[RapidsHostMemoryStore] - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = spy(new RapidsBufferCatalog(devStore)) - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { - hostStore => - assertResult(0)(hostStore.currentSize) - assertResult(hostStoreMaxSize)(hostStore.numBytesFree.get) - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(mockStore) - - val (bufferSize, handle) = withResource(buildContiguousTable()) { ct => - val len = ct.getBuffer.getLength - // store takes ownership of the table - val handle = catalog.addContiguousTable( - ct, - spillPriority) - (len, handle) - } - - catalog.synchronousSpill(devStore, 0) - assertResult(bufferSize)(hostStore.currentSize) - assertResult(hostStoreMaxSize - bufferSize)(hostStore.numBytesFree.get) - verify(catalog, times(2)).registerNewBuffer(ArgumentMatchers.any[RapidsBuffer]) - verify(catalog).removeBufferTier( - ArgumentMatchers.eq(handle.id), ArgumentMatchers.eq(StorageTier.DEVICE)) - withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.HOST)(buffer.storageTier) - assertResult(bufferSize)(buffer.memoryUsedBytes) - assertResult(handle.id)(buffer.id) - assertResult(spillPriority)(buffer.getSpillPriority) - } - } - } - } - - test("get columnar batch") { - val spillPriority = -10 - val hostStoreMaxSize = 1L * 1024 * 1024 - val mockStore = mock[RapidsHostMemoryStore] - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore) - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { - hostStore => - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(mockStore) - var expectedBuffer: HostMemoryBuffer = null - val handle = withResource(buildContiguousTable()) { ct => - expectedBuffer = HostMemoryBuffer.allocate(ct.getBuffer.getLength) - expectedBuffer.copyFromDeviceBuffer(ct.getBuffer) - catalog.addContiguousTable( - ct, - spillPriority) - } - withResource(expectedBuffer) { _ => - catalog.synchronousSpill(devStore, 0) - withResource(catalog.acquireBuffer(handle)) { buffer => - withResource(buffer.getMemoryBuffer) { actualBuffer => - assert(actualBuffer.isInstanceOf[HostMemoryBuffer]) - assertResult(expectedBuffer.asByteBuffer) { - actualBuffer.asInstanceOf[HostMemoryBuffer].asByteBuffer - } - } - } - } - } - } - } - - test("get memory buffer") { - val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType, - DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5)) - val spillPriority = -10 - val hostStoreMaxSize = 1L * 1024 * 1024 - val mockStore = mock[RapidsHostMemoryStore] - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore) - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { - hostStore => - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(mockStore) - var expectedBatch: ColumnarBatch = null - val handle = withResource(buildContiguousTable()) { ct => - // make a copy of the table so we can compare it later to the - // one reconstituted after the spill - withResource(ct.getTable.contiguousSplit()) { copied => - expectedBatch = GpuColumnVector.from(copied(0).getTable, sparkTypes) - } - catalog.addContiguousTable( - ct, - spillPriority) - } - withResource(expectedBatch) { _ => - catalog.synchronousSpill(devStore, 0) - withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.HOST)(buffer.storageTier) - withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => - TestUtils.compareBatches(expectedBatch, actualBatch) - } - } - } - } - } - } - - test("get memory buffer after host spill") { - RapidsBufferCatalog.close() - val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType, - DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5)) - val spillPriority = -10 - val hostStoreMaxSize = 1L * 1024 * 1024 - try { - val bm = new RapidsDiskBlockManager(new SparkConf()) - val (catalog, devStore, hostStore, diskStore) = - closeOnExcept(new RapidsDiskStore(bm)) { diskStore => - closeOnExcept(new RapidsDeviceMemoryStore()) { devStore => - closeOnExcept(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(diskStore) - val catalog = closeOnExcept( - new RapidsBufferCatalog(devStore, hostStore)) { catalog => catalog } - (catalog, devStore, hostStore, diskStore) - } - } - } - - RapidsBufferCatalog.setDeviceStorage(devStore) - RapidsBufferCatalog.setHostStorage(hostStore) - RapidsBufferCatalog.setDiskStorage(diskStore) - RapidsBufferCatalog.setCatalog(catalog) - - var expectedBatch: ColumnarBatch = null - val handle = withResource(buildContiguousTable()) { ct => - // make a copy of the table so we can compare it later to the - // one reconstituted after the spill - withResource(ct.getTable.contiguousSplit()) { copied => - expectedBatch = GpuColumnVector.from(copied(0).getTable, sparkTypes) - } - RapidsBufferCatalog.addContiguousTable( - ct, - spillPriority) - } - withResource(expectedBatch) { _ => - val spilledToHost = - RapidsBufferCatalog.synchronousSpill( - RapidsBufferCatalog.getDeviceStorage, 0) - assert(spilledToHost.isDefined && spilledToHost.get > 0) - - val spilledToDisk = - RapidsBufferCatalog.synchronousSpill( - RapidsBufferCatalog.getHostStorage, 0) - assert(spilledToDisk.isDefined && spilledToDisk.get > 0) - - withResource(RapidsBufferCatalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.DISK)(buffer.storageTier) - withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => - TestUtils.compareBatches(expectedBatch, actualBatch) - } - } - } - } finally { - RapidsBufferCatalog.close() - } - } - - test("host buffer originated: get host memory buffer") { - val spillPriority = -10 - val hostStoreMaxSize = 1L * 1024 * 1024 - val mockStore = mock[RapidsDiskStore] - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore, hostStore) - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(mockStore) - val hmb = HostMemoryBuffer.allocate(1L * 1024) - val spillableBuffer = - SpillableHostBuffer(hmb, hmb.getLength, spillPriority, catalog) - withResource(spillableBuffer) { _ => - // the refcount of 1 is the store - assertResult(1)(hmb.getRefCount) - withResource(spillableBuffer.getHostBuffer()) { memoryBuffer => - assertResult(hmb)(memoryBuffer) - assertResult(2)(memoryBuffer.getRefCount) - } - } - assertResult(0)(hmb.getRefCount) - } - } - } - - test("host buffer originated: get host memory buffer after spill") { - RapidsBufferCatalog.close() - val spillPriority = -10 - val hostStoreMaxSize = 1L * 1024 * 1024 - try { - val bm = new RapidsDiskBlockManager(new SparkConf()) - val (catalog, devStore, hostStore, diskStore) = - closeOnExcept(new RapidsDiskStore(bm)) { diskStore => - closeOnExcept(new RapidsDeviceMemoryStore()) { devStore => - closeOnExcept(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(diskStore) - val catalog = closeOnExcept( - new RapidsBufferCatalog(devStore, hostStore)) { catalog => catalog } - (catalog, devStore, hostStore, diskStore) - } - } - } - - RapidsBufferCatalog.setDeviceStorage(devStore) - RapidsBufferCatalog.setHostStorage(hostStore) - RapidsBufferCatalog.setDiskStorage(diskStore) - RapidsBufferCatalog.setCatalog(catalog) - - val hmb = HostMemoryBuffer.allocate(1L * 1024) - val spillableBuffer = SpillableHostBuffer( - hmb, - hmb.getLength, - spillPriority) - assertResult(1)(hmb.getRefCount) - // we spill it - RapidsBufferCatalog.synchronousSpill(RapidsBufferCatalog.getHostStorage, 0) - withResource(spillableBuffer) { _ => - // the refcount of the original buffer is 0 because it spilled - assertResult(0)(hmb.getRefCount) - withResource(spillableBuffer.getHostBuffer()) { memoryBuffer => - assertResult(memoryBuffer.getLength)(hmb.getLength) - } - } - } finally { - RapidsBufferCatalog.close() - } - } - - test("host buffer originated: get host memory buffer OOM when unable to spill") { - val spillPriority = -10 - val hostStoreMaxSize = 1L * 1024 * 1024 - val bm = new RapidsDiskBlockManager(new SparkConf()) - withResource(new RapidsDiskStore(bm)) { diskStore => - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore, hostStore) - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(diskStore) - val hmb = HostMemoryBuffer.allocate(1L * 1024) - val spillableBuffer = SpillableHostBuffer( - hmb, - hmb.getLength, - spillPriority, - catalog) - // spillable is 1K - assertResult(hmb.getLength)(hostStore.currentSpillableSize) - withResource(spillableBuffer.getHostBuffer()) { memoryBuffer => - // 0 because we have a reference to the memoryBuffer - assertResult(0)(hostStore.currentSpillableSize) - val spilled = catalog.synchronousSpill(hostStore, 0) - assertResult(0)(spilled.get) - } - assertResult(hmb.getLength)(hostStore.currentSpillableSize) - val spilled = catalog.synchronousSpill(hostStore, 0) - assertResult(1L * 1024)(spilled.get) - spillableBuffer.close() - } - } - } - } - - test("host batch originated: get host memory batch") { - val spillPriority = -10 - val hostStoreMaxSize = 1L * 1024 * 1024 - val bm = new RapidsDiskBlockManager(new SparkConf()) - withResource(new RapidsDiskStore(bm)) { diskStore => - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore, hostStore) - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(diskStore) - - val hostCb = buildHostBatch() - - val sizeOnHost = RapidsHostColumnVector.getTotalHostMemoryUsed(hostCb) - - withResource( - SpillableHostColumnarBatch(hostCb, spillPriority, catalog)) { spillableBuffer => - assertResult(sizeOnHost)(hostStore.currentSpillableSize) - - withResource(spillableBuffer.getColumnarBatch()) { hostCb => - // 0 because we have a reference to the memoryBuffer - assertResult(0)(hostStore.currentSpillableSize) - val spilled = catalog.synchronousSpill(hostStore, 0) - assertResult(0)(spilled.get) - } - - assertResult(sizeOnHost)(hostStore.currentSpillableSize) - val spilled = catalog.synchronousSpill(hostStore, 0) - assertResult(sizeOnHost)(spilled.get) - - val sizeOnDisk = diskStore.currentSpillableSize - - // reconstitute batch from disk - withResource(spillableBuffer.getColumnarBatch()) { hostCbFromDisk => - // disk has a different size, so this spillable batch has a different sizeInBytes - // right now, because this is the serialized represenation size - assertResult(sizeOnDisk)(spillableBuffer.sizeInBytes) - // lets recreate our original batch and compare to make sure contents match - withResource(buildHostBatch()) { expectedHostCb => - TestUtils.compareBatches(expectedHostCb, hostCbFromDisk) - } - } - } - } - } - } - } - - test("a host batch is not spillable when we leak it") { - val spillPriority = -10 - val hostStoreMaxSize = 1L * 1024 * 1024 - val bm = new RapidsDiskBlockManager(new SparkConf()) - withResource(new RapidsDiskStore(bm)) { diskStore => - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore, hostStore) - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(diskStore) - - val hostCb = buildHostBatch() - - val sizeOnHost = RapidsHostColumnVector.getTotalHostMemoryUsed(hostCb) - - val leakedBatch = withResource( - SpillableHostColumnarBatch(hostCb, spillPriority, catalog)) { spillableBuffer => - assertResult(sizeOnHost)(hostStore.currentSpillableSize) - - val leakedBatch = spillableBuffer.getColumnarBatch() - // 0 because we have a reference to the host batch - assertResult(0)(hostStore.currentSpillableSize) - val spilled = catalog.synchronousSpill(hostStore, 0) - assertResult(0)(spilled.get) - leakedBatch - } - - withResource(leakedBatch) { _ => - // 0 because we have leaked that the host batch - assertResult(0)(hostStore.currentSize) - assertResult(0)(hostStore.currentSpillableSize) - val spilled = catalog.synchronousSpill(hostStore, 0) - assertResult(0)(spilled.get) - } - // after closing we still have 0 bytes in the store or available to spill - assertResult(0)(hostStore.currentSize) - assertResult(0)(hostStore.currentSpillableSize) - } - } - } - } - - test("a host batch is not spillable when columns are incRefCounted") { - val spillPriority = -10 - val hostStoreMaxSize = 1L * 1024 * 1024 - val bm = new RapidsDiskBlockManager(new SparkConf()) - withResource(new RapidsDiskStore(bm)) { diskStore => - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore, hostStore) - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(diskStore) - - val hostCb = buildHostBatch() - - val sizeOnHost = RapidsHostColumnVector.getTotalHostMemoryUsed(hostCb) - - withResource( - SpillableHostColumnarBatch(hostCb, spillPriority, catalog)) { spillableBuffer => - assertResult(sizeOnHost)(hostStore.currentSpillableSize) - - val leakedFirstColumn = withResource(spillableBuffer.getColumnarBatch()) { hostCb => - // 0 because we have a reference to the host batch - assertResult(0)(hostStore.currentSpillableSize) - val spilled = catalog.synchronousSpill(hostStore, 0) - assertResult(0)(spilled.get) - // leak it by increasing the ref count of the underlying cuDF column - RapidsHostColumnVector.extractBases(hostCb).head.incRefCount() - } - withResource(leakedFirstColumn) { _ => - // 0 because we have a reference to the first column - assertResult(0)(hostStore.currentSpillableSize) - val spilled = catalog.synchronousSpill(hostStore, 0) - assertResult(0)(spilled.get) - } - // batch is now spillable because we close our reference to the column - assertResult(sizeOnHost)(hostStore.currentSpillableSize) - val spilled = catalog.synchronousSpill(hostStore, 0) - assertResult(sizeOnHost)(spilled.get) - } - } - } - } - } - - test("an aliased host batch is not spillable (until closing the original) ") { - val spillPriority = -10 - val hostStoreMaxSize = 1L * 1024 * 1024 - val bm = new RapidsDiskBlockManager(new SparkConf()) - withResource(new RapidsDiskStore(bm)) { diskStore => - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore, hostStore) - val hostBatch = buildHostBatch() - val sizeOnHost = RapidsHostColumnVector.getTotalHostMemoryUsed(hostBatch) - val handle = withResource(hostBatch) { _ => - catalog.addBatch(hostBatch, spillPriority) - } - withResource(handle) { _ => - val types: Array[DataType] = - Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray - assertResult(sizeOnHost)(hostStore.currentSize) - assertResult(sizeOnHost)(hostStore.currentSpillableSize) - withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => - // extract the batch from the table we added, and add it back as a batch - withResource(rapidsBuffer.getHostColumnarBatch(types)) { batch => - catalog.addBatch(batch, spillPriority) - } - } // we now have two copies in the store - assertResult(sizeOnHost * 2)(hostStore.currentSize) - assertResult(0)(hostStore.currentSpillableSize) - } // remove the original - assertResult(sizeOnHost)(hostStore.currentSize) - assertResult(sizeOnHost)(hostStore.currentSpillableSize) - } - } - } - } - - test("an aliased host batch supports duplicated columns") { - val spillPriority = -10 - val hostStoreMaxSize = 1L * 1024 * 1024 - val bm = new RapidsDiskBlockManager(new SparkConf()) - withResource(new RapidsDiskStore(bm)) { diskStore => - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore, hostStore) - val hostBatch = buildHostBatchWithDuplicate() - val sizeOnHost = RapidsHostColumnVector.getTotalHostMemoryUsed(hostBatch) - val handle = withResource(hostBatch) { _ => - catalog.addBatch(hostBatch, spillPriority) - } - withResource(handle) { _ => - val types: Array[DataType] = - Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray - assertResult(sizeOnHost)(hostStore.currentSize) - assertResult(sizeOnHost)(hostStore.currentSpillableSize) - withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => - // extract the batch from the table we added, and add it back as a batch - withResource(rapidsBuffer.getHostColumnarBatch(types)) { batch => - catalog.addBatch(batch, spillPriority) - } - } // we now have two copies in the store - assertResult(sizeOnHost * 2)(hostStore.currentSize) - assertResult(0)(hostStore.currentSpillableSize) - } // remove the original - assertResult(sizeOnHost)(hostStore.currentSize) - assertResult(sizeOnHost)(hostStore.currentSpillableSize) - } - } - } - } - - test("buffer exceeds maximum size") { - val sparkTypes = Array[DataType](LongType) - val spillPriority = -10 - val hostStoreMaxSize = 256 - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore) - val spyStore = spy(new RapidsDiskStore(new RapidsDiskBlockManager(new SparkConf()))) - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(spyStore) - var bigHandle: RapidsBufferHandle = null - var bigTable = buildContiguousTable(1024 * 1024) - closeOnExcept(bigTable) { _ => - // make a copy of the table so we can compare it later to the - // one reconstituted after the spill - val expectedBatch = - withResource(bigTable.getTable.contiguousSplit()) { expectedTable => - GpuColumnVector.from(expectedTable(0).getTable, sparkTypes) - } - withResource(expectedBatch) { _ => - bigHandle = withResource(bigTable) { _ => - catalog.addContiguousTable( - bigTable, - spillPriority) - } // close the bigTable so it can be spilled - bigTable = null - withResource(catalog.acquireBuffer(bigHandle)) { buffer => - assertResult(StorageTier.DEVICE)(buffer.storageTier) - withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => - TestUtils.compareBatches(expectedBatch, actualBatch) - } - } - catalog.synchronousSpill(devStore, 0) - val rapidsBufferCaptor: ArgumentCaptor[RapidsBuffer] = - ArgumentCaptor.forClass(classOf[RapidsBuffer]) - verify(spyStore).copyBuffer( - rapidsBufferCaptor.capture(), - ArgumentMatchers.any[RapidsBufferCatalog], - ArgumentMatchers.any[Cuda.Stream]) - assertResult(bigHandle.id)(rapidsBufferCaptor.getValue.id) - withResource(catalog.acquireBuffer(bigHandle)) { buffer => - assertResult(StorageTier.DISK)(buffer.storageTier) - withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => - TestUtils.compareBatches(expectedBatch, actualBatch) - } - } - } - } - } - } - } - - case class MockRapidsBufferId(tableId: Int) extends RapidsBufferId { - override def getDiskPath(diskBlockManager: RapidsDiskBlockManager): File = - throw new UnsupportedOperationException - } -} diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RmmSparkRetrySuiteBase.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RmmSparkRetrySuiteBase.scala index 5776b2f99a8..24d5984f749 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RmmSparkRetrySuiteBase.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RmmSparkRetrySuiteBase.scala @@ -18,16 +18,15 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.{Rmm, RmmAllocationMode, RmmEventHandler} import com.nvidia.spark.rapids.jni.RmmSpark -import org.mockito.Mockito.spy +import com.nvidia.spark.rapids.spill.SpillFramework import org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite +import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession trait RmmSparkRetrySuiteBase extends AnyFunSuite with BeforeAndAfterEach { private var rmmWasInitialized = false - protected var deviceStorage: RapidsDeviceMemoryStore = _ - override def beforeEach(): Unit = { super.beforeEach() SparkSession.getActiveSession.foreach(_.stop()) @@ -37,14 +36,14 @@ trait RmmSparkRetrySuiteBase extends AnyFunSuite with BeforeAndAfterEach { rmmWasInitialized = true Rmm.initialize(RmmAllocationMode.CUDA_DEFAULT, null, 512 * 1024 * 1024) } - deviceStorage = spy(new RapidsDeviceMemoryStore()) - val hostStore = new RapidsHostMemoryStore(Some(1L * 1024 * 1024)) - deviceStorage.setSpillStore(hostStore) - val catalog = new RapidsBufferCatalog(deviceStorage, hostStore) - // set these against the singleton so we close them later - RapidsBufferCatalog.setDeviceStorage(deviceStorage) - RapidsBufferCatalog.setHostStorage(hostStore) - RapidsBufferCatalog.setCatalog(catalog) + val sc = new SparkConf + sc.set(RapidsConf.HOST_SPILL_STORAGE_SIZE.key, "1MB") + val conf = new RapidsConf(sc) + SpillFramework.shutdown() + SpillFramework.initialize(conf) + + RmmSpark.clearEventHandler() + val mockEventHandler = new BaseRmmEventHandler() RmmSpark.setEventHandler(mockEventHandler) RmmSpark.currentThreadIsDedicatedToTask(1) @@ -53,11 +52,11 @@ trait RmmSparkRetrySuiteBase extends AnyFunSuite with BeforeAndAfterEach { override def afterEach(): Unit = { super.afterEach() + SpillFramework.shutdown() SparkSession.getActiveSession.foreach(_.stop()) SparkSession.clearActiveSession() RmmSpark.removeAllCurrentThreadAssociation() RmmSpark.clearEventHandler() - RapidsBufferCatalog.close() GpuSemaphore.shutdown() if (rmmWasInitialized) { Rmm.shutdown() diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SerializationSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SerializationSuite.scala index 56ecb1a8c57..c6c251aeabb 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/SerializationSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/SerializationSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,13 +18,15 @@ package com.nvidia.spark.rapids import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} -import ai.rapids.cudf.Table +import ai.rapids.cudf.{Rmm, RmmAllocationMode, Table} import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.RapidsPluginImplicits._ +import com.nvidia.spark.rapids.spill.SpillFramework import org.apache.commons.lang3.SerializationUtils import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuite +import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, SerializeBatchDeserializeHostBuffer, SerializeConcatHostBuffersDeserializeBatch} import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StringType} @@ -33,12 +35,17 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} class SerializationSuite extends AnyFunSuite with BeforeAndAfterAll { + override def beforeAll(): Unit = { - RapidsBufferCatalog.setDeviceStorage(new RapidsDeviceMemoryStore()) + super.beforeAll() + Rmm.initialize(RmmAllocationMode.CUDA_DEFAULT, null, 512 * 1024 * 1024) + SpillFramework.initialize(new RapidsConf(new SparkConf)) } override def afterAll(): Unit = { - RapidsBufferCatalog.close() + super.afterAll() + SpillFramework.shutdown() + Rmm.shutdown() } private def buildBatch(): ColumnarBatch = { @@ -170,7 +177,7 @@ class SerializationSuite extends AnyFunSuite withResource(toHostBatch(gpuBatch)) { expectedHostBatch => val broadcast = makeBroadcastBatch(gpuBatch) withBroadcast(broadcast) { _ => - withResource(broadcast.batch.getColumnarBatch()) { materialized => + withResource(broadcast.batch.getColumnarBatch) { materialized => TestUtils.compareBatches(gpuBatch, materialized) } // the host batch here is obtained from the GPU batch since @@ -193,12 +200,12 @@ class SerializationSuite extends AnyFunSuite batches.foreach { gpuExpected => val broadcast = makeBroadcastBatch(gpuExpected) withBroadcast(broadcast) { _ => - withResource(broadcast.batch.getColumnarBatch()) { gpuBatch => + withResource(broadcast.batch.getColumnarBatch) { gpuBatch => TestUtils.compareBatches(gpuExpected, gpuBatch) } // clone via serialization after manifesting the GPU batch withBroadcast(SerializationUtils.clone(broadcast)) { clonedObj => - withResource(clonedObj.batch.getColumnarBatch()) { gpuClonedBatch => + withResource(clonedObj.batch.getColumnarBatch) { gpuClonedBatch => TestUtils.compareBatches(gpuExpected, gpuClonedBatch) } // try to clone it again from the cloned object @@ -214,12 +221,12 @@ class SerializationSuite extends AnyFunSuite batches.foreach { gpuExpected => val broadcast = makeBroadcastBatch(gpuExpected) withBroadcast(broadcast) { _ => - withResource(broadcast.batch.getColumnarBatch()) { gpuBatch => + withResource(broadcast.batch.getColumnarBatch) { gpuBatch => TestUtils.compareBatches(gpuExpected, gpuBatch) } // clone via serialization after manifesting the GPU batch withBroadcast(SerializationUtils.clone(broadcast)) { clonedObj => - withResource(clonedObj.batch.getColumnarBatch()) { gpuClonedBatch => + withResource(clonedObj.batch.getColumnarBatch) { gpuClonedBatch => TestUtils.compareBatches(gpuExpected, gpuClonedBatch) } // try to clone it again from the cloned object @@ -234,12 +241,12 @@ class SerializationSuite extends AnyFunSuite withResource(buildBatch()) { gpuExpected => val broadcast = makeBroadcastBatch(gpuExpected) withBroadcast(broadcast) { _ => - withResource(broadcast.batch.getColumnarBatch()) { gpuBatch => + withResource(broadcast.batch.getColumnarBatch) { gpuBatch => TestUtils.compareBatches(gpuExpected, gpuBatch) } // clone via serialization after manifesting the GPU batch withBroadcast(SerializationUtils.clone(broadcast)) { clonedObj => - withResource(clonedObj.batch.getColumnarBatch()) { gpuClonedBatch => + withResource(clonedObj.batch.getColumnarBatch) { gpuClonedBatch => TestUtils.compareBatches(gpuExpected, gpuClonedBatch) } // try to clone it again from the cloned object @@ -263,7 +270,7 @@ class SerializationSuite extends AnyFunSuite broadcast.doReadObject(inputStream) // use it now - withResource(broadcast.batch.getColumnarBatch()) { gpuBatch => + withResource(broadcast.batch.getColumnarBatch) { gpuBatch => TestUtils.compareBatches(gpuExpected, gpuBatch) } } @@ -275,7 +282,7 @@ class SerializationSuite extends AnyFunSuite val broadcast = makeBroadcastBatch(gpuExpected) withBroadcast(broadcast) { _ => // materialize - withResource(broadcast.batch.getColumnarBatch()) { cb => + withResource(broadcast.batch.getColumnarBatch) { cb => TestUtils.compareBatches(gpuExpected, cb) } @@ -292,7 +299,7 @@ class SerializationSuite extends AnyFunSuite assertResult(before)(broadcast.batch) // it is the same as before // use it now - withResource(broadcast.batch.getColumnarBatch()) { gpuBatch => + withResource(broadcast.batch.getColumnarBatch) { gpuBatch => TestUtils.compareBatches(gpuExpected, gpuBatch) } } @@ -311,7 +318,7 @@ class SerializationSuite extends AnyFunSuite .deserialize[SerializeConcatHostBuffersDeserializeBatch]( inputStream)) { materialized => // this materializes a new batch from what was deserialized - withResource(materialized.batch.getColumnarBatch()) { gpuBatch => + withResource(materialized.batch.getColumnarBatch) { gpuBatch => TestUtils.compareBatches(gpuExpected, gpuBatch) } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ShuffleBufferCatalogSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ShuffleBufferCatalogSuite.scala index 00209461d3c..a7bea7a5b37 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/ShuffleBufferCatalogSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/ShuffleBufferCatalogSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,21 +16,84 @@ package com.nvidia.spark.rapids +import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.format.TableMeta +import com.nvidia.spark.rapids.shuffle.RapidsShuffleTestHelper +import com.nvidia.spark.rapids.spill.SpillFramework +import org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite import org.scalatestplus.mockito.MockitoSugar -import org.apache.spark.sql.rapids.RapidsDiskBlockManager +import org.apache.spark.SparkConf +import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.storage.ShuffleBlockId -class ShuffleBufferCatalogSuite extends AnyFunSuite with MockitoSugar { - test("registered shuffles should be active") { - val catalog = mock[RapidsBufferCatalog] - val rapidsDiskBlockManager = mock[RapidsDiskBlockManager] - val shuffleCatalog = new ShuffleBufferCatalog(catalog, rapidsDiskBlockManager) +class ShuffleBufferCatalogSuite + extends AnyFunSuite with MockitoSugar with BeforeAndAfterEach { + + override def beforeEach(): Unit = { + super.beforeEach() + SpillFramework.initialize(new RapidsConf(new SparkConf)) + } + override def afterEach(): Unit = { + super.afterEach() + SpillFramework.shutdown() + } + + test("registered shuffles should be active") { + val shuffleCatalog = new ShuffleBufferCatalog() assertResult(false)(shuffleCatalog.hasActiveShuffle(123)) shuffleCatalog.registerShuffle(123) assertResult(true)(shuffleCatalog.hasActiveShuffle(123)) shuffleCatalog.unregisterShuffle(123) assertResult(false)(shuffleCatalog.hasActiveShuffle(123)) } + + test("adding a degenerate batch") { + val shuffleCatalog = new ShuffleBufferCatalog() + val tableMeta = mock[TableMeta] + // need to register the shuffle id first + assertThrows[IllegalStateException] { + shuffleCatalog.addDegenerateRapidsBuffer(ShuffleBlockId(1, 1L, 1), tableMeta) + } + shuffleCatalog.registerShuffle(1) + shuffleCatalog.addDegenerateRapidsBuffer(ShuffleBlockId(1,1L,1), tableMeta) + val storedMetas = shuffleCatalog.blockIdToMetas(ShuffleBlockId(1, 1L, 1)) + assertResult(1)(storedMetas.size) + assertResult(tableMeta)(storedMetas.head) + } + + test("adding a contiguous batch adds it to the spill store") { + val shuffleCatalog = new ShuffleBufferCatalog() + val ct = RapidsShuffleTestHelper.buildContiguousTable(1000) + shuffleCatalog.registerShuffle(1) + assertResult(0)(SpillFramework.stores.deviceStore.numHandles) + shuffleCatalog.addContiguousTable(ShuffleBlockId(1, 1L, 1), ct, -1) + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + val storedMetas = shuffleCatalog.blockIdToMetas(ShuffleBlockId(1, 1L, 1)) + assertResult(1)(storedMetas.size) + shuffleCatalog.unregisterShuffle(1) + } + + test("get a columnar batch iterator from catalog") { + val shuffleCatalog = new ShuffleBufferCatalog() + shuffleCatalog.registerShuffle(1) + // add metadata only table + val tableMeta = RapidsShuffleTestHelper.mockTableMeta(0) + shuffleCatalog.addDegenerateRapidsBuffer(ShuffleBlockId(1, 1L, 1), tableMeta) + val ct = RapidsShuffleTestHelper.buildContiguousTable(1000) + shuffleCatalog.addContiguousTable(ShuffleBlockId(1, 1L, 1), ct, -1) + val iter = + shuffleCatalog.getColumnarBatchIterator( + ShuffleBlockId(1, 1L, 1), Array[DataType](IntegerType)) + withResource(iter.toArray) { cbs => + assertResult(2)(cbs.length) + assertResult(0)(cbs.head.numRows()) + assertResult(1)(cbs.head.numCols()) + assertResult(1000)(cbs.last.numRows()) + assertResult(1)(cbs.last.numCols()) + shuffleCatalog.unregisterShuffle(1) + } + } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/WindowRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/WindowRetrySuite.scala index 0ecc5faf1a3..bf61e495c06 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/WindowRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/WindowRetrySuite.scala @@ -80,7 +80,7 @@ class WindowRetrySuite assertResult(4)(hostCol.getLong(row)) } } - verify(inputBatch, times(2)).getColumnarBatch() + verify(inputBatch, times(2)).getColumnarBatch verify(inputBatch, times(1)).close() } } @@ -102,7 +102,7 @@ class WindowRetrySuite assertResult(row + 1)(hostCol.getLong(row)) } } - verify(inputBatch, times(2)).getColumnarBatch() + verify(inputBatch, times(2)).getColumnarBatch verify(inputBatch, times(1)).close() } } @@ -126,7 +126,7 @@ class WindowRetrySuite assertResult(4)(hostCol.getLong(row)) } } - verify(inputBatch, times(2)).getColumnarBatch() + verify(inputBatch, times(2)).getColumnarBatch verify(inputBatch, times(1)).close() } } @@ -143,7 +143,7 @@ class WindowRetrySuite assertThrows[GpuSplitAndRetryOOM] { it.next() } - verify(inputBatch, times(1)).getColumnarBatch() + verify(inputBatch, times(1)).getColumnarBatch verify(inputBatch, times(1)).close() } @@ -173,7 +173,7 @@ class WindowRetrySuite } } } - verify(inputBatch, times(2)).getColumnarBatch() + verify(inputBatch, times(2)).getColumnarBatch verify(inputBatch, times(1)).close() } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/WithRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/WithRetrySuite.scala index aa003c454f1..0f37b7566d9 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/WithRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/WithRetrySuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,11 +21,13 @@ import com.nvidia.spark.Retryable import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{splitTargetSizeInHalfGpu, withRestoreOnRetry, withRetry, withRetryNoSplit} import com.nvidia.spark.rapids.jni.{GpuRetryOOM, GpuSplitAndRetryOOM, RmmSpark} +import com.nvidia.spark.rapids.spill.SpillFramework import org.mockito.Mockito._ import org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite import org.scalatestplus.mockito.MockitoSugar +import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession import org.apache.spark.sql.types.{DataType, LongType} @@ -52,19 +54,16 @@ class WithRetrySuite rmmWasInitialized = true Rmm.initialize(RmmAllocationMode.CUDA_DEFAULT, null, 512 * 1024 * 1024) } - val deviceStorage = new RapidsDeviceMemoryStore() - val catalog = new RapidsBufferCatalog(deviceStorage) - RapidsBufferCatalog.setDeviceStorage(deviceStorage) - RapidsBufferCatalog.setCatalog(catalog) + SpillFramework.initialize(new RapidsConf(new SparkConf)) val mockEventHandler = new BaseRmmEventHandler() RmmSpark.setEventHandler(mockEventHandler) RmmSpark.currentThreadIsDedicatedToTask(1) } override def afterEach(): Unit = { + SpillFramework.shutdown() RmmSpark.removeAllCurrentThreadAssociation() RmmSpark.clearEventHandler() - RapidsBufferCatalog.close() if (rmmWasInitialized) { Rmm.shutdown() } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala index e9873b0bc5f..515ff75ffaa 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -255,7 +255,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { verify(client, times(1)).track(any[DeviceMemoryBuffer](), tmCaptor.capture()) verifyTableMeta(tableMeta, tmCaptor.getValue.asInstanceOf[TableMeta]) verify(mockCatalog, times(1)) - .addBuffer(dmbCaptor.capture(), any(), any(), any()) + .addBuffer(dmbCaptor.capture(), any(), any()) val receivedBuff = dmbCaptor.getValue.asInstanceOf[DeviceMemoryBuffer] assertResult(tableMeta.bufferMeta().size())(receivedBuff.getLength) @@ -310,8 +310,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { verify(client, times(1)).track(any[DeviceMemoryBuffer](), tmCaptor.capture()) verifyTableMeta(tableMeta, tmCaptor.getValue.asInstanceOf[TableMeta]) verify(mockCatalog, times(1)) - .addBuffer(dmbCaptor.capture(), any(), any(), any()) - verify(mockCatalog, times(1)).removeBuffer(any()) + .addBuffer(dmbCaptor.capture(), any(), any()) val receivedBuff = dmbCaptor.getValue.asInstanceOf[DeviceMemoryBuffer] assertResult(tableMeta.bufferMeta().size())(receivedBuff.getLength) @@ -367,7 +366,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { } verify(mockCatalog, times(5)) - .addBuffer(dmbCaptor.capture(), any(), any(), any()) + .addBuffer(dmbCaptor.capture(), any(), any()) assertResult(totalExpectedSize)( dmbCaptor.getAllValues().toArray().map(_.asInstanceOf[DeviceMemoryBuffer].getLength).sum) @@ -424,7 +423,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { } verify(mockCatalog, times(20)) - .addBuffer(dmbCaptor.capture(), any(), any(), any()) + .addBuffer(dmbCaptor.capture(), any(), any()) assertResult(totalExpectedSize)( dmbCaptor.getAllValues().toArray().map(_.asInstanceOf[DeviceMemoryBuffer].getLength).sum) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala index 70064682ed0..af24b332c83 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,8 +16,9 @@ package com.nvidia.spark.rapids.shuffle -import com.nvidia.spark.rapids.{RapidsBuffer, RapidsBufferHandle} +import com.nvidia.spark.rapids.RapidsShuffleHandle import com.nvidia.spark.rapids.jni.RmmSpark +import com.nvidia.spark.rapids.spill.SpillableDeviceBufferHandle import org.mockito.ArgumentCaptor import org.mockito.ArgumentMatchers._ import org.mockito.Mockito._ @@ -30,18 +31,9 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { val taskId = 1 try { RmmSpark.currentThreadIsDedicatedToTask(taskId) - val blocksByAddress = RapidsShuffleTestHelper.getBlocksByAddress - - val cl = new RapidsShuffleIterator( - RapidsShuffleTestHelper.makeMockBlockManager("1", "1"), - mockConf, - mockTransport, - blocksByAddress, - testMetricsUpdater, - Array.empty, - taskId, - mockCatalog, - 123) + val cl = + RapidsShuffleTestHelper.makeIterator( + mockConf, mockTransport, testMetricsUpdater, taskId, mockCatalog) when(mockTransaction.getStatus).thenReturn(TransactionStatus.Error) @@ -64,18 +56,9 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { RmmSpark.currentThreadIsDedicatedToTask(taskId) when(mockTransaction.getStatus).thenReturn(status) - val blocksByAddress = RapidsShuffleTestHelper.getBlocksByAddress - - val cl = spy(new RapidsShuffleIterator( - RapidsShuffleTestHelper.makeMockBlockManager("1", "1"), - mockConf, - mockTransport, - blocksByAddress, - testMetricsUpdater, - Array.empty, - taskId, - mockCatalog, - 123)) + val cl = + RapidsShuffleTestHelper.makeIterator( + mockConf, mockTransport, testMetricsUpdater, taskId, mockCatalog) val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler]) when(mockTransport.makeClient(any())).thenReturn(client) @@ -112,18 +95,10 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { val taskId = 1 try { RmmSpark.currentThreadIsDedicatedToTask(taskId) - val blocksByAddress = RapidsShuffleTestHelper.getBlocksByAddress - - val cl = spy(new RapidsShuffleIterator( - RapidsShuffleTestHelper.makeMockBlockManager("1", "1"), - mockConf, - mockTransport, - blocksByAddress, - testMetricsUpdater, - Array.empty, - taskId, - mockCatalog, - 123)) + + val cl = + RapidsShuffleTestHelper.makeIterator( + mockConf, mockTransport, testMetricsUpdater, taskId, mockCatalog) val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler]) when(mockTransport.makeClient(any())).thenReturn(client) @@ -162,18 +137,10 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { val taskId = 1 try { RmmSpark.currentThreadIsDedicatedToTask(taskId) - val blocksByAddress = RapidsShuffleTestHelper.getBlocksByAddress - - val cl = spy(new RapidsShuffleIterator( - RapidsShuffleTestHelper.makeMockBlockManager("1", "1"), - mockConf, - mockTransport, - blocksByAddress, - testMetricsUpdater, - Array.empty, - taskId, - mockCatalog, - 123)) + + val cl = + RapidsShuffleTestHelper.makeIterator( + mockConf, mockTransport, testMetricsUpdater, taskId, mockCatalog) when(mockTransport.makeClient(any())).thenReturn(client) doNothing().when(client).doFetch(any(), any()) @@ -198,29 +165,21 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { val taskId = 1 try { RmmSpark.currentThreadIsDedicatedToTask(taskId) - val blocksByAddress = RapidsShuffleTestHelper.getBlocksByAddress - - val cl = new RapidsShuffleIterator( - RapidsShuffleTestHelper.makeMockBlockManager("1", "1"), - mockConf, - mockTransport, - blocksByAddress, - testMetricsUpdater, - Array.empty, - taskId, - mockCatalog, - 123) + val cl = + RapidsShuffleTestHelper.makeIterator( + mockConf, mockTransport, testMetricsUpdater, taskId, mockCatalog) val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler]) when(mockTransport.makeClient(any())).thenReturn(client) doNothing().when(client).doFetch(any(), ac.capture()) - val mockBuffer = mock[RapidsBuffer] + val mockBuffer = RapidsShuffleHandle(mock[SpillableDeviceBufferHandle], null) + when(mockBuffer.spillable.sizeInBytes).thenReturn(123L) val cb = new ColumnarBatch(Array.empty, 10) - val handle = mock[RapidsBufferHandle] - when(mockBuffer.getColumnarBatch(Array.empty)).thenReturn(cb) - when(mockCatalog.acquireBuffer(any[RapidsBufferHandle]())).thenReturn(mockBuffer) - doNothing().when(mockCatalog).removeBuffer(any()) + val handle = mock[RapidsShuffleHandle] + doAnswer(_ => (cb, 123L)).when(mockCatalog) + .getColumnarBatchAndRemove(any[RapidsShuffleHandle](), any()) + cl.start() val handler = ac.getValue.asInstanceOf[RapidsShuffleFetchHandler] @@ -232,7 +191,7 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { assert(cl.hasNext) assertResult(cb)(cl.next()) assertResult(1)(testMetricsUpdater.totalRemoteBlocksFetched) - assertResult(mockBuffer.memoryUsedBytes)(testMetricsUpdater.totalRemoteBytesRead) + assertResult(123L)(testMetricsUpdater.totalRemoteBytesRead) assertResult(10)(testMetricsUpdater.totalRowsFetched) } finally { RmmSpark.taskDone(taskId) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala index 3eb73ef0f13..8d7415fba04 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,73 +18,58 @@ package com.nvidia.spark.rapids.shuffle import java.io.IOException import java.nio.ByteBuffer -import java.util -import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer} -import com.nvidia.spark.rapids.{MetaUtils, RapidsBuffer, ShuffleMetadata} +import ai.rapids.cudf.{DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer} +import com.nvidia.spark.rapids.{MetaUtils, RapidsShuffleHandle, ShuffleMetadata} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.format.TableMeta +import com.nvidia.spark.rapids.spill.SpillableDeviceBufferHandle import org.mockito.{ArgumentCaptor, ArgumentMatchers} -import org.mockito.ArgumentMatchers.{any, anyLong} +import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ import org.apache.spark.storage.ShuffleBlockBatchId -class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { +class MockRapidsShuffleRequestHandler(mockBuffers: Seq[RapidsShuffleHandle]) + extends RapidsShuffleRequestHandler with AutoCloseable { + var acquiredTables = Seq[Int]() + override def getShuffleBufferMetas( + shuffleBlockBatchId: ShuffleBlockBatchId): Seq[TableMeta] = { + throw new NotImplementedError("getShuffleBufferMetas") + } - def setupMocks(deviceBuffers: Seq[DeviceMemoryBuffer]): (RapidsShuffleRequestHandler, - Seq[RapidsBuffer], util.HashMap[RapidsBuffer, Int]) = { + override def getShuffleHandle(tableId: Int): RapidsShuffleHandle = { + acquiredTables = acquiredTables :+ tableId + mockBuffers(tableId) + } - val numCloses = new util.HashMap[RapidsBuffer, Int]() + override def close(): Unit = { + // a removeShuffle action would likewise remove handles + mockBuffers.foreach(_.close()) + } +} + +class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { + + def setupMocks(deviceBuffers: Seq[DeviceMemoryBuffer]): MockRapidsShuffleRequestHandler = { val mockBuffers = deviceBuffers.map { deviceBuffer => withResource(HostMemoryBuffer.allocate(deviceBuffer.getLength)) { hostBuff => fillBuffer(hostBuff) deviceBuffer.copyFromHostBuffer(hostBuff) - val mockBuffer = mock[RapidsBuffer] val mockMeta = RapidsShuffleTestHelper.mockTableMeta(100000) - when(mockBuffer.copyToMemoryBuffer(anyLong(), any[MemoryBuffer](), anyLong(), anyLong(), - any[Cuda.Stream]())).thenAnswer { invocation => - // start at 1 close, since we'll need to close at refcount 0 too - val newNumCloses = numCloses.getOrDefault(mockBuffer, 1) + 1 - numCloses.put(mockBuffer, newNumCloses) - val srcOffset = invocation.getArgument[Long](0) - val dst = invocation.getArgument[MemoryBuffer](1) - val dstOffset = invocation.getArgument[Long](2) - val length = invocation.getArgument[Long](3) - val stream = invocation.getArgument[Cuda.Stream](4) - dst.copyFromMemoryBuffer(dstOffset, deviceBuffer, srcOffset, length, stream) - } - when(mockBuffer.getPackedSizeBytes).thenReturn(deviceBuffer.getLength) - when(mockBuffer.meta).thenReturn(mockMeta) - mockBuffer + RapidsShuffleHandle(SpillableDeviceBufferHandle(deviceBuffer), mockMeta) } } - - val handler = new RapidsShuffleRequestHandler { - var acquiredTables = Seq[Int]() - override def getShuffleBufferMetas( - shuffleBlockBatchId: ShuffleBlockBatchId): Seq[TableMeta] = { - throw new NotImplementedError("getShuffleBufferMetas") - } - - override def acquireShuffleBuffer(tableId: Int): RapidsBuffer = { - acquiredTables = acquiredTables :+ tableId - mockBuffers(tableId) - } - } - (handler, mockBuffers, numCloses) + new MockRapidsShuffleRequestHandler(mockBuffers) } - class MockBlockWithSize(val b: DeviceMemoryBuffer) extends BlockWithSize { - override def size: Long = b.getLength - } + class MockBlockWithSize(override val size: Long) extends BlockWithSize {} def compareRanges( bounceBuffer: SendBounceBuffers, - receiveBlocks: Seq[BlockRange[MockBlockWithSize]]): Unit = { + receiveBlocks: Seq[(BlockRange[MockBlockWithSize], DeviceMemoryBuffer)]): Unit = { var bounceBuffOffset = 0L - receiveBlocks.foreach { range => - val deviceBuff = range.block.b + receiveBlocks.foreach { case (range, deviceBuff) => val deviceBounceBuff = bounceBuffer.deviceBounceBuffer.buffer withResource(deviceBounceBuff.slice(bounceBuffOffset, range.rangeSize())) { bbSlice => bounceBuffOffset = bounceBuffOffset + range.rangeSize() @@ -104,26 +89,24 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { when(mockTx.releaseMessage()).thenReturn(transferRequest) val bb = closeOnExcept(getSendBounceBuffer(10000)) { bounceBuffer => - withResource((0 until 10).map(_ => DeviceMemoryBuffer.allocate(1000))) { deviceBuffers => - val receiveSide = deviceBuffers.map(b => new MockBlockWithSize(b)) + val deviceBuffers = (0 until 10).map(_ => DeviceMemoryBuffer.allocate(1000)) + val receiveSide = deviceBuffers.map(_ => new MockBlockWithSize(1000)) + withResource(setupMocks(deviceBuffers)) { handler => val receiveWindow = new WindowedBlockIterator[MockBlockWithSize](receiveSide, 10000) - val (handler, mockBuffers, numCloses) = setupMocks(deviceBuffers) withResource(new BufferSendState(mockTx, bounceBuffer, handler)) { bss => assert(bss.hasMoreSends) withResource(bss.getBufferToSend()) { mb => val receiveBlocks = receiveWindow.next() - compareRanges(bounceBuffer, receiveBlocks) + compareRanges(bounceBuffer, receiveBlocks.zip(deviceBuffers)) assertResult(10000)(mb.getLength) assert(!bss.hasMoreSends) bss.releaseAcquiredToCatalog() - mockBuffers.foreach { b: RapidsBuffer => - // should have seen 2 closes, one for BufferSendState acquiring for metadata - // and the second acquisition for copying - verify(b, times(numCloses.get(b))).close() - } } } } + deviceBuffers.foreach { b => + assertResult(0)(b.getRefCount) + } bounceBuffer } assert(bb.deviceBounceBuffer.isClosed) @@ -136,32 +119,29 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { when(mockTx.releaseMessage()).thenReturn(transferRequest) val bb = closeOnExcept(getSendBounceBuffer(10000)) { bounceBuffer => - withResource((0 until 20).map(_ => DeviceMemoryBuffer.allocate(1000))) { deviceBuffers => - val receiveSide = deviceBuffers.map(b => new MockBlockWithSize(b)) - val receiveWindow = new WindowedBlockIterator[MockBlockWithSize](receiveSide, 10000) - val (handler, mockBuffers, numCloses) = setupMocks(deviceBuffers) + val deviceBuffers = (0 until 20).map(_ => DeviceMemoryBuffer.allocate(1000)) + val receiveSide = deviceBuffers.map(_ => new MockBlockWithSize(1000)) + val receiveWindow = new WindowedBlockIterator[MockBlockWithSize](receiveSide, 10000) + withResource(setupMocks(deviceBuffers)) { handler => withResource(new BufferSendState(mockTx, bounceBuffer, handler)) { bss => withResource(bss.getBufferToSend()) { _ => val receiveBlocks = receiveWindow.next() - compareRanges(bounceBuffer, receiveBlocks) + compareRanges(bounceBuffer, receiveBlocks.zip(deviceBuffers)) assert(bss.hasMoreSends) bss.releaseAcquiredToCatalog() } withResource(bss.getBufferToSend()) { _ => val receiveBlocks = receiveWindow.next() - compareRanges(bounceBuffer, receiveBlocks) + compareRanges(bounceBuffer, receiveBlocks.zip(deviceBuffers)) assert(!bss.hasMoreSends) bss.releaseAcquiredToCatalog() } - - mockBuffers.foreach { b: RapidsBuffer => - // should have seen 2 closes, one for BufferSendState acquiring for metadata - // and the second acquisition for copying - verify(b, times(numCloses.get(b))).close() - } } } + deviceBuffers.foreach { b => + assertResult(0)(b.getRefCount) + } bounceBuffer } assert(bb.deviceBounceBuffer.isClosed) @@ -174,24 +154,23 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { when(mockTx.releaseMessage()).thenReturn(transferRequest) val bb = closeOnExcept(getSendBounceBuffer(10000)) { bounceBuffer => - withResource((0 until 20).map(_ => DeviceMemoryBuffer.allocate(123000))) { deviceBuffers => - val (handler, mockBuffers, numCloses) = setupMocks(deviceBuffers) - - val receiveSide = deviceBuffers.map(b => new MockBlockWithSize(b)) + val deviceBuffers = (0 until 20).map(_ => DeviceMemoryBuffer.allocate(123000)) + withResource(setupMocks(deviceBuffers)) { handler => + val receiveSide = deviceBuffers.map(_ => new MockBlockWithSize(123000)) val receiveWindow = new WindowedBlockIterator[MockBlockWithSize](receiveSide, 10000) withResource(new BufferSendState(mockTx, bounceBuffer, handler)) { bss => (0 until 246).foreach { _ => withResource(bss.getBufferToSend()) { _ => val receiveBlocks = receiveWindow.next() - compareRanges(bounceBuffer, receiveBlocks) + compareRanges(bounceBuffer, receiveBlocks.zip(deviceBuffers)) bss.releaseAcquiredToCatalog() } } assert(!bss.hasMoreSends) } - mockBuffers.foreach { b: RapidsBuffer => - verify(b, times(numCloses.get(b))).close() - } + } + deviceBuffers.foreach { b => + assertResult(0)(b.getRefCount) } bounceBuffer } @@ -224,14 +203,13 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { any(), any(), any(), any[MemoryBuffer](), ac.capture())).thenReturn(mockTransaction) val mockRequestHandler = mock[RapidsShuffleRequestHandler] - val rapidsBuffer = mock[RapidsBuffer] val bb = ByteBuffer.allocateDirect(123) withResource(new RefCountedDirectByteBuffer(bb)) { _ => val tableMeta = MetaUtils.buildTableMeta(1, 456, bb, 100) - when(rapidsBuffer.meta).thenReturn(tableMeta) - when(rapidsBuffer.getPackedSizeBytes).thenReturn(tableMeta.bufferMeta().size()) - when(mockRequestHandler.acquireShuffleBuffer(ArgumentMatchers.eq(1))) + val testHandle = SpillableDeviceBufferHandle(DeviceMemoryBuffer.allocate(456)) + val rapidsBuffer = RapidsShuffleHandle(testHandle, tableMeta) + when(mockRequestHandler.getShuffleHandle(ArgumentMatchers.eq(1))) .thenReturn(rapidsBuffer) val server = new RapidsShuffleServer( @@ -245,7 +223,7 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { server.start() - val bss = new BufferSendState(mockTransaction, mockSendBuffer, mockRequestHandler, null) + val bss = new BufferSendState(mockTransaction, mockSendBuffer, mockRequestHandler) server.doHandleTransferRequest(Seq(bss)) val cb = ac.getValue.asInstanceOf[TransactionCallback] cb(mockTransaction) @@ -253,10 +231,14 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { cb(mockTransaction) // bounce buffers are freed verify(mockSendBuffer, times(1)).close() - // acquire 3 times, and close 3 times - verify(mockRequestHandler, times(3)) - .acquireShuffleBuffer(ArgumentMatchers.eq(1)) - verify(rapidsBuffer, times(3)).close() + // acquire once at the beginning, and closed at the end + verify(mockRequestHandler, times(1)) + .getShuffleHandle(ArgumentMatchers.eq(1)) + withResource(rapidsBuffer.spillable.materialize()) { dmb => + // refcount=2 because it was on the device, and we +1 to materialize. + // but it shows no leaks. + assertResult(2)(dmb.getRefCount) + } } } } @@ -264,74 +246,80 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { test("when we fail to prepare a send, throw if nothing can be handled") { val mockSendBuffer = mock[SendBounceBuffers] val mockDeviceBounceBuffer = mock[BounceBuffer] - withResource(DeviceMemoryBuffer.allocate(123)) { buff => - when(mockDeviceBounceBuffer.buffer).thenReturn(buff) - when(mockSendBuffer.bounceBufferSize).thenReturn(buff.getLength) - when(mockSendBuffer.hostBounceBuffer).thenReturn(None) - when(mockSendBuffer.deviceBounceBuffer).thenReturn(mockDeviceBounceBuffer) - - when(mockTransport.tryGetSendBounceBuffers(any(), any())) - .thenReturn(Seq(mockSendBuffer)) - - val tr = ShuffleMetadata.buildTransferRequest(0, Seq(1)) - when(mockTransaction.getStatus) - .thenReturn(TransactionStatus.Success) - when(mockTransaction.releaseMessage()).thenReturn( - new MetadataTransportBuffer(new RefCountedDirectByteBuffer(tr))) - - val mockServerConnection = mock[ServerConnection] - val mockRequestHandler = mock[RapidsShuffleRequestHandler] - val rapidsBuffer = mock[RapidsBuffer] - - val bb = ByteBuffer.allocateDirect(123) - withResource(new RefCountedDirectByteBuffer(bb)) { _ => - val tableMeta = MetaUtils.buildTableMeta(1, 456, bb, 100) - when(rapidsBuffer.meta).thenReturn(tableMeta) - when(rapidsBuffer.getPackedSizeBytes).thenReturn(tableMeta.bufferMeta().size()) - when(mockRequestHandler.acquireShuffleBuffer(ArgumentMatchers.eq(1))) - .thenReturn(rapidsBuffer) - - val server = spy(new RapidsShuffleServer( - mockTransport, - mockServerConnection, - RapidsShuffleTestHelper.makeMockBlockManager("1", "foo"), - mockRequestHandler, - mockExecutor, - mockBssExecutor, - mockConf)) - - server.start() - - val ioe = new IOException("mmap failed in test") - - when(rapidsBuffer.copyToMemoryBuffer(any(), any(), any(), any(), any())) - .thenAnswer(_ => throw ioe) - - val bss = new BufferSendState(mockTransaction, mockSendBuffer, mockRequestHandler, null) - // if nothing else can be handled, we throw - assertThrows[IllegalStateException] { - try { - server.doHandleTransferRequest(Seq(bss)) - } catch { - case e: Throwable => - assertResult(1)(e.getSuppressed.length) - assertResult(ioe)(e.getSuppressed()(0).getCause) - throw e - } + val mockDeviceMemoryBuffer = mock[DeviceMemoryBuffer] + when(mockDeviceBounceBuffer.buffer).thenReturn(mockDeviceMemoryBuffer) + when(mockSendBuffer.bounceBufferSize).thenReturn(1024) + when(mockSendBuffer.hostBounceBuffer).thenReturn(None) + when(mockSendBuffer.deviceBounceBuffer).thenReturn(mockDeviceBounceBuffer) + + when(mockTransport.tryGetSendBounceBuffers(any(), any())) + .thenReturn(Seq(mockSendBuffer)) + + val tr = ShuffleMetadata.buildTransferRequest(0, Seq(1, 2)) + when(mockTransaction.getStatus) + .thenReturn(TransactionStatus.Success) + when(mockTransaction.releaseMessage()).thenReturn( + new MetadataTransportBuffer(new RefCountedDirectByteBuffer(tr))) + + val mockServerConnection = mock[ServerConnection] + val mockRequestHandler = mock[RapidsShuffleRequestHandler] + + val bb = ByteBuffer.allocateDirect(123) + withResource(new RefCountedDirectByteBuffer(bb)) { _ => + val tableMeta = MetaUtils.buildTableMeta(1, 456, bb, 100) + val mockHandle = mock[SpillableDeviceBufferHandle] + val mockHandleThatThrows = mock[SpillableDeviceBufferHandle] + val mockMaterialized = mock[DeviceMemoryBuffer] + when(mockHandle.sizeInBytes).thenReturn(tableMeta.bufferMeta().size()) + when(mockHandle.materialize()).thenAnswer(_ => mockMaterialized) + + when(mockHandleThatThrows.sizeInBytes).thenReturn(tableMeta.bufferMeta().size()) + val ex = new IllegalStateException("something happened") + when(mockHandleThatThrows.materialize()).thenThrow(ex) + + val rapidsBuffer = RapidsShuffleHandle(mockHandle, tableMeta) + val rapidsBufferThatThrows = RapidsShuffleHandle(mockHandleThatThrows, tableMeta) + + when(mockRequestHandler.getShuffleHandle(ArgumentMatchers.eq(1))) + .thenReturn(rapidsBuffer) + when(mockRequestHandler.getShuffleHandle(ArgumentMatchers.eq(2))) + .thenReturn(rapidsBufferThatThrows) + + val server = spy(new RapidsShuffleServer( + mockTransport, + mockServerConnection, + RapidsShuffleTestHelper.makeMockBlockManager("1", "foo"), + mockRequestHandler, + mockExecutor, + mockBssExecutor, + mockConf)) + + server.start() + + val bss = new BufferSendState(mockTransaction, mockSendBuffer, mockRequestHandler, null) + // if nothing else can be handled, we throw + assertThrows[IllegalStateException] { + try { + server.doHandleTransferRequest(Seq(bss)) + } catch { + case e: Throwable => + assertResult(1)(e.getSuppressed.length) + assertResult(ex)(e.getSuppressed()(0).getCause) + throw e } + } - // since nothing could be handled, we don't try again - verify(server, times(0)).addToContinueQueue(any()) + // since nothing could be handled, we don't try again + verify(server, times(0)).addToContinueQueue(any()) - // bounce buffers are freed - verify(mockSendBuffer, times(1)).close() + // bounce buffers are freed + verify(mockSendBuffer, times(1)).close() - // acquire 2 times, 1 to make the ranges, and the 2 before the copy - // close 2 times corresponding to each open - verify(mockRequestHandler, times(2)) - .acquireShuffleBuffer(ArgumentMatchers.eq(1)) - verify(rapidsBuffer, times(2)).close() - } + verify(mockRequestHandler, times(1)) + .getShuffleHandle(ArgumentMatchers.eq(1)) + + // the spillable that materialized we need to close + verify(mockMaterialized, times(1)).close() } } @@ -367,13 +355,24 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { val mockRequestHandler = mock[RapidsShuffleRequestHandler] - def makeMockBuffer(tableId: Int, bb: ByteBuffer): RapidsBuffer = { - val rapidsBuffer = mock[RapidsBuffer] + def makeMockBuffer(tableId: Int, bb: ByteBuffer, error: Boolean): RapidsShuffleHandle = { val tableMeta = MetaUtils.buildTableMeta(tableId, 456, bb, 100) - when(rapidsBuffer.meta).thenReturn(tableMeta) - when(rapidsBuffer.getPackedSizeBytes).thenReturn(tableMeta.bufferMeta().size()) - when(mockRequestHandler.acquireShuffleBuffer(ArgumentMatchers.eq(tableId))) - .thenReturn(rapidsBuffer) + val rapidsBuffer = if (error) { + val mockHandle = mock[SpillableDeviceBufferHandle] + val rapidsBuffer = RapidsShuffleHandle(mockHandle, tableMeta) + when(mockHandle.sizeInBytes).thenReturn(tableMeta.bufferMeta().size()) + // mock an error with the copy + when(rapidsBuffer.spillable.materialize()) + .thenAnswer(_ => { + throw new IOException("mmap failed in test") + }) + rapidsBuffer + } else { + val testHandle = spy(SpillableDeviceBufferHandle(spy(DeviceMemoryBuffer.allocate(456)))) + RapidsShuffleHandle(testHandle, tableMeta) + } + when(mockRequestHandler.getShuffleHandle(ArgumentMatchers.eq(tableId))) + .thenAnswer(_ => rapidsBuffer) rapidsBuffer } @@ -381,19 +380,8 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { val bb2 = ByteBuffer.allocateDirect(123) withResource(new RefCountedDirectByteBuffer(bb)) { _ => withResource(new RefCountedDirectByteBuffer(bb2)) { _ => - val rapidsBuffer = makeMockBuffer(1, bb) - val rapidsBuffer2 = makeMockBuffer(2, bb2) - - // error with copy - when(rapidsBuffer.copyToMemoryBuffer(any(), any(), any(), any(), any())) - .thenAnswer(_ => { - throw new IOException("mmap failed in test") - }) - - // successful copy - doNothing() - .when(rapidsBuffer2) - .copyToMemoryBuffer(any(), any(), any(), any(), any()) + val rapidsHandle = makeMockBuffer(1, bb, error = true) + val rapidsHandle2 = makeMockBuffer(2, bb2, error = false) val server = spy(new RapidsShuffleServer( mockTransport, @@ -407,14 +395,21 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { server.start() val bssFailed = new BufferSendState( - mockTransaction, mockSendBuffer, mockRequestHandler, null) + mockTransaction, mockSendBuffer, mockRequestHandler) val bssSuccess = spy(new BufferSendState( - mockTransaction2, mockSendBuffer, mockRequestHandler, null)) - - when(bssSuccess.hasMoreSends) - .thenReturn(true) // send 1 bounce buffer length - .thenReturn(false) + mockTransaction2, mockSendBuffer, mockRequestHandler)) + + var callCount = 0 + doAnswer { _ => + callCount += 1 + // send 1 buffer length + if (callCount > 1){ + false + } else { + true + } + }.when(bssSuccess).hasMoreSends // if something else can be handled we don't throw, and re-queue server.doHandleTransferRequest(Seq(bssFailed, bssSuccess)) @@ -427,15 +422,21 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { // the bounce buffer is freed 1 time for `bssSuccess`, but not for `bssFailed` verify(mockSendBuffer, times(1)).close() - // acquire/close 4 times => - // we had two requests for 1 buffer, and each request acquires 2 times and closes - // 2 times. - verify(mockRequestHandler, times(2)) - .acquireShuffleBuffer(ArgumentMatchers.eq(1)) - verify(mockRequestHandler, times(2)) - .acquireShuffleBuffer(ArgumentMatchers.eq(2)) - verify(rapidsBuffer, times(2)).close() - verify(rapidsBuffer2, times(2)).close() + // we obtained the handles once, we don't need to get them again + verify(mockRequestHandler, times(1)) + .getShuffleHandle(ArgumentMatchers.eq(1)) + verify(mockRequestHandler, times(1)) + .getShuffleHandle(ArgumentMatchers.eq(2)) + // this handle fails to materialize + verify(rapidsHandle.spillable, times(1)).materialize() + + // this handle materializes, so make sure we close it + verify(rapidsHandle2.spillable, times(1)).materialize() + withResource(rapidsHandle2.spillable.materialize()) { dmb => + // refcount=2 because it was on the device, and we +1 to materialize. + // but it shows no leaks. + assertResult(2)(dmb.getRefCount) + } } } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala new file mode 100644 index 00000000000..31377695fe4 --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala @@ -0,0 +1,1105 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.spill + +import java.io.File +import java.math.RoundingMode + +import scala.collection.mutable.ArrayBuffer + +import ai.rapids.cudf._ +import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} +import com.nvidia.spark.rapids.format.CodecType +import org.mockito.Mockito.when +import org.scalatest.BeforeAndAfterAll +import org.scalatestplus.mockito.MockitoSugar + +import org.apache.spark.SparkConf +import org.apache.spark.sql.rapids.RapidsDiskBlockManager +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch + +class SpillFrameworkSuite + extends FunSuiteWithTempDir + with MockitoSugar + with BeforeAndAfterAll { + + override def beforeEach(): Unit = { + super.beforeEach() + val sc = new SparkConf + sc.set(RapidsConf.HOST_SPILL_STORAGE_SIZE.key, "1024") + SpillFramework.initialize(new RapidsConf(sc)) + } + + override def afterEach(): Unit = { + super.afterEach() + SpillFramework.shutdown() + } + + private def buildContiguousTable(): (ContiguousTable, Array[DataType]) = { + val (tbl, dataTypes) = buildTable() + withResource(tbl) { _ => + (tbl.contiguousSplit()(0), dataTypes) + } + } + + private def buildTableOfLongs(numRows: Int): (ContiguousTable, Array[DataType])= { + val vals = (0 until numRows).map(_.toLong) + withResource(HostColumnVector.fromLongs(vals: _*)) { hcv => + withResource(hcv.copyToDevice()) { cv => + withResource(new Table(cv)) { table => + (table.contiguousSplit()(0), Array[DataType](LongType)) + } + } + } + } + + private def buildNonContiguousTableOfLongs( + numRows: Int): (Table, Array[DataType])= { + val vals = (0 until numRows).map(_.toLong) + withResource(HostColumnVector.fromLongs(vals: _*)) { hcv => + withResource(hcv.copyToDevice()) { cv => + (new Table(cv), Array[DataType](LongType)) + } + } + } + + private def buildTable(): (Table, Array[DataType]) = { + val tbl = new Table.TestBuilder() + .column(5, null.asInstanceOf[java.lang.Integer], 3, 1) + .column("five", "two", null, null) + .column(5.0, 2.0, 3.0, 1.0) + .decimal64Column(-5, RoundingMode.UNNECESSARY, 0, null, -1.4, 10.123) + .build() + val types: Array[DataType] = + Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray + (tbl, types) + } + + private def buildTableWithDuplicate(): (Table, Array[DataType]) = { + withResource(ColumnVector.fromInts(5, null.asInstanceOf[java.lang.Integer], 3, 1)) { intCol => + withResource(ColumnVector.fromStrings("five", "two", null, null)) { stringCol => + withResource(ColumnVector.fromDoubles(5.0, 2.0, 3.0, 1.0)) { doubleCol => + // add intCol twice + (new Table(intCol, intCol, stringCol, doubleCol), + Array(IntegerType, IntegerType, StringType, DoubleType)) + } + } + } + } + + private def buildEmptyTable(): (Table, Array[DataType]) = { + val (tbl, types) = buildTable() + val emptyTbl = withResource(tbl) { _ => + withResource(ColumnVector.fromBooleans(false, false, false, false)) { mask => + tbl.filter(mask) // filter all out + } + } + (emptyTbl, types) + } + + private def testBufferFileDeletion(canShareDiskPaths: Boolean): Unit = { + val (_, handle, _) = addContiguousTableToFramework() + var path: File = null + withResource(handle) { _ => + SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes) + SpillFramework.stores.hostStore.spill(handle.approxSizeInBytes) + assert(handle.host.isDefined) + assert(handle.host.map(_.disk.isDefined).get) + path = SpillFramework.stores.diskStore.getFile(handle.host.flatMap(_.disk).get.blockId) + assert(path.exists) + } + assert(!path.exists) + } + + private def addContiguousTableToFramework(): ( + Long, SpillableColumnarBatchFromBufferHandle, Array[DataType]) = { + val (ct, dataTypes) = buildContiguousTable() + val bufferSize = ct.getBuffer.getLength + val handle = SpillableColumnarBatchFromBufferHandle(ct, dataTypes) + (bufferSize, handle, dataTypes) + } + + private def addTableToFramework(): (SpillableColumnarBatchHandle, Array[DataType]) = { + // store takes ownership of the table + val (tbl, dataTypes) = buildTable() + val cb = withResource(tbl) { _ => GpuColumnVector.from(tbl, dataTypes) } + val handle = SpillableColumnarBatchHandle(cb) + (handle, dataTypes) + } + + private def addZeroRowsTableToFramework(): (SpillableColumnarBatchHandle, Array[DataType]) = { + val (table, dataTypes) = buildEmptyTable() + val cb = withResource(table) { _ => GpuColumnVector.from(table, dataTypes) } + val handle = SpillableColumnarBatchHandle(cb) + (handle, dataTypes) + } + + private def buildHostBatch(): (ColumnarBatch, Array[DataType]) = { + val (ct, dataTypes) = buildContiguousTable() + val hostCols = withResource(ct) { _ => + withResource(ct.getTable) { tbl => + (0 until tbl.getNumberOfColumns) + .map(c => tbl.getColumn(c).copyToHost()) + } + }.toArray + (new ColumnarBatch( + hostCols.zip(dataTypes).map { case (hostCol, dataType) => + new RapidsHostColumnVector(dataType, hostCol) + }, hostCols.head.getRowCount.toInt), dataTypes) + } + + private def buildHostBatchWithDuplicate(): (ColumnarBatch, Array[DataType]) = { + val (ct, dataTypes) = buildContiguousTable() + val hostCols = withResource(ct) { _ => + withResource(ct.getTable) { tbl => + (0 until tbl.getNumberOfColumns) + .map(c => tbl.getColumn(c).copyToHost()) + } + }.toArray + hostCols.foreach(_.incRefCount()) + (new ColumnarBatch( + (hostCols ++ hostCols).zip(dataTypes ++ dataTypes).map { case (hostCol, dataType) => + new RapidsHostColumnVector(dataType, hostCol) + }, hostCols.head.getRowCount.toInt), dataTypes) + } + + test("add table registers with device store") { + val (ct, dataTypes) = buildContiguousTable() + withResource(SpillableColumnarBatchFromBufferHandle(ct, dataTypes)) { _ => + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + } + } + + test("a non-contiguous table is spillable and it is handed over to the store") { + val (tbl, dataTypes) = buildTable() + withResource(SpillableColumnarBatchHandle(tbl, dataTypes)) { handle => + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + assert(handle.spillable) + } + } + + test("a non-contiguous table becomes non-spillable when batch is obtained") { + val (tbl, dataTypes) = buildTable() + withResource(SpillableColumnarBatchHandle(tbl, dataTypes)) { handle => + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + assert(handle.spillable) + withResource(handle.materialize(dataTypes)) { _ => + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + assert(!handle.spillable) + assertResult(0)(SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes)) + } + assert(handle.spillable) + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + assertResult(handle.approxSizeInBytes)( + SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes)) + } + } + + test("a non-contiguous table is non-spillable until all columns are returned") { + val (table, dataTypes) = buildTable() + withResource(SpillableColumnarBatchHandle(table, dataTypes)) { handle => + assert(handle.spillable) + val cb = handle.materialize(dataTypes) + assert(!handle.spillable) + val columns = GpuColumnVector.extractBases(cb) + withResource(columns.head) { _ => + columns.head.incRefCount() + withResource(cb) { _ => + assert(!handle.spillable) + } + // still 0 after the batch is closed, because of the extra incRefCount + // for columns.head + assert(!handle.spillable) + } + // columns.head is closed, so now our RapidsTable is spillable again + assert(handle.spillable) + } + } + + test("an aliased non-contiguous table is not spillable (until closing the alias) ") { + val (table, dataTypes) = buildTable() + withResource(SpillableColumnarBatchHandle(table, dataTypes)) { handle => + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + assert(handle.spillable) + withResource(SpillableColumnarBatchHandle(handle.materialize(dataTypes))) { aliasHandle => + assertResult(2)(SpillFramework.stores.deviceStore.numHandles) + assert(!handle.spillable) + assert(!aliasHandle.spillable) + } // we now have two copies in the store + assert(handle.spillable) + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + } + } + + test("an aliased contiguous table is not spillable (until closing the alias) ") { + val (table, dataTypes) = buildContiguousTable() + withResource(SpillableColumnarBatchFromBufferHandle(table, dataTypes)) { handle => + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + assert(handle.spillable) + val materialized = handle.materialize(dataTypes) + // note that materialized is a batch "from buffer", it is not a regular batch + withResource(SpillableColumnarBatchFromBufferHandle(materialized)) { aliasHandle => + // we now have two copies in the store + assertResult(2)(SpillFramework.stores.deviceStore.numHandles) + assert(!handle.spillable) + assert(!aliasHandle.spillable) + } + assert(handle.spillable) + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + } + } + + test("an non-contiguous table supports duplicated columns") { + val (table, dataTypes) = buildTableWithDuplicate() + withResource(SpillableColumnarBatchHandle(table, dataTypes)) { handle => + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + assert(handle.spillable) + withResource(SpillableColumnarBatchHandle(handle.materialize(dataTypes))) { aliasHandle => + assertResult(2)(SpillFramework.stores.deviceStore.numHandles) + assert(!handle.spillable) + assert(!aliasHandle.spillable) + } // we now have two copies in the store + assert(handle.spillable) + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + } + } + + test("a buffer is not spillable until the owner closes columns referencing it") { + val (ct, _) = buildContiguousTable() + // the contract for spillable handles is that they take ownership + // incRefCount to follow that pattern + val buff = ct.getBuffer + buff.incRefCount() + withResource(SpillableDeviceBufferHandle(buff)) { handle => + withResource(ct) { _ => + assert(!handle.spillable) + } + assert(handle.spillable) + } + } + + private def buildContiguousTable(start: Int, numRows: Int): ContiguousTable = { + val vals = (0 until numRows).map(_.toLong + start) + withResource(HostColumnVector.fromLongs(vals: _*)) { hcv => + withResource(hcv.copyToDevice()) { cv => + withResource(HostColumnVector.decimalFromLongs(-3, vals: _*)) { decHcv => + withResource(decHcv.copyToDevice()) { decCv => + withResource(new Table(cv, decCv)) { table => + table.contiguousSplit()(0) + } + } + } + } + } + } + + private def buildCompressedBatch(start: Int, numRows: Int): ColumnarBatch = { + val codec = TableCompressionCodec.getCodec( + CodecType.NVCOMP_LZ4, TableCompressionCodec.makeCodecConfig(new RapidsConf(new SparkConf))) + withResource(codec.createBatchCompressor(0, Cuda.DEFAULT_STREAM)) { compressor => + compressor.addTableToCompress(buildContiguousTable(start, numRows)) + withResource(compressor.finish()) { compressed => + GpuCompressedColumnVector.from(compressed.head) + } + } + } + + private def decompressBatch(cb: ColumnarBatch): ColumnarBatch = { + val schema = new StructType().add("i", LongType) + .add("j", DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 3)) + val sparkTypes = GpuColumnVector.extractTypes(schema) + val codec = TableCompressionCodec.getCodec( + CodecType.NVCOMP_LZ4, TableCompressionCodec.makeCodecConfig(new RapidsConf(new SparkConf))) + withResource(codec.createBatchDecompressor(0, Cuda.DEFAULT_STREAM)) { decompressor => + val gcv = cb.column(0).asInstanceOf[GpuCompressedColumnVector] + // we need to incRefCount since the decompressor closes its inputs + gcv.getTableBuffer.incRefCount() + decompressor.addBufferToDecompress(gcv.getTableBuffer, gcv.getTableMeta.bufferMeta()) + withResource(decompressor.finishAsync()) { decompressed => + MetaUtils.getBatchFromMeta( + decompressed.head, + MetaUtils.dropCodecs(gcv.getTableMeta), + sparkTypes) + } + } + } + + test("a compressed batch can be added and recovered") { + val ct = buildCompressedBatch(0, 1000) + withResource(SpillableCompressedColumnarBatchHandle(ct)) { handle => + assert(handle.spillable) + withResource(handle.materialize()) { materialized => + assert(!handle.spillable) + // since we didn't spill, these buffers are exactly the same + assert( + ct.column(0).asInstanceOf[GpuCompressedColumnVector].getTableBuffer == + materialized.column(0).asInstanceOf[GpuCompressedColumnVector].getTableBuffer) + } + assert(handle.spillable) + } + } + + test("a compressed batch can be added and recovered after being spilled to host") { + val ct = buildCompressedBatch(0, 1000) + withResource(decompressBatch(ct)) { decompressedExpected => + withResource(SpillableCompressedColumnarBatchHandle(ct)) { handle => + assert(handle.spillable) + SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes) + assert(!handle.spillable) + assert(handle.dev.isEmpty) + assert(handle.host.isDefined) + withResource(handle.materialize()) { materialized => + withResource(decompressBatch(materialized)) { decompressed => + TestUtils.compareBatches(decompressedExpected, decompressed) + } + } + } + } + } + + test("a compressed batch can be added and recovered after being spilled to disk") { + val ct = buildCompressedBatch(0, 1000) + withResource(decompressBatch(ct)) { decompressedExpected => + withResource(SpillableCompressedColumnarBatchHandle(ct)) { handle => + assert(handle.spillable) + SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes) + assert(!handle.spillable) + SpillFramework.stores.hostStore.spill(handle.approxSizeInBytes) + assert(handle.dev.isEmpty) + assert(handle.host.isDefined) + assert(handle.host.get.host.isEmpty) + assert(handle.host.get.disk.isDefined) + withResource(handle.materialize()) { materialized => + withResource(decompressBatch(materialized)) { decompressed => + TestUtils.compareBatches(decompressedExpected, decompressed) + } + } + } + } + } + + + test("a second handle prevents buffer to be spilled") { + val buffer = DeviceMemoryBuffer.allocate(123) + val handle1 = SpillableDeviceBufferHandle(buffer) + // materialize will incRefCount `buffer`. This looks a little weird + // but it simulates aliasing as it happens in real code + val handle2 = SpillableDeviceBufferHandle(handle1.materialize()) + + withResource(handle1) { _ => + withResource(handle2) { _ => + assertResult(2)(handle1.dev.get.getRefCount) + assertResult(2)(handle2.dev.get.getRefCount) + assertResult(false)(handle1.spillable) + assertResult(false)(handle2.spillable) + } + assertResult(1)(handle1.dev.get.getRefCount) + assertResult(true)(handle1.spillable) + } + } + + test("removing handle releases buffer resources in all stores") { + val handle = SpillableDeviceBufferHandle(DeviceMemoryBuffer.allocate(123)) + withResource(handle) { _ => + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + assertResult(0)(SpillFramework.stores.hostStore.numHandles) + assertResult(0)(SpillFramework.stores.diskStore.numHandles) + + assertResult(123)(SpillFramework.stores.deviceStore.spill(123)) // spill to host memory + assertResult(0)(SpillFramework.stores.deviceStore.numHandles) + assertResult(1)(SpillFramework.stores.hostStore.numHandles) + assertResult(0)(SpillFramework.stores.diskStore.numHandles) + assert(handle.dev.isEmpty) + assert(handle.host.isDefined) + assert(handle.host.get.host.isDefined) + + assertResult(123)(SpillFramework.stores.hostStore.spill(123)) // spill to disk + assertResult(0)(SpillFramework.stores.deviceStore.numHandles) + assertResult(0)(SpillFramework.stores.hostStore.numHandles) + assertResult(1)(SpillFramework.stores.diskStore.numHandles) + assert(handle.dev.isEmpty) + assert(handle.host.isDefined) + assert(handle.host.get.host.isEmpty) + assert(handle.host.get.disk.isDefined) + } + assert(handle.host.isEmpty) + assert(handle.dev.isEmpty) + assertResult(0)(SpillFramework.stores.deviceStore.numHandles) + assertResult(0)(SpillFramework.stores.hostStore.numHandles) + assertResult(0)(SpillFramework.stores.diskStore.numHandles) + } + + test("spill updates store state") { + val diskStore = SpillFramework.stores.diskStore + val hostStore = SpillFramework.stores.hostStore + val deviceStore = SpillFramework.stores.deviceStore + + val (bufferSize, handle, _) = + addContiguousTableToFramework() + + withResource(handle) { _ => + assertResult(1)(deviceStore.numHandles) + assertResult(0)(diskStore.numHandles) + assertResult(0)(hostStore.numHandles) + + assertResult(bufferSize)(SpillFramework.stores.deviceStore.spill(bufferSize)) + assertResult(bufferSize)(SpillFramework.stores.hostStore.spill(bufferSize)) + + assertResult(0)(deviceStore.numHandles) + assertResult(0)(hostStore.numHandles) + assertResult(1)(diskStore.numHandles) + + val diskHandle = handle.host.flatMap(_.disk).get + val path = diskStore.getFile(diskHandle.blockId) + assert(path.exists) + } + } + + test("get columnar batch after host spill") { + val (ct, dataTypes) = buildContiguousTable() + val expectedBatch = GpuColumnVector.from(ct.getTable, dataTypes) + withResource(SpillableColumnarBatchFromBufferHandle( + ct, dataTypes)) { handle => + withResource(expectedBatch) { _ => + SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes) + withResource(handle.materialize(dataTypes)) { cb => + TestUtils.compareBatches(expectedBatch, cb) + } + } + } + } + + test("get memory buffer after host spill") { + val (ct, dataTypes) = buildContiguousTable() + val expectedBatch = closeOnExcept(ct) { _ => + // make a copy of the table so we can compare it later to the + // one reconstituted after the spill + withResource(ct.getTable.contiguousSplit()) { copied => + GpuColumnVector.from(copied(0).getTable, dataTypes) + } + } + val handle = SpillableColumnarBatchFromBufferHandle(ct, dataTypes) + withResource(handle) { _ => + withResource(expectedBatch) { _ => + assertResult(SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes))( + handle.approxSizeInBytes) + val hostSize = handle.host.get.approxSizeInBytes + assertResult(SpillFramework.stores.hostStore.spill(hostSize))(hostSize) + withResource(handle.materialize(dataTypes)) { actualBatch => + TestUtils.compareBatches(expectedBatch, actualBatch) + } + } + } + } + + test("host originated: get host memory buffer") { + val spillPriority = -10 + val hmb = HostMemoryBuffer.allocate(1L * 1024) + val spillableBuffer = SpillableHostBuffer(hmb, hmb.getLength, spillPriority) + withResource(spillableBuffer) { _ => + // the refcount of 1 is the store + assertResult(1)(hmb.getRefCount) + withResource(spillableBuffer.getHostBuffer()) { memoryBuffer => + assertResult(hmb)(memoryBuffer) + assertResult(2)(memoryBuffer.getRefCount) + } + } + assertResult(0)(hmb.getRefCount) + } + + test("host originated: get host memory buffer after spill to disk") { + val spillPriority = -10 + val hmb = HostMemoryBuffer.allocate(1L * 1024) + val spillableBuffer = SpillableHostBuffer( + hmb, + hmb.getLength, + spillPriority) + assertResult(1)(hmb.getRefCount) + // we spill it + SpillFramework.stores.hostStore.spill(hmb.getLength) + withResource(spillableBuffer) { _ => + // the refcount of the original buffer is 0 because it spilled + assertResult(0)(hmb.getRefCount) + withResource(spillableBuffer.getHostBuffer()) { memoryBuffer => + assertResult(memoryBuffer.getLength)(hmb.getLength) + } + } + } + + test("host originated: a buffer is not spillable when we leak it") { + val spillPriority = -10 + val hmb = HostMemoryBuffer.allocate(1L * 1024) + withResource(SpillableHostBuffer(hmb, hmb.getLength, spillPriority)) { spillableBuffer => + withResource(spillableBuffer.getHostBuffer()) { _ => + assertResult(0)(SpillFramework.stores.hostStore.spill(hmb.getLength)) + } + assertResult(hmb.getLength)(SpillFramework.stores.hostStore.spill(hmb.getLength)) + } + } + + test("host originated: a host batch is not spillable when we leak it") { + val (hostCb, sparkTypes) = buildHostBatch() + val sizeOnHost = RapidsHostColumnVector.getTotalHostMemoryUsed(hostCb) + withResource(SpillableHostColumnarBatchHandle(hostCb)) { handle => + assertResult(true)(handle.spillable) + + withResource(handle.materialize(sparkTypes)) { _ => + // 0 because we have a reference to the host batch + assertResult(false)(handle.spillable) + assertResult(0)(SpillFramework.stores.hostStore.spill(sizeOnHost)) + } + + // after closing we still have 0 bytes in the store or available to spill + assertResult(true)(handle.spillable) + } + } + + test("host originated: a host batch is not spillable when columns are incRefCounted") { + val (hostCb, sparkTypes) = buildHostBatch() + val sizeOnHost = RapidsHostColumnVector.getTotalHostMemoryUsed(hostCb) + withResource(SpillableHostColumnarBatchHandle(hostCb)) { handle => + assertResult(true)(handle.spillable) + val leakedFirstColumn = withResource(handle.materialize(sparkTypes)) { cb => + // 0 because we have a reference to the host batch + assertResult(false)(handle.spillable) + assertResult(0)(SpillFramework.stores.hostStore.spill(sizeOnHost)) + // leak it by increasing the ref count of the underlying cuDF column + RapidsHostColumnVector.extractBases(cb).head.incRefCount() + } + withResource(leakedFirstColumn) { _ => + // 0 because we have a reference to the first column + assertResult(false)(handle.spillable) + assertResult(0)(SpillFramework.stores.hostStore.spill(sizeOnHost)) + } + // batch is now spillable because we close our reference to the column + assertResult(true)(handle.spillable) + assertResult(sizeOnHost)(SpillFramework.stores.hostStore.spill(sizeOnHost)) + } + } + + test("host originated: an aliased host batch is not spillable (until closing the original) ") { + val (hostBatch, sparkTypes) = buildHostBatch() + val handle = SpillableHostColumnarBatchHandle(hostBatch) + withResource(handle) { _ => + assertResult(1)(SpillFramework.stores.hostStore.numHandles) + assertResult(true)(handle.spillable) + withResource(handle.materialize(sparkTypes)) { _ => + assertResult(false)(handle.spillable) + } // we now have two copies in the store + assertResult(true)(handle.spillable) + } + } + + test("host originated: a host batch supports duplicated columns") { + val (hostBatch, sparkTypes) = buildHostBatchWithDuplicate() + val handle = SpillableHostColumnarBatchHandle(hostBatch) + withResource(handle) { _ => + assertResult(1)(SpillFramework.stores.hostStore.numHandles) + assertResult(true)(handle.spillable) + withResource(handle.materialize(sparkTypes)) { _ => + assertResult(false)(handle.spillable) + } // we now have two copies in the store + assertResult(true)(handle.spillable) + } + } + + test("host originated: a host batch supports aliasing and duplicated columns") { + SpillFramework.shutdown() + val sc = new SparkConf + // disables the host store limit by enabling off heap limits + sc.set(RapidsConf.OFF_HEAP_LIMIT_ENABLED.key, "true") + SpillFramework.initialize(new RapidsConf(sc)) + + try { + val (hostBatch, sparkTypes) = buildHostBatchWithDuplicate() + withResource(SpillableHostColumnarBatchHandle(hostBatch)) { handle => + withResource(SpillableHostColumnarBatchHandle(handle.materialize(sparkTypes))) { handle2 => + assertResult(2)(SpillFramework.stores.hostStore.numHandles) + assertResult(false)(handle.spillable) + assertResult(false)(handle2.spillable) + } + assertResult(true)(handle.spillable) + } + } finally { + SpillFramework.shutdown() + } + } + + // this is a key behavior that we wanted to keep during the spill refactor + // where host objects that are added directly to the store do not cause a + // host->disk spill on their own, instead they will get spilled later + // due to device->host spills. + test("host factory methods do not spill on addition") { + SpillFramework.shutdown() + val sc = new SparkConf + // set a very small store size + sc.set(RapidsConf.HOST_SPILL_STORAGE_SIZE.key, "1KB") + SpillFramework.initialize(new RapidsConf(sc)) + + try { + // add a lot of batches, surpassing the limits of the store + val handles = new ArrayBuffer[SpillableHostColumnarBatchHandle]() + var dataTypes: Array[DataType] = null + (0 until 100).foreach { _ => + val (hostBatch, dt) = buildHostBatch() + if (dataTypes == null) { + dataTypes = dt + } + handles.append(SpillableHostColumnarBatchHandle(hostBatch)) + } + // no spill to disk + assertResult(100)(SpillFramework.stores.hostStore.numHandles) + + val dmb = DeviceMemoryBuffer.allocate(1024) + withResource(SpillableDeviceBufferHandle(dmb)) { _ => + // simulate an OOM by spilling device memory + SpillFramework.stores.deviceStore.spill(1024) + assertResult(0)(SpillFramework.stores.deviceStore.numHandles) + } + + val buffersSpilledToDisk = SpillFramework.stores.diskStore.numHandles + // we spilled to disk + assert(SpillFramework.stores.diskStore.numHandles > 0) + // and the remaining objects that didn't spill, are still in the host store + assertResult(100 - buffersSpilledToDisk)(SpillFramework.stores.hostStore.numHandles) + assert(SpillFramework.stores.hostStore.totalSize <= 1024) + } finally { + SpillFramework.shutdown() + } + } + + test("direct spill to disk: when buffer exceeds maximum size") { + var (bigTable, sparkTypes) = buildTableOfLongs(2 * 1024 * 1024) + closeOnExcept(bigTable) { _ => + // make a copy of the table so we can compare it later to the + // one reconstituted after the spill + val expectedBatch = + withResource(bigTable.getTable.contiguousSplit()) { expectedTable => + GpuColumnVector.from(expectedTable(0).getTable, sparkTypes) + } + withResource(expectedBatch) { _ => + withResource(SpillableColumnarBatchFromBufferHandle( + bigTable, sparkTypes)) { bigHandle => + bigTable = null + withResource(bigHandle.materialize(sparkTypes)) { actualBatch => + TestUtils.compareBatches(expectedBatch, actualBatch) + } + SpillFramework.stores.deviceStore.spill(bigHandle.approxSizeInBytes) + assertResult(true)(bigHandle.dev.isEmpty) + assertResult(true)(bigHandle.host.get.host.isEmpty) + assertResult(false)(bigHandle.host.get.disk.isEmpty) + + withResource(bigHandle.materialize(sparkTypes)) { actualBatch => + TestUtils.compareBatches(expectedBatch, actualBatch) + } + } + } + } + } + + test("get columnar batch after spilling to disk") { + val (size, handle, dataTypes) = addContiguousTableToFramework() + val diskStore = SpillFramework.stores.diskStore + val hostStore = SpillFramework.stores.hostStore + val deviceStore = SpillFramework.stores.deviceStore + withResource(handle) { _ => + assertResult(1)(deviceStore.numHandles) + assertResult(0)(diskStore.numHandles) + assertResult(0)(hostStore.numHandles) + + val expectedTable = + withResource(handle.materialize(dataTypes)) { beforeSpill => + withResource(GpuColumnVector.from(beforeSpill)) { table => + table.contiguousSplit()(0) + } + } // closing the batch from the store so that we can spill it + + withResource(expectedTable) { _ => + withResource( + GpuColumnVector.from(expectedTable.getTable, dataTypes)) { expectedBatch => + deviceStore.spill(size) + hostStore.spill(size) + + assertResult(0)(deviceStore.numHandles) + assertResult(0)(hostStore.numHandles) + assertResult(1)(diskStore.numHandles) + + val diskHandle = handle.host.flatMap(_.disk).get + val path = diskStore.getFile(diskHandle.blockId) + assert(path.exists) + withResource(handle.materialize(dataTypes)) { actualBatch => + TestUtils.compareBatches(expectedBatch, actualBatch) + } + } + } + } + } + + // -1 disables the host store limit + val hostSpillStorageSizes = Seq("-1", "1MB", "16MB") + val spillToDiskBounceBuffers = Seq("128KB", "2MB", "128MB") + val chunkedPackBounceBuffers = Seq("1MB", "8MB", "128MB") + hostSpillStorageSizes.foreach { hostSpillStorageSize => + spillToDiskBounceBuffers.foreach { spillToDiskBounceBufferSize => + chunkedPackBounceBuffers.foreach { chunkedPackBounceBufferSize => + test("materialize non-contiguous batch after " + + s"host_storage_size=$hostSpillStorageSize " + + s"spilling chunked_pack_bb=$chunkedPackBounceBufferSize " + + s"spill_to_disk_bb=$spillToDiskBounceBufferSize") { + SpillFramework.shutdown() + try { + val sc = new SparkConf + sc.set(RapidsConf.HOST_SPILL_STORAGE_SIZE.key, hostSpillStorageSize) + sc.set(RapidsConf.CHUNKED_PACK_BOUNCE_BUFFER_SIZE.key, chunkedPackBounceBufferSize) + sc.set(RapidsConf.SPILL_TO_DISK_BOUNCE_BUFFER_SIZE.key, spillToDiskBounceBufferSize) + SpillFramework.initialize(new RapidsConf(sc)) + val (largeTable, dataTypes) = buildNonContiguousTableOfLongs(numRows = 1000000) + val handle = SpillableColumnarBatchHandle(largeTable, dataTypes) + val diskStore = SpillFramework.stores.diskStore + val hostStore = SpillFramework.stores.hostStore + val deviceStore = SpillFramework.stores.deviceStore + withResource(handle) { _ => + assertResult(1)(deviceStore.numHandles) + assertResult(0)(diskStore.numHandles) + assertResult(0)(hostStore.numHandles) + + val expectedTable = + withResource(handle.materialize(dataTypes)) { beforeSpill => + withResource(GpuColumnVector.from(beforeSpill)) { table => + table.contiguousSplit()(0) + } + } // closing the batch from the store so that we can spill it + + withResource(expectedTable) { _ => + withResource( + GpuColumnVector.from(expectedTable.getTable, dataTypes)) { expectedBatch => + deviceStore.spill(handle.approxSizeInBytes) + hostStore.spill(handle.approxSizeInBytes) + + assertResult(0)(deviceStore.numHandles) + assertResult(0)(hostStore.numHandles) + assertResult(1)(diskStore.numHandles) + + val diskHandle = handle.host.flatMap(_.disk).get + val path = diskStore.getFile(diskHandle.blockId) + assert(path.exists) + withResource(handle.materialize(dataTypes)) { actualBatch => + TestUtils.compareBatches(expectedBatch, actualBatch) + } + } + } + } + } finally { + SpillFramework.shutdown() + } + } + } + } + } + + test("get memory buffer after spilling to disk") { + val handle = SpillableDeviceBufferHandle(DeviceMemoryBuffer.allocate(123)) + val diskStore = SpillFramework.stores.diskStore + val hostStore = SpillFramework.stores.hostStore + val deviceStore = SpillFramework.stores.deviceStore + withResource(handle) { _ => + assertResult(1)(deviceStore.numHandles) + assertResult(0)(diskStore.numHandles) + assertResult(0)(hostStore.numHandles) + val expectedBuffer = + withResource(handle.materialize()) { devbuf => + closeOnExcept(HostMemoryBuffer.allocate(devbuf.getLength)) { hostbuf => + hostbuf.copyFromDeviceBuffer(devbuf) + hostbuf + } + } + withResource(expectedBuffer) { expectedBuffer => + deviceStore.spill(handle.approxSizeInBytes) + hostStore.spill(handle.approxSizeInBytes) + withResource(handle.host.map(_.materialize()).get) { actualHostBuffer => + assertResult(expectedBuffer. + asByteBuffer.limit())(actualHostBuffer.asByteBuffer.limit()) + } + } + } + } + + test("Compression on with or without encryption for spill block using single batch") { + Seq("true", "false").foreach { encryptionEnabled => + val conf = new SparkConf() + conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) + conf.set("spark.io.compression.codec", "zstd") + conf.set("spark.shuffle.spill.compress", "true") + conf.set("spark.shuffle.compress", "true") + readWriteTestWithBatches(conf, false) + } + } + + test("Compression off with or without encryption for spill block using single batch") { + Seq("true", "false").foreach { encryptionEnabled => + val conf = new SparkConf() + conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) + conf.set("spark.shuffle.spill.compress", "false") + conf.set("spark.shuffle.compress", "false") + readWriteTestWithBatches(conf, false) + } + } + + test("Compression on with or without encryption for spill block using multiple batches") { + Seq("true", "false").foreach { encryptionEnabled => + val conf = new SparkConf() + conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) + conf.set("spark.io.compression.codec", "zstd") + conf.set("spark.shuffle.spill.compress", "true") + conf.set("spark.shuffle.compress", "true") + readWriteTestWithBatches(conf, false) + } + } + + test("Compression off with or without encryption for spill block using multiple batches") { + Seq("true", "false").foreach { encryptionEnabled => + val conf = new SparkConf() + conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) + conf.set("spark.shuffle.spill.compress", "false") + conf.set("spark.shuffle.compress", "false") + readWriteTestWithBatches(conf, false) + } + } + + // ===== Tests for shuffle block ===== + + test("Compression on with or without encryption for shuffle block using single batch") { + Seq("true", "false").foreach { encryptionEnabled => + val conf = new SparkConf() + conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) + conf.set("spark.io.compression.codec", "zstd") + conf.set("spark.shuffle.spill.compress", "true") + conf.set("spark.shuffle.compress", "true") + readWriteTestWithBatches(conf, true) + } + } + + test("Compression off with or without encryption for shuffle block using single batch") { + Seq("true", "false").foreach { encryptionEnabled => + val conf = new SparkConf() + conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) + conf.set("spark.shuffle.spill.compress", "false") + conf.set("spark.shuffle.compress", "false") + readWriteTestWithBatches(conf, true) + } + } + + test("Compression on with or without encryption for shuffle block using multiple batches") { + Seq("true", "false").foreach { encryptionEnabled => + val conf = new SparkConf() + conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) + conf.set("spark.io.compression.codec", "zstd") + conf.set("spark.shuffle.spill.compress", "true") + conf.set("spark.shuffle.compress", "true") + readWriteTestWithBatches(conf, true, true) + } + } + + test("Compression off with or without encryption for shuffle block using multiple batches") { + Seq("true", "false").foreach { encryptionEnabled => + val conf = new SparkConf() + conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) + conf.set("spark.shuffle.spill.compress", "false") + conf.set("spark.shuffle.compress", "false") + readWriteTestWithBatches(conf, true, true) + } + } + + test("No encryption and compression for shuffle block using multiple batches") { + readWriteTestWithBatches(new SparkConf(), true, true) + } + + private def readWriteTestWithBatches(conf: SparkConf, shareDiskPaths: Boolean*) = { + assert(shareDiskPaths.nonEmpty) + val mockDiskBlockManager = mock[RapidsDiskBlockManager] + when(mockDiskBlockManager.getSerializerManager()) + .thenReturn(new RapidsSerializerManager(conf)) + + shareDiskPaths.foreach { _ => + val (_, handle, dataTypes) = addContiguousTableToFramework() + withResource(handle) { _ => + val expectedCt = withResource(handle.materialize(dataTypes)) { devbatch => + withResource(GpuColumnVector.from(devbatch)) { tmpTbl => + tmpTbl.contiguousSplit()(0) + } + } + withResource(expectedCt) { _ => + val expectedBatch = withResource(expectedCt.getTable) { expectedTbl => + GpuColumnVector.from(expectedTbl, dataTypes) + } + withResource(expectedBatch) { _ => + assertResult(true)( + SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes) > 0) + assertResult(true)( + SpillFramework.stores.hostStore.spill(handle.approxSizeInBytes) > 0) + withResource(handle.materialize(dataTypes)) { actualBatch => + TestUtils.compareBatches(expectedBatch, actualBatch) + } + } + } + } + } + } + + test("skip host: spill device memory buffer to disk") { + SpillFramework.shutdown() + try { + val sc = new SparkConf + // disables the host store limit + sc.set(RapidsConf.HOST_SPILL_STORAGE_SIZE.key, "1KB") + SpillFramework.initialize(new RapidsConf(sc)) + // buffer is too big for host store limit, so we will skip host + val handle = SpillableDeviceBufferHandle(DeviceMemoryBuffer.allocate(1025)) + val deviceStore = SpillFramework.stores.deviceStore + withResource(handle) { _ => + val expectedBuffer = + withResource(handle.materialize()) { devbuf => + closeOnExcept(HostMemoryBuffer.allocate(devbuf.getLength)) { hostbuf => + hostbuf.copyFromDeviceBuffer(devbuf) + hostbuf + } + } + + withResource(expectedBuffer) { _ => + // host store will fail to spill + deviceStore.spill(handle.approxSizeInBytes) + assert(handle.host.map(_.host.isEmpty).get) + assert(handle.host.map(_.disk.isDefined).get) + withResource(handle.host.map(_.materialize()).get) { buffer => + assertResult(expectedBuffer.asByteBuffer)(buffer.asByteBuffer) + } + } + } + } finally { + SpillFramework.shutdown() + } + } + + test("skip host: spill table to disk") { + SpillFramework.shutdown() + try { + val sc = new SparkConf + sc.set(RapidsConf.HOST_SPILL_STORAGE_SIZE.key, "1KB") + SpillFramework.initialize(new RapidsConf(sc)) + // fill up the host store + withResource(SpillableHostBufferHandle(HostMemoryBuffer.allocate(1024))) { hostHandle => + // make sure the host handle isn't spillable + withResource(hostHandle.materialize()) { _ => + val (handle, _) = addTableToFramework() + withResource(handle) { _ => + val (expectedTable, dataTypes) = buildTable() + withResource(expectedTable) { _ => + withResource( + GpuColumnVector.from(expectedTable, dataTypes)) { expectedBatch => + SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes) + assert(handle.host.map(_.host.isEmpty).get) + assert(handle.host.map(_.disk.isDefined).get) + withResource(handle.materialize(dataTypes)) { fromDiskBatch => + TestUtils.compareBatches(expectedBatch, fromDiskBatch) + assert(handle.dev.isEmpty) + assert(handle.host.map(_.host.isEmpty).get) + assert(handle.host.map(_.disk.isDefined).get) + } + } + } + } + } + } + } finally { + SpillFramework.shutdown() + } + } + + test("skip host: spill table to disk with small host bounce buffer") { + try { + SpillFramework.shutdown() + val sc = new SparkConf + // make this super small so we skip the host + sc.set(RapidsConf.HOST_SPILL_STORAGE_SIZE.key, "1") + sc.set(RapidsConf.SPILL_TO_DISK_BOUNCE_BUFFER_SIZE.key, "10") + sc.set(RapidsConf.CHUNKED_PACK_BOUNCE_BUFFER_SIZE.key, "1MB") + val rapidsConf = new RapidsConf(sc) + SpillFramework.initialize(rapidsConf) + val (handle, _) = addTableToFramework() + withResource(handle) { _ => + val (expectedTable, dataTypes) = buildTable() + withResource(expectedTable) { _ => + withResource( + GpuColumnVector.from(expectedTable, dataTypes)) { expectedBatch => + SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes) + assert(handle.dev.isEmpty) + assert(handle.host.map(_.host.isEmpty).get) + assert(handle.host.map(_.disk.isDefined).get) + withResource(handle.materialize(dataTypes)) { fromDiskBatch => + TestUtils.compareBatches(expectedBatch, fromDiskBatch) + } + } + } + } + } finally { + SpillFramework.shutdown() + } + } + + test("0-byte table is never spillable") { + val (handle, _) = addZeroRowsTableToFramework() + val (handle2, _) = addTableToFramework() + + withResource(handle) { _ => + withResource(handle2) { _ => + assert(handle2.host.isEmpty) + val (expectedTable, expectedTypes) = buildTable() + withResource(expectedTable) { _ => + withResource( + GpuColumnVector.from(expectedTable, expectedTypes)) { expectedCb => + SpillFramework.stores.deviceStore.spill( + handle.approxSizeInBytes + handle2.approxSizeInBytes) + SpillFramework.stores.hostStore.spill( + handle.approxSizeInBytes + handle2.approxSizeInBytes) + // the 0-byte table never moved from device. It is not spillable + assert(handle.host.isEmpty) + assert(!handle.spillable) + // the second table (with rows) did spill + assert(handle2.host.isDefined) + assert(handle2.host.map(_.host.isEmpty).get) + assert(handle2.host.map(_.disk.isDefined).get) + + withResource(handle2.materialize(expectedTypes)) { spilledBatch => + TestUtils.compareBatches(expectedCb, spilledBatch) + } + } + } + } + } + } + + test("exclusive spill files are deleted when buffer deleted") { + testBufferFileDeletion(canShareDiskPaths = false) + } + + test("shared spill files are not deleted when a buffer is deleted") { + testBufferFileDeletion(canShareDiskPaths = true) + } + +} diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala index a9618a448cf..26ac7a177a3 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala @@ -65,11 +65,13 @@ class TimeZonePerfSuite extends SparkQueryCompareTestSuite with BeforeAndAfterAl * Create a Parquet file to test */ override def beforeAll(): Unit = { + super.beforeAll() withCpuSparkSession( spark => createDF(spark).write.mode("overwrite").parquet(path)) } override def afterAll(): Unit = { + super.afterAll() FileUtils.deleteRecursively(new File(path)) } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZoneSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZoneSuite.scala index dcfbc508034..b22f827916b 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZoneSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZoneSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -319,6 +319,7 @@ class TimeZoneSuite extends SparkQueryCompareTestSuite with BeforeAndAfterAll { } override def afterAll(): Unit = { + super.afterAll() if (useGPU) { GpuTimeZoneDB.shutdown() } diff --git a/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala b/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala index d52c8b47ae7..033173468fd 100644 --- a/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala +++ b/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala @@ -15,10 +15,11 @@ */ package org.apache.spark.sql.rapids -import ai.rapids.cudf.TableWriter -import com.nvidia.spark.rapids.{ColumnarOutputWriter, ColumnarOutputWriterFactory, GpuColumnVector, GpuLiteral, RapidsBufferCatalog, RapidsDeviceMemoryStore, ScalableTaskCompletion} +import ai.rapids.cudf.{Rmm, RmmAllocationMode, TableWriter} +import com.nvidia.spark.rapids.{ColumnarOutputWriter, ColumnarOutputWriterFactory, GpuColumnVector, GpuLiteral, RapidsConf, ScalableTaskCompletion} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.jni.{GpuRetryOOM, GpuSplitAndRetryOOM} +import com.nvidia.spark.rapids.spill.SpillFramework import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FSDataOutputStream import org.apache.hadoop.mapred.TaskAttemptContext @@ -28,6 +29,7 @@ import org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite import org.scalatestplus.mockito.MockitoSugar.mock +import org.apache.spark.SparkConf import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, ExprId, SortOrder} @@ -42,7 +44,6 @@ class GpuFileFormatDataWriterSuite extends AnyFunSuite with BeforeAndAfterEach { private var mockCommitter: FileCommitProtocol = _ private var mockOutputWriterFactory: ColumnarOutputWriterFactory = _ private var mockOutputWriter: NoTransformColumnarOutputWriter = _ - private var devStore: RapidsDeviceMemoryStore = _ private var allCols: Seq[AttributeReference] = _ private var partSpec: Seq[AttributeReference] = _ private var dataSpec: Seq[AttributeReference] = _ @@ -175,16 +176,13 @@ class GpuFileFormatDataWriterSuite extends AnyFunSuite with BeforeAndAfterEach { } override def beforeEach(): Unit = { - devStore = new RapidsDeviceMemoryStore() - val catalog = new RapidsBufferCatalog(devStore) - RapidsBufferCatalog.setCatalog(catalog) + Rmm.initialize(RmmAllocationMode.CUDA_DEFAULT, null, 512 * 1024 * 1024) + SpillFramework.initialize(new RapidsConf(new SparkConf)) } override def afterEach(): Unit = { - // test that no buffers we left in the spill framework - assertResult(0)(RapidsBufferCatalog.numBuffers) - RapidsBufferCatalog.close() - devStore.close() + SpillFramework.shutdown() + Rmm.shutdown() } def buildEmptyBatch: ColumnarBatch = diff --git a/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala b/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala index 001f82ab3a0..9b7a1dbdddd 100644 --- a/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala +++ b/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,53 +16,30 @@ package org.apache.spark.sql.rapids -import java.util.UUID - -import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer} -import com.nvidia.spark.rapids.{RapidsBuffer, RapidsBufferCatalog, RapidsBufferId, SpillableColumnarBatchImpl, StorageTier} -import com.nvidia.spark.rapids.StorageTier.StorageTier -import com.nvidia.spark.rapids.format.TableMeta +import ai.rapids.cudf.DeviceMemoryBuffer +import com.nvidia.spark.rapids.{RapidsConf, SpillableBuffer} +import com.nvidia.spark.rapids.spill.SpillFramework +import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuite -import org.apache.spark.sql.types.{DataType, IntegerType} -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.storage.TempLocalBlockId +import org.apache.spark.SparkConf -class SpillableColumnarBatchSuite extends AnyFunSuite { +class SpillableColumnarBatchSuite extends AnyFunSuite with BeforeAndAfterAll { + override def beforeAll(): Unit = { + super.beforeAll() + SpillFramework.initialize(new RapidsConf(new SparkConf())) + } - test("close updates catalog") { - val id = TempSpillBufferId(0, TempLocalBlockId(new UUID(1, 2))) - val mockBuffer = new MockBuffer(id) - val catalog = RapidsBufferCatalog.singleton - val oldBufferCount = catalog.numBuffers - catalog.registerNewBuffer(mockBuffer) - val handle = catalog.makeNewHandle(id, -1) - assertResult(oldBufferCount + 1)(catalog.numBuffers) - val spillableBatch = new SpillableColumnarBatchImpl( - handle, - 5, - Array[DataType](IntegerType)) - spillableBatch.close() - assertResult(oldBufferCount)(catalog.numBuffers) + override def afterAll(): Unit = { + super.afterAll() + SpillFramework.shutdown() } - class MockBuffer(override val id: RapidsBufferId) extends RapidsBuffer { - override val memoryUsedBytes: Long = 123 - override def meta: TableMeta = null - override val storageTier: StorageTier = StorageTier.DEVICE - override def getMemoryBuffer: MemoryBuffer = null - override def copyToMemoryBuffer(srcOffset: Long, dst: MemoryBuffer, dstOffset: Long, - length: Long, stream: Cuda.Stream): Unit = {} - override def getDeviceMemoryBuffer: DeviceMemoryBuffer = null - override def getHostMemoryBuffer: HostMemoryBuffer = null - override def addReference(): Boolean = true - override def free(): Unit = {} - override def getSpillPriority: Long = 0 - override def setSpillPriority(priority: Long): Unit = {} - override def close(): Unit = {} - override def getColumnarBatch( - sparkTypes: Array[DataType]): ColumnarBatch = null - override def withMemoryBufferReadLock[K](body: MemoryBuffer => K): K = { body(null) } - override def withMemoryBufferWriteLock[K](body: MemoryBuffer => K): K = { body(null) } + test("close updates catalog") { + assertResult(0)(SpillFramework.stores.deviceStore.numHandles) + val deviceHandle = SpillableBuffer(DeviceMemoryBuffer.allocate(1234), -1) + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + deviceHandle.close() + assertResult(0)(SpillFramework.stores.deviceStore.numHandles) } } diff --git a/tests/src/test/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala b/tests/src/test/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala index ab303d8098e..525f30c7a87 100644 --- a/tests/src/test/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala +++ b/tests/src/test/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala @@ -36,10 +36,11 @@ package com.nvidia.spark.rapids.shuffle import java.nio.ByteBuffer import java.util.concurrent.Executor +import scala.collection.immutable import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{ColumnVector, ContiguousTable, DeviceMemoryBuffer, HostMemoryBuffer} -import com.nvidia.spark.rapids.{GpuColumnVector, MetaUtils, RapidsBufferHandle, RapidsConf, RapidsDeviceMemoryStore, RmmSparkRetrySuiteBase, ShuffleMetadata, ShuffleReceivedBufferCatalog} +import com.nvidia.spark.rapids.{GpuColumnVector, MetaUtils, RapidsConf, RapidsShuffleHandle, RmmSparkRetrySuiteBase, ShuffleMetadata, ShuffleReceivedBufferCatalog} import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.format.TableMeta import org.mockito.ArgumentCaptor @@ -79,7 +80,6 @@ abstract class RapidsShuffleTestHelper var mockCopyExecutor: Executor = _ var mockBssExecutor: Executor = _ var mockHandler: RapidsShuffleFetchHandler = _ - var mockStorage: RapidsDeviceMemoryStore = _ var mockCatalog: ShuffleReceivedBufferCatalog = _ var mockConf: RapidsConf = _ var testMetricsUpdater: TestShuffleMetricsUpdater = _ @@ -160,11 +160,11 @@ abstract class RapidsShuffleTestHelper testMetricsUpdater = spy(new TestShuffleMetricsUpdater) val dmbCaptor = ArgumentCaptor.forClass(classOf[DeviceMemoryBuffer]) - when(mockCatalog.addBuffer(dmbCaptor.capture(), any(), any(), any())) + when(mockCatalog.addBuffer(dmbCaptor.capture(), any(), any())) .thenAnswer(_ => { val buffer = dmbCaptor.getValue.asInstanceOf[DeviceMemoryBuffer] buffersToClose.append(buffer) - mock[RapidsBufferHandle] + mock[RapidsShuffleHandle] }) client = spy(new RapidsShuffleClient( @@ -185,25 +185,30 @@ object RapidsShuffleTestHelper extends MockitoSugar { MetaUtils.buildDegenerateTableMeta(new ColumnarBatch(Array.empty, 123)) } - def withMockContiguousTable[T](numRows: Long)(body: ContiguousTable => T): T = { + def buildContiguousTable(numRows: Long): ContiguousTable = { val rows: Seq[Integer] = (0 until numRows.toInt).map(Int.box) withResource(ColumnVector.fromBoxedInts(rows:_*)) { cvBase => cvBase.incRefCount() val gpuCv = GpuColumnVector.from(cvBase, IntegerType) withResource(new ColumnarBatch(Array(gpuCv))) { cb => withResource(GpuColumnVector.from(cb)) { table => - withResource(table.contiguousSplit(0, numRows.toInt)) { ct => - body(ct(1)) // we get a degenerate table at 0 and another at 2 - } + val cts = table.contiguousSplit() + cts(0) } } } } + def withMockContiguousTable[T](numRows: Long)(body: ContiguousTable => T): T = { + withResource(buildContiguousTable(numRows)) { ct => + body(ct) + } + } + def mockMetaResponse( mockTransaction: Transaction, numRows: Long, - numBatches: Int): (Seq[TableMeta], MetadataTransportBuffer) = + numBatches: Int): (immutable.Seq[TableMeta], MetadataTransportBuffer) = withMockContiguousTable(numRows) { ct => val tableMetas = (0 until numBatches).map(b => buildMockTableMeta(b, ct)) val res = ShuffleMetadata.buildMetaResponse(tableMetas) @@ -214,7 +219,7 @@ object RapidsShuffleTestHelper extends MockitoSugar { def mockDegenerateMetaResponse( mockTransaction: Transaction, - numBatches: Int): (Seq[TableMeta], MetadataTransportBuffer) = { + numBatches: Int): (immutable.Seq[TableMeta], MetadataTransportBuffer) = { val tableMetas = (0 until numBatches).map(b => buildDegenerateMockTableMeta()) val res = ShuffleMetadata.buildMetaResponse(tableMetas) val refCountedRes = new MetadataTransportBuffer(new RefCountedDirectByteBuffer(res)) @@ -246,8 +251,8 @@ object RapidsShuffleTestHelper extends MockitoSugar { tableMeta } - def getShuffleBlocks: Seq[(ShuffleBlockBatchId, Long, Int)] = { - Seq( + def getShuffleBlocks: Array[(ShuffleBlockBatchId, Long, Int)] = { + Array( (ShuffleBlockBatchId(1,1,1,1), 123L, 1), (ShuffleBlockBatchId(2,2,2,2), 456L, 2), (ShuffleBlockBatchId(3,3,3,3), 456L, 3) @@ -261,11 +266,24 @@ object RapidsShuffleTestHelper extends MockitoSugar { bmId } - def getBlocksByAddress: Array[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { - val blocksByAddress = new ArrayBuffer[(BlockManagerId, Seq[(BlockId, Long, Int)])]() + def makeIterator(conf: RapidsConf, + transport: RapidsShuffleTransport, + testMetricsUpdater: TestShuffleMetricsUpdater, + taskId: Long, + catalog: ShuffleReceivedBufferCatalog): RapidsShuffleIterator = { + val blocksByAddress = new ArrayBuffer[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])]() val blocks = getShuffleBlocks blocksByAddress.append((makeMockBlockManager("2", "2"), blocks)) - blocksByAddress.toArray + spy(new RapidsShuffleIterator( + RapidsShuffleTestHelper.makeMockBlockManager("1", "1"), + conf, + transport, + blocksByAddress.toArray, + testMetricsUpdater, + Array.empty, + taskId, + catalog, + 123)) } } @@ -289,4 +307,3 @@ class MockClientConnection(mockTransaction: Transaction) extends ClientConnectio override def registerReceiveHandler(messageType: MessageType.Value): Unit = {} } - diff --git a/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala b/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala index 89c317f7620..0efcb4f1d7d 100644 --- a/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala +++ b/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala @@ -37,7 +37,7 @@ import scala.collection import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{ColumnVector, ContiguousTable, DeviceMemoryBuffer, HostMemoryBuffer} -import com.nvidia.spark.rapids.{GpuColumnVector, MetaUtils, RapidsBufferHandle, RapidsConf, RapidsDeviceMemoryStore, RmmSparkRetrySuiteBase, ShuffleMetadata, ShuffleReceivedBufferCatalog} +import com.nvidia.spark.rapids.{GpuColumnVector, MetaUtils, RapidsConf, RapidsShuffleHandle, RmmSparkRetrySuiteBase, ShuffleMetadata, ShuffleReceivedBufferCatalog} import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.format.TableMeta import org.mockito.ArgumentCaptor @@ -77,7 +77,6 @@ abstract class RapidsShuffleTestHelper var mockCopyExecutor: Executor = _ var mockBssExecutor: Executor = _ var mockHandler: RapidsShuffleFetchHandler = _ - var mockStorage: RapidsDeviceMemoryStore = _ var mockCatalog: ShuffleReceivedBufferCatalog = _ var mockConf: RapidsConf = _ var testMetricsUpdater: TestShuffleMetricsUpdater = _ @@ -158,11 +157,11 @@ abstract class RapidsShuffleTestHelper testMetricsUpdater = spy(new TestShuffleMetricsUpdater) val dmbCaptor = ArgumentCaptor.forClass(classOf[DeviceMemoryBuffer]) - when(mockCatalog.addBuffer(dmbCaptor.capture(), any(), any(), any())) + when(mockCatalog.addBuffer(dmbCaptor.capture(), any(), any())) .thenAnswer(_ => { val buffer = dmbCaptor.getValue.asInstanceOf[DeviceMemoryBuffer] buffersToClose.append(buffer) - mock[RapidsBufferHandle] + mock[RapidsShuffleHandle] }) client = spy(new RapidsShuffleClient( @@ -183,25 +182,30 @@ object RapidsShuffleTestHelper extends MockitoSugar { MetaUtils.buildDegenerateTableMeta(new ColumnarBatch(Array.empty, 123)) } - def withMockContiguousTable[T](numRows: Long)(body: ContiguousTable => T): T = { + def buildContiguousTable(numRows: Long): ContiguousTable = { val rows: Seq[Integer] = (0 until numRows.toInt).map(Int.box) withResource(ColumnVector.fromBoxedInts(rows:_*)) { cvBase => cvBase.incRefCount() val gpuCv = GpuColumnVector.from(cvBase, IntegerType) withResource(new ColumnarBatch(Array(gpuCv))) { cb => withResource(GpuColumnVector.from(cb)) { table => - withResource(table.contiguousSplit(0, numRows.toInt)) { ct => - body(ct(1)) // we get a degenerate table at 0 and another at 2 - } + val cts = table.contiguousSplit() + cts(0) } } } } + def withMockContiguousTable[T](numRows: Long)(body: ContiguousTable => T): T = { + withResource(buildContiguousTable(numRows)) { ct => + body(ct) + } + } + def mockMetaResponse( mockTransaction: Transaction, numRows: Long, - numBatches: Int): (Seq[TableMeta], MetadataTransportBuffer) = + numBatches: Int): (collection.Seq[TableMeta], MetadataTransportBuffer) = withMockContiguousTable(numRows) { ct => val tableMetas = (0 until numBatches).map(b => buildMockTableMeta(b, ct)) val res = ShuffleMetadata.buildMetaResponse(tableMetas) @@ -212,7 +216,7 @@ object RapidsShuffleTestHelper extends MockitoSugar { def mockDegenerateMetaResponse( mockTransaction: Transaction, - numBatches: Int): (Seq[TableMeta], MetadataTransportBuffer) = { + numBatches: Int): (collection.Seq[TableMeta], MetadataTransportBuffer) = { val tableMetas = (0 until numBatches).map(b => buildDegenerateMockTableMeta()) val res = ShuffleMetadata.buildMetaResponse(tableMetas) val refCountedRes = new MetadataTransportBuffer(new RefCountedDirectByteBuffer(res)) @@ -244,8 +248,8 @@ object RapidsShuffleTestHelper extends MockitoSugar { tableMeta } - def getShuffleBlocks: collection.Seq[(ShuffleBlockBatchId, Long, Int)] = { - collection.Seq( + def getShuffleBlocks: Array[(ShuffleBlockBatchId, Long, Int)] = { + Array( (ShuffleBlockBatchId(1,1,1,1), 123L, 1), (ShuffleBlockBatchId(2,2,2,2), 456L, 2), (ShuffleBlockBatchId(3,3,3,3), 456L, 3) @@ -259,11 +263,25 @@ object RapidsShuffleTestHelper extends MockitoSugar { bmId } - def getBlocksByAddress: Array[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])] = { + def makeIterator( + conf: RapidsConf, + transport: RapidsShuffleTransport, + testMetricsUpdater: TestShuffleMetricsUpdater, + taskId: Long, + catalog: ShuffleReceivedBufferCatalog): RapidsShuffleIterator = { val blocksByAddress = new ArrayBuffer[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])]() val blocks = getShuffleBlocks blocksByAddress.append((makeMockBlockManager("2", "2"), blocks)) - blocksByAddress.toArray + spy(new RapidsShuffleIterator( + RapidsShuffleTestHelper.makeMockBlockManager("1", "1"), + conf, + transport, + blocksByAddress.toArray, + testMetricsUpdater, + Array.empty, + taskId, + catalog, + 123)) } } @@ -287,4 +305,3 @@ class MockClientConnection(mockTransaction: Transaction) extends ClientConnectio override def registerReceiveHandler(messageType: MessageType.Value): Unit = {} } -