From ed170dc79ae8fe4c814919e150fa8f06c5739c5b Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Wed, 28 Aug 2024 08:05:35 -0700 Subject: [PATCH 01/36] Spill framework refactor for better performance and extensibility Signed-off-by: Alessandro Bellina --- .../nvidia/spark/rapids/GpuColumnVector.java | 4 + .../rapids/GpuColumnVectorFromBuffer.java | 25 +- .../rapids/GpuCompressedColumnVector.java | 11 +- .../InternalRowToColumnarBatchIterator.java | 13 +- .../com/nvidia/spark/rapids/implicits.scala | 54 - .../scala/com/nvidia/spark/rapids/Arm.scala | 14 - .../rapids/DeviceMemoryEventHandler.scala | 33 +- .../spark/rapids/GpuDeviceManager.scala | 28 +- .../rapids/GpuShuffledSizedHashJoinExec.scala | 2 +- .../spark/rapids/GpuUserDefinedFunction.scala | 2 +- .../com/nvidia/spark/rapids/HostAlloc.scala | 30 +- .../nvidia/spark/rapids/JoinGatherer.scala | 1 - .../com/nvidia/spark/rapids/MetaUtils.scala | 9 +- .../nvidia/spark/rapids/RapidsBuffer.scala | 485 ----- .../spark/rapids/RapidsBufferCatalog.scala | 1005 ----------- .../spark/rapids/RapidsBufferStore.scala | 640 ------- .../rapids/RapidsDeviceMemoryStore.scala | 518 ------ .../nvidia/spark/rapids/RapidsDiskStore.scala | 256 --- .../spark/rapids/RapidsHostMemoryStore.scala | 484 ----- .../rapids/RapidsSerializerManager.scala | 39 +- .../spark/rapids/ShuffleBufferCatalog.scala | 260 +-- .../rapids/ShuffleReceivedBufferCatalog.scala | 126 +- .../nvidia/spark/rapids/SpillFramework.scala | 1606 +++++++++++++++++ .../spark/rapids/SpillableColumnarBatch.scala | 289 +-- .../rapids/shuffle/BufferSendState.scala | 90 +- .../rapids/shuffle/RapidsShuffleClient.scala | 20 +- .../rapids/shuffle/RapidsShuffleServer.scala | 6 +- .../spark/sql/rapids/GpuShuffleEnv.scala | 10 +- .../spark/sql/rapids/GpuTaskMetrics.scala | 4 +- .../RapidsShuffleInternalManagerBase.scala | 37 +- .../spark/sql/rapids/TempSpillBufferId.scala | 50 - .../execution/GpuBroadcastExchangeExec.scala | 2 +- .../rapids/execution/GpuBroadcastHelper.scala | 5 +- .../shuffle/RapidsShuffleIterator.scala | 23 +- .../sql/rapids/RapidsCachingReader.scala | 73 +- .../shuffle/RapidsShuffleIterator.scala | 24 +- .../sql/rapids/RapidsCachingReader.scala | 69 +- .../nvidia/spark/rapids/HostAllocSuite.scala | 21 +- .../DeviceMemoryEventHandlerSuite.scala | 37 +- ...ternalRowToCudfRowIteratorRetrySuite.scala | 64 +- .../rapids/GpuCoalesceBatchesRetrySuite.scala | 4 +- .../spark/rapids/GpuGenerateSuite.scala | 4 +- .../spark/rapids/GpuPartitioningSuite.scala | 115 +- .../rapids/GpuSinglePartitioningSuite.scala | 6 +- .../rapids/HashAggregateRetrySuite.scala | 247 +-- .../rapids/NonDeterministicRetrySuite.scala | 3 +- .../rapids/RapidsBufferCatalogSuite.scala | 368 ---- .../rapids/RapidsDeviceMemoryStoreSuite.scala | 489 ----- .../spark/rapids/RapidsDiskStoreSuite.scala | 607 ------- .../rapids/RapidsHostMemoryStoreSuite.scala | 614 ------- .../spark/rapids/RmmSparkRetrySuiteBase.scala | 22 +- .../spark/rapids/SerializationSuite.scala | 36 +- .../rapids/ShuffleBufferCatalogSuite.scala | 11 +- .../spark/rapids/SpillFrameworkSuite.scala | 984 ++++++++++ .../spark/rapids/WindowRetrySuite.scala | 10 +- .../nvidia/spark/rapids/WithRetrySuite.scala | 10 +- .../shuffle/RapidsShuffleClientSuite.scala | 11 +- .../shuffle/RapidsShuffleIteratorSuite.scala | 17 +- .../shuffle/RapidsShuffleServerSuite.scala | 354 ++-- .../rapids/timezone/TimeZonePerfSuite.scala | 4 +- .../spark/rapids/timezone/TimeZoneSuite.scala | 3 +- .../rapids/GpuFileFormatDataWriterSuite.scala | 17 +- 
.../rapids/SpillableColumnarBatchSuite.scala | 62 +- .../shuffle/RapidsShuffleTestHelper.scala | 7 +- .../shuffle/RapidsShuffleTestHelper.scala | 8 +- 65 files changed, 3717 insertions(+), 6765 deletions(-) delete mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala delete mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala delete mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala delete mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala delete mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala delete mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillFramework.scala delete mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/rapids/TempSpillBufferId.scala delete mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala delete mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala delete mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala delete mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala create mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/SpillFrameworkSuite.scala diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java index 30b24fab11d..07111376440 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java @@ -1052,6 +1052,10 @@ public final int numNulls() { public static long getTotalDeviceMemoryUsed(ColumnarBatch batch) { long sum = 0; + if (batch.numCols() == 1 && batch.column(0) instanceof GpuPackedTableColumn) { + // this is a special case for a packed batch + return ((GpuPackedTableColumn) batch.column(0)).getTableBuffer().getLength(); + } if (batch.numCols() > 0) { if (batch.column(0) instanceof WithTableBuffer) { WithTableBuffer wtb = (WithTableBuffer) batch.column(0); diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVectorFromBuffer.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVectorFromBuffer.java index e23fa76c9f3..b5ed621821b 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVectorFromBuffer.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVectorFromBuffer.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -113,6 +113,29 @@ public GpuColumnVectorFromBuffer(DataType type, ColumnVector cudfColumn, this.tableMeta = meta; } + public static boolean isFromBuffer(ColumnarBatch cb) { + if (cb.numCols() > 0) { + long bufferAddr = 0L; + boolean isSet = false; + for (int i = 0; i < cb.numCols(); ++i) { + GpuColumnVectorFromBuffer gcvfb = null; + if (!(cb.column(i) instanceof GpuColumnVectorFromBuffer)) { + return false; + } else { + gcvfb = (GpuColumnVectorFromBuffer) cb.column(i); + if (!isSet) { + bufferAddr = gcvfb.buffer.getAddress(); + isSet = true; + } else if (bufferAddr != gcvfb.buffer.getAddress()) { + return false; + } + } + } + return true; + } + return false; + } + /** * Get the underlying contiguous buffer, shared between columns of the original * `ContiguousTable` diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuCompressedColumnVector.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuCompressedColumnVector.java index cd34f35ecab..1dc85cb2031 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuCompressedColumnVector.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuCompressedColumnVector.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,6 +47,15 @@ public static boolean isBatchCompressed(ColumnarBatch batch) { return batch.numCols() == 1 && batch.column(0) instanceof GpuCompressedColumnVector; } + public static ColumnarBatch incRefCounts(ColumnarBatch batch) { + if (!isBatchCompressed(batch)) { + throw new IllegalStateException( + "Attempted to incRefCount for a compressed batch, but the batch was not compressed."); + } + ((GpuCompressedColumnVector)batch.column(0)).buffer.incRefCount(); + return batch; + } + /** * Build a columnar batch from a compressed data buffer and specified table metadata * NOTE: The data remains compressed and cannot be accessed directly from the columnar batch. 
diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/InternalRowToColumnarBatchIterator.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/InternalRowToColumnarBatchIterator.java index 0aa3f0978e9..400b54626d8 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/InternalRowToColumnarBatchIterator.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/InternalRowToColumnarBatchIterator.java @@ -21,19 +21,11 @@ import java.util.NoSuchElementException; import java.util.Optional; -import com.nvidia.spark.Retryable; import scala.Option; import scala.Tuple2; import scala.collection.Iterator; -import ai.rapids.cudf.ColumnVector; -import ai.rapids.cudf.DType; -import ai.rapids.cudf.HostColumnVector; -import ai.rapids.cudf.HostColumnVectorCore; -import ai.rapids.cudf.HostMemoryBuffer; -import ai.rapids.cudf.NvtxColor; -import ai.rapids.cudf.NvtxRange; -import ai.rapids.cudf.Table; +import ai.rapids.cudf.*; import com.nvidia.spark.rapids.jni.RowConversion; import com.nvidia.spark.rapids.shims.CudfUnsafeRow; @@ -236,8 +228,7 @@ private HostMemoryBuffer[] getHostBuffersWithRetry( try { hBuf = HostAlloc$.MODULE$.alloc((dataBytes + offsetBytes),true); SpillableHostBuffer sBuf = SpillableHostBuffer$.MODULE$.apply(hBuf, hBuf.getLength(), - SpillPriorities$.MODULE$.ACTIVE_ON_DECK_PRIORITY(), - RapidsBufferCatalog$.MODULE$.singleton()); + SpillPriorities$.MODULE$.ACTIVE_ON_DECK_PRIORITY()); hBuf = null; // taken over by spillable host buffer return Tuple2.apply(sBuf, numRowsWrapper); } finally { diff --git a/sql-plugin/src/main/scala-2.12/com/nvidia/spark/rapids/implicits.scala b/sql-plugin/src/main/scala-2.12/com/nvidia/spark/rapids/implicits.scala index eddad69ba97..89e717788f9 100644 --- a/sql-plugin/src/main/scala-2.12/com/nvidia/spark/rapids/implicits.scala +++ b/sql-plugin/src/main/scala-2.12/com/nvidia/spark/rapids/implicits.scala @@ -63,26 +63,6 @@ object RapidsPluginImplicits { } } - implicit class RapidsBufferColumn(rapidsBuffer: RapidsBuffer) { - - /** - * safeFree: Is an implicit on RapidsBuffer class that tries to free the resource, if an - * Exception was thrown prior to this free, it adds the new exception to the suppressed - * exceptions, otherwise just throws - * - * @param e Exception which we don't want to suppress - */ - def safeFree(e: Throwable = null): Unit = { - if (rapidsBuffer != null) { - try { - rapidsBuffer.free() - } catch { - case suppressed: Throwable if e != null => e.addSuppressed(suppressed) - } - } - } - } - implicit class AutoCloseableSeq[A <: AutoCloseable](val in: collection.SeqLike[A, _]) { /** * safeClose: Is an implicit on a sequence of AutoCloseable classes that tries to close each @@ -111,46 +91,12 @@ object RapidsPluginImplicits { } } - implicit class RapidsBufferSeq[A <: RapidsBuffer](val in: collection.SeqLike[A, _]) { - /** - * safeFree: Is an implicit on a sequence of RapidsBuffer classes that tries to free each - * element of the sequence, even if prior free calls fail. In case of failure in any of the - * free calls, an Exception is thrown containing the suppressed exceptions (getSuppressed), - * if any. 
- */ - def safeFree(error: Throwable = null): Unit = if (in != null) { - var freeException: Throwable = null - in.foreach { element => - if (element != null) { - try { - element.free() - } catch { - case e: Throwable if error != null => error.addSuppressed(e) - case e: Throwable if freeException == null => freeException = e - case e: Throwable => freeException.addSuppressed(e) - } - } - } - if (freeException != null) { - // an exception happened while we were trying to safely free - // resources, throw the exception to alert the caller - throw freeException - } - } - } - implicit class AutoCloseableArray[A <: AutoCloseable](val in: Array[A]) { def safeClose(e: Throwable = null): Unit = if (in != null) { in.toSeq.safeClose(e) } } - implicit class RapidsBufferArray[A <: RapidsBuffer](val in: Array[A]) { - def safeFree(e: Throwable = null): Unit = if (in != null) { - in.toSeq.safeFree(e) - } - } - class MapsSafely[A, Repr] { /** * safeMap: safeMap implementation that is leveraged by other type-specific implicits. diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala index 926f770a683..848c8ce1b81 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala @@ -163,20 +163,6 @@ object Arm extends ArmScalaSpecificImpl { } } - /** Executes the provided code block, freeing the RapidsBuffer only if an exception occurs */ - def freeOnExcept[T <: RapidsBuffer, V](r: T)(block: T => V): V = { - try { - block(r) - } catch { - case t: ControlThrowable => - // Don't close for these cases.. - throw t - case t: Throwable => - r.safeFree(t) - throw t - } - } - /** Executes the provided code block and then closes the resource */ def withResource[T <: AutoCloseable, V](h: CloseableHolder[T]) (block: CloseableHolder[T] => V): V = { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala index 72808d1f376..0e415ca06dc 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,8 +34,7 @@ import org.apache.spark.sql.rapids.execution.TrampolineUtil * depleting the device store */ class DeviceMemoryEventHandler( - catalog: RapidsBufferCatalog, - store: RapidsDeviceMemoryStore, + store: SpillableDeviceStore, oomDumpDir: Option[String], maxFailedOOMRetries: Int) extends RmmEventHandler with Logging { @@ -92,8 +91,8 @@ class DeviceMemoryEventHandler( * from cuDF. If we succeed, cuDF resets `retryCount`, and so the new count sent to us * must be <= than what we saw last, so we can reset our tracking. 
*/ - def resetIfNeeded(retryCount: Int, storeSpillableSize: Long): Unit = { - if (storeSpillableSize != 0 || retryCount <= retryCountLastSynced) { + def resetIfNeeded(retryCount: Int, couldSpill: Boolean): Unit = { + if (couldSpill || retryCount <= retryCountLastSynced) { reset() } } @@ -114,9 +113,6 @@ class DeviceMemoryEventHandler( s"onAllocFailure invoked with invalid retryCount $retryCount") try { - val storeSize = store.currentSize - val storeSpillableSize = store.currentSpillableSize - val attemptMsg = if (retryCount > 0) { s"Attempt ${retryCount}. " } else { @@ -124,12 +120,13 @@ class DeviceMemoryEventHandler( } val retryState = oomRetryState.get() - retryState.resetIfNeeded(retryCount, storeSpillableSize) - logInfo(s"Device allocation of $allocSize bytes failed, device store has " + - s"$storeSize total and $storeSpillableSize spillable bytes. $attemptMsg" + - s"Total RMM allocated is ${Rmm.getTotalBytesAllocated} bytes. ") - if (storeSpillableSize == 0) { + val amountSpilled = store.spill(allocSize) + retryState.resetIfNeeded(retryCount, amountSpilled > 0) + logInfo(s"Device allocation of $allocSize bytes failed. " + + s"Device store spilled $amountSpilled bytes. $attemptMsg " + + s"Total RMM allocated is ${Rmm.getTotalBytesAllocated} bytes.") + val shouldRetry = if (amountSpilled == 0) { if (retryState.shouldTrySynchronizing(retryCount)) { Cuda.deviceSynchronize() logWarning(s"[RETRY ${retryState.getRetriesSoFar}] " + @@ -149,15 +146,11 @@ class DeviceMemoryEventHandler( false } } else { - val targetSize = Math.max(storeSpillableSize - allocSize, 0) - logDebug(s"Targeting device store size of $targetSize bytes") - val maybeAmountSpilled = catalog.synchronousSpill(store, targetSize, Cuda.DEFAULT_STREAM) - maybeAmountSpilled.foreach { amountSpilled => - logInfo(s"Spilled $amountSpilled bytes from the device store") - TrampolineUtil.incTaskMetricsMemoryBytesSpilled(amountSpilled) - } + TrampolineUtil.incTaskMetricsMemoryBytesSpilled(amountSpilled) true } + + shouldRetry } catch { case t: Throwable => logError(s"Error handling allocation failure", t) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala index b0c86773166..f71ae813010 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import ai.rapids.cudf._ +import com.nvidia.spark.rapids.jni.RmmSpark import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.internal.Logging @@ -169,7 +170,9 @@ object GpuDeviceManager extends Logging { chunkedPackMemoryResource = None poolSizeLimit = 0L - RapidsBufferCatalog.close() + SpillFramework.shutdown() + RmmSpark.clearEventHandler() + Rmm.clearEventHandler() GpuShuffleEnv.shutdown() // try to avoid segfault on RMM shutdown val timeout = System.nanoTime() + TimeUnit.SECONDS.toNanos(10) @@ -278,6 +281,8 @@ object GpuDeviceManager extends Logging { } } + private var memoryEventHandler: DeviceMemoryEventHandler = _ + private def initializeRmm(gpuId: Int, rapidsConf: Option[RapidsConf]): Unit = { if (!Rmm.isInitialized) { val conf = rapidsConf.getOrElse(new RapidsConf(SparkEnv.get.conf)) @@ -385,8 +390,25 @@ object GpuDeviceManager extends Logging { } } - RapidsBufferCatalog.init(conf) - GpuShuffleEnv.init(conf, RapidsBufferCatalog.getDiskBlockManager()) + 
SpillFramework.initialize(conf) + + memoryEventHandler = new DeviceMemoryEventHandler( + SpillFramework.stores.deviceStore, + conf.gpuOomDumpDir, + conf.gpuOomMaxRetries) + + if (conf.sparkRmmStateEnable) { + val debugLoc = if (conf.sparkRmmDebugLocation.isEmpty) { + null + } else { + conf.sparkRmmDebugLocation + } + RmmSpark.setEventHandler(memoryEventHandler, debugLoc) + } else { + logWarning("SparkRMM retry has been disabled") + Rmm.setEventHandler(memoryEventHandler) + } + GpuShuffleEnv.init(conf) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala index 252c31da125..1217e2252e3 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala @@ -1110,7 +1110,7 @@ class CudfSpillableHostConcatResult( val hmb: HostMemoryBuffer) extends SpillableHostConcatResult { override def toBatch: ColumnarBatch = { - closeOnExcept(buffer.getHostBuffer()) { hostBuf => + closeOnExcept(buffer.getHostBuffer) { hostBuf => SerializedTableColumn.from(header, hostBuf) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuUserDefinedFunction.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuUserDefinedFunction.scala index 90fe8b29e3d..eae23b86dd5 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuUserDefinedFunction.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuUserDefinedFunction.scala @@ -123,7 +123,7 @@ trait GpuRowBasedUserDefinedFunction extends GpuExpression val retConverter = GpuRowToColumnConverter.getConverterForType(dataType, nullable) val retType = GpuColumnVector.convertFrom(dataType, nullable) val retRow = new GenericInternalRow(size = 1) - closeOnExcept(new RapidsHostColumnBuilder(retType, batch.numRows)) { builder => + withResource(new RapidsHostColumnBuilder(retType, batch.numRows)) { builder => /** * This `nullSafe` is for https://github.com/NVIDIA/spark-rapids/issues/3942. * And more details can be found from diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala index 7223463b8b7..ea61d640906 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala @@ -120,34 +120,26 @@ private class HostAlloc(nonPinnedLimit: Long) extends HostMemoryAllocator with L require(retryCount >= 0, s"spillAndCheckRetry invoked with invalid retryCount $retryCount") - val store = RapidsBufferCatalog.getHostStorage - val storeSize = store.currentSize - val storeSpillableSize = store.currentSpillableSize + val store = SpillFramework.stores.hostStore val totalSize: Long = synchronized { currentPinnedAllocated + currentNonPinnedAllocated } - val attemptMsg = if (retryCount > 0) { - s"Attempt $retryCount" - } else { - "First attempt" - } + // TODO: AB fix this + //val attemptMsg = if (retryCount > 0) { + // s"Attempt $retryCount" + //} else { + // "First attempt" + //} - logInfo(s"Host allocation of $allocSize bytes failed, host store has " + - s"$storeSize total and $storeSpillableSize spillable bytes. $attemptMsg.") - if (storeSpillableSize == 0) { + val amountSpilled = store.spill(allocSize) + + if (amountSpilled == 0) { logWarning(s"Host store exhausted, unable to allocate $allocSize bytes. 
" + s"Total host allocated is $totalSize bytes.") false } else { - val targetSize = Math.max(storeSpillableSize - allocSize, 0) - logDebug(s"Targeting host store size of $targetSize bytes") - // We could not make it work so try and spill enough to make it work - val maybeAmountSpilled = - RapidsBufferCatalog.synchronousSpill(RapidsBufferCatalog.getHostStorage, targetSize) - maybeAmountSpilled.foreach { amountSpilled => - logInfo(s"Spilled $amountSpilled bytes from the host store") - } + logInfo(s"Spilled $amountSpilled bytes from the host store") true } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/JoinGatherer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/JoinGatherer.scala index c4584086173..771d91ddada 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/JoinGatherer.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/JoinGatherer.scala @@ -307,7 +307,6 @@ class LazySpillableColumnarBatchImpl( spill = Some(SpillableColumnarBatch(cached.get, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)) } finally { - // Putting data in a SpillableColumnarBatch takes ownership of it. cached = None } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala index 80acddcb257..f1561e2c251 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -51,6 +51,13 @@ object MetaUtils { ct.getMetadataDirectBuffer, ct.getRowCount) + def buildTableMeta(tableId: Int, compressed: GpuCompressedColumnVector): TableMeta = + buildTableMeta( + tableId, + compressed.getTableBuffer.getLength, + compressed.getTableMeta.bufferMeta().getByteBuffer, + compressed.getTableMeta.rowCount()) + def buildTableMeta(tableId: Int, bufferSize: Long, packedMeta: ByteBuffer, rowCount: Long): TableMeta = { val fbb = new FlatBufferBuilder(1024) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala deleted file mode 100644 index a332755745f..00000000000 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala +++ /dev/null @@ -1,485 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.nvidia.spark.rapids - -import java.io.File -import java.nio.channels.WritableByteChannel - -import scala.collection.mutable.ArrayBuffer - -import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, Table} -import com.nvidia.spark.rapids.Arm.withResource -import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.StorageTier.StorageTier -import com.nvidia.spark.rapids.format.TableMeta - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.rapids.RapidsDiskBlockManager -import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.vectorized.ColumnarBatch - -/** - * An identifier for a RAPIDS buffer that can be automatically spilled between buffer stores. - * NOTE: Derived classes MUST implement proper hashCode and equals methods, as these objects are - * used as keys in hash maps. Scala case classes are recommended. - */ -trait RapidsBufferId { - val tableId: Int - - /** - * Indicates whether the buffer may share a spill file with other buffers. - * If false then the spill file will be automatically removed when the buffer is freed. - * If true then the spill file will not be automatically removed, and another subsystem needs - * to be responsible for cleaning up the spill files for those types of buffers. - */ - val canShareDiskPaths: Boolean = false - - /** - * Generate a path to a local file that can be used to spill the corresponding buffer to disk. - * The path must be unique across all buffers unless canShareDiskPaths is true. - */ - def getDiskPath(diskBlockManager: RapidsDiskBlockManager): File -} - -/** Enumeration of the storage tiers */ -object StorageTier extends Enumeration { - type StorageTier = Value - val DEVICE: StorageTier = Value(0, "device memory") - val HOST: StorageTier = Value(1, "host memory") - val DISK: StorageTier = Value(2, "local disk") -} - -/** - * ChunkedPacker is an Iterator that uses a cudf::chunked_pack to copy a cuDF `Table` - * to a target buffer in chunks. - * - * Each chunk is sized at most `bounceBuffer.getLength`, and the caller should cudaMemcpy - * bytes from `bounceBuffer` to a target buffer after each call to `next()`. - * - * @note `ChunkedPacker` must be closed by the caller as it has GPU and host resources - * associated with it. - * - * @param id The RapidsBufferId for this pack operation to be included in the metadata - * @param table cuDF Table to chunk_pack - * @param bounceBuffer GPU memory to be used for packing. The buffer should be at least 1MB - * in length. - */ -class ChunkedPacker( - id: RapidsBufferId, - table: Table, - bounceBuffer: DeviceMemoryBuffer) - extends Iterator[MemoryBuffer] - with Logging - with AutoCloseable { - - private var closed: Boolean = false - - // When creating cudf::chunked_pack use a pool if available, otherwise default to the - // per-device memory resource - private val chunkedPack = { - val pool = GpuDeviceManager.chunkedPackMemoryResource - val cudfChunkedPack = try { - pool.flatMap { chunkedPool => - Some(table.makeChunkedPack(bounceBuffer.getLength, chunkedPool)) - } - } catch { - case _: OutOfMemoryError => - if (!ChunkedPacker.warnedAboutPoolFallback) { - ChunkedPacker.warnedAboutPoolFallback = true - logWarning( - s"OOM while creating chunked_pack using pool sized ${pool.map(_.getMaxSize)}B. 
" + - "Falling back to the per-device memory resource.") - } - None - } - - // if the pool is not configured, or we got an OOM, try again with the per-device pool - cudfChunkedPack.getOrElse { - table.makeChunkedPack(bounceBuffer.getLength) - } - } - - private val tableMeta = withResource(chunkedPack.buildMetadata()) { packedMeta => - MetaUtils.buildTableMeta( - id.tableId, - chunkedPack.getTotalContiguousSize, - packedMeta.getMetadataDirectBuffer, - table.getRowCount) - } - - // take out a lease on the bounce buffer - bounceBuffer.incRefCount() - - def getTotalContiguousSize: Long = chunkedPack.getTotalContiguousSize - - def getMeta: TableMeta = { - tableMeta - } - - override def hasNext: Boolean = synchronized { - if (closed) { - throw new IllegalStateException(s"ChunkedPacker for $id is closed") - } - chunkedPack.hasNext - } - - def next(): MemoryBuffer = synchronized { - if (closed) { - throw new IllegalStateException(s"ChunkedPacker for $id is closed") - } - val bytesWritten = chunkedPack.next(bounceBuffer) - // we increment the refcount because the caller has no idea where - // this memory came from, so it should close it. - bounceBuffer.slice(0, bytesWritten) - } - - override def close(): Unit = synchronized { - if (!closed) { - closed = true - val toClose = new ArrayBuffer[AutoCloseable]() - toClose.append(chunkedPack, bounceBuffer) - toClose.safeClose() - } - } -} - -object ChunkedPacker { - private var warnedAboutPoolFallback: Boolean = false -} - -/** - * This iterator encapsulates a buffer's internal `MemoryBuffer` access - * for spill reasons. Internally, there are two known implementations: - * - either this is a "single shot" copy, where the entirety of the `RapidsBuffer` is - * already represented as a single contiguous blob of memory, then the expectation - * is that this iterator is exhausted with a single call to `next` - * - or, we have a `RapidsBuffer` that isn't contiguous. This iteration will then - * drive a `ChunkedPacker` to pack the `RapidsBuffer`'s table as needed. The - * iterator will likely need several calls to `next` to be exhausted. - * - * @param buffer `RapidsBuffer` to copy out of its tier. 
- */ -class RapidsBufferCopyIterator(buffer: RapidsBuffer) - extends Iterator[MemoryBuffer] with AutoCloseable with Logging { - - private val chunkedPacker: Option[ChunkedPacker] = if (buffer.supportsChunkedPacker) { - Some(buffer.makeChunkedPacker) - } else { - None - } - def isChunked: Boolean = chunkedPacker.isDefined - - // this is used for the single shot case to flag when `next` is call - // to satisfy the Iterator interface - private var singleShotCopyHasNext: Boolean = false - private var singleShotBuffer: MemoryBuffer = _ - - if (!isChunked) { - singleShotCopyHasNext = true - singleShotBuffer = buffer.getMemoryBuffer - } - - override def hasNext: Boolean = - chunkedPacker.map(_.hasNext).getOrElse(singleShotCopyHasNext) - - override def next(): MemoryBuffer = { - require(hasNext, - "next called on exhausted iterator") - chunkedPacker.map(_.next()).getOrElse { - singleShotCopyHasNext = false - singleShotBuffer.slice(0, singleShotBuffer.getLength) - } - } - - def getTotalCopySize: Long = { - chunkedPacker - .map(_.getTotalContiguousSize) - .getOrElse(singleShotBuffer.getLength) - } - - override def close(): Unit = { - val toClose = new ArrayBuffer[AutoCloseable]() - toClose.appendAll(chunkedPacker) - toClose.appendAll(Option(singleShotBuffer)) - - toClose.safeClose() - } -} - -/** Interface provided by all types of RAPIDS buffers */ -trait RapidsBuffer extends AutoCloseable { - /** The buffer identifier for this buffer. */ - val id: RapidsBufferId - - /** - * The size of this buffer in bytes in its _current_ store. As the buffer goes through - * contiguous split (either added as a contiguous table already, or spilled to host), - * its size changes because contiguous_split adds its own alignment padding. - * - * @note Do not use this size to allocate a target buffer to copy, always use `getPackedSize.` - */ - val memoryUsedBytes: Long - - /** - * The size of this buffer if it has already gone through contiguous_split. - * - * @note Use this function when allocating a target buffer for spill or shuffle purposes. - */ - def getPackedSizeBytes: Long = memoryUsedBytes - - /** - * At spill time, obtain an iterator used to copy this buffer to a different tier. - */ - def getCopyIterator: RapidsBufferCopyIterator = - new RapidsBufferCopyIterator(this) - - /** Descriptor for how the memory buffer is formatted */ - def meta: TableMeta - - /** The storage tier for this buffer */ - val storageTier: StorageTier - - /** - * Get the columnar batch within this buffer. The caller must have - * successfully acquired the buffer beforehand. - * @param sparkTypes the spark data types the batch should have - * @see [[addReference]] - * @note It is the responsibility of the caller to close the batch. - * @note If the buffer is compressed data then the resulting batch will be built using - * `GpuCompressedColumnVector`, and it is the responsibility of the caller to deal - * with decompressing the data if necessary. - */ - def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch - - /** - * Get the host-backed columnar batch from this buffer. The caller must have - * successfully acquired the buffer beforehand. - * - * If this `RapidsBuffer` was added originally to the device tier, or if this is - * a just a buffer (not a batch), this function will throw. - * - * @param sparkTypes the spark data types the batch should have - * @see [[addReference]] - * @note It is the responsibility of the caller to close the batch. 
- */ - def getHostColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = { - throw new IllegalStateException(s"$this does not support host columnar batches.") - } - - /** - * Get the underlying memory buffer. This may be either a HostMemoryBuffer or a DeviceMemoryBuffer - * depending on where the buffer currently resides. - * The caller must have successfully acquired the buffer beforehand. - * @see [[addReference]] - * @note It is the responsibility of the caller to close the buffer. - */ - def getMemoryBuffer: MemoryBuffer - - val supportsChunkedPacker: Boolean = false - - /** - * Makes a new chunked packer. It is the responsibility of the caller to close this. - */ - def makeChunkedPacker: ChunkedPacker = { - throw new NotImplementedError("not implemented for this store") - } - - /** - * Copy the content of this buffer into the specified memory buffer, starting from the given - * offset. - * - * @param srcOffset offset to start copying from. - * @param dst the memory buffer to copy into. - * @param dstOffset offset to copy into. - * @param length number of bytes to copy. - * @param stream CUDA stream to use - */ - def copyToMemoryBuffer( - srcOffset: Long, dst: MemoryBuffer, dstOffset: Long, length: Long, stream: Cuda.Stream): Unit - - /** - * Get the device memory buffer from the underlying storage. If the buffer currently resides - * outside of device memory, a new DeviceMemoryBuffer is created with the data copied over. - * The caller must have successfully acquired the buffer beforehand. - * @see [[addReference]] - * @note It is the responsibility of the caller to close the buffer. - */ - def getDeviceMemoryBuffer: DeviceMemoryBuffer - - /** - * Get the host memory buffer from the underlying storage. If the buffer currently resides - * outside of host memory, a new HostMemoryBuffer is created with the data copied over. - * The caller must have successfully acquired the buffer beforehand. - * @see [[addReference]] - * @note It is the responsibility of the caller to close the buffer. - */ - def getHostMemoryBuffer: HostMemoryBuffer - - /** - * Try to add a reference to this buffer to acquire it. - * @note The close method must be called for every successfully obtained reference. - * @return true if the reference was added or false if this buffer is no longer valid - */ - def addReference(): Boolean - - /** - * Schedule the release of the buffer's underlying resources. - * Subsequent attempts to acquire the buffer will fail. As soon as the - * buffer has no outstanding references, the resources will be released. - *
- * This is separate from the close method which does not normally release - * resources. close will only release resources if called as the last - * outstanding reference and the buffer was previously marked as freed. - */ - def free(): Unit - - /** - * Get the spill priority value for this buffer. Lower values are higher - * priority for spilling, meaning buffers with lower values will be - * preferred for spilling over buffers with a higher value. - */ - def getSpillPriority: Long - - /** - * Set the spill priority for this buffer. Lower values are higher priority - * for spilling, meaning buffers with lower values will be preferred for - * spilling over buffers with a higher value. - * @note should only be called from the buffer catalog - * @param priority new priority value for this buffer - */ - def setSpillPriority(priority: Long): Unit - - /** - * Function invoked by the `RapidsBufferStore.addBuffer` method that prompts - * the specific `RapidsBuffer` to check its reference counting to make itself - * spillable or not. Only `RapidsTable` and `RapidsHostMemoryBuffer` implement - * this method. - */ - def updateSpillability(): Unit = {} - - /** - * Obtains a read lock on this instance of `RapidsBuffer` and calls the function - * in `body` while holding the lock. - * @param body function that takes a `MemoryBuffer` and produces `K` - * @tparam K any return type specified by `body` - * @return the result of body(memoryBuffer) - */ - def withMemoryBufferReadLock[K](body: MemoryBuffer => K): K - - /** - * Obtains a write lock on this instance of `RapidsBuffer` and calls the function - * in `body` while holding the lock. - * @param body function that takes a `MemoryBuffer` and produces `K` - * @tparam K any return type specified by `body` - * @return the result of body(memoryBuffer) - */ - def withMemoryBufferWriteLock[K](body: MemoryBuffer => K): K -} - -/** - * A buffer with no corresponding device data (zero rows or columns). - * These buffers are not tracked in buffer stores since they have no - * device memory. They are only tracked in the catalog and provide - * a representative `ColumnarBatch` but cannot provide a - * `MemoryBuffer`. 
- * @param id buffer ID to associate with the buffer - * @param meta schema metadata - */ -sealed class DegenerateRapidsBuffer( - override val id: RapidsBufferId, - override val meta: TableMeta) extends RapidsBuffer { - - override val memoryUsedBytes: Long = 0L - - override val storageTier: StorageTier = StorageTier.DEVICE - - override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = { - val rowCount = meta.rowCount - val packedMeta = meta.packedMetaAsByteBuffer() - if (packedMeta != null) { - withResource(DeviceMemoryBuffer.allocate(0)) { deviceBuffer => - withResource(Table.fromPackedTable(meta.packedMetaAsByteBuffer(), deviceBuffer)) { table => - GpuColumnVectorFromBuffer.from(table, deviceBuffer, meta, sparkTypes) - } - } - } else { - // no packed metadata, must be a table with zero columns - new ColumnarBatch(Array.empty, rowCount.toInt) - } - } - - override def free(): Unit = {} - - override def getMemoryBuffer: MemoryBuffer = - throw new UnsupportedOperationException("degenerate buffer has no memory buffer") - - override def copyToMemoryBuffer(srcOffset: Long, dst: MemoryBuffer, dstOffset: Long, length: Long, - stream: Cuda.Stream): Unit = - throw new UnsupportedOperationException("degenerate buffer cannot copy to memory buffer") - - override def getDeviceMemoryBuffer: DeviceMemoryBuffer = - throw new UnsupportedOperationException("degenerate buffer has no device memory buffer") - - override def getHostMemoryBuffer: HostMemoryBuffer = - throw new UnsupportedOperationException("degenerate buffer has no host memory buffer") - - override def addReference(): Boolean = true - - override def getSpillPriority: Long = Long.MaxValue - - override def setSpillPriority(priority: Long): Unit = {} - - override def withMemoryBufferReadLock[K](body: MemoryBuffer => K): K = { - throw new UnsupportedOperationException("degenerate buffer has no memory buffer") - } - - override def withMemoryBufferWriteLock[K](body: MemoryBuffer => K): K = { - throw new UnsupportedOperationException("degenerate buffer has no memory buffer") - } - - override def close(): Unit = {} -} - -trait RapidsHostBatchBuffer extends AutoCloseable { - /** - * Get the host-backed columnar batch from this buffer. The caller must have - * successfully acquired the buffer beforehand. - * - * If this `RapidsBuffer` was added originally to the device tier, or if this is - * a just a buffer (not a batch), this function will throw. - * - * @param sparkTypes the spark data types the batch should have - * @see [[addReference]] - * @note It is the responsibility of the caller to close the batch. - */ - def getHostColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch - - val memoryUsedBytes: Long -} - -trait RapidsBufferChannelWritable { - /** - * At spill time, write this buffer to an nio WritableByteChannel. - * @param writableChannel that this buffer can just write itself to, either byte-for-byte - * or via serialization if needed. - * @param stream the Cuda.Stream for the spilling thread. If the `RapidsBuffer` that - * implements this method is on the device, synchronization may be needed - * for staged copies. 
- * @return the amount of bytes written to the channel - */ - def writeToChannel(writableChannel: WritableByteChannel, stream: Cuda.Stream): Long -} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala deleted file mode 100644 index f61291a31ce..00000000000 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala +++ /dev/null @@ -1,1005 +0,0 @@ -/* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.nvidia.spark.rapids - -import java.util.concurrent.ConcurrentHashMap -import java.util.function.BiFunction - -import scala.collection.JavaConverters.collectionAsScalaIterableConverter - -import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, Rmm, Table} -import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} -import com.nvidia.spark.rapids.RapidsBufferCatalog.getExistingRapidsBufferAndAcquire -import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.StorageTier.StorageTier -import com.nvidia.spark.rapids.format.TableMeta -import com.nvidia.spark.rapids.jni.RmmSpark - -import org.apache.spark.{SparkConf, SparkEnv} -import org.apache.spark.internal.Logging -import org.apache.spark.sql.rapids.{RapidsDiskBlockManager, TempSpillBufferId} -import org.apache.spark.sql.vectorized.ColumnarBatch - -/** - * Exception thrown when inserting a buffer into the catalog with a duplicate buffer ID - * and storage tier combination. - */ -class DuplicateBufferException(s: String) extends RuntimeException(s) {} - -/** - * An object that client code uses to interact with an underlying RapidsBufferId. - * - * A handle is obtained when a buffer, batch, or table is added to the spill framework - * via the `RapidsBufferCatalog` api. - */ -trait RapidsBufferHandle extends AutoCloseable { - val id: RapidsBufferId - - /** - * Sets the spill priority for this handle and updates the maximum priority - * for the underlying `RapidsBuffer` if this new priority is the maximum. - * @param newPriority new priority for this handle - */ - def setSpillPriority(newPriority: Long): Unit -} - -/** - * Catalog for lookup of buffers by ID. The constructor is only visible for testing, generally - * `RapidsBufferCatalog.singleton` should be used instead. 
- */ -class RapidsBufferCatalog( - deviceStorage: RapidsDeviceMemoryStore = RapidsBufferCatalog.deviceStorage, - hostStorage: RapidsHostMemoryStore = RapidsBufferCatalog.hostStorage) - extends AutoCloseable with Logging { - - /** Map of buffer IDs to buffers sorted by storage tier */ - private[this] val bufferMap = new ConcurrentHashMap[RapidsBufferId, Seq[RapidsBuffer]] - - /** Map of buffer IDs to buffer handles in insertion order */ - private[this] val bufferIdToHandles = - new ConcurrentHashMap[RapidsBufferId, Seq[RapidsBufferHandleImpl]]() - - /** A counter used to skip a spill attempt if we detect a different thread has spilled */ - @volatile private[this] var spillCount: Integer = 0 - - class RapidsBufferHandleImpl( - override val id: RapidsBufferId, - var priority: Long) - extends RapidsBufferHandle { - - private var closed = false - - override def toString: String = - s"buffer handle $id at $priority" - - override def setSpillPriority(newPriority: Long): Unit = { - priority = newPriority - updateUnderlyingRapidsBuffer(this) - } - - /** - * Get the spill priority that was associated with this handle. Since there can - * be multiple handles associated with one `RapidsBuffer`, the priority returned - * here is only useful for code in the catalog that updates the maximum priority - * for the underlying `RapidsBuffer` as handles are added and removed. - * - * @return this handle's spill priority - */ - def getSpillPriority: Long = priority - - override def close(): Unit = synchronized { - // since the handle is stored in the catalog in addition to being - // handed out to potentially a `SpillableColumnarBatch` or `SpillableBuffer` - // there is a chance we may double close it. For example, a broadcast exec - // that is closing its spillable (and therefore the handle) + the handle being - // closed from the catalog's close method. - if (!closed) { - removeBuffer(this) - } - closed = true - } - } - - /** - * Makes a new `RapidsBufferHandle` associated with `id`, keeping track - * of the spill priority and callback within this handle. - * - * This function also adds the handle for internal tracking in the catalog. - * - * @param id the `RapidsBufferId` that this handle refers to - * @param spillPriority the spill priority specified on creation of the handle - * @note public for testing - * @return a new instance of `RapidsBufferHandle` - */ - def makeNewHandle( - id: RapidsBufferId, - spillPriority: Long): RapidsBufferHandle = { - val handle = new RapidsBufferHandleImpl(id, spillPriority) - trackNewHandle(handle) - handle - } - - /** - * Adds a handle to the internal `bufferIdToHandles` map. - * - * The priority and callback of the `RapidsBuffer` will also be updated. - * - * @param handle handle to start tracking - */ - private def trackNewHandle(handle: RapidsBufferHandleImpl): Unit = { - bufferIdToHandles.compute(handle.id, (_, h) => { - var handles = h - if (handles == null) { - handles = Seq.empty[RapidsBufferHandleImpl] - } - handles :+ handle - }) - updateUnderlyingRapidsBuffer(handle) - } - - /** - * Called when the `RapidsBufferHandle` is no longer needed by calling code - * - * If this is the last handle associated with a `RapidsBuffer`, `stopTrackingHandle` - * returns true, otherwise it returns false. - * - * @param handle handle to stop tracking - * @return true: if this was the last `RapidsBufferHandle` associated with the - * underlying buffer. 
- * false: if there are remaining live handles - */ - private def stopTrackingHandle(handle: RapidsBufferHandle): Boolean = { - withResource(acquireBuffer(handle)) { buffer => - val id = handle.id - var maxPriority = Long.MinValue - val newHandles = bufferIdToHandles.compute(id, (_, handles) => { - if (handles == null) { - throw new IllegalStateException( - s"$id not found and we attempted to remove handles!") - } - if (handles.size == 1) { - require(handles.head == handle, - "Tried to remove a single handle, and we couldn't match on it") - null - } else { - val newHandles = handles.filter(h => h != handle).map { h => - maxPriority = maxPriority.max(h.getSpillPriority) - h - } - if (newHandles.isEmpty) { - null // remove since no more handles exist, should not happen - } else { - newHandles - } - } - }) - - if (newHandles == null) { - // tell calling code that no more handles exist, - // for this RapidsBuffer - true - } else { - // more handles remain, our priority changed so we need to update things - buffer.setSpillPriority(maxPriority) - false // we have handles left - } - } - } - - /** - * Adds a buffer to the catalog and store. This does NOT take ownership of the - * buffer, so it is the responsibility of the caller to close it. - * - * This version of `addBuffer` should not be called from the shuffle catalogs - * since they provide their own ids. - * - * @param buffer buffer that will be owned by the store - * @param tableMeta metadata describing the buffer layout - * @param initialSpillPriority starting spill priority value for the buffer - * @param needsSync whether the spill framework should stream synchronize while adding - * this device buffer (defaults to true) - * @return RapidsBufferHandle handle for this buffer - */ - def addBuffer( - buffer: MemoryBuffer, - tableMeta: TableMeta, - initialSpillPriority: Long, - needsSync: Boolean = true): RapidsBufferHandle = synchronized { - // first time we see `buffer` - val existing = getExistingRapidsBufferAndAcquire(buffer) - existing match { - case None => - addBuffer( - TempSpillBufferId(), - buffer, - tableMeta, - initialSpillPriority, - needsSync) - case Some(rapidsBuffer) => - withResource(rapidsBuffer) { _ => - makeNewHandle(rapidsBuffer.id, initialSpillPriority) - } - } - } - - /** - * Adds a contiguous table to the device storage. This does NOT take ownership of the - * contiguous table, so it is the responsibility of the caller to close it. The refcount of the - * underlying device buffer will be incremented so the contiguous table can be closed before - * this buffer is destroyed. - * - * This version of `addContiguousTable` should not be called from the shuffle catalogs - * since they provide their own ids. 
- * - * @param contigTable contiguous table to track in storage - * @param initialSpillPriority starting spill priority value for the buffer - * @param needsSync whether the spill framework should stream synchronize while adding - * this device buffer (defaults to true) - * @return RapidsBufferHandle handle for this table - */ - def addContiguousTable( - contigTable: ContiguousTable, - initialSpillPriority: Long, - needsSync: Boolean = true): RapidsBufferHandle = synchronized { - val existing = getExistingRapidsBufferAndAcquire(contigTable.getBuffer) - existing match { - case None => - addContiguousTable( - TempSpillBufferId(), - contigTable, - initialSpillPriority, - needsSync) - case Some(rapidsBuffer) => - withResource(rapidsBuffer) { _ => - makeNewHandle(rapidsBuffer.id, initialSpillPriority) - } - } - } - - /** - * Adds a contiguous table to the device storage. This does NOT take ownership of the - * contiguous table, so it is the responsibility of the caller to close it. The refcount of the - * underlying device buffer will be incremented so the contiguous table can be closed before - * this buffer is destroyed. - * - * @param id the RapidsBufferId to use for this buffer - * @param contigTable contiguous table to track in storage - * @param initialSpillPriority starting spill priority value for the buffer - * @param needsSync whether the spill framework should stream synchronize while adding - * this device buffer (defaults to true) - * @return RapidsBufferHandle handle for this table - */ - def addContiguousTable( - id: RapidsBufferId, - contigTable: ContiguousTable, - initialSpillPriority: Long, - needsSync: Boolean): RapidsBufferHandle = synchronized { - addBuffer( - id, - contigTable.getBuffer, - MetaUtils.buildTableMeta(id.tableId, contigTable), - initialSpillPriority, - needsSync) - } - - /** - * Adds a buffer to either the device or host storage. This does NOT take - * ownership of the buffer, so it is the responsibility of the caller to close it. - * - * @param id the RapidsBufferId to use for this buffer - * @param buffer buffer that will be owned by the target store - * @param tableMeta metadata describing the buffer layout - * @param initialSpillPriority starting spill priority value for the buffer - * @param needsSync whether the spill framework should stream synchronize while adding - * this buffer (defaults to true) - * @return RapidsBufferHandle handle for this RapidsBuffer - */ - def addBuffer( - id: RapidsBufferId, - buffer: MemoryBuffer, - tableMeta: TableMeta, - initialSpillPriority: Long, - needsSync: Boolean): RapidsBufferHandle = synchronized { - val rapidsBuffer = buffer match { - case gpuBuffer: DeviceMemoryBuffer => - deviceStorage.addBuffer( - id, - gpuBuffer, - tableMeta, - initialSpillPriority, - needsSync) - case hostBuffer: HostMemoryBuffer => - hostStorage.addBuffer( - id, - hostBuffer, - tableMeta, - initialSpillPriority, - needsSync) - case _ => - throw new IllegalArgumentException( - s"Cannot call addBuffer with buffer $buffer") - } - registerNewBuffer(rapidsBuffer) - makeNewHandle(id, initialSpillPriority) - } - - /** - * Adds a batch to the device storage. This does NOT take ownership of the - * batch, so it is the responsibility of the caller to close it. 
- * - * @param batch batch that will be added to the store - * @param initialSpillPriority starting spill priority value for the batch - * @param needsSync whether the spill framework should stream synchronize while adding - * this batch (defaults to true) - * @return RapidsBufferHandle handle for this RapidsBuffer - */ - def addBatch( - batch: ColumnarBatch, - initialSpillPriority: Long, - needsSync: Boolean = true): RapidsBufferHandle = { - require(batch.numCols() > 0, - "Cannot call addBatch with a batch that doesn't have columns") - batch.column(0) match { - case _: RapidsHostColumnVector => - addHostBatch(batch, initialSpillPriority, needsSync) - case _ => - closeOnExcept(GpuColumnVector.from(batch)) { table => - addTable(table, initialSpillPriority, needsSync) - } - } - } - - /** - * Adds a table to the device storage. - * - * This takes ownership of the table. The reason for this is that tables - * don't have a reference count, so we cannot cleanly capture ownership by increasing - * ref count and decreasing from the caller. - * - * @param table table that will be owned by the store - * @param initialSpillPriority starting spill priority value - * @param needsSync whether the spill framework should stream synchronize while adding - * this table (defaults to true) - * @return RapidsBufferHandle handle for this RapidsBuffer - */ - def addTable( - table: Table, - initialSpillPriority: Long, - needsSync: Boolean = true): RapidsBufferHandle = { - addTable(TempSpillBufferId(), table, initialSpillPriority, needsSync) - } - - /** - * Adds a table to the device storage. - * - * This takes ownership of the table. The reason for this is that tables - * don't have a reference count, so we cannot cleanly capture ownership by increasing - * ref count and decreasing from the caller. - * - * @param id specific RapidsBufferId to use for this table - * @param table table that will be owned by the store - * @param initialSpillPriority starting spill priority value - * @param needsSync whether the spill framework should stream synchronize while adding - * this table (defaults to true) - * @return RapidsBufferHandle handle for this RapidsBuffer - */ - def addTable( - id: RapidsBufferId, - table: Table, - initialSpillPriority: Long, - needsSync: Boolean): RapidsBufferHandle = { - val rapidsBuffer = deviceStorage.addTable( - id, - table, - initialSpillPriority, - needsSync) - registerNewBuffer(rapidsBuffer) - makeNewHandle(id, initialSpillPriority) - } - - - /** - * Add a host-backed ColumnarBatch to the catalog. This is only called from addBatch - * after we detect that this is a host-backed batch. - */ - private def addHostBatch( - hostCb: ColumnarBatch, - initialSpillPriority: Long, - needsSync: Boolean): RapidsBufferHandle = { - val id = TempSpillBufferId() - val rapidsBuffer = hostStorage.addBatch( - id, - hostCb, - initialSpillPriority, - needsSync) - registerNewBuffer(rapidsBuffer) - makeNewHandle(id, initialSpillPriority) - } - - /** - * Register a degenerate RapidsBufferId given a TableMeta - * @note this is called from the shuffle catalogs only - */ - def registerDegenerateBuffer( - bufferId: RapidsBufferId, - meta: TableMeta): RapidsBufferHandle = synchronized { - val buffer = new DegenerateRapidsBuffer(bufferId, meta) - registerNewBuffer(buffer) - makeNewHandle(buffer.id, buffer.getSpillPriority) - } - - /** - * Called by the catalog when a handle is first added to the catalog, or to refresh - * the priority of the underlying buffer if a handle's priority changed. 
- */ - private def updateUnderlyingRapidsBuffer(handle: RapidsBufferHandle): Unit = { - withResource(acquireBuffer(handle)) { buffer => - val handles = bufferIdToHandles.get(buffer.id) - val maxPriority = handles.map(_.getSpillPriority).max - // update the priority of the underlying RapidsBuffer to be the - // maximum priority for all handles associated with it - buffer.setSpillPriority(maxPriority) - } - } - - /** - * Lookup the buffer that corresponds to the specified handle at the highest storage tier, - * and acquire it. - * NOTE: It is the responsibility of the caller to close the buffer. - * @param handle handle associated with this `RapidsBuffer` - * @return buffer that has been acquired - */ - def acquireBuffer(handle: RapidsBufferHandle): RapidsBuffer = { - val id = handle.id - def lookupAndReturn: Option[RapidsBuffer] = { - val buffers = bufferMap.get(id) - if (buffers == null || buffers.isEmpty) { - throw new NoSuchElementException( - s"Cannot locate buffers associated with ID: $id") - } - val buffer = buffers.head - if (buffer.addReference()) { - Some(buffer) - } else { - None - } - } - - // fast path - (0 until RapidsBufferCatalog.MAX_BUFFER_LOOKUP_ATTEMPTS).foreach { _ => - val mayBuffer = lookupAndReturn - if (mayBuffer.isDefined) { - return mayBuffer.get - } - } - - // try one last time after locking the catalog (slow path) - // if there is a lot of contention here, I would rather lock the world than - // have tasks error out with "Unable to acquire" - synchronized { - val mayBuffer = lookupAndReturn - if (mayBuffer.isDefined) { - return mayBuffer.get - } - } - throw new IllegalStateException(s"Unable to acquire buffer for ID: $id") - } - - /** - * Acquires a RapidsBuffer that the caller expects to be host-backed and not - * device bound. This ensures that the buffer acquired implements the correct - * trait, otherwise it throws and removes its buffer acquisition. - * - * @param handle handle associated with this `RapidsBuffer` - * @return host-backed RapidsBuffer that has been acquired - */ - def acquireHostBatchBuffer(handle: RapidsBufferHandle): RapidsHostBatchBuffer = { - closeOnExcept(acquireBuffer(handle)) { - case hrb: RapidsHostBatchBuffer => hrb - case other => - throw new IllegalStateException( - s"Attempted to acquire a RapidsHostBatchBuffer, but got $other instead") - } - } - - /** - * Lookup the buffer that corresponds to the specified buffer ID at the specified storage tier, - * and acquire it. - * NOTE: It is the responsibility of the caller to close the buffer. - * @param id buffer identifier - * @return buffer that has been acquired, None if not found - */ - def acquireBuffer(id: RapidsBufferId, tier: StorageTier): Option[RapidsBuffer] = { - val buffers = bufferMap.get(id) - if (buffers != null) { - buffers.find(_.storageTier == tier).foreach(buffer => - if (buffer.addReference()) { - return Some(buffer) - } - ) - } - None - } - - /** - * Check if the buffer that corresponds to the specified buffer ID is stored in a slower storage - * tier. - * - * @param id buffer identifier - * @param tier storage tier to check - * @note public for testing - * @return true if the buffer is stored in multiple tiers - */ - def isBufferSpilled(id: RapidsBufferId, tier: StorageTier): Boolean = { - val buffers = bufferMap.get(id) - buffers != null && buffers.exists(_.storageTier > tier) - } - - /** Get the table metadata corresponding to a buffer ID. 
*/ - def getBufferMeta(id: RapidsBufferId): TableMeta = { - val buffers = bufferMap.get(id) - if (buffers == null || buffers.isEmpty) { - throw new NoSuchElementException(s"Cannot locate buffer associated with ID: $id") - } - buffers.head.meta - } - - /** - * Register a new buffer with the catalog. An exception will be thrown if an - * existing buffer was registered with the same buffer ID and storage tier. - * @note public for testing - */ - def registerNewBuffer(buffer: RapidsBuffer): Unit = { - val updater = new BiFunction[RapidsBufferId, Seq[RapidsBuffer], Seq[RapidsBuffer]] { - override def apply(key: RapidsBufferId, value: Seq[RapidsBuffer]): Seq[RapidsBuffer] = { - if (value == null) { - Seq(buffer) - } else { - val(first, second) = value.partition(_.storageTier < buffer.storageTier) - if (second.nonEmpty && second.head.storageTier == buffer.storageTier) { - throw new DuplicateBufferException( - s"Buffer ID ${buffer.id} at tier ${buffer.storageTier} already registered " + - s"${second.head}") - } - first ++ Seq(buffer) ++ second - } - } - } - - bufferMap.compute(buffer.id, updater) - } - - /** - * Free memory in `store` by spilling buffers to the spill store synchronously. - * @param store store to spill from - * @param targetTotalSize maximum total size of this store after spilling completes - * @param stream CUDA stream to use or omit for default stream - * @return optionally number of bytes that were spilled, or None if this call - * made no attempt to spill due to a detected spill race - */ - def synchronousSpill( - store: RapidsBufferStore, - targetTotalSize: Long, - stream: Cuda.Stream = Cuda.DEFAULT_STREAM): Option[Long] = { - if (store.spillStore == null) { - throw new OutOfMemoryError("Requested to spill without a spill store") - } - require(targetTotalSize >= 0, s"Negative spill target size: $targetTotalSize") - - val mySpillCount = spillCount - - // we have to hold this lock while freeing buffers, otherwise we could run - // into the case where a buffer is spilled yet it is aliased in addBuffer - // via an event handler that hasn't been reset (it resets during the free) - synchronized { - if (mySpillCount != spillCount) { - // a different thread already spilled, returning - // None which lets the calling code know that rmm should retry allocation - None - } else { - // this thread wins the race and should spill - spillCount += 1 - Some(store.synchronousSpill(targetTotalSize, this, stream)) - } - } - } - - def updateTiers(bufferSpill: SpillAction): Long = bufferSpill match { - case BufferSpill(spilledBuffer, maybeNewBuffer) => - logDebug(s"Spilled ${spilledBuffer.id} from tier ${spilledBuffer.storageTier}. " + - s"Removing. Registering ${maybeNewBuffer.map(_.id).getOrElse ("None")} " + - s"${maybeNewBuffer}") - maybeNewBuffer.foreach(registerNewBuffer) - removeBufferTier(spilledBuffer.id, spilledBuffer.storageTier) - spilledBuffer.memoryUsedBytes - - case BufferUnspill(unspilledBuffer, maybeNewBuffer) => - logDebug(s"Unspilled ${unspilledBuffer.id} from tier ${unspilledBuffer.storageTier}. " + - s"Removing. 
Registering ${maybeNewBuffer.map(_.id).getOrElse ("None")} " + - s"${maybeNewBuffer}") - maybeNewBuffer.foreach(registerNewBuffer) - removeBufferTier(unspilledBuffer.id, unspilledBuffer.storageTier) - unspilledBuffer.memoryUsedBytes - } - - /** - * Copies `buffer` to the `deviceStorage` store, registering a new `RapidsBuffer` in - * the process - * @param buffer - buffer to copy - * @param stream - Cuda.Stream to synchronize on - * @return - The `RapidsBuffer` instance that was added to the device store. - */ - def unspillBufferToDeviceStore( - buffer: RapidsBuffer, - stream: Cuda.Stream): RapidsBuffer = synchronized { - // try to acquire the buffer, if it's already in the store - // do not create a new one, else add a reference - acquireBuffer(buffer.id, StorageTier.DEVICE) match { - case None => - val maybeNewBuffer = deviceStorage.copyBuffer(buffer, this, stream) - maybeNewBuffer.map { newBuffer => - newBuffer.addReference() // add a reference since we are about to use it - registerNewBuffer(newBuffer) - newBuffer - }.get // the GPU store has to return a buffer here for now, or throw OOM - case Some(existingBuffer) => existingBuffer - } - } - - /** - * Copies `buffer` to the `hostStorage` store, registering a new `RapidsBuffer` in - * the process - * - * @param buffer - buffer to copy - * @param stream - Cuda.Stream to synchronize on - * @return - The `RapidsBuffer` instance that was added to the host store. - */ - def unspillBufferToHostStore( - buffer: RapidsBuffer, - stream: Cuda.Stream): RapidsBuffer = synchronized { - // try to acquire the buffer, if it's already in the store - // do not create a new one, else add a reference - acquireBuffer(buffer.id, StorageTier.HOST) match { - case Some(existingBuffer) => existingBuffer - case None => - val maybeNewBuffer = hostStorage.copyBuffer(buffer, this, stream) - maybeNewBuffer.map { newBuffer => - logDebug(s"got new RapidsHostMemoryStore buffer ${newBuffer.id}") - newBuffer.addReference() // add a reference since we are about to use it - updateTiers(BufferUnspill(buffer, Some(newBuffer))) - buffer.safeFree() - newBuffer - }.get // the host store has to return a buffer here for now, or throw OOM - } - } - - - /** - * Remove a buffer ID from the catalog at the specified storage tier. - * @note public for testing - */ - def removeBufferTier(id: RapidsBufferId, tier: StorageTier): Unit = synchronized { - val updater = new BiFunction[RapidsBufferId, Seq[RapidsBuffer], Seq[RapidsBuffer]] { - override def apply(key: RapidsBufferId, value: Seq[RapidsBuffer]): Seq[RapidsBuffer] = { - val updated = value.filter(_.storageTier != tier) - if (updated.isEmpty) { - null - } else { - updated - } - } - } - bufferMap.computeIfPresent(id, updater) - } - - /** - * Remove a buffer handle from the catalog and, if it this was the final handle, - * release the resources of the registered buffers. - * - * @return true: if the buffer for this handle was removed from the spill framework - * (`handle` was the last handle) - * false: if buffer was not removed due to other live handles. - */ - private def removeBuffer(handle: RapidsBufferHandle): Boolean = synchronized { - // if this is the last handle, remove the buffer - if (stopTrackingHandle(handle)) { - logDebug(s"Removing buffer ${handle.id}") - bufferMap.remove(handle.id).safeFree() - true - } else { - false - } - } - - /** Return the number of buffers currently in the catalog. 
*/ - def numBuffers: Int = bufferMap.size() - - override def close(): Unit = { - bufferIdToHandles.values.asScala.toSeq.flatMap(_.seq).safeClose() - - bufferIdToHandles.clear() - } -} - -object RapidsBufferCatalog extends Logging { - private val MAX_BUFFER_LOOKUP_ATTEMPTS = 100 - - private var deviceStorage: RapidsDeviceMemoryStore = _ - private var hostStorage: RapidsHostMemoryStore = _ - private var diskBlockManager: RapidsDiskBlockManager = _ - private var diskStorage: RapidsDiskStore = _ - private var memoryEventHandler: DeviceMemoryEventHandler = _ - private var _shouldUnspill: Boolean = _ - private var _singleton: RapidsBufferCatalog = null - - def singleton: RapidsBufferCatalog = { - if (_singleton == null) { - synchronized { - if (_singleton == null) { - _singleton = new RapidsBufferCatalog(deviceStorage) - } - } - } - _singleton - } - - private lazy val conf: SparkConf = { - val env = SparkEnv.get - if (env != null) { - env.conf - } else { - // For some unit tests - new SparkConf() - } - } - - /** - * Set a `RapidsDeviceMemoryStore` instance to use when instantiating our - * catalog. - * @note This should only be called from tests! - */ - def setDeviceStorage(rdms: RapidsDeviceMemoryStore): Unit = { - deviceStorage = rdms - } - - /** - * Set a `RapidsDiskStore` instance to use when instantiating our - * catalog. - * - * @note This should only be called from tests! - */ - def setDiskStorage(rdms: RapidsDiskStore): Unit = { - diskStorage = rdms - } - - /** - * Set a `RapidsHostMemoryStore` instance to use when instantiating our - * catalog. - * - * @note This should only be called from tests! - */ - def setHostStorage(rhms: RapidsHostMemoryStore): Unit = { - hostStorage = rhms - } - - /** - * Set a `RapidsBufferCatalog` instance to use our singleton. - * @note This should only be called from tests! - */ - def setCatalog(catalog: RapidsBufferCatalog): Unit = synchronized { - if (_singleton != null) { - _singleton.close() - } - _singleton = catalog - } - - def init(rapidsConf: RapidsConf): Unit = { - // We are going to re-initialize so make sure all of the old things were closed... 
- closeImpl() - assert(memoryEventHandler == null) - deviceStorage = new RapidsDeviceMemoryStore( - rapidsConf.chunkedPackBounceBufferSize, - rapidsConf.spillToDiskBounceBufferSize) - diskBlockManager = new RapidsDiskBlockManager(conf) - val hostSpillStorageSize = if (rapidsConf.offHeapLimitEnabled) { - // Disable the limit because it is handled by the RapidsHostMemoryStore - None - } else if (rapidsConf.hostSpillStorageSize == -1) { - // + 1 GiB by default to match backwards compatibility - Some(rapidsConf.pinnedPoolSize + (1024 * 1024 * 1024)) - } else { - Some(rapidsConf.hostSpillStorageSize) - } - hostStorage = new RapidsHostMemoryStore(hostSpillStorageSize) - diskStorage = new RapidsDiskStore(diskBlockManager) - deviceStorage.setSpillStore(hostStorage) - hostStorage.setSpillStore(diskStorage) - - logInfo("Installing GPU memory handler for spill") - memoryEventHandler = new DeviceMemoryEventHandler( - singleton, - deviceStorage, - rapidsConf.gpuOomDumpDir, - rapidsConf.gpuOomMaxRetries) - - if (rapidsConf.sparkRmmStateEnable) { - val debugLoc = if (rapidsConf.sparkRmmDebugLocation.isEmpty) { - null - } else { - rapidsConf.sparkRmmDebugLocation - } - - RmmSpark.setEventHandler(memoryEventHandler, debugLoc) - } else { - logWarning("SparkRMM retry has been disabled") - Rmm.setEventHandler(memoryEventHandler) - } - - _shouldUnspill = rapidsConf.isUnspillEnabled - } - - def close(): Unit = { - logInfo("Closing storage") - closeImpl() - } - - /** - * Only used in unit tests, it returns the number of buffers in the catalog. - */ - def numBuffers: Int = { - _singleton.numBuffers - } - - private def closeImpl(): Unit = synchronized { - Seq(_singleton, deviceStorage, hostStorage, diskStorage).safeClose() - - _singleton = null - // Workaround for shutdown ordering problems where device buffers allocated - // with this handler are being freed after the handler is destroyed - //Rmm.clearEventHandler() - memoryEventHandler = null - deviceStorage = null - hostStorage = null - diskStorage = null - } - - def getDeviceStorage: RapidsDeviceMemoryStore = deviceStorage - - def getHostStorage: RapidsHostMemoryStore = hostStorage - - def shouldUnspill: Boolean = _shouldUnspill - - /** - * Adds a contiguous table to the device storage. This does NOT take ownership of the - * contiguous table, so it is the responsibility of the caller to close it. The refcount of the - * underlying device buffer will be incremented so the contiguous table can be closed before - * this buffer is destroyed. - * @param contigTable contiguous table to trackNewHandle in device storage - * @param initialSpillPriority starting spill priority value for the buffer - * @return RapidsBufferHandle associated with this buffer - */ - def addContiguousTable( - contigTable: ContiguousTable, - initialSpillPriority: Long): RapidsBufferHandle = { - singleton.addContiguousTable(contigTable, initialSpillPriority) - } - - /** - * Adds a buffer to the catalog and store. This does NOT take ownership of the - * buffer, so it is the responsibility of the caller to close it. 
- * @param buffer buffer that will be owned by the store - * @param tableMeta metadata describing the buffer layout - * @param initialSpillPriority starting spill priority value for the buffer - * @return RapidsBufferHandle associated with this buffer - */ - def addBuffer( - buffer: MemoryBuffer, - tableMeta: TableMeta, - initialSpillPriority: Long): RapidsBufferHandle = { - singleton.addBuffer(buffer, tableMeta, initialSpillPriority) - } - - def addBatch( - batch: ColumnarBatch, - initialSpillPriority: Long): RapidsBufferHandle = { - singleton.addBatch(batch, initialSpillPriority) - } - - /** - * Lookup the buffer that corresponds to the specified buffer handle and acquire it. - * NOTE: It is the responsibility of the caller to close the buffer. - * @param handle buffer handle - * @return buffer that has been acquired - */ - def acquireBuffer(handle: RapidsBufferHandle): RapidsBuffer = - singleton.acquireBuffer(handle) - - /** - * Acquires a RapidsBuffer that the caller expects to be host-backed and not - * device bound. This ensures that the buffer acquired implements the correct - * trait, otherwise it throws and removes its buffer acquisition. - * - * @param handle handle associated with this `RapidsBuffer` - * @return host-backed RapidsBuffer that has been acquired - */ - def acquireHostBatchBuffer(handle: RapidsBufferHandle): RapidsHostBatchBuffer = - singleton.acquireHostBatchBuffer(handle) - - def getDiskBlockManager(): RapidsDiskBlockManager = diskBlockManager - - /** - * Free memory in `store` by spilling buffers to its spill store synchronously. - * @param store store to spill from - * @param targetTotalSize maximum total size of this store after spilling completes - * @param stream CUDA stream to use or omit for default stream - * @return optionally number of bytes that were spilled, or None if this call - * made no attempt to spill due to a detected spill race - */ - def synchronousSpill( - store: RapidsBufferStore, - targetTotalSize: Long, - stream: Cuda.Stream = Cuda.DEFAULT_STREAM): Option[Long] = { - singleton.synchronousSpill(store, targetTotalSize, stream) - } - - /** - * Given a `MemoryBuffer` find out if a `MemoryBuffer.EventHandler` is associated - * with it. - * - * After getting the `RapidsBuffer` try to acquire it via `addReference`. - * If successful, we can point to this buffer with a new handle, otherwise the buffer is - * about to be removed/freed (unlikely, because we are holding onto the reference as we - * are adding it again). - * - * @note public for testing - * @param buffer - the `MemoryBuffer` to inspect - * @return - Some(RapidsBuffer): the handler is associated with a rapids buffer - * and the rapids buffer is currently valid, or - * - * - None: if no `RapidsBuffer` is associated with this buffer (it is - * brand new to the store, or the `RapidsBuffer` is invalid and - * about to be removed). 
- */ - private def getExistingRapidsBufferAndAcquire(buffer: MemoryBuffer): Option[RapidsBuffer] = { - buffer match { - case hb: HostMemoryBuffer => - HostAlloc.findEventHandler(hb) { - case rapidsBuffer: RapidsBuffer => - if (rapidsBuffer.addReference()) { - Some(rapidsBuffer) - } else { - None - } - }.flatten - case _ => - val eh = buffer.getEventHandler - eh match { - case null => - None - case rapidsBuffer: RapidsBuffer => - if (rapidsBuffer.addReference()) { - Some(rapidsBuffer) - } else { - None - } - case _ => - throw new IllegalStateException("Unknown event handler") - } - } - } -} - diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala deleted file mode 100644 index b1ee9e7a863..00000000000 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala +++ /dev/null @@ -1,640 +0,0 @@ -/* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.nvidia.spark.rapids - -import java.util.Comparator -import java.util.concurrent.locks.ReentrantReadWriteLock - -import scala.collection.mutable - -import ai.rapids.cudf.{BaseDeviceMemoryBuffer, Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, NvtxColor, NvtxRange} -import com.nvidia.spark.rapids.Arm._ -import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.StorageTier.{DEVICE, HOST, StorageTier} -import com.nvidia.spark.rapids.format.TableMeta - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.vectorized.ColumnarBatch - -/** - * Helper case classes that contain the buffer we spilled or unspilled from our current tier - * and likely a new buffer created in a target store tier, but it can be set to None. - * If the buffer already exists in the target store, `newBuffer` will be None. - * @param spillBuffer a `RapidsBuffer` we spilled or unspilled from this store - * @param newBuffer an optional `RapidsBuffer` in the target store. - */ -trait SpillAction { - val spillBuffer: RapidsBuffer - val newBuffer: Option[RapidsBuffer] -} - -case class BufferSpill(spillBuffer: RapidsBuffer, newBuffer: Option[RapidsBuffer]) - extends SpillAction - -case class BufferUnspill(spillBuffer: RapidsBuffer, newBuffer: Option[RapidsBuffer]) - extends SpillAction - -/** - * Base class for all buffer store types. 
- * - * @param tier storage tier of this store - * @param catalog catalog to register this store - */ -abstract class RapidsBufferStore(val tier: StorageTier) - extends AutoCloseable with Logging { - - val name: String = tier.toString - - private class BufferTracker { - private[this] val comparator: Comparator[RapidsBufferBase] = - (o1: RapidsBufferBase, o2: RapidsBufferBase) => - java.lang.Long.compare(o1.getSpillPriority, o2.getSpillPriority) - // buffers: contains all buffers in this store, whether spillable or not - private[this] val buffers = new java.util.HashMap[RapidsBufferId, RapidsBufferBase] - // spillable: contains only those buffers that are currently spillable - private[this] val spillable = new HashedPriorityQueue[RapidsBufferBase](comparator) - // spilling: contains only those buffers that are currently being spilled, but - // have not been removed from the store - private[this] val spilling = new mutable.HashSet[RapidsBufferId]() - // total bytes stored, regardless of spillable status - private[this] var totalBytesStored: Long = 0L - // total bytes that are currently eligible to be spilled - private[this] var totalBytesSpillable: Long = 0L - - def add(buffer: RapidsBufferBase): Unit = synchronized { - val old = buffers.put(buffer.id, buffer) - // it is unlikely that the buffer was in this collection, but removing - // anyway. We assume the buffer is safe in this tier, and is not spilling - spilling.remove(buffer.id) - if (old != null) { - throw new DuplicateBufferException(s"duplicate buffer registered: ${buffer.id}") - } - totalBytesStored += buffer.memoryUsedBytes - - // device buffers "spillability" is handled via DeviceMemoryBuffer ref counting - // so spillableOnAdd should be false, all other buffer tiers are spillable at - // all times. - if (spillableOnAdd && buffer.memoryUsedBytes > 0) { - if (spillable.offer(buffer)) { - totalBytesSpillable += buffer.memoryUsedBytes - } - } - } - - def remove(id: RapidsBufferId): Unit = synchronized { - // when removing a buffer we no longer need to know if it was spilling - spilling.remove(id) - val obj = buffers.remove(id) - if (obj != null) { - totalBytesStored -= obj.memoryUsedBytes - if (spillable.remove(obj)) { - totalBytesSpillable -= obj.memoryUsedBytes - } - } - } - - def freeAll(): Unit = { - val values = synchronized { - val buffs = buffers.values().toArray(new Array[RapidsBufferBase](0)) - buffers.clear() - spillable.clear() - spilling.clear() - buffs - } - // We need to release the `RapidsBufferStore` lock to prevent a lock order inversion - // deadlock: (1) `RapidsBufferBase.free` calls (2) `RapidsBufferStore.remove` and - // (1) `RapidsBufferStore.freeAll` calls (2) `RapidsBufferBase.free`. - values.safeFree() - } - - /** - * Sets a buffers state to spillable or non-spillable. - * - * If the buffer is currently being spilled or it is no longer in the `buffers` collection - * (e.g. it is not in this store), the action is skipped. - * - * @param buffer the buffer to mark as spillable or not - * @param isSpillable whether the buffer should now be spillable - */ - def setSpillable(buffer: RapidsBufferBase, isSpillable: Boolean): Unit = synchronized { - if (isSpillable && buffer.memoryUsedBytes > 0) { - // if this buffer is in the store and isn't currently spilling - if (!spilling.contains(buffer.id) && buffers.containsKey(buffer.id)) { - // try to add it to the spillable collection - if (spillable.offer(buffer)) { - totalBytesSpillable += buffer.memoryUsedBytes - logDebug(s"Buffer ${buffer.id} is spillable. 
" + - s"total=${totalBytesStored} spillable=${totalBytesSpillable}") - } // else it was already there (unlikely) - } - } else { - if (spillable.remove(buffer)) { - totalBytesSpillable -= buffer.memoryUsedBytes - logDebug(s"Buffer ${buffer.id} is not spillable. " + - s"total=${totalBytesStored}, spillable=${totalBytesSpillable}") - } // else it was already removed - } - } - - def nextSpillableBuffer(): RapidsBufferBase = synchronized { - val buffer = spillable.poll() - if (buffer != null) { - // mark the id as "spilling" (this buffer is in the middle of a spill operation) - spilling.add(buffer.id) - totalBytesSpillable -= buffer.memoryUsedBytes - logDebug(s"Spilling buffer ${buffer.id}. size=${buffer.memoryUsedBytes} " + - s"total=${totalBytesStored}, new spillable=${totalBytesSpillable}") - } - buffer - } - - def updateSpillPriority(buffer: RapidsBufferBase, priority:Long): Unit = synchronized { - buffer.updateSpillPriorityValue(priority) - spillable.priorityUpdated(buffer) - } - - def getTotalBytes: Long = synchronized { totalBytesStored } - - def getTotalSpillableBytes: Long = synchronized { totalBytesSpillable } - } - - /** - * Stores that need to stay within a specific byte limit of buffers stored override - * this function. Only the `HostMemoryBufferStore` requires such a limit. - * @return maximum amount of bytes that can be stored in the store, None for no - * limit - */ - def getMaxSize: Option[Long] = None - - private[this] val buffers = new BufferTracker - - /** A store that can be used for spilling. */ - var spillStore: RapidsBufferStore = _ - - /** Return the current byte total of buffers in this store. */ - def currentSize: Long = buffers.getTotalBytes - - def currentSpillableSize: Long = buffers.getTotalSpillableBytes - - /** - * A store that manages spillability of buffers should override this method - * to false, otherwise `BufferTracker` treats buffers as always spillable. - */ - protected def spillableOnAdd: Boolean = true - - /** - * Specify another store that can be used when this store needs to spill. - * @note Only one spill store can be registered. This will throw if a - * spill store has already been registered. - */ - def setSpillStore(store: RapidsBufferStore): Unit = { - require(spillStore == null, "spill store already registered") - spillStore = store - } - - /** - * Adds an existing buffer from another store to this store. The buffer must already - * have an active reference by the caller and needs to be eventually closed by the caller - * (i.e.: this method will not take ownership of the incoming buffer object). - * This does not need to update the catalog, the caller is responsible for that. - * @param buffer data from another store - * @param catalog RapidsBufferCatalog we may need to modify during this copy - * @param stream CUDA stream to use for copy or null - * @return the new buffer that was created - */ - def copyBuffer( - buffer: RapidsBuffer, - catalog: RapidsBufferCatalog, - stream: Cuda.Stream): Option[RapidsBufferBase] = { - createBuffer(buffer, catalog, stream).map { newBuffer => - freeOnExcept(newBuffer) { newBuffer => - addBuffer(newBuffer) - newBuffer - } - } - } - - protected def setSpillable(buffer: RapidsBufferBase, isSpillable: Boolean): Unit = { - buffers.setSpillable(buffer, isSpillable) - } - - /** - * Create a new buffer from an existing buffer in another store. 
- * If the data transfer will be performed asynchronously, this method is responsible for - * adding a reference to the existing buffer and later closing it when the transfer completes. - * - * @note DO NOT close the buffer unless adding a reference! - * @note `createBuffer` impls should synchronize against `stream` before returning, if needed. - * @param buffer data from another store - * @param catalog RapidsBufferCatalog we may need to modify during this create - * @param stream CUDA stream to use or null - * @return the new buffer that was created. - */ - protected def createBuffer( - buffer: RapidsBuffer, - catalog: RapidsBufferCatalog, - stream: Cuda.Stream): Option[RapidsBufferBase] - - /** Update bookkeeping for a new buffer */ - protected def addBuffer(buffer: RapidsBufferBase): Unit = { - buffers.add(buffer) - buffer.updateSpillability() - } - - /** - * Adds a buffer to the spill framework, stream synchronizing with the producer - * stream to ensure that the buffer is fully materialized, and can be safely copied - * as part of the spill. - * - * @param needsSync true if we should stream synchronize before adding the buffer - */ - protected def addBuffer(buffer: RapidsBufferBase, needsSync: Boolean): Unit = { - if (needsSync) { - Cuda.DEFAULT_STREAM.sync() - } - addBuffer(buffer) - } - - override def close(): Unit = { - buffers.freeAll() - } - - def nextSpillable(): RapidsBuffer = { - buffers.nextSpillableBuffer() - } - - def synchronousSpill( - targetTotalSize: Long, - catalog: RapidsBufferCatalog, - stream: Cuda.Stream = Cuda.DEFAULT_STREAM): Long = { - if (currentSpillableSize > targetTotalSize) { - logWarning(s"Targeting a ${name} size of $targetTotalSize. " + - s"Current total ${currentSize}. " + - s"Current spillable ${currentSpillableSize}") - val bufferSpills = new mutable.ArrayBuffer[BufferSpill]() - withResource(new NvtxRange(s"${name} sync spill", NvtxColor.ORANGE)) { _ => - logWarning(s"${name} store spilling to reduce usage from " + - s"${currentSize} total (${currentSpillableSize} spillable) " + - s"to $targetTotalSize bytes") - - // If the store has 0 spillable bytes left, it has exhausted. - try { - var exhausted = false - var totalSpilled = 0L - while (!exhausted && - currentSpillableSize > targetTotalSize) { - val nextSpillableBuffer = nextSpillable() - if (nextSpillableBuffer != null) { - if (nextSpillableBuffer.addReference()) { - withResource(nextSpillableBuffer) { _ => - val bufferHasSpilled = - catalog.isBufferSpilled( - nextSpillableBuffer.id, - nextSpillableBuffer.storageTier) - val bufferSpill = if (!bufferHasSpilled) { - spillBuffer( - nextSpillableBuffer, this, catalog, stream) - } else { - // if `nextSpillableBuffer` already spilled, we still need to - // remove it from our tier and call free on it, but set - // `newBuffer` to None because there's nothing to register - // as it has already spilled. - BufferSpill(nextSpillableBuffer, None) - } - totalSpilled += bufferSpill.spillBuffer.memoryUsedBytes - bufferSpills.append(bufferSpill) - catalog.updateTiers(bufferSpill) - } - } - } - } - if (totalSpilled <= 0) { - // we didn't spill in this iteration, exit loop - exhausted = true - logWarning("Unable to spill enough to meet request. " + - s"Total=${currentSize} " + - s"Spillable=${currentSpillableSize} " + - s"Target=$targetTotalSize") - } - totalSpilled - } finally { - if (bufferSpills.nonEmpty) { - // This is a hack in order to completely synchronize with the GPU before we free - // a buffer. 
It is necessary because of non-synchronous cuDF calls that could fall - // behind where the CPU is. Freeing a rapids buffer in these cases needs to wait for - // all launched GPU work, otherwise crashes or data corruption could occur. - // A more performant implementation would be to synchronize on the thread that read - // the buffer via events. - // https://github.com/NVIDIA/spark-rapids/issues/8610 - Cuda.deviceSynchronize() - bufferSpills.foreach(_.spillBuffer.safeFree()) - } - } - } - } else { - 0L // nothing spilled - } - } - - /** - * Given a specific `RapidsBuffer` spill it to `spillStore` - * - * @return a `BufferSpill` instance with the target buffer in this store, and an optional - * new `RapidsBuffer` in the target spill store if this rapids buffer hadn't already - * spilled. - * @note called with catalog lock held - */ - private def spillBuffer( - buffer: RapidsBuffer, - store: RapidsBufferStore, - catalog: RapidsBufferCatalog, - stream: Cuda.Stream): BufferSpill = { - // copy the buffer to spillStore - var maybeNewBuffer: Option[RapidsBuffer] = None - var lastTier: Option[StorageTier] = None - var nextSpillStore = store.spillStore - while (maybeNewBuffer.isEmpty && nextSpillStore != null) { - lastTier = Some(nextSpillStore.tier) - // copy buffer if it fits - maybeNewBuffer = nextSpillStore.copyBuffer(buffer, catalog, stream) - - // if it didn't fit, we can try a lower tier that has more space - if (maybeNewBuffer.isEmpty) { - nextSpillStore = nextSpillStore.spillStore - } - } - if (maybeNewBuffer.isEmpty) { - throw new IllegalStateException( - s"Unable to spill buffer ${buffer.id} of size ${buffer.memoryUsedBytes} " + - s"to tier ${lastTier}") - } - // return the buffer to free and the new buffer to register - BufferSpill(buffer, maybeNewBuffer) - } - - /** - * Tries to make room for `buffer` in the host store by spilling. - * - * @param buffer buffer that will be copied to the host store if it fits - * @param stream CUDA stream to synchronize for memory operations - * @return true if the buffer fits after a potential spill - */ - protected def trySpillToMaximumSize( - buffer: RapidsBuffer, - catalog: RapidsBufferCatalog, - stream: Cuda.Stream): Boolean = { - true // default to success, HostMemoryStore overrides this - } - - /** Base class for all buffers in this store. */ - abstract class RapidsBufferBase( - override val id: RapidsBufferId, - _meta: TableMeta, - initialSpillPriority: Long, - catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton) - extends RapidsBuffer { - private val MAX_UNSPILL_ATTEMPTS = 100 - - // isValid and refcount must be used with the `RapidsBufferBase` lock held - protected[this] var isValid = true - protected[this] var refcount = 0 - - private[this] var spillPriority: Long = initialSpillPriority - - private[this] val rwl: ReentrantReadWriteLock = new ReentrantReadWriteLock() - - - def meta: TableMeta = _meta - - /** Release the underlying resources for this buffer. */ - protected def releaseResources(): Unit - - /** - * Materialize the memory buffer from the underlying storage. - * - * If the buffer resides in device or host memory, only reference count is incremented. - * If the buffer resides in secondary storage, a new host or device memory buffer is created, - * with the data copied to the new buffer. - * The caller must have successfully acquired the buffer beforehand. - * @see [[addReference]] - * @note It is the responsibility of the caller to close the buffer. - * @note This is an internal API only used by Rapids buffer stores. 
- */ - protected def materializeMemoryBuffer: MemoryBuffer = getMemoryBuffer - - override def addReference(): Boolean = synchronized { - if (isValid) { - refcount += 1 - } - isValid - } - - override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = { - // NOTE: Cannot hold a lock on this buffer here because memory is being - // allocated. Allocations can trigger synchronous spills which can - // deadlock if another thread holds the device store lock and is trying - // to spill to this store. - withResource(getDeviceMemoryBuffer) { deviceBuffer => - columnarBatchFromDeviceBuffer(deviceBuffer, sparkTypes) - } - } - - protected def columnarBatchFromDeviceBuffer(devBuffer: DeviceMemoryBuffer, - sparkTypes: Array[DataType]): ColumnarBatch = { - val bufferMeta = meta.bufferMeta() - if (bufferMeta == null || bufferMeta.codecBufferDescrsLength == 0) { - MetaUtils.getBatchFromMeta(devBuffer, meta, sparkTypes) - } else { - GpuCompressedColumnVector.from(devBuffer, meta) - } - } - - override def copyToMemoryBuffer(srcOffset: Long, dst: MemoryBuffer, dstOffset: Long, - length: Long, stream: Cuda.Stream): Unit = { - withResource(getMemoryBuffer) { memBuff => - dst match { - case _: HostMemoryBuffer => - // TODO: consider moving to the async version. - dst.copyFromMemoryBuffer(dstOffset, memBuff, srcOffset, length, stream) - case _: BaseDeviceMemoryBuffer => - dst.copyFromMemoryBufferAsync(dstOffset, memBuff, srcOffset, length, stream) - case _ => - throw new IllegalStateException(s"Infeasible destination buffer type ${dst.getClass}") - } - } - } - - /** - * TODO: we want to remove this method from the buffer, instead we want the catalog - * to be responsible for producing the DeviceMemoryBuffer by asking the buffer. This - * hides the RapidsBuffer from clients and simplifies locking. 
- */ - override def getDeviceMemoryBuffer: DeviceMemoryBuffer = { - if (RapidsBufferCatalog.shouldUnspill) { - (0 until MAX_UNSPILL_ATTEMPTS).foreach { _ => - catalog.acquireBuffer(id, DEVICE) match { - case Some(buffer) => - withResource(buffer) { _ => - return buffer.getDeviceMemoryBuffer - } - case _ => - try { - logDebug(s"Unspilling $this $id to $DEVICE") - val newBuffer = catalog.unspillBufferToDeviceStore( - this, - Cuda.DEFAULT_STREAM) - withResource(newBuffer) { _ => - return newBuffer.getDeviceMemoryBuffer - } - } catch { - case _: DuplicateBufferException => - logDebug(s"Lost device buffer registration race for buffer $id, retrying...") - } - } - } - throw new IllegalStateException(s"Unable to get device memory buffer for ID: $id") - } else { - materializeMemoryBuffer match { - case h: HostMemoryBuffer => - withResource(h) { _ => - closeOnExcept(DeviceMemoryBuffer.allocate(h.getLength)) { deviceBuffer => - logDebug(s"copying ${h.getLength} from host $h to device $deviceBuffer " + - s"of size ${deviceBuffer.getLength}") - deviceBuffer.copyFromHostBuffer(h) - deviceBuffer - } - } - case d: DeviceMemoryBuffer => d - case b => throw new IllegalStateException(s"Unrecognized buffer: $b") - } - } - } - - override def getHostMemoryBuffer: HostMemoryBuffer = { - (0 until MAX_UNSPILL_ATTEMPTS).foreach { _ => - catalog.acquireBuffer(id, HOST) match { - case Some(buffer) => - withResource(buffer) { _ => - return buffer.getHostMemoryBuffer - } - case _ => - try { - logDebug(s"Unspilling $this $id to $HOST") - val newBuffer = catalog.unspillBufferToHostStore( - this, - Cuda.DEFAULT_STREAM) - withResource(newBuffer) { _ => - return newBuffer.getHostMemoryBuffer - } - } catch { - case _: DuplicateBufferException => - logDebug(s"Lost host buffer registration race for buffer $id, retrying...") - } - } - } - throw new IllegalStateException(s"Unable to get host memory buffer for ID: $id") - } - - /** - * close() is called by client code to decrease the ref count of this RapidsBufferBase. - * In the off chance that by the time close is invoked, the buffer was freed (not valid) - * then this close call winds up freeing the resources of the rapids buffer. - */ - override def close(): Unit = synchronized { - if (refcount == 0) { - throw new IllegalStateException("Buffer already closed") - } - refcount -= 1 - if (refcount == 0 && !isValid) { - freeBuffer() - } - } - - /** - * Mark the buffer as freed and no longer valid. This is called by the store when removing a - * buffer (it is no longer tracked). - * - * @note The resources may not be immediately released if the buffer has outstanding references. - * In that case the resources will be released when the reference count reaches zero. 
- */ - override def free(): Unit = synchronized { - if (isValid) { - isValid = false - buffers.remove(id) - if (refcount == 0) { - freeBuffer() - } - } else { - logWarning(s"Trying to free an invalid buffer => $id, size = ${memoryUsedBytes}, $this") - } - } - - override def getSpillPriority: Long = spillPriority - - override def setSpillPriority(priority: Long): Unit = - buffers.updateSpillPriority(this, priority) - - private[RapidsBufferStore] def updateSpillPriorityValue(priority: Long): Unit = { - spillPriority = priority - } - - override def withMemoryBufferReadLock[K](body: MemoryBuffer => K): K = { - withResource(getMemoryBuffer) { buff => - val lock = rwl.readLock() - try { - lock.lock() - body(buff) - } finally { - lock.unlock() - } - } - } - - override def withMemoryBufferWriteLock[K](body: MemoryBuffer => K): K = { - withResource(getMemoryBuffer) { buff => - val lock = rwl.writeLock() - try { - lock.lock() - body(buff) - } finally { - lock.unlock() - } - } - } - - /** Must be called with a lock on the buffer */ - private def freeBuffer(): Unit = { - releaseResources() - } - - override def toString: String = s"$name buffer size=$memoryUsedBytes" - } -} - -/** - * Buffers that inherit from this type do not support changing the spillable status - * of a `RapidsBuffer`. This is only used right now for disk. - * @param tier storage tier of this store - */ -abstract class RapidsBufferStoreWithoutSpill(override val tier: StorageTier) - extends RapidsBufferStore(tier) { - - override def setSpillable(rapidsBuffer: RapidsBufferBase, isSpillable: Boolean): Unit = { - throw new NotImplementedError(s"This store ${this} does not implement setSpillable") - } -} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala deleted file mode 100644 index c56806bc965..00000000000 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala +++ /dev/null @@ -1,518 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.nvidia.spark.rapids - -import java.nio.channels.WritableByteChannel -import java.util.concurrent.ConcurrentHashMap - -import scala.collection.mutable - -import ai.rapids.cudf.{ColumnVector, Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, Table} -import com.nvidia.spark.rapids.Arm._ -import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableSeq -import com.nvidia.spark.rapids.StorageTier.StorageTier -import com.nvidia.spark.rapids.format.TableMeta - -import org.apache.spark.sql.rapids.GpuTaskMetrics -import org.apache.spark.sql.rapids.storage.RapidsStorageUtils -import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.vectorized.ColumnarBatch - -/** - * Buffer storage using device memory. - * @param chunkedPackBounceBufferSize this is the size of the bounce buffer to be used - * during spill in chunked_pack. 
The parameter defaults to 128MB, - * with a rule-of-thumb of 1MB per SM. - */ -class RapidsDeviceMemoryStore( - chunkedPackBounceBufferSize: Long = 128L*1024*1024, - hostBounceBufferSize: Long = 128L*1024*1024) - extends RapidsBufferStore(StorageTier.DEVICE) { - - // The RapidsDeviceMemoryStore handles spillability via ref counting - override protected def spillableOnAdd: Boolean = false - - // bounce buffer to be used during chunked pack in GPU to host memory spill - private var chunkedPackBounceBuffer: DeviceMemoryBuffer = - DeviceMemoryBuffer.allocate(chunkedPackBounceBufferSize) - - private var hostSpillBounceBuffer: HostMemoryBuffer = - HostMemoryBuffer.allocate(hostBounceBufferSize) - - override protected def createBuffer( - other: RapidsBuffer, - catalog: RapidsBufferCatalog, - stream: Cuda.Stream): Option[RapidsBufferBase] = { - val memoryBuffer = withResource(other.getCopyIterator) { copyIterator => - copyIterator.next() - } - withResource(memoryBuffer) { _ => - val deviceBuffer = { - memoryBuffer match { - case d: DeviceMemoryBuffer => d - case h: HostMemoryBuffer => - GpuTaskMetrics.get.readSpillFromHostTime { - closeOnExcept(DeviceMemoryBuffer.allocate(memoryBuffer.getLength)) { deviceBuffer => - logDebug(s"copying from host $h to device $deviceBuffer") - deviceBuffer.copyFromHostBuffer(h, stream) - deviceBuffer - } - } - case b => throw new IllegalStateException(s"Unrecognized buffer: $b") - } - } - Some(new RapidsDeviceMemoryBuffer( - other.id, - deviceBuffer.getLength, - other.meta, - deviceBuffer, - other.getSpillPriority)) - } - } - - /** - * Adds a buffer to the device storage. This does NOT take ownership of the - * buffer, so it is the responsibility of the caller to close it. - * - * This function is called only from the RapidsBufferCatalog, under the - * catalog lock. - * - * @param id the RapidsBufferId to use for this buffer - * @param buffer buffer that will be owned by the store - * @param tableMeta metadata describing the buffer layout - * @param initialSpillPriority starting spill priority value for the buffer - * @param needsSync whether the spill framework should stream synchronize while adding - * this device buffer (defaults to true) - * @return the RapidsBuffer instance that was added. - */ - def addBuffer( - id: RapidsBufferId, - buffer: DeviceMemoryBuffer, - tableMeta: TableMeta, - initialSpillPriority: Long, - needsSync: Boolean): RapidsBuffer = { - buffer.incRefCount() - val rapidsBuffer = new RapidsDeviceMemoryBuffer( - id, - buffer.getLength, - tableMeta, - buffer, - initialSpillPriority) - freeOnExcept(rapidsBuffer) { _ => - logDebug(s"Adding receive side table for: [id=$id, size=${buffer.getLength}, " + - s"uncompressed=${rapidsBuffer.meta.bufferMeta.uncompressedSize}, " + - s"meta_id=${tableMeta.bufferMeta.id}, " + - s"meta_size=${tableMeta.bufferMeta.size}]") - addBuffer(rapidsBuffer, needsSync) - rapidsBuffer - } - } - - /** - * Adds a table to the device storage. - * - * This takes ownership of the table. - * - * This function is called only from the RapidsBufferCatalog, under the - * catalog lock. - * - * @param id the RapidsBufferId to use for this table - * @param table table that will be owned by the store - * @param initialSpillPriority starting spill priority value - * @param needsSync whether the spill framework should stream synchronize while adding - * this table (defaults to true) - * @return the RapidsBuffer instance that was added. 
- */ - def addTable( - id: RapidsBufferId, - table: Table, - initialSpillPriority: Long, - needsSync: Boolean): RapidsBuffer = { - val rapidsTable = new RapidsTable( - id, - table, - initialSpillPriority) - freeOnExcept(rapidsTable) { _ => - addBuffer(rapidsTable, needsSync) - rapidsTable - } - } - - /** - * A per cuDF column event handler that handles calls to .close() - * inside of the `ColumnVector` lock. - */ - class RapidsDeviceColumnEventHandler - extends ColumnVector.EventHandler { - - // Every RapidsTable that references this column has an entry in this map. - // The value represents the number of times (normally 1) that a ColumnVector - // appears in the RapidsTable. This is also the ColumnVector refCount at which - // the column is considered spillable. - // The map is protected via the ColumnVector lock. - private val registration = new mutable.HashMap[RapidsTable, Int]() - - /** - * Every RapidsTable iterates through its columns and either creates - * a `ColumnTracking` object and associates it with the column's - * `eventHandler` or calls into the existing one, and registers itself. - * - * The registration has two goals: it accounts for repetition of a column - * in a `RapidsTable`. If a table has the same column repeated it must adjust - * the refCount at which this column is considered spillable. - * - * The second goal is to account for aliasing. If two tables alias this column - * we are going to mark it as non spillable. - * - * @param rapidsTable - the table that is registering itself with this tracker - */ - def register(rapidsTable: RapidsTable, repetition: Int): Unit = { - registration.put(rapidsTable, repetition) - } - - /** - * This is invoked during `RapidsTable.free` in order to remove the entry - * in `registration`. - * @param rapidsTable - the table that is de-registering itself - */ - def deregister(rapidsTable: RapidsTable): Unit = { - registration.remove(rapidsTable) - } - - // called with the cudfCv lock held from cuDF's side - override def onClosed(cudfCv: ColumnVector, refCount: Int): Unit = { - // we only handle spillability if there is a single table registered - // (no aliasing) - if (registration.size == 1) { - val (rapidsTable, spillableRefCount) = registration.head - if (spillableRefCount == refCount) { - rapidsTable.onColumnSpillable(cudfCv) - } - } - } - } - - /** - * A `RapidsTable` is the spill store holder of a cuDF `Table`. - * - * The table is not contiguous in GPU memory. Instead, this `RapidsBuffer` instance - * allows us to use the cuDF chunked_pack API to make the table contiguous as the spill - * is happening. - * - * This class owns the cuDF table and will close it when `close` is called. - * - * @param id the `RapidsBufferId` this table is associated with - * @param table the cuDF table that we are managing - * @param spillPriority a starting spill priority - */ - class RapidsTable( - id: RapidsBufferId, - table: Table, - spillPriority: Long) - extends RapidsBufferBase( - id, - null, - spillPriority) - with RapidsBufferChannelWritable { - - /** The storage tier for this buffer */ - override val storageTier: StorageTier = StorageTier.DEVICE - - override val supportsChunkedPacker: Boolean = true - - // This is the current size in batch form. It is to be used while this - // table hasn't migrated to another store. 
- private val unpackedSizeInBytes: Long = GpuColumnVector.getTotalDeviceMemoryUsed(table) - - // By default all columns are NOT spillable since we are not the only owners of - // the columns (the caller is holding onto a ColumnarBatch that will be closed - // after instantiation, triggering onClosed callbacks) - // This hash set contains the columns that are currently spillable. - private val columnSpillability = new ConcurrentHashMap[ColumnVector, Boolean]() - - private val numDistinctColumns = - (0 until table.getNumberOfColumns).map(table.getColumn).distinct.size - - // we register our event callbacks as the very first action to deal with - // spillability - registerOnCloseEventHandler() - - /** Release the underlying resources for this buffer. */ - override protected def releaseResources(): Unit = { - table.close() - } - - private lazy val (cachedMeta, cachedPackedSize) = { - withResource(makeChunkedPacker) { cp => - (cp.getMeta, cp.getTotalContiguousSize) - } - } - - override def meta: TableMeta = cachedMeta - - override val memoryUsedBytes: Long = unpackedSizeInBytes - - override def getPackedSizeBytes: Long = cachedPackedSize - - override def makeChunkedPacker: ChunkedPacker = - new ChunkedPacker(id, table, chunkedPackBounceBuffer) - - /** - * Mark a column as spillable - * - * @param column the ColumnVector to mark as spillable - */ - def onColumnSpillable(column: ColumnVector): Unit = { - columnSpillability.put(column, true) - updateSpillability() - } - - /** - * Update the spillability state of this RapidsTable. This is invoked from - * two places: - * - * - from the onColumnSpillable callback, which is invoked from a - * ColumnVector.EventHandler.onClosed callback. - * - * - after adding a table to the store to mark the table as spillable if - * all columns are spillable. - */ - override def updateSpillability(): Unit = { - setSpillable(this, columnSpillability.size == numDistinctColumns) - } - - /** - * Produce a `ColumnarBatch` from our table, and in the process make ourselves - * not spillable. - * - * @param sparkTypes the spark data types the batch should have - */ - override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = { - columnSpillability.clear() - setSpillable(this, false) - GpuColumnVector.from(table, sparkTypes) - } - - /** - * Get the underlying memory buffer. This may be either a HostMemoryBuffer or a - * DeviceMemoryBuffer depending on where the buffer currently resides. - * The caller must have successfully acquired the buffer beforehand. - * - * @see [[addReference]] - * @note It is the responsibility of the caller to close the buffer. - */ - override def getMemoryBuffer: MemoryBuffer = { - throw new UnsupportedOperationException( - "RapidsDeviceMemoryBatch doesn't support getMemoryBuffer") - } - - override def free(): Unit = { - // lets remove our handler from the chain of handlers for each column - removeOnCloseEventHandler() - super.free() - } - - private def registerOnCloseEventHandler(): Unit = { - val columns = (0 until table.getNumberOfColumns).map(table.getColumn) - // cudfColumns could contain duplicates. 
We need to take this into account when we are - // deciding the floor refCount for a duplicated column - val repetitionPerColumn = new mutable.HashMap[ColumnVector, Int]() - columns.foreach { col => - val repetitionCount = repetitionPerColumn.getOrElse(col, 0) - repetitionPerColumn(col) = repetitionCount + 1 - } - repetitionPerColumn.foreach { case (distinctCv, repetition) => - // lock the column because we are setting its event handler, and we are inspecting - // its refCount. - distinctCv.synchronized { - val eventHandler = distinctCv.getEventHandler match { - case null => - val eventHandler = new RapidsDeviceColumnEventHandler - distinctCv.setEventHandler(eventHandler) - eventHandler - case existing: RapidsDeviceColumnEventHandler => - existing - case other => - throw new IllegalStateException( - s"Invalid column event handler $other") - } - eventHandler.register(this, repetition) - if (repetition == distinctCv.getRefCount) { - onColumnSpillable(distinctCv) - } - } - } - } - - // this method is called from free() - private def removeOnCloseEventHandler(): Unit = { - val distinctColumns = - (0 until table.getNumberOfColumns).map(table.getColumn).distinct - distinctColumns.foreach { distinctCv => - distinctCv.synchronized { - distinctCv.getEventHandler match { - case eventHandler: RapidsDeviceColumnEventHandler => - eventHandler.deregister(this) - case t => - throw new IllegalStateException( - s"Invalid column event handler $t") - } - } - } - } - - override def writeToChannel(outputChannel: WritableByteChannel, stream: Cuda.Stream): Long = { - var written: Long = 0L - withResource(getCopyIterator) { copyIter => - while(copyIter.hasNext) { - withResource(copyIter.next()) { slice => - val iter = - new MemoryBufferToHostByteBufferIterator( - slice, - hostSpillBounceBuffer, - stream) - iter.foreach { bb => - try { - while (bb.hasRemaining) { - written += outputChannel.write(bb) - } - } finally { - RapidsStorageUtils.dispose(bb) - } - } - } - } - written - } - } - - } - - class RapidsDeviceMemoryBuffer( - id: RapidsBufferId, - size: Long, - meta: TableMeta, - contigBuffer: DeviceMemoryBuffer, - spillPriority: Long) - extends RapidsBufferBase(id, meta, spillPriority) - with MemoryBuffer.EventHandler - with RapidsBufferChannelWritable { - - override val memoryUsedBytes: Long = size - - override val storageTier: StorageTier = StorageTier.DEVICE - - // If this require triggers, we are re-adding a `DeviceMemoryBuffer` outside of - // the catalog lock, which should not possible. The event handler is set to null - // when we free the `RapidsDeviceMemoryBuffer` and if the buffer is not free, we - // take out another handle (in the catalog). - // TODO: This is not robust (to rely on outside locking and addReference/free) - // and should be revisited. - require(contigBuffer.setEventHandler(this) == null, - "DeviceMemoryBuffer with non-null event handler failed to add!!") - - /** - * Override from the MemoryBuffer.EventHandler interface. 
- * - * If we are being invoked we have the `contigBuffer` lock, as this callback - * is being invoked from `MemoryBuffer.close` - * - * @param refCount - contigBuffer's current refCount - */ - override def onClosed(refCount: Int): Unit = { - // refCount == 1 means only 1 reference exists to `contigBuffer` in the - // RapidsDeviceMemoryBuffer (we own it) - if (refCount == 1) { - // setSpillable is being called here as an extension of `MemoryBuffer.close()` - // we hold the MemoryBuffer lock and we could be called from a Spark task thread - // Since we hold the MemoryBuffer lock, `incRefCount` waits for us. The only other - // call to `setSpillable` is also under this same MemoryBuffer lock (see: - // `getDeviceMemoryBuffer`) - setSpillable(this, true) - } - } - - override protected def releaseResources(): Unit = synchronized { - // we need to disassociate this RapidsBuffer from the underlying buffer - contigBuffer.close() - } - - /** - * Get and increase the reference count of the device memory buffer - * in this RapidsBuffer, while making the RapidsBuffer non-spillable. - * - * @note It is the responsibility of the caller to close the DeviceMemoryBuffer - */ - override def getDeviceMemoryBuffer: DeviceMemoryBuffer = synchronized { - contigBuffer.synchronized { - setSpillable(this, false) - contigBuffer.incRefCount() - contigBuffer - } - } - - override def getMemoryBuffer: MemoryBuffer = getDeviceMemoryBuffer - - override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = { - // calling `getDeviceMemoryBuffer` guarantees that we have marked this RapidsBuffer - // as not spillable and increased its refCount atomically - withResource(getDeviceMemoryBuffer) { buff => - columnarBatchFromDeviceBuffer(buff, sparkTypes) - } - } - - /** - * We overwrite free to make sure we don't have a handler for the underlying - * contigBuffer, since this `RapidsBuffer` is no longer tracked. - */ - override def free(): Unit = synchronized { - if (isValid) { - // it is going to be invalid when calling super.free() - contigBuffer.setEventHandler(null) - } - super.free() - } - - override def writeToChannel(outputChannel: WritableByteChannel, stream: Cuda.Stream): Long = { - var written: Long = 0L - val iter = new MemoryBufferToHostByteBufferIterator( - contigBuffer, - hostSpillBounceBuffer, - stream) - iter.foreach { bb => - try { - while (bb.hasRemaining) { - written += outputChannel.write(bb) - } - } finally { - RapidsStorageUtils.dispose(bb) - } - } - written - } - - } - override def close(): Unit = { - try { - super.close() - } finally { - Seq(chunkedPackBounceBuffer, hostSpillBounceBuffer).safeClose() - chunkedPackBounceBuffer = null - hostSpillBounceBuffer = null - } - } -} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala deleted file mode 100644 index eb3692d434a..00000000000 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala +++ /dev/null @@ -1,256 +0,0 @@ -/* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.nvidia.spark.rapids - -import java.io.{File, FileInputStream} -import java.nio.channels.{Channels, FileChannel} -import java.nio.channels.FileChannel.MapMode -import java.nio.file.StandardOpenOption -import java.util.concurrent.ConcurrentHashMap - -import ai.rapids.cudf.{Cuda, HostMemoryBuffer, MemoryBuffer} -import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} -import com.nvidia.spark.rapids.StorageTier.StorageTier -import com.nvidia.spark.rapids.format.TableMeta -import org.apache.commons.io.IOUtils - -import org.apache.spark.TaskContext -import org.apache.spark.sql.rapids.{GpuTaskMetrics, RapidsDiskBlockManager} -import org.apache.spark.sql.rapids.execution.{SerializedHostTableUtils, TrampolineUtil} -import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.vectorized.ColumnarBatch - -/** A buffer store using files on the local disks. */ -class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager) - extends RapidsBufferStoreWithoutSpill(StorageTier.DISK) { - private[this] val sharedBufferFiles = new ConcurrentHashMap[RapidsBufferId, File] - - private def reportDiskAllocMetrics(metrics: GpuTaskMetrics): String = { - val taskId = TaskContext.get().taskAttemptId() - val totalSize = metrics.getDiskBytesAllocated - val maxSize = metrics.getMaxDiskBytesAllocated - s"total size for task $taskId is $totalSize, max size is $maxSize" - } - - override protected def createBuffer( - incoming: RapidsBuffer, - catalog: RapidsBufferCatalog, - stream: Cuda.Stream): Option[RapidsBufferBase] = { - // assuming that the disk store gets contiguous buffers - val id = incoming.id - val path = if (id.canShareDiskPaths) { - sharedBufferFiles.computeIfAbsent(id, _ => id.getDiskPath(diskBlockManager)) - } else { - id.getDiskPath(diskBlockManager) - } - - val (fileOffset, uncompressedSize, diskLength) = if (id.canShareDiskPaths) { - // only one writer at a time for now when using shared files - path.synchronized { - writeToFile(incoming, path, append = true, stream) - } - } else { - writeToFile(incoming, path, append = false, stream) - } - logDebug(s"Spilled to $path $fileOffset:$diskLength") - val buff = incoming match { - case _: RapidsHostBatchBuffer => - new RapidsDiskColumnarBatch( - id, - fileOffset, - uncompressedSize, - diskLength, - incoming.meta, - incoming.getSpillPriority) - - case _ => - new RapidsDiskBuffer( - id, - fileOffset, - uncompressedSize, - diskLength, - incoming.meta, - incoming.getSpillPriority) - } - TrampolineUtil.incTaskMetricsDiskBytesSpilled(uncompressedSize) - - val metrics = GpuTaskMetrics.get - metrics.incDiskBytesAllocated(uncompressedSize) - logDebug(s"acquiring resources for disk buffer $id of size $uncompressedSize bytes") - logDebug(reportDiskAllocMetrics(metrics)) - Some(buff) - } - - /** - * Copy a host buffer to a file. It leverages [[RapidsSerializerManager]] from - * [[RapidsDiskBlockManager]] to do compression or encryption if needed. 
- * - * @param incoming the rapid buffer to be written into a file - * @param path file path - * @param append whether to append or written into the beginning of the file - * @param stream cuda stream - * @return a tuple of file offset, memory byte size and written size on disk. File offset is where - * buffer starts in the targeted file path. Memory byte size is the size of byte buffer - * occupied in memory before writing to disk. Written size on disk is actual byte size - * written to disk. - */ - private def writeToFile( - incoming: RapidsBuffer, - path: File, - append: Boolean, - stream: Cuda.Stream): (Long, Long, Long) = { - incoming match { - case fileWritable: RapidsBufferChannelWritable => - val option = if (append) { - Array(StandardOpenOption.CREATE, StandardOpenOption.APPEND) - } else { - Array(StandardOpenOption.CREATE, StandardOpenOption.WRITE) - } - var currentPos, writtenBytes = 0L - - GpuTaskMetrics.get.spillToDiskTime { - withResource(FileChannel.open(path.toPath, option: _*)) { fc => - currentPos = fc.position() - withResource(Channels.newOutputStream(fc)) { os => - withResource(diskBlockManager.getSerializerManager() - .wrapStream(incoming.id, os)) { cos => - val outputChannel = Channels.newChannel(cos) - writtenBytes = fileWritable.writeToChannel(outputChannel, stream) - } - } - (currentPos, writtenBytes, path.length() - currentPos) - } - } - case other => - throw new IllegalStateException( - s"Unable to write $other to file") - } - } - - /** - * A RapidsDiskBuffer that is mean to represent device-bound memory. This - * buffer can produce a device-backed ColumnarBatch. - */ - class RapidsDiskBuffer( - id: RapidsBufferId, - fileOffset: Long, - uncompressedSize: Long, - onDiskSizeInBytes: Long, - meta: TableMeta, - spillPriority: Long) - extends RapidsBufferBase(id, meta, spillPriority) { - - // FIXME: Need to be clean up. Tracked in https://github.com/NVIDIA/spark-rapids/issues/9496 - override val memoryUsedBytes: Long = uncompressedSize - - override val storageTier: StorageTier = StorageTier.DISK - - override def getMemoryBuffer: MemoryBuffer = synchronized { - require(onDiskSizeInBytes > 0, - s"$this attempted an invalid 0-byte mmap of a file") - val path = id.getDiskPath(diskBlockManager) - val serializerManager = diskBlockManager.getSerializerManager() - val memBuffer = if (serializerManager.isRapidsSpill(id)) { - // Only go through serializerManager's stream wrapper for spill case - closeOnExcept(HostAlloc.alloc(uncompressedSize)) { - decompressed => GpuTaskMetrics.get.readSpillFromDiskTime { - withResource(FileChannel.open(path.toPath, StandardOpenOption.READ)) { c => - c.position(fileOffset) - withResource(Channels.newInputStream(c)) { compressed => - withResource(serializerManager.wrapStream(id, compressed)) { in => - withResource(new HostMemoryOutputStream(decompressed)) { out => - IOUtils.copy(in, out) - } - decompressed - } - } - } - } - } - } else { - // Reserved mmap read fashion for UCX shuffle path. Also it's skipping encryption and - // compression. 
- HostMemoryBuffer.mapFile(path, MapMode.READ_WRITE, fileOffset, onDiskSizeInBytes) - } - memBuffer - } - - override def close(): Unit = synchronized { - super.close() - } - - override protected def releaseResources(): Unit = { - logDebug(s"releasing resources for disk buffer $id of size $memoryUsedBytes bytes") - val metrics = GpuTaskMetrics.get - metrics.decDiskBytesAllocated(memoryUsedBytes) - logDebug(reportDiskAllocMetrics(metrics)) - - // Buffers that share paths must be cleaned up elsewhere - if (id.canShareDiskPaths) { - sharedBufferFiles.remove(id) - } else { - val path = id.getDiskPath(diskBlockManager) - if (!path.delete() && path.exists()) { - logWarning(s"Unable to delete spill path $path") - } - } - } - } - - /** - * A RapidsDiskBuffer that should remain in the host, producing host-backed - * ColumnarBatch if the caller invokes getHostColumnarBatch, but not producing - * anything on the device. - */ - class RapidsDiskColumnarBatch( - id: RapidsBufferId, - fileOffset: Long, - size: Long, - uncompressedSize: Long, - // TODO: remove meta - meta: TableMeta, - spillPriority: Long) - extends RapidsDiskBuffer( - id, fileOffset, size, uncompressedSize, meta, spillPriority) - with RapidsHostBatchBuffer { - - override def getMemoryBuffer: MemoryBuffer = - throw new IllegalStateException( - "Called getMemoryBuffer on a disk buffer that needs deserialization") - - override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = - throw new IllegalStateException( - "Called getColumnarBatch on a disk buffer that needs deserialization") - - override def getHostColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = { - require(fileOffset == 0, - "Attempted to obtain a HostColumnarBatch from a spilled RapidsBuffer that is sharing " + - "paths on disk") - val path = id.getDiskPath(diskBlockManager) - withResource(new FileInputStream(path)) { fis => - withResource(diskBlockManager.getSerializerManager() - .wrapStream(id, fis)) { fs => - val (header, hostBuffer) = SerializedHostTableUtils.readTableHeaderAndBuffer(fs) - val hostCols = withResource(hostBuffer) { _ => - SerializedHostTableUtils.buildHostColumns(header, hostBuffer, sparkTypes) - } - new ColumnarBatch(hostCols.toArray, header.getNumRows) - } - } - } - } -} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala deleted file mode 100644 index 235ed9ddb45..00000000000 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala +++ /dev/null @@ -1,484 +0,0 @@ -/* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.nvidia.spark.rapids - -import java.io.DataOutputStream -import java.nio.channels.{Channels, WritableByteChannel} -import java.util.concurrent.ConcurrentHashMap - -import scala.collection.mutable - -import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, HostColumnVector, HostMemoryBuffer, JCudfSerialization, MemoryBuffer} -import com.nvidia.spark.rapids.Arm.{closeOnExcept, freeOnExcept, withResource} -import com.nvidia.spark.rapids.SpillPriorities.{applyPriorityOffset, HOST_MEMORY_BUFFER_SPILL_OFFSET} -import com.nvidia.spark.rapids.StorageTier.StorageTier -import com.nvidia.spark.rapids.format.TableMeta - -import org.apache.spark.TaskContext -import org.apache.spark.sql.rapids.GpuTaskMetrics -import org.apache.spark.sql.rapids.storage.RapidsStorageUtils -import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.vectorized.ColumnarBatch - -/** - * A buffer store using host memory. - * @param maxSize maximum size in bytes for all buffers in this store - */ -class RapidsHostMemoryStore( - maxSize: Option[Long]) - extends RapidsBufferStore(StorageTier.HOST) { - - override protected def spillableOnAdd: Boolean = false - - override def getMaxSize: Option[Long] = maxSize - - def addBuffer( - id: RapidsBufferId, - buffer: HostMemoryBuffer, - tableMeta: TableMeta, - initialSpillPriority: Long, - needsSync: Boolean): RapidsBuffer = { - buffer.incRefCount() - val rapidsBuffer = new RapidsHostMemoryBuffer( - id, - buffer.getLength, - tableMeta, - initialSpillPriority, - buffer) - freeOnExcept(rapidsBuffer) { _ => - logDebug(s"Adding host buffer for: [id=$id, size=${buffer.getLength}, " + - s"uncompressed=${rapidsBuffer.meta.bufferMeta.uncompressedSize}, " + - s"meta_id=${tableMeta.bufferMeta.id}, " + - s"meta_size=${tableMeta.bufferMeta.size}]") - addBuffer(rapidsBuffer, needsSync) - rapidsBuffer - } - } - - def addBatch(id: RapidsBufferId, - hostCb: ColumnarBatch, - initialSpillPriority: Long, - needsSync: Boolean): RapidsBuffer = { - RapidsHostColumnVector.incRefCounts(hostCb) - val rapidsBuffer = new RapidsHostColumnarBatch( - id, - hostCb, - initialSpillPriority) - freeOnExcept(rapidsBuffer) { _ => - addBuffer(rapidsBuffer, needsSync) - rapidsBuffer - } - } - - override protected def trySpillToMaximumSize( - buffer: RapidsBuffer, - catalog: RapidsBufferCatalog, - stream: Cuda.Stream): Boolean = { - maxSize.forall { ms => - // this spillStore has a maximum size requirement (host only). We need to spill from it - // in order to make room for `buffer`. - val targetTotalSize = ms - buffer.memoryUsedBytes - if (targetTotalSize < 0) { - // lets not spill to host when the buffer we are about - // to spill is larger than our limit - false - } else { - val amountSpilled = synchronousSpill(targetTotalSize, catalog, stream) - if (amountSpilled != 0) { - logDebug(s"Task ${TaskContext.get.taskAttemptId()} spilled $amountSpilled bytes from" + - s"${name} to make room for ${buffer.id}") - } - // if after spill we can fit the new buffer, return true - buffer.memoryUsedBytes <= (ms - currentSize) - } - } - } - - override protected def createBuffer( - other: RapidsBuffer, - catalog: RapidsBufferCatalog, - stream: Cuda.Stream): Option[RapidsBufferBase] = { - val wouldFit = trySpillToMaximumSize(other, catalog, stream) - if (!wouldFit) { - // skip host - logWarning(s"Buffer $other with size ${other.memoryUsedBytes} does not fit " + - s"in the host store, skipping tier.") - None - } else { - // If the other is from the local disk store, we are unspilling to host memory. 
- if (other.storageTier == StorageTier.DISK) { - logDebug(s"copying RapidsDiskStore buffer ${other.id} to a HostMemoryBuffer") - val hostBuffer = other.getMemoryBuffer.asInstanceOf[HostMemoryBuffer] - Some(new RapidsHostMemoryBuffer( - other.id, - hostBuffer.getLength(), - other.meta, - applyPriorityOffset(other.getSpillPriority, HOST_MEMORY_BUFFER_SPILL_OFFSET), - hostBuffer)) - } else { - withResource(other.getCopyIterator) { otherBufferIterator => - val isChunked = otherBufferIterator.isChunked - val totalCopySize = otherBufferIterator.getTotalCopySize - closeOnExcept(HostAlloc.tryAlloc(totalCopySize)) { hb => - hb.map { hostBuffer => - val spillNs = GpuTaskMetrics.get.spillToHostTime { - var hostOffset = 0L - val start = System.nanoTime() - while (otherBufferIterator.hasNext) { - val otherBuffer = otherBufferIterator.next() - withResource(otherBuffer) { _ => - otherBuffer match { - case devBuffer: DeviceMemoryBuffer => - hostBuffer.copyFromMemoryBufferAsync( - hostOffset, devBuffer, 0, otherBuffer.getLength, stream) - hostOffset += otherBuffer.getLength - case _ => - throw new IllegalStateException("copying from buffer without device memory") - } - } - } - stream.sync() - System.nanoTime() - start - } - val szMB = (totalCopySize.toDouble / 1024.0 / 1024.0).toLong - val bw = (szMB.toDouble / (spillNs.toDouble / 1000000000.0)).toLong - logDebug(s"Spill to host (chunked=$isChunked) " + - s"size=$szMB MiB bandwidth=$bw MiB/sec") - new RapidsHostMemoryBuffer( - other.id, - totalCopySize, - other.meta, - applyPriorityOffset(other.getSpillPriority, HOST_MEMORY_BUFFER_SPILL_OFFSET), - hostBuffer) - }.orElse { - // skip host - logWarning(s"Buffer $other with size ${other.memoryUsedBytes} does not fit " + - s"in the host store, skipping tier.") - None - } - } - } - } - } - } - - def numBytesFree: Option[Long] = maxSize.map(_ - currentSize) - - class RapidsHostMemoryBuffer( - id: RapidsBufferId, - size: Long, - meta: TableMeta, - spillPriority: Long, - buffer: HostMemoryBuffer) - extends RapidsBufferBase(id, meta, spillPriority) - with RapidsBufferChannelWritable - with MemoryBuffer.EventHandler { - override val storageTier: StorageTier = StorageTier.HOST - - override def getMemoryBuffer: MemoryBuffer = getHostMemoryBuffer - - override def getHostMemoryBuffer: HostMemoryBuffer = synchronized { - buffer.synchronized { - setSpillable(this, false) - buffer.incRefCount() - buffer - } - } - - override def writeToChannel(outputChannel: WritableByteChannel, ignored: Cuda.Stream): Long = { - var written: Long = 0L - val iter = new HostByteBufferIterator(buffer) - iter.foreach { bb => - try { - while (bb.hasRemaining) { - written += outputChannel.write(bb) - } - } finally { - RapidsStorageUtils.dispose(bb) - } - } - written - } - - override def updateSpillability(): Unit = { - if (buffer.getRefCount == 1) { - setSpillable(this, true) - } - } - - override protected def releaseResources(): Unit = { - buffer.close() - } - - /** The size of this buffer in bytes. */ - override val memoryUsedBytes: Long = size - - // If this require triggers, we are re-adding a `HostMemoryBuffer` outside of - // the catalog lock, which should not possible. The event handler is set to null - // when we free the `RapidsHostMemoryBuffer` and if the buffer is not free, we - // take out another handle (in the catalog). - HostAlloc.addEventHandler(buffer, this) - - /** - * Override from the MemoryBuffer.EventHandler interface. 
- * - * If we are being invoked we have the `buffer` lock, as this callback - * is being invoked from `MemoryBuffer.close` - * - * @param refCount - buffer's current refCount - */ - override def onClosed(refCount: Int): Unit = { - // refCount == 1 means only 1 reference exists to `buffer` in the - // RapidsHostMemoryBuffer (we own it) - if (refCount == 1) { - // setSpillable is being called here as an extension of `MemoryBuffer.close()` - // we hold the MemoryBuffer lock and we could be called from a Spark task thread - // Since we hold the MemoryBuffer lock, `incRefCount` waits for us. The only other - // call to `setSpillable` is also under this same MemoryBuffer lock (see: - // `getMemoryBuffer`) - setSpillable(this, true) - } - } - - /** - * We overwrite free to make sure we don't have a handler for the underlying - * buffer, since this `RapidsBuffer` is no longer tracked. - */ - override def free(): Unit = synchronized { - if (isValid) { - // it is going to be invalid when calling super.free() - HostAlloc.removeEventHandler(buffer, this) - } - super.free() - } - } - - /** - * A per cuDF host column event handler that handles calls to .close() - * inside of the `HostColumnVector` lock. - */ - class RapidsHostColumnEventHandler - extends HostColumnVector.EventHandler { - - // Every RapidsHostColumnarBatch that references this column has an entry in this map. - // The value represents the number of times (normally 1) that a ColumnVector - // appears in the RapidsHostColumnarBatch. This is also the HosColumnVector refCount at which - // the column is considered spillable. - // The map is protected via the ColumnVector lock. - private val registration = new mutable.HashMap[RapidsHostColumnarBatch, Int]() - - /** - * Every RapidsHostColumnarBatch iterates through its columns and either creates - * a `RapidsHostColumnEventHandler` object and associates it with the column's - * `eventHandler` or calls into the existing one, and registers itself. - * - * The registration has two goals: it accounts for repetition of a column - * in a `RapidsHostColumnarBatch`. If a batch has the same column repeated it must adjust - * the refCount at which this column is considered spillable. - * - * The second goal is to account for aliasing. If two host batches alias this column - * we are going to mark it as non spillable. - * - * @param rapidsHostCb - the host batch that is registering itself with this tracker - */ - def register(rapidsHostCb: RapidsHostColumnarBatch, repetition: Int): Unit = { - registration.put(rapidsHostCb, repetition) - } - - /** - * This is invoked during `RapidsHostColumnarBatch.free` in order to remove the entry - * in `registration`. - * - * @param rapidsHostCb - the batch that is de-registering itself - */ - def deregister(rapidsHostCb: RapidsHostColumnarBatch): Unit = { - registration.remove(rapidsHostCb) - } - - // called with the cudf HostColumnVector lock held from cuDF's side - override def onClosed(cudfCv: HostColumnVector, refCount: Int): Unit = { - // we only handle spillability if there is a single batch registered - // (no aliasing) - if (registration.size == 1) { - val (rapidsHostCb, spillableRefCount) = registration.head - if (spillableRefCount == refCount) { - rapidsHostCb.onColumnSpillable(cudfCv) - } - } - } - } - - /** - * A `RapidsHostColumnarBatch` is the spill store holder of ColumnarBatch backed by - * HostColumnVector. - * - * This class owns the host batch and will close it when `close` is called. 
- * - * @param id the `RapidsBufferId` this batch is associated with - * @param batch the host ColumnarBatch we are managing - * @param spillPriority a starting spill priority - */ - class RapidsHostColumnarBatch( - id: RapidsBufferId, - hostCb: ColumnarBatch, - spillPriority: Long) - extends RapidsBufferBase( - id, - null, - spillPriority) - with RapidsBufferChannelWritable - with RapidsHostBatchBuffer { - - override val storageTier: StorageTier = StorageTier.HOST - - // By default all columns are NOT spillable since we are not the only owners of - // the columns (the caller is holding onto a ColumnarBatch that will be closed - // after instantiation, triggering onClosed callbacks) - // This hash set contains the columns that are currently spillable. - private val columnSpillability = new ConcurrentHashMap[HostColumnVector, Boolean]() - - private val numDistinctColumns = RapidsHostColumnVector.extractBases(hostCb).distinct.size - - // we register our event callbacks as the very first action to deal with - // spillability - registerOnCloseEventHandler() - - /** Release the underlying resources for this buffer. */ - override protected def releaseResources(): Unit = { - hostCb.close() - } - - override def meta: TableMeta = { - null - } - - // This is the current size in batch form. It is to be used while this - // batch hasn't migrated to another store. - override val memoryUsedBytes: Long = RapidsHostColumnVector.getTotalHostMemoryUsed(hostCb) - - /** - * Mark a column as spillable - * - * @param column the ColumnVector to mark as spillable - */ - def onColumnSpillable(column: HostColumnVector): Unit = { - columnSpillability.put(column, true) - updateSpillability() - } - - /** - * Update the spillability state of this RapidsHostColumnarBatch. This is invoked from - * two places: - * - * - from the onColumnSpillable callback, which is invoked from a - * HostColumnVector.EventHandler.onClosed callback. - * - * - after adding a batch to the store to mark the batch as spillable if - * all columns are spillable. 
- */ - override def updateSpillability(): Unit = { - setSpillable(this, columnSpillability.size == numDistinctColumns) - } - - override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = { - throw new UnsupportedOperationException( - "RapidsHostColumnarBatch does not support getColumnarBatch") - } - - override def getHostColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = { - columnSpillability.clear() - setSpillable(this, false) - RapidsHostColumnVector.incRefCounts(hostCb) - } - - override def getMemoryBuffer: MemoryBuffer = { - throw new UnsupportedOperationException( - "RapidsHostColumnarBatch does not support getMemoryBuffer") - } - - override def getCopyIterator: RapidsBufferCopyIterator = { - throw new UnsupportedOperationException( - "RapidsHostColumnarBatch does not support getCopyIterator") - } - - override def writeToChannel(outputChannel: WritableByteChannel, ignored: Cuda.Stream): Long = { - withResource(Channels.newOutputStream(outputChannel)) { outputStream => - withResource(new DataOutputStream(outputStream)) { dos => - val columns = RapidsHostColumnVector.extractBases(hostCb) - JCudfSerialization.writeToStream(columns, dos, 0, hostCb.numRows()) - dos.size() - } - } - } - - override def free(): Unit = { - // lets remove our handler from the chain of handlers for each column - removeOnCloseEventHandler() - super.free() - } - - private def registerOnCloseEventHandler(): Unit = { - val columns = RapidsHostColumnVector.extractBases(hostCb) - // cudfColumns could contain duplicates. We need to take this into account when we are - // deciding the floor refCount for a duplicated column - val repetitionPerColumn = new mutable.HashMap[HostColumnVector, Int]() - columns.foreach { col => - val repetitionCount = repetitionPerColumn.getOrElse(col, 0) - repetitionPerColumn(col) = repetitionCount + 1 - } - repetitionPerColumn.foreach { case (distinctCv, repetition) => - // lock the column because we are setting its event handler, and we are inspecting - // its refCount. - distinctCv.synchronized { - val eventHandler = distinctCv.getEventHandler match { - case null => - val eventHandler = new RapidsHostColumnEventHandler - distinctCv.setEventHandler(eventHandler) - eventHandler - case existing: RapidsHostColumnEventHandler => - existing - case other => - throw new IllegalStateException( - s"Invalid column event handler $other") - } - eventHandler.register(this, repetition) - if (repetition == distinctCv.getRefCount) { - onColumnSpillable(distinctCv) - } - } - } - } - - // this method is called from free() - private def removeOnCloseEventHandler(): Unit = { - val distinctColumns = RapidsHostColumnVector.extractBases(hostCb).distinct - distinctColumns.foreach { distinctCv => - distinctCv.synchronized { - distinctCv.getEventHandler match { - case eventHandler: RapidsHostColumnEventHandler => - eventHandler.deregister(this) - case t => - throw new IllegalStateException( - s"Invalid column event handler $t") - } - } - } - } - } -} - - diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsSerializerManager.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsSerializerManager.scala index ab4a6398d32..74aed062142 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsSerializerManager.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsSerializerManager.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,8 +20,8 @@ import java.io.{InputStream, OutputStream} import org.apache.spark.SparkConf import org.apache.spark.io.CompressionCodec -import org.apache.spark.sql.rapids.TempSpillBufferId import org.apache.spark.sql.rapids.execution.TrampolineUtil +import org.apache.spark.storage.BlockId /** @@ -44,22 +44,20 @@ class RapidsSerializerManager (conf: SparkConf) { private lazy val compressionCodec: CompressionCodec = TrampolineUtil.createCodec(conf) - // Whether it really goes through crypto streams replies on Spark configuration - // (e.g., `` `spark.io.encryption.enabled` ``) and the existence of crypto keys. - def wrapStream(bufferId: RapidsBufferId, s: OutputStream): OutputStream = { - if(isRapidsSpill(bufferId)) wrapForCompression(bufferId, wrapForEncryption(s)) else s + def wrapStream(blockId: BlockId, s: OutputStream): OutputStream = { + if(isRapidsSpill(blockId)) wrapForCompression(blockId, wrapForEncryption(s)) else s } - def wrapStream(bufferId: RapidsBufferId, s: InputStream): InputStream = { - if(isRapidsSpill(bufferId)) wrapForCompression(bufferId, wrapForEncryption(s)) else s + def wrapStream(blockId: BlockId, s: InputStream): InputStream = { + if(isRapidsSpill(blockId)) wrapForCompression(blockId, wrapForEncryption(s)) else s } - private[this] def wrapForCompression(bufferId: RapidsBufferId, s: InputStream): InputStream = { - if (shouldCompress(bufferId)) compressionCodec.compressedInputStream(s) else s + private[this] def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { + if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s } - private[this] def wrapForCompression(bufferId: RapidsBufferId, s: OutputStream): OutputStream = { - if (shouldCompress(bufferId)) compressionCodec.compressedOutputStream(s) else s + private[this] def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { + if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s } private[this] def wrapForEncryption(s: InputStream): InputStream = { @@ -70,18 +68,15 @@ class RapidsSerializerManager (conf: SparkConf) { if (serializerManager != null) serializerManager.wrapForEncryption(s) else s } - def isRapidsSpill(bufferId: RapidsBufferId): Boolean = { - bufferId match { - case _: TempSpillBufferId => true - case _ => false - } + def isRapidsSpill(blockId: BlockId): Boolean = { + !blockId.isShuffle } - private[this] def shouldCompress(bufferId: RapidsBufferId): Boolean = { - bufferId match { - case _: TempSpillBufferId => compressSpill - case _: ShuffleBufferId | _: ShuffleReceivedBufferId => false - case _ => false + private[this] def shouldCompress(blockId: BlockId): Boolean = { + if (!blockId.isShuffle) { + compressSpill + } else { + false } } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala index a587c5cd7ae..6c17a29fdc9 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,52 +16,76 @@ package com.nvidia.spark.rapids -import java.io.File import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicInteger import java.util.function.{Consumer, IntUnaryOperator} import scala.collection.mutable.ArrayBuffer -import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer} +import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, Table} import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.format.TableMeta -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.internal.Logging -import org.apache.spark.sql.rapids.RapidsDiskBlockManager import org.apache.spark.sql.rapids.execution.TrampolineUtil +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.storage.ShuffleBlockId /** Identifier for a shuffle buffer that holds the data for a table */ case class ShuffleBufferId( blockId: ShuffleBlockId, - override val tableId: Int) extends RapidsBufferId { + tableId: Int) { val shuffleId: Int = blockId.shuffleId val mapId: Long = blockId.mapId - - override val canShareDiskPaths: Boolean = true - - override def getDiskPath(diskBlockManager: RapidsDiskBlockManager): File = { - diskBlockManager.getFile(blockId) - } } /** Catalog for lookup of shuffle buffers by block ID */ -class ShuffleBufferCatalog( - catalog: RapidsBufferCatalog, - diskBlockManager: RapidsDiskBlockManager) extends Logging { +class ShuffleBufferCatalog extends Logging { + /** + * Information stored for each active shuffle. + * NOTE: ArrayBuffer in blockMap must be explicitly locked when using it! + * + * @param blockMap mapping of block ID to array of buffers for the block + */ + private case class ShuffleInfo( + blockMap: ConcurrentHashMap[ShuffleBlockId, ArrayBuffer[ShuffleBufferId]]) + + private val bufferIdToHandle = + new ConcurrentHashMap[ + ShuffleBufferId, + (Option[SpillableDeviceBufferHandle], TableMeta)]() + + /** shuffle information for each active shuffle */ + private[this] val activeShuffles = new ConcurrentHashMap[Int, ShuffleInfo] - private val bufferIdToHandle = new ConcurrentHashMap[RapidsBufferId, RapidsBufferHandle]() + /** Mapping of table ID to shuffle buffer ID */ + private[this] val tableMap = new ConcurrentHashMap[Int, ShuffleBufferId] + + /** Tracks the next table identifier */ + private[this] val tableIdCounter = new AtomicInteger(0) private def trackCachedHandle( bufferId: ShuffleBufferId, - bufferHandle: RapidsBufferHandle): Unit = { - bufferIdToHandle.put(bufferId, bufferHandle) + handle: SpillableDeviceBufferHandle, + meta: TableMeta): Unit = { + bufferIdToHandle.put(bufferId, (Some(handle), meta)) + } + + private def trackDegenerate(bufferId: ShuffleBufferId, + meta: TableMeta): Unit = { + bufferIdToHandle.put(bufferId, (None, meta)) } def removeCachedHandles(): Unit = { - bufferIdToHandle.forEach { (_, handle) => removeBuffer(handle) } + val bufferIt = bufferIdToHandle.keySet().iterator() + while (bufferIt.hasNext) { + val buffer = bufferIt.next() + val (maybeHandle, _) = bufferIdToHandle.remove(buffer) + tableMap.remove(buffer.tableId) + maybeHandle.foreach(_.close()) + } } /** @@ -70,56 +94,48 @@ class ShuffleBufferCatalog( * The refcount of the underlying device buffer will be incremented so the contiguous table * can be closed before this buffer is destroyed. 
* - * @param blockId Spark's `ShuffleBlockId` that identifies this buffer - * @param contigTable contiguous table to track in storage + * @param blockId Spark's `ShuffleBlockId` that identifies this buffer + * @param contigTable contiguous table to track in storage * @param initialSpillPriority starting spill priority value for the buffer - * @param needsSync whether the spill framework should stream synchronize while adding - * this device buffer (defaults to true) * @return RapidsBufferHandle identifying this table */ - def addContiguousTable( - blockId: ShuffleBlockId, - contigTable: ContiguousTable, - initialSpillPriority: Long, - needsSync: Boolean): RapidsBufferHandle = { - val bufferId = nextShuffleBufferId(blockId) + def addContiguousTable(blockId: ShuffleBlockId, + contigTable: ContiguousTable, + initialSpillPriority: Long): Unit = { withResource(contigTable) { _ => - val handle = catalog.addContiguousTable( - bufferId, - contigTable, - initialSpillPriority, - needsSync) - trackCachedHandle(bufferId, handle) - handle + val bufferId = nextShuffleBufferId(blockId) + val tableMeta = MetaUtils.buildTableMeta(bufferId.tableId, contigTable) + val buff = contigTable.getBuffer + buff.incRefCount() + val handle = SpillableDeviceBufferHandle(buff) + trackCachedHandle(bufferId, handle, tableMeta) } } /** * Adds a buffer to the device storage, taking ownership of the buffer. * - * @param blockId Spark's `ShuffleBlockId` that identifies this buffer - * @param buffer buffer that will be owned by the store - * @param tableMeta metadata describing the buffer layout + * @param blockId Spark's `ShuffleBlockId` that identifies this buffer + * @param buffer buffer that will be owned by the store + * @param tableMeta metadata describing the buffer layout * @param initialSpillPriority starting spill priority value for the buffer * @return RapidsBufferHandle associated with this buffer */ - def addBuffer( - blockId: ShuffleBlockId, - buffer: DeviceMemoryBuffer, - tableMeta: TableMeta, - initialSpillPriority: Long, - needsSync: Boolean): RapidsBufferHandle = { - val bufferId = nextShuffleBufferId(blockId) - // update the table metadata for the buffer ID generated above - tableMeta.bufferMeta.mutateId(bufferId.tableId) - val handle = catalog.addBuffer( - bufferId, - buffer, - tableMeta, - initialSpillPriority, - needsSync) - trackCachedHandle(bufferId, handle) - handle + def addCompressedBatch( + blockId: ShuffleBlockId, + compressedBatch: ColumnarBatch, + initialSpillPriority: Long): Unit = { + withResource(compressedBatch) { _ => + val bufferId = nextShuffleBufferId(blockId) + val compressed = compressedBatch.column(0).asInstanceOf[GpuCompressedColumnVector] + val tableMeta = compressed.getTableMeta + // update the table metadata for the buffer ID generated above + tableMeta.bufferMeta().mutateId(bufferId.tableId) + val buff = compressed.getTableBuffer + buff.incRefCount() + val handle = SpillableDeviceBufferHandle(buff) + trackCachedHandle(bufferId, handle, tableMeta) + } } /** @@ -128,30 +144,11 @@ class ShuffleBufferCatalog( */ def addDegenerateRapidsBuffer( blockId: ShuffleBlockId, - meta: TableMeta): RapidsBufferHandle = { + meta: TableMeta): Unit = { val bufferId = nextShuffleBufferId(blockId) - val handle = catalog.registerDegenerateBuffer(bufferId, meta) - trackCachedHandle(bufferId, handle) - handle + trackDegenerate(bufferId, meta) } - /** - * Information stored for each active shuffle. - * NOTE: ArrayBuffer in blockMap must be explicitly locked when using it! 
- *
- * @param blockMap mapping of block ID to array of buffers for the block
- */
-  private case class ShuffleInfo(
-      blockMap: ConcurrentHashMap[ShuffleBlockId, ArrayBuffer[ShuffleBufferId]])
-
-  /** shuffle information for each active shuffle */
-  private[this] val activeShuffles = new ConcurrentHashMap[Int, ShuffleInfo]
-
-  /** Mapping of table ID to shuffle buffer ID */
-  private[this] val tableMap = new ConcurrentHashMap[Int, ShuffleBufferId]
-
-  /** Tracks the next table identifier */
-  private[this] val tableIdCounter = new AtomicInteger(0)
 
   /**
    * Register a new shuffle.
@@ -174,22 +171,11 @@ class ShuffleBufferCatalog(
       // NOTE: Not synchronizing array buffer because this shuffle should be inactive.
       bufferIds.foreach { id =>
         tableMap.remove(id.tableId)
-        val handle = bufferIdToHandle.remove(id)
-        if (handle != null) {
-          handle.close()
-        }
+        val handleAndMeta = bufferIdToHandle.remove(id)
+        handleAndMeta._1.foreach(_.close())
       }
     }
     info.blockMap.forEachValue(Long.MaxValue, bufferRemover)
-
-    val fileRemover: Consumer[ShuffleBlockId] = { blockId =>
-      val file = diskBlockManager.getFile(blockId)
-      logDebug(s"Deleting file $file")
-      if (!file.delete() && file.exists()) {
-        logWarning(s"Unable to delete $file")
-      }
-    }
-    info.blockMap.forEachKey(Long.MaxValue, fileRemover)
   } else {
     // currently shuffle unregister can get called on the driver which never saw a register
     if (!TrampolineUtil.isDriver(SparkEnv.get)) {
@@ -201,10 +187,10 @@ class ShuffleBufferCatalog(
   def hasActiveShuffle(shuffleId: Int): Boolean = activeShuffles.containsKey(shuffleId)
 
   /** Get all the buffer IDs that correspond to a shuffle block identifier. */
-  def blockIdToBuffersIds(blockId: ShuffleBlockId): Array[ShuffleBufferId] = {
+  private def blockIdToBuffersIds(blockId: ShuffleBlockId): Array[ShuffleBufferId] = {
     val info = activeShuffles.get(blockId.shuffleId)
     if (info == null) {
-      throw new NoSuchElementException(s"unknown shuffle $blockId.shuffleId")
+      throw new NoSuchElementException(s"unknown shuffle ${blockId.shuffleId}")
     }
     val entries = info.blockMap.get(blockId)
     if (entries == null) {
@@ -215,27 +201,61 @@ class ShuffleBufferCatalog(
     }
   }
 
-  def blockIdToBufferHandles(blockId: ShuffleBlockId): Array[RapidsBufferHandle] = {
+  def getColumnarBatchIterator(
+      blockId: ShuffleBlockId,
+      sparkTypes: Array[DataType]): Iterator[ColumnarBatch] = {
+    val bufferIDs = blockIdToBuffersIds(blockId)
+    bufferIDs.iterator.map { bId =>
+      GpuSemaphore.acquireIfNecessary(TaskContext.get)
+      val (maybeHandle, meta) = bufferIdToHandle.get(bId)
+      maybeHandle.map { handle =>
+        withResource(handle.materialize) { buff =>
+          val bufferMeta = meta.bufferMeta()
+          if (bufferMeta == null || bufferMeta.codecBufferDescrsLength == 0) {
+            MetaUtils.getBatchFromMeta(buff, meta, sparkTypes)
+          } else {
+            GpuCompressedColumnVector.from(buff, meta)
+          }
+        }
+      }.getOrElse {
+        // degenerate table (handle is None)
+        // make a batch out of the degenerate meta
+        val rowCount = meta.rowCount
+        val packedMeta = meta.packedMetaAsByteBuffer()
+        if (packedMeta != null) {
+          withResource(DeviceMemoryBuffer.allocate(0)) { deviceBuffer =>
+            withResource(Table.fromPackedTable(
+              meta.packedMetaAsByteBuffer(), deviceBuffer)) { table =>
+              GpuColumnVectorFromBuffer.from(table, deviceBuffer, meta, sparkTypes)
+            }
+          }
+        } else {
+          // no packed metadata, must be a table with zero columns
+          new ColumnarBatch(Array.empty, rowCount.toInt)
+        }
+      }
+    }
+  }
+
+  /** Get all the buffer metadata that correspond to a shuffle block identifier. 
*/ + def blockIdToMetas(blockId: ShuffleBlockId): Seq[TableMeta] = { val info = activeShuffles.get(blockId.shuffleId) if (info == null) { - throw new NoSuchElementException(s"unknown shuffle $blockId.shuffleId") + throw new NoSuchElementException(s"unknown shuffle ${blockId.shuffleId}") } val entries = info.blockMap.get(blockId) if (entries == null) { throw new NoSuchElementException(s"unknown shuffle block $blockId") } - entries.synchronized { - entries.map(bufferIdToHandle.get).toArray + entries.synchronized { + entries.map(bufferIdToHandle.get).map { case (_, meta) => + meta + } } } - /** Get all the buffer metadata that correspond to a shuffle block identifier. */ - def blockIdToMetas(blockId: ShuffleBlockId): Seq[TableMeta] = { - blockIdToBuffersIds(blockId).map(catalog.getBufferMeta) - } - /** Allocate a new shuffle buffer identifier and update the shuffle block mapping. */ - def nextShuffleBufferId(blockId: ShuffleBlockId): ShuffleBufferId = { + private def nextShuffleBufferId(blockId: ShuffleBlockId): ShuffleBufferId = { val info = activeShuffles.get(blockId.shuffleId) if (info == null) { throw new IllegalStateException(s"unknown shuffle ${blockId.shuffleId}") @@ -258,35 +278,29 @@ class ShuffleBufferCatalog( } /** Lookup the shuffle buffer handle that corresponds to the specified table identifier. */ - def getShuffleBufferHandle(tableId: Int): RapidsBufferHandle = { + def getShuffleBufferHandle(tableId: Int): RapidsShuffleHandle = { val shuffleBufferId = tableMap.get(tableId) if (shuffleBufferId == null) { throw new NoSuchElementException(s"unknown table ID $tableId") } - bufferIdToHandle.get(shuffleBufferId) + val (maybeHandle, meta) = bufferIdToHandle.get(shuffleBufferId) + maybeHandle match { + case Some(spillable) => + RapidsShuffleHandle(spillable, meta) + case None => + throw new IllegalStateException( + "a buffer handle could not be obtained for a degenerate buffer") + } } /** * Update the spill priority of a shuffle buffer that soon will be read locally. * @param handle shuffle buffer handle of buffer to update */ - def updateSpillPriorityForLocalRead(handle: RapidsBufferHandle): Unit = { - handle.setSpillPriority(SpillPriorities.INPUT_FROM_SHUFFLE_PRIORITY) - } - - /** - * Lookup the shuffle buffer that corresponds to the specified buffer handle and acquire it. - * NOTE: It is the responsibility of the caller to close the buffer. - * @param handle shuffle buffer handle - * @return shuffle buffer that has been acquired - */ - def acquireBuffer(handle: RapidsBufferHandle): RapidsBuffer = { - val buffer = catalog.acquireBuffer(handle) - // Shuffle buffers that have been read are less likely to be read again, - // so update the spill priority based on this access - handle.setSpillPriority(SpillPriorities.getShuffleOutputBufferReadPriority) - buffer - } + // TODO: AB: priorities + //def updateSpillPriorityForLocalRead(handle: RapidsBufferHandle): Unit = { + // handle.setSpillPriority(SpillPriorities.INPUT_FROM_SHUFFLE_PRIORITY) + //} /** * Remove a buffer and table given a buffer handle @@ -294,9 +308,7 @@ class ShuffleBufferCatalog( * the handle being removed is not being utilized by another thread. 
* @param handle buffer handle */ - def removeBuffer(handle: RapidsBufferHandle): Unit = { - val id = handle.id - tableMap.remove(id.tableId) + def removeBuffer(handle: SpillableHandle): Unit = { handle.close() } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala index 0ff4f9278be..87c8deb2c5d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,49 +16,23 @@ package com.nvidia.spark.rapids -import java.io.File -import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.atomic.AtomicInteger -import java.util.function.IntUnaryOperator - -import ai.rapids.cudf.DeviceMemoryBuffer +import ai.rapids.cudf.{DeviceMemoryBuffer, Table} import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.format.TableMeta import org.apache.spark.internal.Logging -import org.apache.spark.sql.rapids.RapidsDiskBlockManager - -/** Identifier for a shuffle buffer that holds the data for a table on the read side */ - -case class ShuffleReceivedBufferId( - override val tableId: Int) extends RapidsBufferId { - override val canShareDiskPaths: Boolean = false +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.ColumnarBatch - override def getDiskPath(diskBlockManager: RapidsDiskBlockManager): File = { - diskBlockManager.getFile(s"temp_shuffle_${tableId}") +case class RapidsShuffleHandle( + spillable: SpillableDeviceBufferHandle, tableMeta: TableMeta) extends AutoCloseable { + override def close(): Unit = { + spillable.close() } } /** Catalog for lookup of shuffle buffers by block ID */ -class ShuffleReceivedBufferCatalog( - catalog: RapidsBufferCatalog) extends Logging { - - /** Mapping of table ID to shuffle buffer ID */ - private[this] val tableMap = new ConcurrentHashMap[Int, ShuffleReceivedBufferId] - - /** Tracks the next table identifier */ - private[this] val tableIdCounter = new AtomicInteger(0) - - /** Allocate a new shuffle buffer identifier and update the shuffle block mapping. */ - private def nextShuffleReceivedBufferId(): ShuffleReceivedBufferId = { - val tableId = tableIdCounter.getAndUpdate(ShuffleReceivedBufferCatalog.TABLE_ID_UPDATER) - val id = ShuffleReceivedBufferId(tableId) - val prev = tableMap.put(tableId, id) - if (prev != null) { - throw new IllegalStateException(s"table ID $tableId is already in use") - } - id - } +class ShuffleReceivedBufferCatalog() extends Logging { /** * Adds a buffer to the device storage, taking ownership of the buffer. 
@@ -70,64 +44,52 @@ class ShuffleReceivedBufferCatalog( * @param initialSpillPriority starting spill priority value for the buffer * @param needsSync tells the store a synchronize in the current stream is required * before storing this buffer - * @return RapidsBufferHandle associated with this buffer + * @return RapidsShuffleHandle associated with this buffer */ def addBuffer( buffer: DeviceMemoryBuffer, tableMeta: TableMeta, - initialSpillPriority: Long, - needsSync: Boolean): RapidsBufferHandle = { - val bufferId = nextShuffleReceivedBufferId() - tableMeta.bufferMeta.mutateId(bufferId.tableId) - // when we call `addBuffer` the store will incRefCount - withResource(buffer) { _ => - catalog.addBuffer( - bufferId, - buffer, - tableMeta, - initialSpillPriority, - needsSync) - } + initialSpillPriority: Long): RapidsShuffleHandle = { + RapidsShuffleHandle(SpillableDeviceBufferHandle(buffer), tableMeta) } /** - * Adds a degenerate buffer (zero rows or columns) + * Adds a degenerate batch (zero rows or columns), described only by metadata. * * @param meta metadata describing the buffer layout - * @return RapidsBufferHandle associated with this buffer + * @return RapidsShuffleHandle associated with this buffer */ - def addDegenerateRapidsBuffer( - meta: TableMeta): RapidsBufferHandle = { - val bufferId = nextShuffleReceivedBufferId() - catalog.registerDegenerateBuffer(bufferId, meta) + def addDegenerateBatch(meta: TableMeta): RapidsShuffleHandle = { + RapidsShuffleHandle(null, meta) } - /** - * Lookup the shuffle buffer that corresponds to the specified shuffle buffer - * handle and acquire it. - * NOTE: It is the responsibility of the caller to close the buffer. - * - * @param handle shuffle buffer handle - * @return shuffle buffer that has been acquired - */ - def acquireBuffer(handle: RapidsBufferHandle): RapidsBuffer = catalog.acquireBuffer(handle) - - /** - * Remove a buffer and table given a buffer handle - * NOTE: This function is not thread safe! The caller should only invoke if - * the handle being removed is not being utilized by another thread. 
- * @param handle buffer handle - */ - def removeBuffer(handle: RapidsBufferHandle): Unit = { - val id = handle.id - tableMap.remove(id.tableId) - handle.close() - } -} - -object ShuffleReceivedBufferCatalog{ - private val MAX_TABLE_ID = Integer.MAX_VALUE - private val TABLE_ID_UPDATER = new IntUnaryOperator { - override def applyAsInt(i: Int): Int = if (i < MAX_TABLE_ID) i + 1 else 0 + def getColumnarBatchAndRemove(handle: RapidsShuffleHandle, + sparkTypes: Array[DataType]): (ColumnarBatch, Long) = { + withResource(handle) { _ => + val spillable = handle.spillable + var memoryUsedBytes = 0L + val cb = if (spillable != null) { + memoryUsedBytes = spillable.sizeInBytes + withResource(spillable.materialize) { buff => + MetaUtils.getBatchFromMeta(buff, handle.tableMeta, sparkTypes) + } + } else { + val rowCount = handle.tableMeta.rowCount + val packedMeta = handle.tableMeta.packedMetaAsByteBuffer() + if (packedMeta != null) { + withResource(DeviceMemoryBuffer.allocate(0)) { deviceBuffer => + withResource(Table.fromPackedTable( + handle.tableMeta.packedMetaAsByteBuffer(), deviceBuffer)) { table => + GpuColumnVectorFromBuffer.from( + table, deviceBuffer, handle.tableMeta, sparkTypes) + } + } + } else { + // no packed metadata, must be a table with zero columns + new ColumnarBatch(Array.empty, rowCount.toInt) + } + } + (cb, memoryUsedBytes) + } } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillFramework.scala new file mode 100644 index 00000000000..aa529e83bbf --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillFramework.scala @@ -0,0 +1,1606 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids + +import java.io.{DataOutputStream, File, FileInputStream, InputStream, OutputStream} +import java.nio.ByteBuffer +import java.nio.channels.{Channels, FileChannel, WritableByteChannel} +import java.nio.file.StandardOpenOption +import java.util +import java.util.UUID +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.mutable + +import ai.rapids.cudf.{ColumnVector, ContiguousTable, Cuda, DeviceMemoryBuffer, HostColumnVector, HostMemoryBuffer, JCudfSerialization, NvtxColor, NvtxRange, Table} +import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} +import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableSeq +import com.nvidia.spark.rapids.format.TableMeta +import org.apache.commons.io.IOUtils + +import org.apache.spark.{SparkConf, SparkEnv, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.rapids.{GpuTaskMetrics, RapidsDiskBlockManager} +import org.apache.spark.sql.rapids.execution.SerializedHostTableUtils +import org.apache.spark.sql.rapids.storage.RapidsStorageUtils +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.storage.BlockId + +/** + * Spark-RAPIDS Spill Framework + * + * The spill framework tracks device/host/disk object lifecycle in the RAPIDS Accelerator + * for Apache Spark. A set of stores is used to track these objects, which are wrapped in + * "handles" that describe the state of each to the user and to the framework. + * + * This file comment covers some pieces of the framework that are worth knowing up front. + * + * Ownership: + * + * Any object handed to the framework via the factory methods for each of the handles + * should not be used directly by the user. The framework takes ownership of all objects. + * To get a reference back, call the `materialize` method, and always close what the framework + * returns. + * + * CUDA/Host synchronization: + * + * We assume all device backed handles are completely materialized on the device (before adding + * to the store, the CUDA stream has been synchronized with the CPU thread creating the handle). + * + * We assume all host memory backed handles are completely materialized and not mutated by + * other CPU threads once handed to the framework. + * + * Spillability: + * + * An object is spillable (it will be copied to host or disk during OOM) if: + * - it has a sizeInBytes > 0 + * - it is not actively being referenced by the user (call to `materialize`, or aliased) + * - it hasn't already spilled + * - it hasn't been closed + * + * Aliasing: + * + * We handle aliasing of objects, either in the spill framework or outside, by looking at the + * reference count. All objects added to the store should support a ref count. + * If the ref count is greater than the expected value, we assume it is being aliased, + * and therefore we don't waste time spilling the aliased object. + * + * Materialization: + * + * Every store handle supports a `materialize` method that isn't part of the interface. + * The reason is that to materialize certain objects, you may need some data (for example, + * spark schema descriptors). `materialize` incRefCounts the object if it's resident in the + * intended store (`DeviceSpillableHandle` incRefCounts an object if it is in the device store), + * and otherwise it will create a new copy from the spilled version and hand it to the user. 
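+ *
+ * As an illustration only (not part of the framework contract), a minimal lifecycle for a
+ * caller that owns a hypothetical `DeviceMemoryBuffer` named `dmb` could look like:
+ *
+ * {{{
+ *   val handle = SpillableDeviceBufferHandle(dmb) // the framework now owns `dmb`
+ *   // ... later, when the data is needed again:
+ *   withResource(handle.materialize) { buf =>
+ *     // `buf` is either the original device buffer, or a copy restored from host/disk
+ *   }
+ *   handle.close() // releases the handle and whatever it is tracking
+ * }}}
+ *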
+ * Any time a user calls `materialize`, they are responsible for closing the returned object.
+ *
+ * Spilling:
+ *
+ * A `SpillableHandle` tracks an object in a specific store (for example, a `DeviceSpillable`
+ * tracks objects intended for the device). If the handle is asked to spill, it is the handle's
+ * responsibility to initiate that spill, and to track the spilled handle (a device spillable
+ * would have a `host` handle, which tracks the host spilled object).
+ *
+ * A cascade of spills can occur device -> host -> disk, since a host allocation can fail, or
+ * may not fit within the SpillableHostStore's limit (if defined). In this case, the call to spill
+ * will either create a host handle tracking an object on the host store (if we made room), or it
+ * will create a host handle that points to a disk handle, tracking a file on disk.
+ *
+ * If the disk is full, we do not handle this in any special way. We expect this to be a
+ * terminal state for the executor. Every handle spills to its own file on disk, identified
+ * as a "temporary block" `BlockId` from Spark.
+ *
+ * Notes on locking:
+ *
+ * All stores use a concurrent hash map to store instances of `StoreHandle`. The only store
+ * with extra locking is the `SpillableHostStore`, to maintain a `totalSize` number that is
+ * used to figure out cheaply when it is full.
+ *
+ * Handles hold a lock to protect the user from accessing a handle that is either in the middle
+ * of spilling, or already closed. Most handles, except for disk handles, hold a reference to an
+ * object in their respective store (`SpillableDeviceBufferHandle` has a `dev` reference that holds
+ * a `DeviceMemoryBuffer`, and a `host` reference to `SpillableHostBufferHandle` that is only
+ * set if spilled), and the handle guarantees that one of these will be set, or none if the handle
+ * is closed.
+ *
+ * We hold the handle lock when we are spilling (performing IO). That means that no other consumer
+ * can access this spillable device handle while it is being spilled, including a second thread
+ * trying to spill and asking each of the handles whether they are spillable or not, as that
+ * also requires the handle lock. We will likely relax this in follow-on work.
+ *
+ * We never hold a store-wide coarse-grained lock in the stores when we do IO.
+ */
+
+/**
+ * Common interface for all handles in the spill framework.
+ */
+trait StoreHandle extends AutoCloseable {
+  /**
+   * Approximate size of this handle, used in two scenarios:
+   * - Used from the host store to figure out how much host memory total it is tracking.
+   * - If sizeInBytes is 0, the object is tracked by the stores so it can be
+   *   removed on shutdown, or by handle.close, but 0-byte handles are not spillable.
+   */
+  val sizeInBytes: Long
+}
+
+trait SpillableHandle extends StoreHandle {
+  /**
+   * Method called to spill this handle. It can be triggered from the spill store,
+   * or directly against the handle.
+   * @return sizeInBytes if spilled, 0 for any other reason (not spillable, closed)
+   */
+  def spill: Long
+
+  /**
+   * Method used to determine whether a handle tracks an object that could be spilled.
+   * @note At the level of `SpillableHandle`, the only requirement of spillability
+   *       is that the size of the handle is > 0. `sizeInBytes` is known at construction, and
+   *       is a `val`.
+   * @return true if currently spillable, false otherwise
+   */
+  def spillable: Boolean = sizeInBytes > 0
+}
+
+/**
+ * Spillable handles that can be materialized on the device.
+ * @tparam T an auto closeable subclass. 
`dev` tracks an instance of this object, + * on the device. + */ +trait DeviceSpillableHandle[T <: AutoCloseable] extends SpillableHandle { + var dev: Option[T] + + override def spillable: Boolean = synchronized { + super.spillable && dev.isDefined + } +} + +/** + * Spillable handles that can be materialized on the host. + * @tparam T an auto closeable subclass. `host` tracks an instance of this object, + * on the host. + */ +trait HostSpillableHandle[T <: AutoCloseable] extends SpillableHandle { + var host: Option[T] + + override def spillable: Boolean = synchronized { + super.spillable && host.isDefined + } +} + +object SpillableHostBufferHandle extends Logging { + def apply(hmb: HostMemoryBuffer): SpillableHostBufferHandle = { + val handle = new SpillableHostBufferHandle(hmb.getLength, host = Some(hmb)) + SpillFramework.stores.hostStore.track(handle) + // do we care if: + // handle didn't fit in the store as it is too large. + // we made memory for this so we are going to hold our noses and keep going + // we could spill `handle` at this point. + handle + } + + def createHostHandleWithPacker( + chunkedPacker: ChunkedPacker): SpillableHostBufferHandle = { + val handle = new SpillableHostBufferHandle(chunkedPacker.getTotalContiguousSize) + withResource( + SpillFramework.stores.hostStore.makeBuilder(handle)) { builder => + SpillFramework.withChunkedPackBounceBuffer { bb => + while (chunkedPacker.hasNext) { + withResource(chunkedPacker.next(bb)) { n => + builder.copyNext(n, Cuda.DEFAULT_STREAM) + // we are calling chunked packer on `bb` again each time, we need + // to synchronize before we ask for the next chunk + Cuda.DEFAULT_STREAM.sync() + } + } + } + builder.build + } + } + + def createHostHandleFromDeviceBuff( + buff: DeviceMemoryBuffer): SpillableHostBufferHandle = { + val handle = new SpillableHostBufferHandle(buff.getLength) + withResource( + SpillFramework.stores.hostStore.makeBuilder(handle)) { builder => + builder.copyNext(buff, Cuda.DEFAULT_STREAM) + Cuda.DEFAULT_STREAM.sync() + builder.build + } + } +} + +class SpillableHostBufferHandle private ( + override val sizeInBytes: Long, + override var host: Option[HostMemoryBuffer] = None, + var disk: Option[DiskHandle] = None) + extends HostSpillableHandle[HostMemoryBuffer] { + + override def spillable: Boolean = synchronized { + if (super.spillable) { + host.getOrElse { + throw new IllegalStateException( + s"$this is spillable but it doesn't have a materialized host buffer!") + }.getRefCount == 1 + } else { + false + } + } + + def materialize: HostMemoryBuffer = { + var materialized: HostMemoryBuffer = null + var diskHandle: DiskHandle = null + synchronized { + if (host.isDefined) { + materialized = host.get + materialized.incRefCount() + } else if (disk.isDefined) { + diskHandle = disk.get + } else { + throw new IllegalStateException( + "attempting to materialize a closed handle") + } + } + if (materialized == null) { + materialized = closeOnExcept(HostMemoryBuffer.allocate(sizeInBytes)) { hmb => + diskHandle.materializeToHostMemoryBuffer(hmb) + hmb + } + } + materialized + } + + override def spill: Long = { + if (!spillable) { + 0L + } else { + synchronized { + if (disk.isEmpty) { + withResource(DiskHandleStore.makeBuilder) { diskHandleBuilder => + val outputChannel = diskHandleBuilder.getChannel + GpuTaskMetrics.get.spillToDiskTime { + withResource(getHostBuffer.get) { hmb => + val iter = new HostByteBufferIterator(hmb) + iter.foreach { bb => + try { + while (bb.hasRemaining) { + outputChannel.write(bb) + } + } finally { + 
RapidsStorageUtils.dispose(bb) + } + } + } + } + disk = Some(diskHandleBuilder.build) + disk + } + } else { + None + } + } + releaseHostResource() + sizeInBytes + } + } + + private def releaseHostResource(): Unit = { + SpillFramework.removeFromHostStore(this) + synchronized { + host.foreach(_.close()) + host = None + } + } + + def getHostBuffer: Option[HostMemoryBuffer] = synchronized { + host.foreach(_.incRefCount()) + host + } + + override def close(): Unit = { + releaseHostResource() + synchronized { + disk.foreach(_.close()) + disk = None + } + } + + def materializeToDeviceMemoryBuffer(dmb: DeviceMemoryBuffer): Unit = { + var hostBuffer: HostMemoryBuffer = null + var diskHandle: DiskHandle = null + synchronized { + if (host.isDefined) { + hostBuffer = host.get + hostBuffer.incRefCount() + } else if (disk.isDefined) { + diskHandle = disk.get + } else { + throw new IllegalStateException( + "attempting to materialize a closed handle") + } + } + if (hostBuffer != null) { + GpuTaskMetrics.get.readSpillFromHostTime { + withResource(hostBuffer) { _ => + dmb.copyFromHostBuffer( + /*dstOffset*/ 0, + /*src*/ hostBuffer, + /*srcOffset*/ 0, + /*length*/ hostBuffer.getLength) + } + } + } else { + // cannot find a full host buffer, get chunked api + // from disk + diskHandle.materializeToDeviceMemoryBuffer(dmb) + } + } + + def setHost(singleShotBuffer: HostMemoryBuffer): Unit = synchronized { + host = Some(singleShotBuffer) + } + + def setDisk(handle: DiskHandle): Unit = synchronized { + disk = Some(handle) + } +} + +object SpillableDeviceBufferHandle { + def apply(dmb: DeviceMemoryBuffer): SpillableDeviceBufferHandle = { + val handle = new SpillableDeviceBufferHandle(dmb.getLength, dev = Some(dmb)) + SpillFramework.stores.deviceStore.track(handle) + handle + } +} + +class SpillableDeviceBufferHandle private ( + override val sizeInBytes: Long, + override var dev: Option[DeviceMemoryBuffer], + var host: Option[SpillableHostBufferHandle] = None) + extends DeviceSpillableHandle[DeviceMemoryBuffer] { + + override def spillable: Boolean = synchronized { + if (super.spillable) { + dev.getOrElse { + throw new IllegalStateException( + s"$this is spillable but it doesn't have a dev buffer!") + }.getRefCount == 1 + } else { + false + } + } + + def materialize: DeviceMemoryBuffer = { + var materialized: DeviceMemoryBuffer = null + var hostHandle: SpillableHostBufferHandle = null + synchronized { + if (dev.isDefined) { + materialized = dev.get + materialized.incRefCount() + } else if (host.isDefined) { + // since we spilled, host must be set. + hostHandle = host.get + } else { + throw new IllegalStateException( + "attempting to materialize a closed handle") + } + } + // if `materialized` is null, we spilled. This is a terminal + // state, as we are not allowing unspill, and we don't need + // to hold locks while we copy back from here. 
+ if (materialized == null) { + materialized = closeOnExcept(DeviceMemoryBuffer.allocate(sizeInBytes)) { dmb => + hostHandle.materializeToDeviceMemoryBuffer(dmb) + dmb + } + } + materialized + } + + override def spill: Long = { + if (!spillable) { + 0L + } else { + synchronized { + if (host.isEmpty) { + host = Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff(dev.get)) + } + } + releaseDeviceResources() + sizeInBytes + } + } + + private def releaseDeviceResources(): Unit = { + SpillFramework.removeFromDeviceStore(this) + synchronized { + dev.foreach(_.close()) + dev = None + } + } + + override def close(): Unit = { + releaseDeviceResources() + synchronized { + host.foreach(_.close()) + host = None + } + } +} + +class SpillableColumnarBatchHandle private ( + override val sizeInBytes: Long, + override var dev: Option[ColumnarBatch], + var host: Option[SpillableHostBufferHandle] = None) + extends DeviceSpillableHandle[ColumnarBatch] with Logging { + + override def spillable: Boolean = synchronized { + if (super.spillable) { + val dcvs = GpuColumnVector.extractBases(dev.get) + val colRepetition = mutable.HashMap[ColumnVector, Int]() + dcvs.foreach { hcv => + colRepetition.put(hcv, colRepetition.getOrElse(hcv, 0) + 1) + } + dcvs.forall(dcv => { + colRepetition(dcv) == dcv.getRefCount + }) + } else { + false + } + } + + private var meta: Option[ByteBuffer] = None + + def materialize(dt: Array[DataType]): ColumnarBatch = { + var materialized: ColumnarBatch = null + var hostHandle: SpillableHostBufferHandle = null + synchronized { + if (dev.isDefined) { + materialized = GpuColumnVector.incRefCounts(dev.get) + } else if (host.isDefined) { + hostHandle = host.get + } else { + throw new IllegalStateException( + "attempting to materialize a closed handle") + } + } + if (materialized == null) { + val devBuffer = closeOnExcept(DeviceMemoryBuffer.allocate(hostHandle.sizeInBytes)) { dmb => + hostHandle.materializeToDeviceMemoryBuffer(dmb) + dmb + } + val cb = withResource(devBuffer) { _ => + withResource(Table.fromPackedTable(meta.get, devBuffer)) { tbl => + GpuColumnVector.from(tbl, dt) + } + } + materialized = cb + } + materialized + } + + override def spill: Long = { + if (!spillable) { + 0L + } else { + synchronized { + if (host.isEmpty) { + withChunkedPacker { chunkedPacker => + meta = Some(chunkedPacker.getPackedMeta) + host = Some(SpillableHostBufferHandle.createHostHandleWithPacker(chunkedPacker)) + } + } + } + releaseDeviceResource() + // We return the size we were created with. This is not the actual size + // of this batch when it is packed, and it is used by the calling code + // to figure out more or less how much did we free in the device. 
+ sizeInBytes + } + } + + private def withChunkedPacker[T](body: ChunkedPacker => T): T = { + val tbl = synchronized { + if (dev.isEmpty) { + throw new IllegalStateException("cannot get copier without a batch") + } + GpuColumnVector.from(dev.get) + } + withResource(tbl) { _ => + withResource(new ChunkedPacker(tbl, SpillFramework.chunkedPackBounceBufferSize)) { packer => + body(packer) + } + } + } + + private def releaseDeviceResource(): Unit = { + SpillFramework.removeFromDeviceStore(this) + synchronized { + dev.foreach(_.close()) + dev = None + } + } + + override def close(): Unit = { + releaseDeviceResource() + synchronized { + host.foreach(_.close()) + host = None + } + } +} + +object SpillableColumnarBatchFromBufferHandle { + def apply( + ct: ContiguousTable, + dataTypes: Array[DataType]): SpillableColumnarBatchFromBufferHandle = { + withResource(ct) { _ => + val sizeInBytes = ct.getBuffer.getLength + val cb = GpuColumnVectorFromBuffer.from(ct, dataTypes) + val handle = new SpillableColumnarBatchFromBufferHandle( + sizeInBytes, dev = Some(cb)) + SpillFramework.stores.deviceStore.track(handle) + handle + } + } + + def apply(cb: ColumnarBatch): SpillableColumnarBatchFromBufferHandle = { + require(GpuColumnVectorFromBuffer.isFromBuffer(cb), + "Columnar batch isn't a batch from buffer") + val sizeInBytes = + cb.column(0).asInstanceOf[GpuColumnVectorFromBuffer].getBuffer.getLength + val handle = new SpillableColumnarBatchFromBufferHandle( + sizeInBytes, dev = Some(cb)) + SpillFramework.stores.deviceStore.track(handle) + handle + } +} + +class SpillableColumnarBatchFromBufferHandle private ( + override val sizeInBytes: Long, + override var dev: Option[ColumnarBatch], + var host: Option[SpillableHostBufferHandle] = None) + extends DeviceSpillableHandle[ColumnarBatch] { + + private var meta: Option[TableMeta] = None + + override def spillable: Boolean = synchronized { + if (super.spillable) { + val dcvs = GpuColumnVector.extractBases(dev.get) + val colRepetition = mutable.HashMap[ColumnVector, Int]() + dcvs.foreach { hcv => + colRepetition.put(hcv, colRepetition.getOrElse(hcv, 0) + 1) + } + dcvs.forall(dcv => { + colRepetition(dcv) == dcv.getRefCount + }) + } else { + false + } + } + + def materialize(dt: Array[DataType]): ColumnarBatch = { + var materialized: ColumnarBatch = null + var hostHandle: SpillableHostBufferHandle = null + synchronized { + if (dev.isDefined) { + materialized = GpuColumnVector.incRefCounts(dev.get) + } else if (host.isDefined) { + hostHandle = host.get + } else { + throw new IllegalStateException( + "attempting to materialize a closed handle") + } + } + if (materialized == null) { + val devBuffer = closeOnExcept(DeviceMemoryBuffer.allocate(hostHandle.sizeInBytes)) { dmb => + hostHandle.materializeToDeviceMemoryBuffer(dmb) + dmb + } + val cb = withResource(devBuffer) { _ => + withResource(Table.fromPackedTable(meta.get.packedMetaAsByteBuffer(), devBuffer)) { tbl => + GpuColumnVector.from(tbl, dt) + } + } + materialized = cb + } + materialized + } + + override def spill: Long = { + if (!spillable) { + 0 + } else { + synchronized { + if (host.isEmpty) { + val cvFromBuffer = dev.get.column(0).asInstanceOf[GpuColumnVectorFromBuffer] + meta = Some(cvFromBuffer.getTableMeta) + host = Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff( + cvFromBuffer.getBuffer)) + } + } + releaseDeviceResource() + sizeInBytes + } + } + + private def releaseDeviceResource(): Unit = { + SpillFramework.removeFromDeviceStore(this) + synchronized { + dev.foreach(_.close()) + dev = None 
+ } + } + + override def close(): Unit = { + releaseDeviceResource() + synchronized { + host.foreach(_.close()) + host = None + } + } +} + +object SpillableCompressedColumnarBatchHandle { + def apply(cb: ColumnarBatch): SpillableCompressedColumnarBatchHandle = { + require(GpuCompressedColumnVector.isBatchCompressed(cb), + "Tried to track a compressed batch, but the batch wasn't compressed") + val compressedSize = + cb.column(0).asInstanceOf[GpuCompressedColumnVector].getTableBuffer.getLength + val handle = new SpillableCompressedColumnarBatchHandle(compressedSize, dev = Some(cb)) + SpillFramework.stores.deviceStore.track(handle) + handle + } +} + +class SpillableCompressedColumnarBatchHandle private ( + val compressedSizeInBytes: Long, + override var dev: Option[ColumnarBatch], + var host: Option[SpillableHostBufferHandle] = None) + extends DeviceSpillableHandle[ColumnarBatch] { + + override val sizeInBytes: Long = compressedSizeInBytes + + protected var meta: Option[TableMeta] = None + + override def spillable: Boolean = synchronized { + if (super.spillable) { + val cb = dev.get + val buff = cb.column(0).asInstanceOf[GpuCompressedColumnVector].getTableBuffer + buff.getRefCount == 1 + } else { + false + } + } + + def materialize: ColumnarBatch = { + var materialized: ColumnarBatch = null + var hostHandle: SpillableHostBufferHandle = null + synchronized { + if (dev.isDefined) { + materialized = GpuCompressedColumnVector.incRefCounts(dev.get) + } else if (host.isDefined) { + hostHandle = host.get + } else { + throw new IllegalStateException( + "attempting to materialize a closed handle") + } + } + if (materialized == null) { + val devBuffer = closeOnExcept(DeviceMemoryBuffer.allocate(hostHandle.sizeInBytes)) { dmb => + hostHandle.materializeToDeviceMemoryBuffer(dmb) + dmb + } + materialized = withResource(devBuffer) { _ => + GpuCompressedColumnVector.from(devBuffer, meta.get) + } + } + materialized + } + + override def spill: Long = { + if (!spillable) { + 0L + } else { + synchronized { + if (host.isEmpty) { + val cvFromBuffer = dev.get.column(0).asInstanceOf[GpuCompressedColumnVector] + meta = Some(cvFromBuffer.getTableMeta) + host = Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff( + cvFromBuffer.getTableBuffer)) + } + } + releaseDeviceResource() + compressedSizeInBytes + } + } + + private def releaseDeviceResource(): Unit = { + SpillFramework.removeFromDeviceStore(this) + synchronized { + dev.foreach(_.close()) + dev = None + } + } + + override def close(): Unit = { + releaseDeviceResource() + synchronized { + host.foreach(_.close()) + host = None + meta = None + } + } +} + +object SpillableHostColumnarBatchHandle { + def apply(cb: ColumnarBatch): SpillableHostColumnarBatchHandle = { + val sizeInBytes = RapidsHostColumnVector.getTotalHostMemoryUsed(cb) + val handle = new SpillableHostColumnarBatchHandle(sizeInBytes, cb.numRows(), host = Some(cb)) + SpillFramework.stores.hostStore.track(handle) + handle + } +} + +class SpillableHostColumnarBatchHandle private ( + val sizeInBytes: Long, + val numRows: Int, + override var host: Option[ColumnarBatch], + var disk: Option[DiskHandle] = None) + extends HostSpillableHandle[ColumnarBatch] { + + override def spillable: Boolean = synchronized { + if (super.spillable) { + val hcvs = RapidsHostColumnVector.extractBases(host.get) + val colRepetition = mutable.HashMap[HostColumnVector, Int]() + hcvs.foreach { hcv => + colRepetition.put(hcv, colRepetition.getOrElse(hcv, 0) + 1) + } + hcvs.forall(hcv => { + colRepetition(hcv) == 
hcv.getRefCount + }) + } else { + false + } + } + + def materialize(sparkTypes: Array[DataType]): ColumnarBatch = { + var materialized: ColumnarBatch = null + var diskHandle: DiskHandle = null + synchronized { + if (host.isDefined) { + materialized = RapidsHostColumnVector.incRefCounts(host.get) + } else if (disk.isDefined) { + diskHandle = disk.get + } else { + throw new IllegalStateException( + "attempting to materialize a closed handle") + } + } + if (materialized == null) { + materialized = diskHandle.withInputWrappedStream { inputStream => + val (header, hostBuffer) = SerializedHostTableUtils.readTableHeaderAndBuffer(inputStream) + val hostCols = withResource(hostBuffer) { _ => + SerializedHostTableUtils.buildHostColumns(header, hostBuffer, sparkTypes) + } + new ColumnarBatch(hostCols.toArray, numRows) + } + } + materialized + } + + override def spill: Long = { + if (!spillable) { + 0L + } else { + synchronized { + if (disk.isEmpty) { + withResource(DiskHandleStore.makeBuilder) { diskHandleBuilder => + GpuTaskMetrics.get.spillToDiskTime { + val dos = diskHandleBuilder.getDataOutputStream + val columns = RapidsHostColumnVector.extractBases(host.get) + JCudfSerialization.writeToStream(columns, dos, 0, host.get.numRows()) + } + disk = Some(diskHandleBuilder.build) + disk + } + } else { + None + } + } + releaseHostResource() + sizeInBytes + } + } + + private def releaseHostResource(): Unit = { + SpillFramework.removeFromHostStore(this) + synchronized { + host.foreach(_.close()) + host = None + } + } + + override def close(): Unit = { + releaseHostResource() + synchronized { + disk.foreach(_.close()) + disk = None + } + } +} + +object DiskHandle { + def apply(blockId: BlockId, + offset: Long, + diskSizeInBytes: Long): DiskHandle = { + val handle = new DiskHandle( + blockId, offset, diskSizeInBytes) + SpillFramework.stores.diskStore.track(handle) + handle + } +} + +/** + * A disk buffer handle helps us track spill-framework originated data on disk. + * This type of handle isn't spillable, and therefore it just implements `StoreHandle` + * @param blockId - a spark `BlockId` obtained from the configured `BlockManager` + * @param offset - starting offset for the data within the file backing `blockId` + * @param sizeInBytes - amount of bytes on disk (usually compressed and could also be encrypted). 
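+ *
+ * For reference (an editor's sketch, not code from this change; `hmb` is a placeholder
+ * host buffer), a `DiskHandle` is produced and read back roughly like this:
+ *
+ *   val handle = withResource(DiskHandleStore.makeBuilder) { builder =>
+ *     // write bytes through builder.getChannel (or builder.getDataOutputStream)
+ *     builder.build // registers the resulting DiskHandle with the disk store
+ *   }
+ *   handle.materializeToHostMemoryBuffer(hmb) // reads it back, undoing compression/encryption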
+ */ +class DiskHandle private( + val blockId: BlockId, + val offset: Long, + override val sizeInBytes: Long) + extends StoreHandle { + + private def withInputChannel[T](body: FileChannel => T): T = synchronized { + val file = SpillFramework.stores.diskStore.diskBlockManager.getFile(blockId) + GpuTaskMetrics.get.readSpillFromDiskTime { + withResource(new FileInputStream(file)) { fs => + withResource(fs.getChannel) { channel => + body(channel) + } + } + } + } + + def withInputWrappedStream[T](body: InputStream => T): T = synchronized { + val diskBlockManager = SpillFramework.stores.diskStore.diskBlockManager + val serializerManager = diskBlockManager.getSerializerManager() + GpuTaskMetrics.get.readSpillFromDiskTime { + withInputChannel { inputChannel => + inputChannel.position(offset) + withResource(Channels.newInputStream(inputChannel)) { compressed => + withResource(serializerManager.wrapStream(blockId, compressed)) { in => + body(in) + } + } + } + } + } + + override def close(): Unit = { + SpillFramework.removeFromDiskStore(this) + SpillFramework.stores.diskStore.deleteFile(blockId) + } + + def materializeToHostMemoryBuffer(mb: HostMemoryBuffer): Unit = { + withInputWrappedStream { in => + withResource(new HostMemoryOutputStream(mb)) { out => + IOUtils.copy(in, out) + } + } + } + + def materializeToDeviceMemoryBuffer(dmb: DeviceMemoryBuffer): Unit = { + var copyOffset = 0L + withInputWrappedStream { in => + SpillFramework.withHostSpillBounceBuffer { hmb => + val bbLength = hmb.getLength.toInt + withResource(new HostMemoryOutputStream(hmb)) { out => + var sizeRead = IOUtils.copyLarge(in, out, 0, bbLength) + while (sizeRead > 0) { + // this syncs at every copy, since for now we are + // reusing a single host spill bounce buffer + dmb.copyFromHostBuffer( + /*dstOffset*/ copyOffset, + /*src*/ hmb, + /*srcOffset*/ 0, + /*length*/ sizeRead) + out.seek(0) // start over + copyOffset += sizeRead + sizeRead = IOUtils.copyLarge(in, out, 0, bbLength) + } + } + } + } + } +} + +trait HandleStore[T <: StoreHandle] extends AutoCloseable with Logging { + protected val handles = new ConcurrentHashMap[T, java.lang.Boolean]() + + def numHandles: Int = { + handles.size() + } + + def track(handle: T): Unit = { + doTrack(handle) + } + + def remove(handle: T): Unit = { + doRemove(handle) + } + + protected def doTrack(handle: T): Boolean = { + handles.put(handle, true) == null + } + + protected def doRemove(handle: T): Boolean = { + handles.remove(handle) != null + } + + override def close(): Unit = { + handles.forEach((handle, _ )=> { + handle.close() + }) + handles.clear() + } +} + +trait SpillableStore extends HandleStore[SpillableHandle] with Logging { + protected def spillNvtxRange: NvtxRange + + def spill(spillNeeded: Long): Long = { + if (spillNeeded == 0) { + 0L + } else { + withResource(spillNvtxRange) { _ => + var amountSpilled = 0L + val spillables = new util.ArrayList[SpillableHandle]() + var amountToSpill = 0L + val allHandles = handles.keySet().iterator() + // two threads could be here trying to spill and creating a list of spillables + while (allHandles.hasNext && amountToSpill < spillNeeded) { + val handle = allHandles.next() + if (handle.spillable) { + amountToSpill += handle.sizeInBytes + spillables.add(handle) + } + } + val it = spillables.iterator() + var numSpilled = 0 + while (it.hasNext) { + val handle = it.next() + val spilled = handle.spill + if (spilled > 0) { + // this thread was successful at spilling handle. 
+            amountSpilled += spilled
+            numSpilled += 1
+          } // else, either:
+            // - this thread lost the race and the handle was closed
+            // - another thread spilled it
+            // - the handle isn't spillable anymore, due to ref count.
+        }
+
+        amountSpilled
+      }
+    }
+  }
+}
+
+class SpillableHostStore(val maxSize: Option[Long] = None)
+    extends SpillableStore
+    with Logging {
+
+  private var totalSize: Long = 0L
+
+  private def tryTrack(handle: SpillableHandle): Boolean = {
+    if (maxSize.isEmpty || handle.sizeInBytes == 0) {
+      super.doTrack(handle)
+      // for now keep updating `totalSize`, even though we technically
+      // do not need to track it when there is no limit
+      synchronized {
+        totalSize += handle.sizeInBytes
+      }
+      true
+    } else {
+      synchronized {
+        val storeMaxSize = maxSize.get
+        if (totalSize + handle.sizeInBytes > storeMaxSize) {
+          // we want to try to make room for this buffer
+          false
+        } else {
+          // it fits
+          if (super.doTrack(handle)) {
+            totalSize += handle.sizeInBytes
+          }
+          true
+        }
+      }
+    }
+  }
+
+  override def track(handle: SpillableHandle): Unit = {
+    trackInternal(handle)
+  }
+
+  private def trackInternal(handle: SpillableHandle): Boolean = {
+    // try to track the handle: with no limit configured this just
+    // adds it to the store
+    var tracked = tryTrack(handle)
+    if (!tracked) {
+      // we only end up here if we have host store limits.
+      var numRetries = 0
+      // we are going to try to track again, in a loop, spilling
+      // between attempts to release host memory
+      var canFit = true
+      val handleSize = handle.sizeInBytes
+      var amountSpilled = 0L
+      val hadHandlesToSpill = !handles.isEmpty
+      while (canFit && !tracked && numRetries < 5) {
+        // if we are trying to add a handle larger than our limit
+        if (maxSize.get < handleSize) {
+          // no point in checking how much is free, just spill all
+          // we have
+          amountSpilled += spill(maxSize.get)
+        } else {
+          // handleSize is within the limits
+          val freeAmount = synchronized {
+            maxSize.get - totalSize
+          }
+          val spillNeeded = handleSize - freeAmount
+          if (spillNeeded > 0) {
+            amountSpilled += spill(spillNeeded)
+          }
+        }
+        tracked = tryTrack(handle)
+        if (!tracked) {
+          // we tried to spill, and we still couldn't fit this buffer.
+          // If totalSize > 0 there may be more we can spill, so retry;
+          // otherwise give up and let the caller fall back to the disk path.
+          synchronized {
+            canFit = totalSize > 0
+          }
+        }
+        numRetries += 1
+      }
+      val taskId = Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(0L)
+      if (hadHandlesToSpill) {
+        logInfo(s"Task $taskId spilled $amountSpilled bytes while trying to " +
+          s"track $handleSize bytes.")
+      }
+    }
+    tracked
+  }
+
+  override def remove(handle: SpillableHandle): Unit = {
+    synchronized {
+      if (doRemove(handle)) {
+        totalSize -= handle.sizeInBytes
+      }
+    }
+  }
+
+  /**
+   * Makes a builder object for `SpillableHostBufferHandle`. The builder will copy either
+   * to host or to disk: to host if the buffer fits in the host store (when a limit is
+   * configured), and to disk otherwise.
+   *
+   * Host store locks and disk store locks will be taken/released during this call, but
+   * after the builder is created, no locks are held in the store.
+   *
+   * @param handle a host handle that only has a size set, and no backing store.
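+   * @note (editor's note) the intended call pattern, as used by
+   *       `SpillableHostBufferHandle.createHostHandleFromDeviceBuff` above where
+   *       `devBuffer` is the device buffer being spilled, is roughly:
+   *       withResource(SpillFramework.stores.hostStore.makeBuilder(handle)) { b =>
+   *         b.copyNext(devBuffer, Cuda.DEFAULT_STREAM)
+   *         Cuda.DEFAULT_STREAM.sync()
+   *         b.build
+   *       }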
+ * @return the builder to be closed by caller + */ + def makeBuilder(handle: SpillableHostBufferHandle): SpillableHostBufferHandleBuilder = { + var builder: Option[SpillableHostBufferHandleBuilder] = None + if (handle.sizeInBytes < maxSize.getOrElse(Long.MaxValue)) { + HostAlloc.tryAlloc(handle.sizeInBytes).foreach { hmb => + withResource(hmb) { _ => + if (trackInternal(handle)) { + hmb.incRefCount + // the host store made room or fit this buffer + builder = Some(new SpillableHostBufferHandleBuilderForHost(handle, hmb)) + } + } + } + } + builder.getOrElse { + // the disk store will track this when we call .build + new SpillableHostBufferHandleBuilderForDisk(handle) + } + } + + trait SpillableHostBufferHandleBuilder extends AutoCloseable { + /** + * Copy the entirety of `mb` (0 to getLength) to host or disk. + * + * We synchronize after each copy since we do not manage the lifetime + * of `mb`. + * + * @param mb buffer to copy from + * @param stream CUDA stream to use, and synchronize against + */ + def copyNext(mb: DeviceMemoryBuffer, stream: Cuda.Stream): Unit + + /** + * Returns a usable `SpillableHostBufferHandle` with either the + * `host` or `disk` set with the appropriate object. + * + * Note that if we are writing to disk, we are going to add a + * new `DiskHandle` in the disk store's concurrent collection. + * + * @return host handle with data in host or disk + */ + def build: SpillableHostBufferHandle + } + + private class SpillableHostBufferHandleBuilderForHost( + var handle: SpillableHostBufferHandle, + var singleShotBuffer: HostMemoryBuffer) + extends SpillableHostBufferHandleBuilder with Logging { + private var copied = 0L + + override def copyNext(mb: DeviceMemoryBuffer, stream: Cuda.Stream): Unit = { + GpuTaskMetrics.get.spillToHostTime { + singleShotBuffer.copyFromMemoryBufferAsync( + copied, + mb, + 0, + mb.getLength, + stream) + copied += mb.getLength + } + } + + override def build: SpillableHostBufferHandle = { + // add some sort of setter method to Host Handle + require(handle != null, "Called build too many times") + require(copied == handle.sizeInBytes, + s"Expected ${handle.sizeInBytes} B but copied ${copied} B instead") + handle.setHost(singleShotBuffer) + singleShotBuffer = null + val res = handle + handle = null + res + } + + override def close(): Unit = { + if (handle != null) { + handle.close() + handle = null + } + if (singleShotBuffer != null) { + singleShotBuffer.close() + singleShotBuffer = null + } + } + } + + private class SpillableHostBufferHandleBuilderForDisk( + var handle: SpillableHostBufferHandle) + extends SpillableHostBufferHandleBuilder { + private var copied = 0L + private var diskHandleBuilder = DiskHandleStore.makeBuilder + + override def copyNext(mb: DeviceMemoryBuffer, stream: Cuda.Stream): Unit = { + SpillFramework.withHostSpillBounceBuffer { hostSpillBounceBuffer => + GpuTaskMetrics.get.spillToDiskTime { + val outputChannel = diskHandleBuilder.getChannel + val iter = new MemoryBufferToHostByteBufferIterator( + mb, + hostSpillBounceBuffer, + Cuda.DEFAULT_STREAM) + iter.foreach { byteBuff => + try { + while (byteBuff.hasRemaining) { + outputChannel.write(byteBuff) + } + copied += byteBuff.capacity() + } finally { + RapidsStorageUtils.dispose(byteBuff) + } + } + } + } + } + + override def build: SpillableHostBufferHandle = { + // add some sort of setter method to Host Handle + require(handle != null, "Called build too many times") + require(copied == handle.sizeInBytes, + s"Expected ${handle.sizeInBytes} B but copied ${copied} B instead") + 
+      handle.setDisk(diskHandleBuilder.build)
+      val res = handle
+      handle = null
+      res
+    }
+
+    override def close(): Unit = {
+      if (handle != null) {
+        handle.close()
+        handle = null
+      }
+      if (diskHandleBuilder != null) {
+        diskHandleBuilder.close()
+        diskHandleBuilder = null
+      }
+    }
+  }
+
+  override protected def spillNvtxRange: NvtxRange =
+    new NvtxRange("disk spill", NvtxColor.RED)
+}
+
+class SpillableDeviceStore extends SpillableStore {
+  override protected def spillNvtxRange: NvtxRange =
+    new NvtxRange("device spill", NvtxColor.ORANGE)
+}
+
+class DiskHandleStore(conf: SparkConf)
+    extends HandleStore[DiskHandle] with Logging {
+  val diskBlockManager: RapidsDiskBlockManager = new RapidsDiskBlockManager(conf)
+
+  def getFile(blockId: BlockId): File = {
+    diskBlockManager.getFile(blockId)
+  }
+
+  def deleteFile(blockId: BlockId): Unit = {
+    val file = getFile(blockId)
+    file.delete()
+    if (file.exists()) {
+      logWarning(s"Unable to delete $file")
+    }
+  }
+
+  // TODO: AB: the GpuTaskMetrics disk accounting here still needs to be revisited
+  override def track(handle: DiskHandle): Unit = {
+    // protects against the off chance that someone adds this handle twice
+    if (doTrack(handle)) {
+      GpuTaskMetrics.get.incDiskBytesAllocated(handle.sizeInBytes)
+    }
+  }
+
+  override def remove(handle: DiskHandle): Unit = {
+    // protects against the off chance that someone removes this handle twice
+    if (doRemove(handle)) {
+      GpuTaskMetrics.get.decDiskBytesAllocated(handle.sizeInBytes)
+    }
+  }
+}
+
+object DiskHandleStore {
+  /**
+   * An object that knows how to write a block to disk in Spark. It supports writing
+   * either through a `WritableByteChannel` or a `DataOutputStream` (but not both),
+   * wrapping the output with the configured serializer manager.
+   * @param blockId the BlockManager `BlockId` to use.
+   * @param startPos the position to start writing from, useful if files can be shared.
+   */
+  class DiskHandleBuilder(val blockId: BlockId,
+      val startPos: Long = 0L) extends AutoCloseable {
+    private val file = SpillFramework.stores.diskStore.getFile(blockId)
+
+    private val serializerManager =
+      SpillFramework.stores.diskStore.diskBlockManager.getSerializerManager()
+
+    // this is just to make sure we use the builder once and we are not leaking.
+    // As it is, we could use `DiskHandleBuilder` to start writing at other offsets.
+    private var closed = false
+
+    private var fc: FileChannel = null
+
+    private def getFileChannel: FileChannel = {
+      val options = Seq(StandardOpenOption.CREATE, StandardOpenOption.WRITE)
+      fc = FileChannel.open(file.toPath, options:_*)
+      // seek to the starting pos
+      fc.position(startPos)
+      fc
+    }
+
+    private def wrapChannel(channel: FileChannel): OutputStream = {
+      val os = Channels.newOutputStream(channel)
+      serializerManager.wrapStream(blockId, os)
+    }
+
+    private var outputChannel: WritableByteChannel = _
+    private var outputStream: DataOutputStream = _
+
+    def getChannel: WritableByteChannel = {
+      require(!closed, "Cannot write to closed DiskWriter")
+      require(outputStream == null,
+        "either channel or data output stream supported, but not both")
+      if (outputChannel != null) {
+        outputChannel
+      } else {
+        val fc = getFileChannel
+        val wrappedStream = closeOnExcept(fc)(wrapChannel)
+        outputChannel = closeOnExcept(wrappedStream)(Channels.newChannel)
+        outputChannel
+      }
+    }
+
+    def getDataOutputStream: DataOutputStream = {
+      require(!closed, "Cannot write to closed DiskWriter")
+      require(outputChannel == null,
+        "either channel or data output stream supported, but not both")
+      if (outputStream != null) {
+        outputStream
+      } else {
+        val fc = getFileChannel
+        val wrappedStream = closeOnExcept(fc)(wrapChannel)
+        outputStream = new DataOutputStream(wrappedStream)
+        outputStream
+      }
+ } + + override def close(): Unit = { + if (closed) { + throw new IllegalStateException("already closed DiskWriter") + } + if (outputStream != null) { + outputStream.close() + outputStream = null + } + if (outputChannel != null) { + outputChannel.close() + outputChannel = null + } + closed = true + } + + def build: DiskHandle = + DiskHandle( + blockId, + startPos, + fc.position() - startPos) + } + + def makeBuilder: DiskHandleBuilder = { + val blockId = BlockId(s"temp_local_${UUID.randomUUID().toString}") + new DiskHandleBuilder(blockId) + } +} + +trait SpillableStores extends AutoCloseable { + var deviceStore: SpillableDeviceStore + var hostStore: SpillableHostStore + var diskStore: DiskHandleStore + override def close(): Unit = { + Seq(deviceStore, hostStore, diskStore).safeClose() + } +} + +/** + * A spillable that is meant to be interacted with from the device. + */ +object SpillableColumnarBatchHandle { + def apply(tbl: Table, dataTypes: Array[DataType]): SpillableColumnarBatchHandle = { + withResource(tbl) { _ => + SpillableColumnarBatchHandle(GpuColumnVector.from(tbl, dataTypes)) + } + } + + def apply(cb: ColumnarBatch): SpillableColumnarBatchHandle = { + require(!GpuColumnVectorFromBuffer.isFromBuffer(cb), + "A SpillableColumnarBatchHandle doesn't support cuDF packed batches") + require(!GpuCompressedColumnVector.isBatchCompressed(cb), + "A SpillableColumnarBatchHandle doesn't support comprssed batches") + val sizeInBytes = GpuColumnVector.getTotalDeviceMemoryUsed(cb) + val handle = new SpillableColumnarBatchHandle(sizeInBytes, dev = Some(cb)) + SpillFramework.stores.deviceStore.track(handle) + handle + } +} + +object SpillFramework extends Logging { + // public for tests + var storesInternal: SpillableStores = _ + + def stores = { + if (storesInternal == null) { + throw new IllegalStateException( + "Cannot use SpillFramework without calling SpillFramework.initialize first") + } + storesInternal + } + + // used by the chunked packer to construct the cuDF-side packer object + var chunkedPackBounceBufferSize: Long = -1L + + // TODO: these should be pools, instead of individual buffers + private var hostSpillBounceBuffer: HostMemoryBuffer = _ + private var chunkedPackBounceBuffer: DeviceMemoryBuffer = _ + + private lazy val conf: SparkConf = { + val env = SparkEnv.get + if (env != null) { + env.conf + } else { + // For some unit tests + new SparkConf() + } + } + + def initialize(rapidsConf: RapidsConf): Unit = synchronized { + val hostSpillStorageSize = if (rapidsConf.offHeapLimitEnabled) { + // Disable the limit because it is handled by the RapidsHostMemoryStore + None + } else if (rapidsConf.hostSpillStorageSize == -1) { + // + 1 GiB by default to match backwards compatibility + Some(rapidsConf.pinnedPoolSize + (1024L * 1024 * 1024)) + } else { + Some(rapidsConf.hostSpillStorageSize) + } + chunkedPackBounceBufferSize = rapidsConf.chunkedPackBounceBufferSize + // this should hopefully be pinned, but it would work without + hostSpillBounceBuffer = HostMemoryBuffer.allocate(rapidsConf.spillToDiskBounceBufferSize) + chunkedPackBounceBuffer = DeviceMemoryBuffer.allocate(rapidsConf.chunkedPackBounceBufferSize) + storesInternal = new SpillableStores { + override var deviceStore: SpillableDeviceStore = new SpillableDeviceStore + override var hostStore: SpillableHostStore = new SpillableHostStore(hostSpillStorageSize) + override var diskStore: DiskHandleStore = new DiskHandleStore(conf) + } + val hostSpillStorageSizeStr = hostSpillStorageSize.map(sz => s"$sz B").getOrElse("unlimited") + 
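+    // Editor's note, a worked example of the sizing above: with hostSpillStorageSize left at
+    // its default of -1 and a 1 GiB pinned pool, the host spill store limit works out to
+    // pinnedPoolSize + 1 GiB = 2 GiB; an explicit setting is used as-is, and enabling the
+    // off-heap limit leaves the store-level limit unset (None), since limiting is expected
+    // to be enforced elsewhere in that mode.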
logInfo(s"Initialized SpillFramework. Host spill store max size is: $hostSpillStorageSizeStr.") + } + + def shutdown(): Unit = { + if (hostSpillBounceBuffer != null) { + hostSpillBounceBuffer.close() + hostSpillBounceBuffer = null + } + if (chunkedPackBounceBuffer != null) { + chunkedPackBounceBuffer.close() + chunkedPackBounceBuffer = null + } + if (storesInternal != null) { + storesInternal.close() + storesInternal = null + } + } + + def withHostSpillBounceBuffer[T](body: HostMemoryBuffer => T): T = + hostSpillBounceBuffer.synchronized { + body(hostSpillBounceBuffer) + } + + def withChunkedPackBounceBuffer[T](body: DeviceMemoryBuffer => T): T = + chunkedPackBounceBuffer.synchronized { + body(chunkedPackBounceBuffer) + } + + def removeFromDeviceStore(value: SpillableHandle): Unit = { + // if the stores have already shut down, we don't want to create them here + val deviceStore = synchronized { + Option(storesInternal).map(_.deviceStore) + } + deviceStore.foreach(_.remove(value)) + } + + def removeFromHostStore(value: SpillableHandle): Unit = { + // if the stores have already shut down, we don't want to create them here + val hostStore = synchronized { + Option(storesInternal).map(_.hostStore) + } + hostStore.foreach(_.remove(value)) + } + + def removeFromDiskStore(value: DiskHandle): Unit = { + // if the stores have already shut down, we don't want to create them here + val maybeStores = synchronized { + Option(storesInternal).map(_.diskStore) + } + maybeStores.foreach(_.remove(value)) + } +} + +/** + * ChunkedPacker is an Iterator-like class that uses a cudf::chunked_pack to copy a cuDF `Table` + * to a target buffer in chunks. It implements a next method that takes a DeviceMemoryBuffer + * as an argument to be used for the copy. + * + * Each chunk is sized at most `bounceBuffer.getLength`, and the caller should cudaMemcpy + * bytes from `bounceBuffer` to a target buffer after each call to `next()`. + * + * @note `ChunkedPacker` must be closed by the caller as it has GPU and host resources + * associated with it. + * + * @param table cuDF Table to chunk_pack + * @param bounceBufferSize size of GPU memory to be used for packing. The buffer will be + * obtained during the iterator-like 'next(DeviceMemoryBuffer)' + */ +class ChunkedPacker(table: Table, + bounceBufferSize: Long) + extends Logging with AutoCloseable { + + private var closed: Boolean = false + + // When creating cudf::chunked_pack use a pool if available, otherwise default to the + // per-device memory resource + private val chunkedPack = { + val pool = GpuDeviceManager.chunkedPackMemoryResource + val cudfChunkedPack = try { + pool.flatMap { chunkedPool => + Some(table.makeChunkedPack(bounceBufferSize, chunkedPool)) + } + } catch { + case _: OutOfMemoryError => + if (!ChunkedPacker.warnedAboutPoolFallback) { + ChunkedPacker.warnedAboutPoolFallback = true + logWarning( + s"OOM while creating chunked_pack using pool sized ${pool.map(_.getMaxSize)}B. 
" + + "Falling back to the per-device memory resource.") + } + None + } + + // if the pool is not configured, or we got an OOM, try again with the per-device pool + cudfChunkedPack.getOrElse { + table.makeChunkedPack(bounceBufferSize) + } + } + + private val packedMeta = withResource(chunkedPack.buildMetadata()) { packedMeta => + val tmpBB = packedMeta.getMetadataDirectBuffer + val metaCopy = ByteBuffer.allocateDirect(tmpBB.capacity()) + metaCopy.put(tmpBB) + metaCopy.flip() + metaCopy + } + + def getTotalContiguousSize: Long = chunkedPack.getTotalContiguousSize + + def getPackedMeta: ByteBuffer = { + packedMeta + } + + def hasNext: Boolean = { + if (closed) { + throw new IllegalStateException(s"ChunkedPacker is closed") + } + chunkedPack.hasNext + } + + def next(bounceBuffer: DeviceMemoryBuffer): DeviceMemoryBuffer = { + require(bounceBuffer.getLength == bounceBufferSize, + s"Bounce buffer ${bounceBuffer} doesn't match size ${bounceBufferSize} B.") + + if (closed) { + throw new IllegalStateException(s"ChunkedPacker is closed") + } + val bytesWritten = chunkedPack.next(bounceBuffer) + // we increment the refcount because the caller has no idea where + // this memory came from, so it should close it. + bounceBuffer.slice(0, bytesWritten) + } + + override def close(): Unit = { + if (!closed) { + closed = true + chunkedPack.close() + } + } +} + +object ChunkedPacker { + private var warnedAboutPoolFallback: Boolean = false +} \ No newline at end of file diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala index d5216cbda9f..fc10477520b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala @@ -16,8 +16,8 @@ package com.nvidia.spark.rapids -import ai.rapids.cudf.{ContiguousTable, DeviceMemoryBuffer, HostMemoryBuffer} -import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} +import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, HostMemoryBuffer} +import com.nvidia.spark.rapids.Arm.closeOnExcept import org.apache.spark.TaskContext import org.apache.spark.sql.types.DataType @@ -88,7 +88,7 @@ class JustRowsColumnarBatch(numRows: Int) * use `SpillableColumnarBatch.apply` instead. */ class SpillableColumnarBatchImpl ( - handle: RapidsBufferHandle, + handle: SpillableColumnarBatchHandle, rowCount: Int, sparkTypes: Array[DataType]) extends SpillableColumnarBatch { @@ -100,27 +100,128 @@ class SpillableColumnarBatchImpl ( */ override def numRows(): Int = rowCount - private def withRapidsBuffer[T](fn: RapidsBuffer => T): T = { - withResource(RapidsBufferCatalog.acquireBuffer(handle)) { rapidsBuffer => - fn(rapidsBuffer) + override lazy val sizeInBytes: Long = handle.sizeInBytes + + /** + * Set a new spill priority. + */ + override def setSpillPriority(priority: Long): Unit = { + // TODO: handle.setSpillPriority(priority) + } + + override def getColumnarBatch(): ColumnarBatch = { + GpuSemaphore.acquireIfNecessary(TaskContext.get()) + handle.materialize(sparkTypes) + } + + override def incRefCount(): SpillableColumnarBatch = { + if (refCount <= 0) { + throw new IllegalStateException("Use after free on SpillableColumnarBatchImpl") } + refCount += 1 + this } - override lazy val sizeInBytes: Long = - withRapidsBuffer(_.memoryUsedBytes) + /** + * Remove the `ColumnarBatch` from the cache. 
+ */ + override def close(): Unit = { + refCount -= 1 + if (refCount == 0) { + // closing my reference + handle.close() + } + // TODO this is causing problems so we need to look into this + // https://github.com/NVIDIA/spark-rapids/issues/10161 + //else if (refCount < 0) { + // throw new IllegalStateException("Double free on SpillableColumnarBatchImpl") + //} + } + + override def toString: String = + s"SCB $handle $rowCount ${sparkTypes.toList} $refCount" +} + +class SpillableCompressedColumnarBatchImpl( + handle: SpillableCompressedColumnarBatchHandle, rowCount: Int) + extends SpillableColumnarBatch { + + private var refCount = 1 + + /** + * The number of rows stored in this batch. + */ + override def numRows(): Int = rowCount + + override lazy val sizeInBytes: Long = handle.compressedSizeInBytes /** * Set a new spill priority. */ override def setSpillPriority(priority: Long): Unit = { - handle.setSpillPriority(priority) + // TODO: handle.setSpillPriority(priority) } override def getColumnarBatch(): ColumnarBatch = { - withRapidsBuffer { rapidsBuffer => - GpuSemaphore.acquireIfNecessary(TaskContext.get()) - rapidsBuffer.getColumnarBatch(sparkTypes) + GpuSemaphore.acquireIfNecessary(TaskContext.get()) + handle.materialize + } + + override def incRefCount(): SpillableColumnarBatch = { + if (refCount <= 0) { + throw new IllegalStateException("Use after free on SpillableColumnarBatchImpl") + } + refCount += 1 + this + } + + /** + * Remove the `ColumnarBatch` from the cache. + */ + override def close(): Unit = { + refCount -= 1 + if (refCount == 0) { + // closing my reference + handle.close() } + // TODO this is causing problems so we need to look into this + // https://github.com/NVIDIA/spark-rapids/issues/10161 + //else if (refCount < 0) { + // throw new IllegalStateException("Double free on SpillableColumnarBatchImpl") + //} + } + + override def toString: String = + s"SCCB $handle $rowCount $refCount" + + override def dataTypes: Array[DataType] = null +} + +class SpillableColumnarBatchFromBufferImpl( + handle: SpillableColumnarBatchFromBufferHandle, + rowCount: Int, + sparkTypes: Array[DataType]) + extends SpillableColumnarBatch { + private var refCount = 1 + + override def dataTypes: Array[DataType] = sparkTypes + /** + * The number of rows stored in this batch. + */ + override def numRows(): Int = rowCount + + override lazy val sizeInBytes: Long = handle.sizeInBytes + + /** + * Set a new spill priority. + */ + override def setSpillPriority(priority: Long): Unit = { + // TODO: handle.setSpillPriority(priority) + } + + override def getColumnarBatch(): ColumnarBatch = { + GpuSemaphore.acquireIfNecessary(TaskContext.get()) + handle.materialize(dataTypes) } override def incRefCount(): SpillableColumnarBatch = { @@ -142,9 +243,9 @@ class SpillableColumnarBatchImpl ( } // TODO this is causing problems so we need to look into this // https://github.com/NVIDIA/spark-rapids/issues/10161 -// else if (refCount < 0) { -// throw new IllegalStateException("Double free on SpillableColumnarBatchImpl") -// } + //else if (refCount < 0) { + // throw new IllegalStateException("Double free on SpillableColumnarBatchImpl") + //} } override def toString: String = @@ -176,10 +277,9 @@ class JustRowsHostColumnarBatch(numRows: Int) * use `SpillableHostColumnarBatch.apply` instead. 
*/ class SpillableHostColumnarBatchImpl ( - handle: RapidsBufferHandle, + handle: SpillableHostColumnarBatchHandle, rowCount: Int, - sparkTypes: Array[DataType], - catalog: RapidsBufferCatalog) + sparkTypes: Array[DataType]) extends SpillableColumnarBatch { private var refCount = 1 @@ -190,27 +290,19 @@ class SpillableHostColumnarBatchImpl ( */ override def numRows(): Int = rowCount - private def withRapidsHostBatchBuffer[T](fn: RapidsHostBatchBuffer => T): T = { - withResource(catalog.acquireHostBatchBuffer(handle)) { rapidsBuffer => - fn(rapidsBuffer) - } - } - override lazy val sizeInBytes: Long = { - withRapidsHostBatchBuffer(_.memoryUsedBytes) + handle.sizeInBytes } /** * Set a new spill priority. */ override def setSpillPriority(priority: Long): Unit = { - handle.setSpillPriority(priority) + // TODO: handle.setSpillPriority(priority) } override def getColumnarBatch(): ColumnarBatch = { - withRapidsHostBatchBuffer { hostBatchBuffer => - hostBatchBuffer.getHostColumnarBatch(sparkTypes) - } + handle.materialize(sparkTypes) } override def incRefCount(): SpillableColumnarBatch = { @@ -245,18 +337,29 @@ object SpillableColumnarBatch { */ def apply(batch: ColumnarBatch, priority: Long): SpillableColumnarBatch = { + Cuda.DEFAULT_STREAM.sync() val numRows = batch.numRows() if (batch.numCols() <= 0) { // We consumed it batch.close() new JustRowsColumnarBatch(numRows) } else { - val types = GpuColumnVector.extractTypes(batch) - val handle = addBatch(batch, priority) - new SpillableColumnarBatchImpl( - handle, - numRows, - types) + if (GpuCompressedColumnVector.isBatchCompressed(batch)) { + new SpillableCompressedColumnarBatchImpl( + SpillableCompressedColumnarBatchHandle(batch), + numRows) + } else if (GpuColumnVectorFromBuffer.isFromBuffer(batch)) { + new SpillableColumnarBatchFromBufferImpl( + SpillableColumnarBatchFromBufferHandle(batch), + numRows, + GpuColumnVector.extractTypes(batch) + ) + } else { + new SpillableColumnarBatchImpl( + SpillableColumnarBatchHandle(batch), + numRows, + GpuColumnVector.extractTypes(batch)) + } } } @@ -271,54 +374,11 @@ object SpillableColumnarBatch { ct: ContiguousTable, sparkTypes: Array[DataType], priority: Long): SpillableColumnarBatch = { - withResource(ct) { _ => - val handle = RapidsBufferCatalog.addContiguousTable(ct, priority) - new SpillableColumnarBatchImpl(handle, ct.getRowCount.toInt, sparkTypes) - } - } - - private[this] def allFromSameBuffer(batch: ColumnarBatch): Boolean = { - var bufferAddr = 0L - var isSet = false - val numColumns = batch.numCols() - (0 until numColumns).forall { i => - batch.column(i) match { - case fb: GpuColumnVectorFromBuffer => - if (!isSet) { - bufferAddr = fb.getBuffer.getAddress - isSet = true - true - } else { - bufferAddr == fb.getBuffer.getAddress - } - case _ => false - } - } - } - - private[this] def addBatch( - batch: ColumnarBatch, - initialSpillPriority: Long): RapidsBufferHandle = { - withResource(batch) { batch => - val numColumns = batch.numCols() - if (GpuCompressedColumnVector.isBatchCompressed(batch)) { - val cv = batch.column(0).asInstanceOf[GpuCompressedColumnVector] - val buff = cv.getTableBuffer - RapidsBufferCatalog.addBuffer(buff, cv.getTableMeta, initialSpillPriority) - } else if (GpuPackedTableColumn.isBatchPacked(batch)) { - val cv = batch.column(0).asInstanceOf[GpuPackedTableColumn] - RapidsBufferCatalog.addContiguousTable( - cv.getContiguousTable, - initialSpillPriority) - } else if (numColumns > 0 && - allFromSameBuffer(batch)) { - val cv = 
batch.column(0).asInstanceOf[GpuColumnVectorFromBuffer] - val buff = cv.getBuffer - RapidsBufferCatalog.addBuffer(buff, cv.getTableMeta, initialSpillPriority) - } else { - RapidsBufferCatalog.addBatch(batch, initialSpillPriority) - } - } + Cuda.DEFAULT_STREAM.sync() + new SpillableColumnarBatchFromBufferImpl( + SpillableColumnarBatchFromBufferHandle(ct, sparkTypes), + ct.getRowCount.toInt, + sparkTypes) } } @@ -330,10 +390,7 @@ object SpillableHostColumnarBatch { * @param batch the batch to make spillable * @param priority the initial spill priority of this batch */ - def apply( - batch: ColumnarBatch, - priority: Long, - catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton): SpillableColumnarBatch = { + def apply(batch: ColumnarBatch, priority: Long): SpillableColumnarBatch = { val numRows = batch.numRows() if (batch.numCols() <= 0) { // We consumed it @@ -341,45 +398,30 @@ object SpillableHostColumnarBatch { new JustRowsHostColumnarBatch(numRows) } else { val types = RapidsHostColumnVector.extractColumns(batch).map(_.dataType()) - val handle = addHostBatch(batch, priority, catalog) - new SpillableHostColumnarBatchImpl( - handle, - numRows, - types, - catalog) - } - } - - private[this] def addHostBatch( - batch: ColumnarBatch, - initialSpillPriority: Long, - catalog: RapidsBufferCatalog): RapidsBufferHandle = { - withResource(batch) { batch => - catalog.addBatch(batch, initialSpillPriority) + val handle = SpillableHostColumnarBatchHandle(batch) + new SpillableHostColumnarBatchImpl(handle, numRows, types) } } - } + /** * Just like a SpillableColumnarBatch but for buffers. */ class SpillableBuffer( - handle: RapidsBufferHandle) extends AutoCloseable { + handle: SpillableDeviceBufferHandle) extends AutoCloseable { /** * Set a new spill priority. */ def setSpillPriority(priority: Long): Unit = { - handle.setSpillPriority(priority) + // TODO: handle.setSpillPriority(priority) } /** * Use the device buffer. */ def getDeviceBuffer(): DeviceMemoryBuffer = { - withResource(RapidsBufferCatalog.acquireBuffer(handle)) { rapidsBuffer => - rapidsBuffer.getDeviceMemoryBuffer - } + handle.materialize } /** @@ -397,17 +439,15 @@ class SpillableBuffer( * @param length a metadata-only length that is kept in the `SpillableHostBuffer` * instance. Used in cases where the backing host buffer is larger * than the number of usable bytes. - * @param catalog this was added for tests, it defaults to - * `RapidsBufferCatalog.singleton` in the companion object. */ -class SpillableHostBuffer(handle: RapidsBufferHandle, - val length: Long, - catalog: RapidsBufferCatalog) extends AutoCloseable { +class SpillableHostBuffer(handle: SpillableHostBufferHandle, + val length: Long) + extends AutoCloseable { /** * Set a new spill priority. 
*/ def setSpillPriority(priority: Long): Unit = { - handle.setSpillPriority(priority) + // TODO: handle.setSpillPriority(priority) } /** @@ -417,10 +457,8 @@ class SpillableHostBuffer(handle: RapidsBufferHandle, handle.close() } - def getHostBuffer(): HostMemoryBuffer = { - withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => - rapidsBuffer.getHostMemoryBuffer - } + def getHostBuffer: HostMemoryBuffer = { + handle.materialize } } @@ -435,10 +473,8 @@ object SpillableBuffer { def apply( buffer: DeviceMemoryBuffer, priority: Long): SpillableBuffer = { - val meta = MetaUtils.getTableMetaNoTable(buffer.getLength) - val handle = withResource(buffer) { _ => - RapidsBufferCatalog.addBuffer(buffer, meta, priority) - } + Cuda.DEFAULT_STREAM.sync() + val handle = SpillableDeviceBufferHandle(buffer) // TODO: AB: priority new SpillableBuffer(handle) } } @@ -456,17 +492,12 @@ object SpillableHostBuffer { */ def apply(buffer: HostMemoryBuffer, length: Long, - priority: Long, - catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton): SpillableHostBuffer = { + priority: Long): SpillableHostBuffer = { closeOnExcept(buffer) { _ => require(length <= buffer.getLength, s"Attempted to add a host spillable with a length ${length} B which is " + s"greater than the backing host buffer length ${buffer.getLength} B") } - val meta = MetaUtils.getTableMetaNoTable(buffer.getLength) - val handle = withResource(buffer) { _ => - catalog.addBuffer(buffer, meta, priority) - } - new SpillableHostBuffer(handle, length, catalog) + new SpillableHostBuffer(SpillableHostBufferHandle(buffer), length) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala index 08a1ae22f5e..3512871726a 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,11 +16,9 @@ package com.nvidia.spark.rapids.shuffle -import java.io.IOException - -import ai.rapids.cudf.{Cuda, MemoryBuffer} -import com.nvidia.spark.rapids.{RapidsBuffer, ShuffleMetadata, StorageTier} -import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} +import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, MemoryBuffer} +import com.nvidia.spark.rapids.{RapidsShuffleHandle, ShuffleMetadata} +import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.format.{BufferMeta, BufferTransferRequest} @@ -60,8 +58,17 @@ class BufferSendState( serverStream: Cuda.Stream = Cuda.DEFAULT_STREAM) extends AutoCloseable with Logging { - class SendBlock(val bufferId: Int, tableSize: Long) extends BlockWithSize { - override def size: Long = tableSize + class SendBlock(val bufferHandle: RapidsShuffleHandle) extends BlockWithSize { + // we assume that the size of the buffer won't change as it goes to host/disk + // we also are likely to assume this is just a device buffer, and so we should + // copy to device and then send. 
+ override def size: Long = { + if (bufferHandle.spillable != null) { + bufferHandle.spillable.sizeInBytes + } else { + 0L // degenerate + } + } } val peerExecutorId: Long = transaction.peerExecutorId() @@ -80,13 +87,10 @@ class BufferSendState( val btr = new BufferTransferRequest() // for reuse val blocksToSend = (0 until transferRequest.requestsLength()).map { ix => val bufferTransferRequest = transferRequest.requests(btr, ix) - withResource(requestHandler.acquireShuffleBuffer( - bufferTransferRequest.bufferId())) { table => - bufferMetas(ix) = table.meta.bufferMeta() - new SendBlock(bufferTransferRequest.bufferId(), table.getPackedSizeBytes) - } + val handle = requestHandler.getShuffleHandle(bufferTransferRequest.bufferId()) + bufferMetas(ix) = handle.tableMeta.bufferMeta() + new SendBlock(handle) } - (peerBufferReceiveHeader, bufferMetas, blocksToSend) } } @@ -145,7 +149,7 @@ class BufferSendState( } case class RangeBuffer( - range: BlockRange[SendBlock], rapidsBuffer: RapidsBuffer) + range: BlockRange[SendBlock], rapidsBuffer: MemoryBuffer) extends AutoCloseable { override def close(): Unit = { rapidsBuffer.close() @@ -170,50 +174,50 @@ class BufferSendState( if (hasMoreBlocks) { var deviceBuffs = 0L var hostBuffs = 0L - acquiredBuffs = blockRanges.safeMap { blockRange => - val bufferId = blockRange.block.bufferId - // we acquire these buffers now, and keep them until the caller releases them - // using `releaseAcquiredToCatalog` - closeOnExcept( - requestHandler.acquireShuffleBuffer(bufferId)) { rapidsBuffer => + var needsCleanup = false + try { + acquiredBuffs = blockRanges.safeMap { blockRange => + // we acquire these buffers now, and keep them until the caller releases them + // using `releaseAcquiredToCatalog` //these are closed later, after we synchronize streams - rapidsBuffer.storageTier match { - case StorageTier.DEVICE => + val spillable = blockRange.block.bufferHandle.spillable + val buff = spillable.materialize + buff match { + case _: DeviceMemoryBuffer => deviceBuffs += blockRange.rangeSize() - case _ => // host/disk + case _ => hostBuffs += blockRange.rangeSize() } - RangeBuffer(blockRange, rapidsBuffer) + RangeBuffer(blockRange, buff) } - } - logDebug(s"Occupancy for bounce buffer is [device=${deviceBuffs}, host=${hostBuffs}] Bytes") + logDebug(s"Occupancy for bounce buffer is " + + s"[device=${deviceBuffs}, host=${hostBuffs}] Bytes") - bounceBuffToUse = if (deviceBuffs >= hostBuffs || hostBounceBuffer == null) { - deviceBounceBuffer.buffer - } else { - hostBounceBuffer.buffer - } + bounceBuffToUse = if (deviceBuffs >= hostBuffs || hostBounceBuffer == null) { + deviceBounceBuffer.buffer + } else { + hostBounceBuffer.buffer + } - // `copyToMemoryBuffer` can throw if the `RapidsBuffer` is in the DISK tier and - // the file fails to mmap. We catch the `IOException` and attempt a retry - // in the server. 
- var needsCleanup = false - try { - acquiredBuffs.foreach { case RangeBuffer(blockRange, rapidsBuffer) => + acquiredBuffs.foreach { case RangeBuffer(blockRange, memoryBuffer) => needsCleanup = true require(blockRange.rangeSize() <= bounceBuffToUse.getLength - buffOffset) - rapidsBuffer.copyToMemoryBuffer(blockRange.rangeStart, bounceBuffToUse, buffOffset, - blockRange.rangeSize(), serverStream) + bounceBuffToUse.copyFromMemoryBufferAsync( + buffOffset, + memoryBuffer, + blockRange.rangeStart, + blockRange.rangeSize(), + serverStream) buffOffset += blockRange.rangeSize() } needsCleanup = false } catch { - case ioe: IOException => + case ex: Throwable => throw new RapidsShuffleSendPrepareException( s"Error while copying to bounce buffer for executor ${peerExecutorId} and " + - s"header ${TransportUtils.toHex(peerBufferReceiveHeader)}", ioe) + s"header ${TransportUtils.toHex(peerBufferReceiveHeader)}", ex) } finally { if (needsCleanup) { // we likely failed in `copyToMemoryBuffer` @@ -251,4 +255,4 @@ class BufferSendState( acquiredBuffs.foreach(_.close()) acquiredBuffs = Seq.empty } -} \ No newline at end of file +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala index b73f9820bad..2723e24b0f0 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,7 +47,7 @@ trait RapidsShuffleFetchHandler { * @return a boolean that lets the caller know the batch was accepted (true), or * rejected (false), in which case the caller should dispose of the batch. */ - def batchReceived(handle: RapidsBufferHandle): Boolean + def batchReceived(handle: RapidsShuffleHandle): Boolean /** * Called when the transport layer is not able to handle a fetch error for metadata @@ -390,7 +390,7 @@ class RapidsShuffleClient( buffMetas.foreach { consumed: ConsumedBatchFromBounceBuffer => val handle = track(consumed.contigBuffer, consumed.meta) if (!consumed.handler.batchReceived(handle)) { - catalog.removeBuffer(handle) + handle.close() numBatchesRejected += 1 } transport.doneBytesInFlight(consumed.contigBuffer.getLength) @@ -431,25 +431,19 @@ class RapidsShuffleClient( * used to look up the buffer from the catalog going (e.g. from the iterator) * @param buffer contiguous [[DeviceMemoryBuffer]] with the tables' data * @param meta [[TableMeta]] describing [[buffer]] - * @return the [[RapidsBufferId]] to be used to look up the buffer from catalog + * @return a [[RapidsShuffleHandle]] with a spillable and metadata */ private[shuffle] def track( - buffer: DeviceMemoryBuffer, meta: TableMeta): RapidsBufferHandle = { + buffer: DeviceMemoryBuffer, meta: TableMeta): RapidsShuffleHandle = { if (buffer != null) { // add the buffer to the catalog so it is available for spill catalog.addBuffer( buffer, meta, - SpillPriorities.INPUT_FROM_SHUFFLE_PRIORITY, - // set needsSync to false because we already have stream synchronized after - // consuming the bounce buffer, so we know these buffers are synchronized - // w.r.t. 
the CPU - needsSync = false) + SpillPriorities.INPUT_FROM_SHUFFLE_PRIORITY) } else { // no device data, just tracking metadata - catalog.addDegenerateRapidsBuffer( - meta) - + catalog.addDegenerateBatch(meta) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala index 126b9200c90..95c411c64f2 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ import java.util.concurrent.{ConcurrentLinkedQueue, Executor} import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{Cuda, MemoryBuffer, NvtxColor, NvtxRange} -import com.nvidia.spark.rapids.{RapidsBuffer, RapidsConf, ShuffleMetadata} +import com.nvidia.spark.rapids.{RapidsConf, RapidsShuffleHandle, ShuffleMetadata} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.format.TableMeta @@ -49,7 +49,7 @@ trait RapidsShuffleRequestHandler { * @param tableId the unique id for a table in the catalog * @return a [[RapidsBuffer]] which is reference counted, and should be closed by the acquirer */ - def acquireShuffleBuffer(tableId: Int): RapidsBuffer + def getShuffleHandle(tableId: Int): RapidsShuffleHandle } /** diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala index 1b0ee21d494..7f8733b9e00 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala @@ -44,12 +44,12 @@ class GpuShuffleEnv(rapidsConf: RapidsConf) extends Logging { } } - def init(diskBlockManager: RapidsDiskBlockManager): Unit = { + def init(): Unit = { if (isRapidsShuffleConfigured) { shuffleCatalog = - new ShuffleBufferCatalog(RapidsBufferCatalog.singleton, diskBlockManager) + new ShuffleBufferCatalog() shuffleReceivedBufferCatalog = - new ShuffleReceivedBufferCatalog(RapidsBufferCatalog.singleton) + new ShuffleReceivedBufferCatalog() } } @@ -172,9 +172,9 @@ object GpuShuffleEnv extends Logging { // Functions below only get called from the executor // - def init(conf: RapidsConf, diskBlockManager: RapidsDiskBlockManager): Unit = { + def init(conf: RapidsConf): Unit = { val shuffleEnv = new GpuShuffleEnv(conf) - shuffleEnv.init(diskBlockManager) + shuffleEnv.init() env = shuffleEnv } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala index 5f1052f0e59..e447aed87d8 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala @@ -130,12 +130,12 @@ class GpuTaskMetrics extends Serializable { def getMaxDiskBytesAllocated: Long = maxDiskBytesAllocated - def incDiskBytesAllocated(bytes: Long): Unit = { + def incDiskBytesAllocated(bytes: Long): Unit = synchronized { diskBytesAllocated += bytes maxDiskBytesAllocated = maxDiskBytesAllocated.max(diskBytesAllocated) } - def 
decDiskBytesAllocated(bytes: Long): Unit = { + def decDiskBytesAllocated(bytes: Long): Unit = synchronized { diskBytesAllocated -= bytes // For some reason it's possible for the task to start out by releasing resources, // possibly from a previous task, in such case we probably should just ignore it. diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala index da54735aaf4..0b2c2bdef6f 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala @@ -1116,7 +1116,7 @@ class RapidsCachingWriter[K, V]( val blockId = ShuffleBlockId(handle.shuffleId, mapId, partId) if (batch.numRows > 0 && batch.numCols > 0) { // Add the table to the shuffle store - val handle = batch.column(0) match { + batch.column(0) match { case c: GpuPackedTableColumn => val contigTable = c.getContiguousTable partSize = c.getTableBuffer.getLength @@ -1124,23 +1124,14 @@ class RapidsCachingWriter[K, V]( catalog.addContiguousTable( blockId, contigTable, - SpillPriorities.OUTPUT_FOR_SHUFFLE_INITIAL_PRIORITY, - // we don't need to sync here, because we sync on the cuda - // stream after sliceInternalOnGpu (contiguous_split) - needsSync = false) + SpillPriorities.OUTPUT_FOR_SHUFFLE_INITIAL_PRIORITY) case c: GpuCompressedColumnVector => - val buffer = c.getTableBuffer - partSize = buffer.getLength - val tableMeta = c.getTableMeta - uncompressedMetric += tableMeta.bufferMeta().uncompressedSize() - catalog.addBuffer( + partSize = c.getTableBuffer.getLength + uncompressedMetric += c.getTableMeta.bufferMeta().uncompressedSize() + catalog.addCompressedBatch( blockId, - buffer, - tableMeta, - SpillPriorities.OUTPUT_FOR_SHUFFLE_INITIAL_PRIORITY, - // we don't need to sync here, because we sync on the cuda - // stream after compression. - needsSync = false) + batch, + SpillPriorities.OUTPUT_FOR_SHUFFLE_INITIAL_PRIORITY) case c => throw new IllegalStateException(s"Unexpected column type: ${c.getClass}") } @@ -1154,21 +1145,18 @@ class RapidsCachingWriter[K, V]( } else { sizes(partId) += partSize } - handle } else { // no device data, tracking only metadata val tableMeta = MetaUtils.buildDegenerateTableMeta(batch) - val handle = - catalog.addDegenerateRapidsBuffer( - blockId, - tableMeta) + catalog.addDegenerateRapidsBuffer( + blockId, + tableMeta) // ensure that we set the partition size to the default in this case if // we have non-zero rows, so this degenerate batch is shuffled. 
if (batch.numRows > 0) { sizes(partId) += DEGENERATE_PARTITION_BYTE_SIZE_DEFAULT } - handle } } metricsReporter.incBytesWritten(bytesWritten) @@ -1357,9 +1345,8 @@ class RapidsShuffleInternalManagerBase(conf: SparkConf, val isDriver: Boolean) if (rapidsConf.isGPUShuffle && !isDriver) { val catalog = getCatalogOrThrow val requestHandler = new RapidsShuffleRequestHandler() { - override def acquireShuffleBuffer(tableId: Int): RapidsBuffer = { - val handle = catalog.getShuffleBufferHandle(tableId) - catalog.acquireBuffer(handle) + override def getShuffleHandle(tableId: Int): RapidsShuffleHandle = { + catalog.getShuffleBufferHandle(tableId) } override def getShuffleBufferMetas(sbbId: ShuffleBlockBatchId): Seq[TableMeta] = { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/TempSpillBufferId.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/TempSpillBufferId.scala deleted file mode 100644 index 0f9510a28ba..00000000000 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/TempSpillBufferId.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright (c) 2020, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.rapids - -import java.io.File -import java.util.UUID -import java.util.concurrent.atomic.AtomicInteger -import java.util.function.IntUnaryOperator - -import com.nvidia.spark.rapids.RapidsBufferId - -import org.apache.spark.storage.TempLocalBlockId - -object TempSpillBufferId { - private val MAX_TABLE_ID = Integer.MAX_VALUE - private val TABLE_ID_UPDATER = new IntUnaryOperator { - override def applyAsInt(i: Int): Int = if (i < MAX_TABLE_ID) i + 1 else 0 - } - - /** Tracks the next table identifier */ - private[this] val tableIdCounter = new AtomicInteger(0) - - def apply(): TempSpillBufferId = { - val tableId = tableIdCounter.getAndUpdate(TABLE_ID_UPDATER) - val tempBlockId = TempLocalBlockId(UUID.randomUUID()) - new TempSpillBufferId(tableId, tempBlockId) - } -} - -case class TempSpillBufferId private( - override val tableId: Int, - bufferId: TempLocalBlockId) extends RapidsBufferId { - - override def getDiskPath(diskBlockManager: RapidsDiskBlockManager): File = - diskBlockManager.getFile(bufferId) -} \ No newline at end of file diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala index bd30459d63e..9290a8a9482 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala @@ -167,7 +167,7 @@ class SerializeConcatHostBuffersDeserializeBatch( * This will populate `data` before any task has had a chance to call `.batch` on this class. * * If `batchInternal` is defined we are in the executor, and there is no work to be done. 
- * This broadcast has been materialized on the GPU/RapidsBufferCatalog, and it is completely + * This broadcast has been materialized on the GPU/spill store, and it is completely * managed by the plugin. * * Public for unit tests. diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHelper.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHelper.scala index b2bb5461a40..cfd7e4e60b8 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHelper.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHelper.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,7 +47,8 @@ object GpuBroadcastHelper { case broadcastBatch: SerializeConcatHostBuffersDeserializeBatch => RmmRapidsRetryIterator.withRetryNoSplit { withResource(new NvtxRange("getBroadcastBatch", NvtxColor.YELLOW)) { _ => - broadcastBatch.batch.getColumnarBatch() + val spillable = broadcastBatch.batch + spillable.getColumnarBatch() } } case v if SparkShimImpl.isEmptyRelation(v) => diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala index 70942001cae..e06f708b6f5 100644 --- a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala +++ b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala @@ -39,7 +39,7 @@ import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} import scala.collection.mutable import ai.rapids.cudf.{NvtxColor, NvtxRange} -import com.nvidia.spark.rapids.{GpuSemaphore, RapidsBuffer, RapidsBufferHandle, RapidsConf, ShuffleReceivedBufferCatalog} +import com.nvidia.spark.rapids.{GpuSemaphore, RapidsConf, RapidsShuffleHandle, ShuffleReceivedBufferCatalog} import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion import com.nvidia.spark.rapids.jni.RmmSpark @@ -90,7 +90,7 @@ class RapidsShuffleIterator( * A result for a successful buffer received * @param handle - the shuffle received buffer handle as tracked in the catalog */ - case class BufferReceived(handle: RapidsBufferHandle) extends ShuffleClientResult + case class BufferReceived(handle: RapidsShuffleHandle) extends ShuffleClientResult /** * A result for a failed attempt at receiving block metadata, or corresponding batches. @@ -245,7 +245,7 @@ class RapidsShuffleIterator( def clientDone: Boolean = clientExpectedBatches > 0 && clientExpectedBatches == clientResolvedBatches - override def batchReceived(handle: RapidsBufferHandle): Boolean = { + override def batchReceived(handle: RapidsShuffleHandle): Boolean = { resolvedBatches.synchronized { if (taskComplete) { false @@ -310,8 +310,7 @@ class RapidsShuffleIterator( logWarning(s"Iterator for task ${taskAttemptIdStr} closing, " + s"but it is not done. 
Closing ${resolvedBatches.size()} resolved batches!!") resolvedBatches.forEach { - case BufferReceived(handle) => - GpuShuffleEnv.getReceivedCatalog.removeBuffer(handle) + case BufferReceived(handle) => handle.close() case _ => } // tell the client to cancel pending requests @@ -337,8 +336,6 @@ class RapidsShuffleIterator( } override def next(): ColumnarBatch = { - var cb: ColumnarBatch = null - var sb: RapidsBuffer = null val range = new NvtxRange(s"RapidshuffleIterator.next", NvtxColor.RED) // If N tasks downstream are accumulating memory we run the risk OOM @@ -356,6 +353,7 @@ class RapidsShuffleIterator( // fetches and so it could produce device memory. Note this is not allowing for some external // thread to schedule the fetches for us, it may be something we consider in the future, given // memory pressure. + // No good way to get a metric in here for semaphore time. taskContext.foreach(GpuSemaphore.acquireIfNecessary) if (!started) { @@ -379,16 +377,12 @@ class RapidsShuffleIterator( val nvtxRangeAfterGettingBatch = new NvtxRange("RapidsShuffleIterator.gotBatch", NvtxColor.PURPLE) try { - sb = catalog.acquireBuffer(handle) - cb = sb.getColumnarBatch(sparkTypes) - metricsUpdater.update(blockedTime, 1, sb.memoryUsedBytes, cb.numRows()) + val (cb, memoryUsedBytes) = catalog.getColumnarBatchAndRemove(handle, sparkTypes) + metricsUpdater.update(blockedTime, 1, memoryUsedBytes, cb.numRows()) + cb } finally { nvtxRangeAfterGettingBatch.close() range.close() - if (sb != null) { - sb.close() - } - catalog.removeBuffer(handle) } case Some( TransferError(blockManagerId, shuffleBlockBatchId, mapIndex, errorMessage, throwable)) => @@ -414,6 +408,5 @@ class RapidsShuffleIterator( case _ => throw new IllegalStateException(s"Invalid result type $result") } - cb } } diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala index 3334187bb16..1905190f30e 100644 --- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala @@ -38,13 +38,13 @@ import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{NvtxColor, NvtxRange} import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.shuffle.{RapidsShuffleIterator, RapidsShuffleTransport} import org.apache.spark.{InterruptibleIterator, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.shuffle.{ShuffleReader, ShuffleReadMetricsReporter} import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockBatchId, ShuffleBlockId} import org.apache.spark.util.CompletionIterator @@ -79,10 +79,12 @@ class RapidsCachingReader[K, C]( override def read(): Iterator[Product2[K, C]] = { val readRange = new NvtxRange(s"RapidsCachingReader.read", NvtxColor.DARK_GREEN) try { - val blocksForRapidsTransport = new ArrayBuffer[(BlockManagerId, Seq[(BlockId, Long, Int)])]() - val cachedBlocks = new ArrayBuffer[BlockId]() - val cachedBufferHandles = new ArrayBuffer[RapidsBufferHandle]() - val blocksByAddressMap: Map[BlockManagerId, Seq[(BlockId, Long, Int)]] = blocksByAddress.toMap + val blocksForRapidsTransport = + new ArrayBuffer[(BlockManagerId, Seq[(BlockId, Long, Int)])]() + var cachedBatchIterator: Iterator[ColumnarBatch] = 
Iterator.empty + val blocksByAddressMap: Map[BlockManagerId, Seq[(BlockId, Long, Int)]] = + blocksByAddress.toMap + var numCachedBlocks: Int = 0 blocksByAddressMap.keys.foreach(blockManagerId => { val blockInfos: Seq[(BlockId, Long, Int)] = blocksByAddressMap(blockManagerId) @@ -91,33 +93,29 @@ class RapidsCachingReader[K, C]( if (blockManagerId.executorId == localId.executorId) { val readLocalRange = new NvtxRange("Read Local", NvtxColor.GREEN) try { - blockInfos.foreach( - blockInfo => { - val blockId = blockInfo._1 - val shuffleBufferHandles: IndexedSeq[RapidsBufferHandle] = blockId match { - case sbbid: ShuffleBlockBatchId => - (sbbid.startReduceId to sbbid.endReduceId).flatMap { reduceId => - cachedBlocks.append(blockId) - val sBlockId = ShuffleBlockId(sbbid.shuffleId, sbbid.mapId, reduceId) - catalog.blockIdToBufferHandles(sBlockId) - } - case sbid: ShuffleBlockId => - cachedBlocks.append(blockId) - catalog.blockIdToBufferHandles(sbid) - case _ => throw new IllegalArgumentException( - s"${blockId.getClass} $blockId is not currently supported") - } - - cachedBufferHandles ++= shuffleBufferHandles - - // Update the spill priorities of these buffers to indicate they are about - // to be read and therefore should not be spilled if possible. - shuffleBufferHandles.foreach(catalog.updateSpillPriorityForLocalRead) - - if (shuffleBufferHandles.nonEmpty) { - metrics.incLocalBlocksFetched(1) - } - }) + cachedBatchIterator = blockInfos.iterator.flatMap { blockInfo => + val blockId = blockInfo._1 + val shuffleBufferHandles = blockId match { + case sbbid: ShuffleBlockBatchId => + (sbbid.startReduceId to sbbid.endReduceId).iterator.flatMap { reduceId => + val sbid = ShuffleBlockId(sbbid.shuffleId, sbbid.mapId, reduceId) + numCachedBlocks += 1 + catalog.getColumnarBatchIterator(sbid, sparkTypes) + } + case sbid: ShuffleBlockId => + numCachedBlocks += 1 + catalog.getColumnarBatchIterator(sbid, sparkTypes) + case _ => throw new IllegalArgumentException( + s"${blockId.getClass} $blockId is not currently supported") + } + + shuffleBufferHandles + } + + // Update the spill priorities of these buffers to indicate they are about + // to be read and therefore should not be spilled if possible. + // TODO: AB: shuffleBufferHandles.foreach(catalog.updateSpillPriorityForLocalRead) + metrics.incLocalBlocksFetched(numCachedBlocks) } finally { readLocalRange.close() } @@ -139,7 +137,7 @@ class RapidsCachingReader[K, C]( } }) - logInfo(s"Will read ${cachedBlocks.size} cached blocks, " + + logInfo(s"Will read ${numCachedBlocks} cached blocks, " + s"${blocksForRapidsTransport.size} remote blocks from the RapidsShuffleTransport. 
") if (transport.isEmpty && blocksForRapidsTransport.nonEmpty) { @@ -159,17 +157,12 @@ class RapidsCachingReader[K, C]( val itRange = new NvtxRange("Shuffle Iterator prep", NvtxColor.BLUE) try { - val cachedIt = cachedBufferHandles.iterator.map(bufferHandle => { - // No good way to get a metric in here for semaphore wait time - GpuSemaphore.acquireIfNecessary(context) - val cb = withResource(catalog.acquireBuffer(bufferHandle)) { buffer => - buffer.getColumnarBatch(sparkTypes) - } + val cachedIt = cachedBatchIterator.map { cb => val cachedBytesRead = GpuColumnVector.getTotalDeviceMemoryUsed(cb) metrics.incLocalBytesRead(cachedBytesRead) metrics.incRecordsRead(cb.numRows()) (0, cb) - }).asInstanceOf[Iterator[(K, C)]] + }.asInstanceOf[Iterator[(K, C)]] val cbArrayFromUcx: Iterator[(K, C)] = if (blocksForRapidsTransport.nonEmpty) { val rapidsShuffleIterator = new RapidsShuffleIterator(localId, rapidsConf, transport.get, diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala index 55873e8020b..78cf79ad4c1 100644 --- a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala +++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala @@ -34,7 +34,7 @@ import scala.collection import scala.collection.mutable import ai.rapids.cudf.{NvtxColor, NvtxRange} -import com.nvidia.spark.rapids.{GpuSemaphore, RapidsBuffer, RapidsBufferHandle, RapidsConf, ShuffleReceivedBufferCatalog} +import com.nvidia.spark.rapids.{GpuSemaphore, RapidsConf, RapidsShuffleHandle, ShuffleReceivedBufferCatalog} import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion import com.nvidia.spark.rapids.jni.RmmSpark @@ -85,7 +85,7 @@ class RapidsShuffleIterator( * A result for a successful buffer received * @param handle - the shuffle received buffer handle as tracked in the catalog */ - case class BufferReceived(handle: RapidsBufferHandle) extends ShuffleClientResult + case class BufferReceived(handle: RapidsShuffleHandle) extends ShuffleClientResult /** * A result for a failed attempt at receiving block metadata, or corresponding batches. @@ -221,7 +221,7 @@ class RapidsShuffleIterator( override def getTaskIds: Array[Long] = taskIds - def start(expectedBatches: Int): Unit = resolvedBatches.synchronized { + override def start(expectedBatches: Int): Unit = resolvedBatches.synchronized { if (expectedBatches == 0) { throw new IllegalStateException( s"Received an invalid response from shuffle server: " + @@ -240,7 +240,7 @@ class RapidsShuffleIterator( def clientDone: Boolean = clientExpectedBatches > 0 && clientExpectedBatches == clientResolvedBatches - def batchReceived(handle: RapidsBufferHandle): Boolean = { + override def batchReceived(handle: RapidsShuffleHandle): Boolean = { resolvedBatches.synchronized { if (taskComplete) { false @@ -305,8 +305,7 @@ class RapidsShuffleIterator( logWarning(s"Iterator for task ${taskAttemptIdStr} closing, " + s"but it is not done. 
Closing ${resolvedBatches.size()} resolved batches!!") resolvedBatches.forEach { - case BufferReceived(handle) => - GpuShuffleEnv.getReceivedCatalog.removeBuffer(handle) + case BufferReceived(handle) => handle.close() case _ => } // tell the client to cancel pending requests @@ -332,8 +331,6 @@ class RapidsShuffleIterator( } override def next(): ColumnarBatch = { - var cb: ColumnarBatch = null - var sb: RapidsBuffer = null val range = new NvtxRange(s"RapidshuffleIterator.next", NvtxColor.RED) // If N tasks downstream are accumulating memory we run the risk OOM @@ -375,16 +372,12 @@ class RapidsShuffleIterator( val nvtxRangeAfterGettingBatch = new NvtxRange("RapidsShuffleIterator.gotBatch", NvtxColor.PURPLE) try { - sb = catalog.acquireBuffer(handle) - cb = sb.getColumnarBatch(sparkTypes) - metricsUpdater.update(blockedTime, 1, sb.memoryUsedBytes, cb.numRows()) + val (cb, memoryUsedBytes) = catalog.getColumnarBatchAndRemove(handle, sparkTypes) + metricsUpdater.update(blockedTime, 1, memoryUsedBytes, cb.numRows()) + cb } finally { nvtxRangeAfterGettingBatch.close() range.close() - if (sb != null) { - sb.close() - } - catalog.removeBuffer(handle) } case Some( TransferError(blockManagerId, shuffleBlockBatchId, mapIndex, errorMessage, throwable)) => @@ -410,6 +403,5 @@ class RapidsShuffleIterator( case _ => throw new IllegalStateException(s"Invalid result type $result") } - cb } } diff --git a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala index 4781e649c21..c9386cff4f3 100644 --- a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala +++ b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala @@ -33,13 +33,13 @@ import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{NvtxColor, NvtxRange} import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.shuffle.{RapidsShuffleIterator, RapidsShuffleTransport} import org.apache.spark.{InterruptibleIterator, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.shuffle.{ShuffleReader, ShuffleReadMetricsReporter} import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockBatchId, ShuffleBlockId} import org.apache.spark.util.CompletionIterator @@ -76,10 +76,10 @@ class RapidsCachingReader[K, C]( try { val blocksForRapidsTransport = new ArrayBuffer[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])]() - val cachedBlocks = new ArrayBuffer[BlockId]() - val cachedBufferHandles = new ArrayBuffer[RapidsBufferHandle]() - val blocksByAddressMap: Map[BlockManagerId, collection.Seq[(BlockId, Long, Int)]] = + var cachedBatchIterator: Iterator[ColumnarBatch] = Iterator.empty + val blocksByAddressMap: Map[BlockManagerId, collection.Seq[(BlockId, Long, Int)]] = blocksByAddress.toMap + var numCachedBlocks: Int = 0 blocksByAddressMap.keys.foreach(blockManagerId => { val blockInfos: collection.Seq[(BlockId, Long, Int)] = blocksByAddressMap(blockManagerId) @@ -88,33 +88,29 @@ class RapidsCachingReader[K, C]( if (blockManagerId.executorId == localId.executorId) { val readLocalRange = new NvtxRange("Read Local", NvtxColor.GREEN) try { - blockInfos.foreach( - blockInfo => { - val blockId = blockInfo._1 - val shuffleBufferHandles: IndexedSeq[RapidsBufferHandle] = blockId match { 
- case sbbid: ShuffleBlockBatchId => - (sbbid.startReduceId to sbbid.endReduceId).flatMap { reduceId => - cachedBlocks.append(blockId) - val sBlockId = ShuffleBlockId(sbbid.shuffleId, sbbid.mapId, reduceId) - catalog.blockIdToBufferHandles(sBlockId) - } - case sbid: ShuffleBlockId => - cachedBlocks.append(blockId) - catalog.blockIdToBufferHandles(sbid) - case _ => throw new IllegalArgumentException( - s"${blockId.getClass} $blockId is not currently supported") - } - - cachedBufferHandles ++= shuffleBufferHandles - - // Update the spill priorities of these buffers to indicate they are about - // to be read and therefore should not be spilled if possible. - shuffleBufferHandles.foreach(catalog.updateSpillPriorityForLocalRead) - - if (shuffleBufferHandles.nonEmpty) { - metrics.incLocalBlocksFetched(1) - } - }) + cachedBatchIterator = blockInfos.iterator.flatMap { blockInfo => + val blockId = blockInfo._1 + val shuffleBufferHandles = blockId match { + case sbbid: ShuffleBlockBatchId => + (sbbid.startReduceId to sbbid.endReduceId).iterator.flatMap { reduceId => + val sbid = ShuffleBlockId(sbbid.shuffleId, sbbid.mapId, reduceId) + numCachedBlocks += 1 + catalog.getColumnarBatchIterator(sbid, sparkTypes) + } + case sbid: ShuffleBlockId => + numCachedBlocks += 1 + catalog.getColumnarBatchIterator(sbid, sparkTypes) + case _ => throw new IllegalArgumentException( + s"${blockId.getClass} $blockId is not currently supported") + } + + shuffleBufferHandles + } + + // Update the spill priorities of these buffers to indicate they are about + // to be read and therefore should not be spilled if possible. + // TODO: AB: shuffleBufferHandles.foreach(catalog.updateSpillPriorityForLocalRead) + metrics.incLocalBlocksFetched(numCachedBlocks) } finally { readLocalRange.close() } @@ -136,7 +132,7 @@ class RapidsCachingReader[K, C]( } }) - logInfo(s"Will read ${cachedBlocks.size} cached blocks, " + + logInfo(s"Will read ${numCachedBlocks} cached blocks, " + s"${blocksForRapidsTransport.size} remote blocks from the RapidsShuffleTransport. 
") if (transport.isEmpty && blocksForRapidsTransport.nonEmpty) { @@ -156,17 +152,12 @@ class RapidsCachingReader[K, C]( val itRange = new NvtxRange("Shuffle Iterator prep", NvtxColor.BLUE) try { - val cachedIt = cachedBufferHandles.iterator.map(bufferHandle => { - // No good way to get a metric in here for semaphore wait time - GpuSemaphore.acquireIfNecessary(context) - val cb = withResource(catalog.acquireBuffer(bufferHandle)) { buffer => - buffer.getColumnarBatch(sparkTypes) - } + val cachedIt = cachedBatchIterator.map { cb => val cachedBytesRead = GpuColumnVector.getTotalDeviceMemoryUsed(cb) metrics.incLocalBytesRead(cachedBytesRead) metrics.incRecordsRead(cb.numRows()) (0, cb) - }).asInstanceOf[Iterator[(K, C)]] + }.asInstanceOf[Iterator[(K, C)]] val cbArrayFromUcx: Iterator[(K, C)] = if (blocksForRapidsTransport.nonEmpty) { val rapidsShuffleIterator = new RapidsShuffleIterator(localId, rapidsConf, transport.get, diff --git a/sql-plugin/src/test/scala/com/nvidia/spark/rapids/HostAllocSuite.scala b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/HostAllocSuite.scala index 24755c2c0a1..fb990b99d3b 100644 --- a/sql-plugin/src/test/scala/com/nvidia/spark/rapids/HostAllocSuite.scala +++ b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/HostAllocSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.funsuite.AnyFunSuite import org.scalatest.time._ import org.scalatestplus.mockito.MockitoSugar.mock -import org.apache.spark.TaskContext +import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids.execution.TrampolineUtil @@ -36,7 +36,7 @@ import org.apache.spark.sql.rapids.execution.TrampolineUtil class HostAllocSuite extends AnyFunSuite with BeforeAndAfterEach with BeforeAndAfterAll with TimeLimits { private val sqlConf = new SQLConf() - private val rc = new RapidsConf(sqlConf) + Rmm.shutdown() private val timeoutMs = 10000 def setMockContext(taskAttemptId: Long): Unit = { @@ -316,23 +316,34 @@ class HostAllocSuite extends AnyFunSuite with BeforeAndAfterEach with private var rmmWasInitialized = false override def beforeEach(): Unit = { - RapidsBufferCatalog.close() + val sc = new SparkConf SparkSession.getActiveSession.foreach(_.stop()) SparkSession.clearActiveSession() + SpillFramework.shutdown() if (Rmm.isInitialized) { rmmWasInitialized = true Rmm.shutdown() } Rmm.initialize(RmmAllocationMode.CUDA_DEFAULT, null, 512 * 1024 * 1024) + // this doesn't allocate memory for bounce buffers, as HostAllocSuite + // is playing games with the pools. 
+ SpillFramework.storesInternal = new SpillableStores { + override var deviceStore: SpillableDeviceStore = new SpillableDeviceStore + override var hostStore: SpillableHostStore = new SpillableHostStore(None) + override var diskStore: DiskHandleStore = new DiskHandleStore(sc) + } + // some tests need an event handler + RmmSpark.setEventHandler( + new DeviceMemoryEventHandler(SpillFramework.stores.deviceStore, None, 0)) PinnedMemoryPool.shutdown() HostAlloc.initialize(-1) - RapidsBufferCatalog.init(rc) } override def afterAll(): Unit = { - RapidsBufferCatalog.close() + SpillFramework.shutdown() PinnedMemoryPool.shutdown() Rmm.shutdown() + RmmSpark.clearEventHandler() if (rmmWasInitialized) { // put RMM back for other tests to use Rmm.initialize(RmmAllocationMode.CUDA_DEFAULT, null, 512 * 1024 * 1024) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala index 0b531adabb7..8045691b504 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,12 +23,9 @@ import org.scalatestplus.mockito.MockitoSugar class DeviceMemoryEventHandlerSuite extends RmmSparkRetrySuiteBase with MockitoSugar { test("a failed allocation should be retried if we spilled enough") { - val mockCatalog = mock[RapidsBufferCatalog] - val mockStore = mock[RapidsDeviceMemoryStore] - when(mockStore.currentSpillableSize).thenReturn(1024) - when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(1024L)) + val mockStore = mock[SpillableDeviceStore] + when(mockStore.spill(any())).thenAnswer(_ => 1024L) val handler = new DeviceMemoryEventHandler( - mockCatalog, mockStore, None, 2) @@ -36,12 +33,9 @@ class DeviceMemoryEventHandlerSuite extends RmmSparkRetrySuiteBase with MockitoS } test("when we deplete the store, retry up to max failed OOM retries") { - val mockCatalog = mock[RapidsBufferCatalog] - val mockStore = mock[RapidsDeviceMemoryStore] - when(mockStore.currentSpillableSize).thenReturn(0) - when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(0L)) + val mockStore = mock[SpillableDeviceStore] + when(mockStore.spill(any())).thenAnswer(_ => 0L) val handler = new DeviceMemoryEventHandler( - mockCatalog, mockStore, None, 2) @@ -51,12 +45,9 @@ class DeviceMemoryEventHandlerSuite extends RmmSparkRetrySuiteBase with MockitoS } test("we reset our OOM state after a successful retry") { - val mockCatalog = mock[RapidsBufferCatalog] - val mockStore = mock[RapidsDeviceMemoryStore] - when(mockStore.currentSpillableSize).thenReturn(0) - when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(0L)) + val mockStore = mock[SpillableDeviceStore] + when(mockStore.spill(any())).thenAnswer(_ => 0L) val handler = new DeviceMemoryEventHandler( - mockCatalog, mockStore, None, 2) @@ -69,12 +60,9 @@ class DeviceMemoryEventHandlerSuite extends RmmSparkRetrySuiteBase with MockitoS } test("a negative allocation cannot be retried and handler throws") { - val mockCatalog = mock[RapidsBufferCatalog] - val mockStore = mock[RapidsDeviceMemoryStore] - when(mockStore.currentSpillableSize).thenReturn(1024) 
- when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(1024L)) + val mockStore = mock[SpillableDeviceStore] + when(mockStore.spill(any())).thenAnswer(_ => 1024L) val handler = new DeviceMemoryEventHandler( - mockCatalog, mockStore, None, 2) @@ -82,12 +70,9 @@ class DeviceMemoryEventHandlerSuite extends RmmSparkRetrySuiteBase with MockitoS } test("a negative retry count is invalid") { - val mockCatalog = mock[RapidsBufferCatalog] - val mockStore = mock[RapidsDeviceMemoryStore] - when(mockStore.currentSpillableSize).thenReturn(1024) - when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(1024L)) + val mockStore = mock[SpillableDeviceStore] + when(mockStore.spill(any())).thenAnswer(_ => 1024L) val handler = new DeviceMemoryEventHandler( - mockCatalog, mockStore, None, 2) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GeneratedInternalRowToCudfRowIteratorRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GeneratedInternalRowToCudfRowIteratorRetrySuite.scala index 725a2e37032..202c12eb774 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GeneratedInternalRowToCudfRowIteratorRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GeneratedInternalRowToCudfRowIteratorRetrySuite.scala @@ -20,7 +20,7 @@ import ai.rapids.cudf.Table import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.jni.{GpuSplitAndRetryOOM, RmmSpark} import org.mockito.ArgumentMatchers.any -import org.mockito.Mockito.{doAnswer, spy, times, verify} +import org.mockito.Mockito.{doAnswer, spy} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatestplus.mockito.MockitoSugar @@ -41,6 +41,14 @@ class GeneratedInternalRowToCudfRowIteratorRetrySuite } } + override def beforeEach(): Unit = { + // some tests in this suite will want to perform `verify` calls on the device store + // so we close it and create a spy around one. + super.beforeEach() + SpillFramework.storesInternal.deviceStore.close() + SpillFramework.storesInternal.deviceStore = spy(new SpillableDeviceStore) + } + private def getAndResetNumRetryThrowCurrentTask: Int = { // taskId 1 was associated with the current thread in RmmSparkRetrySuiteBase RmmSpark.getAndResetNumRetryThrow(/*taskId*/ 1) @@ -65,26 +73,24 @@ class GeneratedInternalRowToCudfRowIteratorRetrySuite } assert(!GpuColumnVector.extractBases(batch).exists(_.getRefCount > 0)) assert(!myIter.hasNext) - assertResult(0)(RapidsBufferCatalog.getDeviceStorage.currentSize) + assertResult(0)(SpillFramework.stores.deviceStore.spill(1)) } } test("a retry when converting to a table is handled") { val batch = buildBatch() val batchIter = Seq(batch).iterator - var rapidsBufferSpy: RapidsBuffer = null - doAnswer(new Answer[AnyRef]() { - override def answer(invocation: InvocationOnMock): AnyRef = { - val res = invocation.callRealMethod() + doAnswer(new Answer[Boolean]() { + override def answer(invocation: InvocationOnMock): Boolean = { + invocation.callRealMethod() // we mock things this way due to code generation issues with mockito. 
// when we add a table we have RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 3, RmmSpark.OomInjectionType.GPU.ordinal, 0) - rapidsBufferSpy = spy(res.asInstanceOf[RapidsBuffer]) - rapidsBufferSpy + true } - }).when(deviceStorage) - .addTable(any(), any(), any(), any()) + }).when(SpillFramework.stores.deviceStore) + .track(any()) withResource(new ColumnarToRowIterator(batchIter, NoopMetric, NoopMetric, NoopMetric, NoopMetric)) { ctriter => @@ -102,35 +108,27 @@ class GeneratedInternalRowToCudfRowIteratorRetrySuite } assertResult(6)(getAndResetNumRetryThrowCurrentTask) assert(!myIter.hasNext) - assertResult(0)(RapidsBufferCatalog.getDeviceStorage.currentSize) - // This is my wrap around of checking that we did retry the last part - // where we are converting the device column of rows into an actual column. - // Because we asked for 3 retries, we would ask the spill framework 4 times to materialize - // a batch. - verify(rapidsBufferSpy, times(4)) - .getColumnarBatch(any()) + assertResult(0)(SpillFramework.stores.deviceStore.spill(1)) } } test("spilling the device column of rows works") { val batch = buildBatch() val batchIter = Seq(batch).iterator - var rapidsBufferSpy: RapidsBuffer = null - doAnswer(new Answer[AnyRef]() { - override def answer(invocation: InvocationOnMock): AnyRef = { - val res = invocation.callRealMethod() + doAnswer(new Answer[Boolean]() { + override def answer(invocation: InvocationOnMock): Boolean = { + val handle = invocation.getArgument(0).asInstanceOf[SpillableColumnarBatchHandle] // we mock things this way due to code generation issues with mockito. // when we add a table we have RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 3, RmmSpark.OomInjectionType.GPU.ordinal, 0) - rapidsBufferSpy = spy(res.asInstanceOf[RapidsBuffer]) // at this point we have created a buffer in the Spill Framework // lets spill it - RapidsBufferCatalog.singleton.synchronousSpill(deviceStorage, 0) - rapidsBufferSpy + SpillFramework.stores.deviceStore.spill(handle.sizeInBytes) + true } - }).when(deviceStorage) - .addTable(any(), any(), any(), any()) + }).when(SpillFramework.stores.deviceStore) + .track(any()) withResource(new ColumnarToRowIterator(batchIter, NoopMetric, NoopMetric, NoopMetric, NoopMetric)) { ctriter => @@ -148,13 +146,7 @@ class GeneratedInternalRowToCudfRowIteratorRetrySuite } assertResult(6)(getAndResetNumRetryThrowCurrentTask) assert(!myIter.hasNext) - assertResult(0)(RapidsBufferCatalog.getDeviceStorage.currentSize) - // This is my wrap around of checking that we did retry the last part - // where we are converting the device column of rows into an actual column. - // Because we asked for 3 retries, we would ask the spill framework 4 times to materialize - // a batch. 
- verify(rapidsBufferSpy, times(4)) - .getColumnarBatch(any()) + assertResult(0)(SpillFramework.stores.deviceStore.spill(1)) } } @@ -173,7 +165,7 @@ class GeneratedInternalRowToCudfRowIteratorRetrySuite assertThrows[GpuSplitAndRetryOOM] { myIter.next() } - assertResult(0)(RapidsBufferCatalog.getDeviceStorage.currentSize) + assertResult(0)(SpillFramework.stores.deviceStore.spill(1)) } } @@ -199,7 +191,7 @@ class GeneratedInternalRowToCudfRowIteratorRetrySuite } assert(!GpuColumnVector.extractBases(batch).exists(_.getRefCount > 0)) assert(!myIter.hasNext) - assertResult(0)(RapidsBufferCatalog.getDeviceStorage.currentSize) + assertResult(0)(SpillFramework.stores.deviceStore.spill(1)) } } @@ -225,7 +217,7 @@ class GeneratedInternalRowToCudfRowIteratorRetrySuite } assert(!GpuColumnVector.extractBases(batch).exists(_.getRefCount > 0)) assert(!myIter.hasNext) - assertResult(0)(RapidsBufferCatalog.getDeviceStorage.currentSize) + assertResult(0)(SpillFramework.stores.deviceStore.spill(1)) } } } \ No newline at end of file diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesRetrySuite.scala index 34a9fa984d5..4f0027d4c2a 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesRetrySuite.scala @@ -199,7 +199,7 @@ class GpuCoalesceBatchesRetrySuite val batches = iter.asInstanceOf[CoalesceIteratorMocks].getBatches() assertResult(10)(batches.length) batches.foreach(b => - verify(b, times(1)).close() + GpuColumnVector.extractBases(b).forall(_.getRefCount == 0) ) } } @@ -209,7 +209,7 @@ class GpuCoalesceBatchesRetrySuite var refCount = 1 override def numRows(): Int = 0 override def setSpillPriority(priority: Long): Unit = {} - override def getColumnarBatch(): ColumnarBatch = { + override def getColumnarBatch: ColumnarBatch = { throw new GpuSplitAndRetryOOM() } override def sizeInBytes: Long = 0 diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuGenerateSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuGenerateSuite.scala index e69f7c75118..fbbf0acc20d 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuGenerateSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuGenerateSuite.scala @@ -268,12 +268,12 @@ class GpuGenerateSuite var forceOOM: Boolean) extends SpillableColumnarBatch { override def numRows(): Int = spillable.numRows() override def setSpillPriority(priority: Long): Unit = spillable.setSpillPriority(priority) - override def getColumnarBatch(): ColumnarBatch = { + override def getColumnarBatch: ColumnarBatch = { if (forceOOM) { forceOOM = false throw new GpuSplitAndRetryOOM(s"mock split and retry") } - spillable.getColumnarBatch() + spillable.getColumnarBatch } override def sizeInBytes: Long = spillable.sizeInBytes override def dataTypes: Array[DataType] = spillable.dataTypes diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala index 1e3c0f699da..fdbe316a394 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala @@ -16,20 +16,20 @@ package com.nvidia.spark.rapids -import java.io.File import java.math.RoundingMode import ai.rapids.cudf.{ColumnVector, Cuda, DType, Table} import com.nvidia.spark.rapids.Arm.withResource +import 
org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite import org.apache.spark.SparkConf -import org.apache.spark.sql.rapids.{GpuShuffleEnv, RapidsDiskBlockManager} +import org.apache.spark.sql.rapids.GpuShuffleEnv import org.apache.spark.sql.rapids.execution.TrampolineUtil import org.apache.spark.sql.types.{DecimalType, DoubleType, IntegerType, StringType} import org.apache.spark.sql.vectorized.ColumnarBatch -class GpuPartitioningSuite extends AnyFunSuite { +class GpuPartitioningSuite extends AnyFunSuite with BeforeAndAfterEach { var rapidsConf = new RapidsConf(Map[String, String]()) private def buildBatch(): ColumnarBatch = { @@ -113,7 +113,7 @@ class GpuPartitioningSuite extends AnyFunSuite { TrampolineUtil.cleanupAnyExistingSession() val conf = new SparkConf().set(RapidsConf.SHUFFLE_COMPRESSION_CODEC.key, "none") TestUtils.withGpuSparkSession(conf) { _ => - GpuShuffleEnv.init(new RapidsConf(conf), new RapidsDiskBlockManager(conf)) + GpuShuffleEnv.init(new RapidsConf(conf)) val partitionIndices = Array(0, 2, 2) val gp = new GpuPartitioning { override val numPartitions: Int = partitionIndices.length @@ -157,61 +157,53 @@ class GpuPartitioningSuite extends AnyFunSuite { val conf = new SparkConf() .set(RapidsConf.SHUFFLE_COMPRESSION_CODEC.key, codecName) TestUtils.withGpuSparkSession(conf) { _ => - GpuShuffleEnv.init(new RapidsConf(conf), new RapidsDiskBlockManager(conf)) - val spillPriority = 7L - - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = new RapidsBufferCatalog(store) - val partitionIndices = Array(0, 2, 2) - val gp = new GpuPartitioning { - override val numPartitions: Int = partitionIndices.length - } - withResource(buildBatch()) { batch => - // `sliceInternalOnGpuAndClose` will close the batch, but in this test we want to - // reuse it - GpuColumnVector.incRefCounts(batch) - val columns = GpuColumnVector.extractColumns(batch) - val sparkTypes = GpuColumnVector.extractTypes(batch) - val numRows = batch.numRows - withResource( - gp.sliceInternalOnGpuAndClose(numRows, partitionIndices, columns)) { partitions => - partitions.zipWithIndex.foreach { case (partBatch, partIndex) => - val startRow = partitionIndices(partIndex) - val endRow = if (partIndex < partitionIndices.length - 1) { - partitionIndices(partIndex + 1) - } else { - batch.numRows - } - val expectedRows = endRow - startRow - assertResult(expectedRows)(partBatch.numRows) - val columns = (0 until partBatch.numCols).map(i => partBatch.column(i)) - columns.foreach { column => - // batches with any rows should be compressed, and - // batches with no rows should not be compressed. 
- val actualRows = column match { - case c: GpuCompressedColumnVector => - val rows = c.getTableMeta.rowCount - assert(rows != 0) - rows - case c: GpuPackedTableColumn => - val rows = c.getContiguousTable.getRowCount - assert(rows == 0) - rows - case _ => - throw new IllegalStateException("column should either be compressed or packed") - } - assertResult(expectedRows)(actualRows) + GpuShuffleEnv.init(new RapidsConf(conf)) + val partitionIndices = Array(0, 2, 2) + val gp = new GpuPartitioning { + override val numPartitions: Int = partitionIndices.length + } + withResource(buildBatch()) { batch => + // `sliceInternalOnGpuAndClose` will close the batch, but in this test we want to + // reuse it + GpuColumnVector.incRefCounts(batch) + val columns = GpuColumnVector.extractColumns(batch) + val numRows = batch.numRows + withResource( + gp.sliceInternalOnGpuAndClose(numRows, partitionIndices, columns)) { partitions => + partitions.zipWithIndex.foreach { case (partBatch, partIndex) => + val startRow = partitionIndices(partIndex) + val endRow = if (partIndex < partitionIndices.length - 1) { + partitionIndices(partIndex + 1) + } else { + batch.numRows + } + val expectedRows = endRow - startRow + assertResult(expectedRows)(partBatch.numRows) + val columns = (0 until partBatch.numCols).map(i => partBatch.column(i)) + columns.foreach { column => + // batches with any rows should be compressed, and + // batches with no rows should not be compressed. + val actualRows = column match { + case c: GpuCompressedColumnVector => + val rows = c.getTableMeta.rowCount + assert(rows != 0) + rows + case c: GpuPackedTableColumn => + val rows = c.getContiguousTable.getRowCount + assert(rows == 0) + rows + case _ => + throw new IllegalStateException( + "column should either be compressed or packed") } - if (GpuCompressedColumnVector.isBatchCompressed(partBatch)) { - val gccv = columns.head.asInstanceOf[GpuCompressedColumnVector] - val devBuffer = gccv.getTableBuffer - val handle = catalog.addBuffer(devBuffer, gccv.getTableMeta, spillPriority) - withResource(buildSubBatch(batch, startRow, endRow)) { expectedBatch => - withResource(catalog.acquireBuffer(handle)) { buffer => - withResource(buffer.getColumnarBatch(sparkTypes)) { batch => - compareBatches(expectedBatch, batch) - } - } + assertResult(expectedRows)(actualRows) + } + if (GpuCompressedColumnVector.isBatchCompressed(partBatch)) { + GpuCompressedColumnVector.incRefCounts(partBatch) + val handle = SpillableColumnarBatch(partBatch, -1) + withResource(buildSubBatch(batch, startRow, endRow)) { expectedBatch => + withResource(handle.getColumnarBatch()) { cb => + compareBatches(expectedBatch, cb) } } } @@ -220,9 +212,4 @@ class GpuPartitioningSuite extends AnyFunSuite { } } } -} - -case class MockRapidsBufferId(tableId: Int) extends RapidsBufferId { - override def getDiskPath(diskBlockManager: RapidsDiskBlockManager): File = - throw new UnsupportedOperationException -} +} \ No newline at end of file diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuSinglePartitioningSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuSinglePartitioningSuite.scala index 9211c32e142..6c5ff16ffcb 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuSinglePartitioningSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuSinglePartitioningSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ import com.nvidia.spark.rapids.Arm.withResource import org.scalatest.funsuite.AnyFunSuite import org.apache.spark.SparkConf -import org.apache.spark.sql.rapids.{GpuShuffleEnv, RapidsDiskBlockManager} +import org.apache.spark.sql.rapids.GpuShuffleEnv import org.apache.spark.sql.types.{DecimalType, DoubleType, IntegerType, StringType} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -48,7 +48,7 @@ class GpuSinglePartitioningSuite extends AnyFunSuite { .set("spark.rapids.shuffle.mode", RapidsConf.RapidsShuffleManagerMode.UCX.toString) .set(RapidsConf.SHUFFLE_COMPRESSION_CODEC.key, "none") TestUtils.withGpuSparkSession(conf) { _ => - GpuShuffleEnv.init(new RapidsConf(conf), new RapidsDiskBlockManager(conf)) + GpuShuffleEnv.init(new RapidsConf(conf)) val partitioner = GpuSinglePartitioning withResource(buildBatch()) { batch => withResource(GpuColumnVector.from(batch)) { table => diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/HashAggregateRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/HashAggregateRetrySuite.scala index 58608ed132c..82fb1dd4154 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/HashAggregateRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/HashAggregateRetrySuite.scala @@ -70,10 +70,8 @@ class HashAggregateRetrySuite when(aggHelper.aggOrdinals).thenReturn(aggOrdinals) // attempt a cuDF reduction - withResource(input) { _ => - GpuAggregateIterator.aggregate( - aggHelper, input, mockMetrics) - } + GpuAggregateIterator.aggregate( + aggHelper, input, mockMetrics) } def makeGroupByAggHelper(forceMerge: Boolean): AggHelper = { @@ -118,172 +116,187 @@ class HashAggregateRetrySuite test("computeAndAggregate reduction with retry") { val reductionBatch = buildReductionBatch() - RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1, - RmmSpark.OomInjectionType.GPU.ordinal, 0) - val result = doReduction(reductionBatch) - withResource(result) { spillable => - withResource(spillable.getColumnarBatch) { cb => - assertResult(1)(cb.numRows) - val gcv = cb.column(0).asInstanceOf[GpuColumnVector] - withResource(gcv.getBase.copyToHost()) { hcv => - assertResult(9)(hcv.getLong(0)) + withResource(reductionBatch.incRefCount()) { _ => + RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1, + RmmSpark.OomInjectionType.GPU.ordinal, 0) + val result = doReduction(reductionBatch) + withResource(result) { spillable => + withResource(spillable.getColumnarBatch) { cb => + assertResult(1)(cb.numRows) + val gcv = cb.column(0).asInstanceOf[GpuColumnVector] + withResource(gcv.getBase.copyToHost()) { hcv => + assertResult(9)(hcv.getLong(0)) + } } } + // we need to request a ColumnarBatch twice here for the retry + // why is this invoking the underlying method + verify(reductionBatch, times(2)).getColumnarBatch } - // we need to request a ColumnarBatch twice here for the retry - verify(reductionBatch, times(2)).getColumnarBatch() } test("computeAndAggregate reduction with two retries") { val reductionBatch = buildReductionBatch() - RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 2, - RmmSpark.OomInjectionType.GPU.ordinal, 0) - val result = doReduction(reductionBatch) - withResource(result) { spillable => - withResource(spillable.getColumnarBatch) { cb => - assertResult(1)(cb.numRows) - val gcv = cb.column(0).asInstanceOf[GpuColumnVector] - withResource(gcv.getBase.copyToHost()) { hcv => - 
assertResult(9)(hcv.getLong(0)) + withResource(reductionBatch.incRefCount()) { _ => + RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 2, + RmmSpark.OomInjectionType.GPU.ordinal, 0) + val result = doReduction(reductionBatch) + withResource(result) { spillable => + withResource(spillable.getColumnarBatch) { cb => + assertResult(1)(cb.numRows) + val gcv = cb.column(0).asInstanceOf[GpuColumnVector] + withResource(gcv.getBase.copyToHost()) { hcv => + assertResult(9)(hcv.getLong(0)) + } } } + // we need to request a ColumnarBatch three times, because of 1 regular attempt, + // and two retries + verify(reductionBatch, times(3)).getColumnarBatch } - // we need to request a ColumnarBatch three times, because of 1 regular attempt, - // and two retries - verify(reductionBatch, times(3)).getColumnarBatch() } test("computeAndAggregate reduction with cudf exception") { val reductionBatch = buildReductionBatch() - RmmSpark.forceCudfException(RmmSpark.getCurrentThreadId) - assertThrows[CudfException] { - doReduction(reductionBatch) + withResource(reductionBatch.incRefCount()) { _ => + RmmSpark.forceCudfException(RmmSpark.getCurrentThreadId) + assertThrows[CudfException] { + doReduction(reductionBatch) + } + // columnar batch was obtained once, but since this was not a retriable exception + // we don't retry it + verify(reductionBatch, times(1)).getColumnarBatch } - // columnar batch was obtained once, but since this was not a retriable exception - // we don't retry it - verify(reductionBatch, times(1)).getColumnarBatch() } test("computeAndAggregate group by with retry") { val groupByBatch = buildGroupByBatch() - RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1, - RmmSpark.OomInjectionType.GPU.ordinal, 0) - val result = doGroupBy(groupByBatch) - withResource(result) { spillable => - withResource(spillable.getColumnarBatch) { cb => - assertResult(3)(cb.numRows) - val gcv = cb.column(0).asInstanceOf[GpuColumnVector] - val aggv = cb.column(1).asInstanceOf[GpuColumnVector] - var rowsLeftToMatch = 3 - withResource(aggv.getBase.copyToHost()) { aggvh => - withResource(gcv.getBase.copyToHost()) { grph => - (0 until 3).foreach { row => - if (grph.isNull(row)) { - assertResult(2L)(aggvh.getLong(row)) - rowsLeftToMatch -= 1 - } else if (grph.getInt(row) == 5) { - assertResult(1L)(aggvh.getLong(row)) - rowsLeftToMatch -= 1 - } else if (grph.getInt(row) == 1) { - assertResult(7L)(aggvh.getLong(row)) - rowsLeftToMatch -= 1 + withResource(groupByBatch.incRefCount()) { _ => + RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1, + RmmSpark.OomInjectionType.GPU.ordinal, 0) + val result = doGroupBy(groupByBatch) + withResource(result) { spillable => + withResource(spillable.getColumnarBatch) { cb => + assertResult(3)(cb.numRows) + val gcv = cb.column(0).asInstanceOf[GpuColumnVector] + val aggv = cb.column(1).asInstanceOf[GpuColumnVector] + var rowsLeftToMatch = 3 + withResource(aggv.getBase.copyToHost()) { aggvh => + withResource(gcv.getBase.copyToHost()) { grph => + (0 until 3).foreach { row => + if (grph.isNull(row)) { + assertResult(2L)(aggvh.getLong(row)) + rowsLeftToMatch -= 1 + } else if (grph.getInt(row) == 5) { + assertResult(1L)(aggvh.getLong(row)) + rowsLeftToMatch -= 1 + } else if (grph.getInt(row) == 1) { + assertResult(7L)(aggvh.getLong(row)) + rowsLeftToMatch -= 1 + } } } } + assertResult(0)(rowsLeftToMatch) } - assertResult(0)(rowsLeftToMatch) } + // we need to request a ColumnarBatch twice here for the retry + verify(groupByBatch, times(2)).getColumnarBatch } - // we need to request a ColumnarBatch 
twice here for the retry - verify(groupByBatch, times(2)).getColumnarBatch() } test("computeAndAggregate reduction with split and retry") { val reductionBatch = buildReductionBatch() - RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1, - RmmSpark.OomInjectionType.GPU.ordinal, 0) - val result = doReduction(reductionBatch) - withResource(result) { spillable => - withResource(spillable.getColumnarBatch) { cb => - assertResult(1)(cb.numRows) - val gcv = cb.column(0).asInstanceOf[GpuColumnVector] + withResource(reductionBatch.incRefCount()) { _ => + RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1, + RmmSpark.OomInjectionType.GPU.ordinal, 0) + val result = doReduction(reductionBatch) + withResource(result) { spillable => + withResource(spillable.getColumnarBatch) { cb => + assertResult(1)(cb.numRows) + val gcv = cb.column(0).asInstanceOf[GpuColumnVector] - withResource(gcv.getBase.copyToHost()) { hcv => - assertResult(9L)(hcv.getLong(0)) + withResource(gcv.getBase.copyToHost()) { hcv => + assertResult(9L)(hcv.getLong(0)) + } } } + // the second time we access this batch is to split it + verify(reductionBatch, times(2)).getColumnarBatch } - // the second time we access this batch is to split it - verify(reductionBatch, times(2)).getColumnarBatch() } test("computeAndAggregate group by with split retry") { val groupByBatch = buildGroupByBatch() - RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1, - RmmSpark.OomInjectionType.GPU.ordinal, 0) - val result = doGroupBy(groupByBatch) - withResource(result) { spillable => - withResource(spillable.getColumnarBatch) { cb => - assertResult(3)(cb.numRows) - val gcv = cb.column(0).asInstanceOf[GpuColumnVector] - val aggv = cb.column(1).asInstanceOf[GpuColumnVector] - var rowsLeftToMatch = 3 - withResource(aggv.getBase.copyToHost()) { aggvh => - withResource(gcv.getBase.copyToHost()) { grph => - (0 until 3).foreach { row => - if (grph.isNull(row)) { - assertResult(2L)(aggvh.getLong(row)) - rowsLeftToMatch -= 1 - } else if (grph.getInt(row) == 5) { - assertResult(1L)(aggvh.getLong(row)) - rowsLeftToMatch -= 1 - } else if (grph.getInt(row) == 1) { - assertResult(7L)(aggvh.getLong(row)) - rowsLeftToMatch -= 1 + withResource(groupByBatch.incRefCount()) { _ => + RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1, + RmmSpark.OomInjectionType.GPU.ordinal, 0) + val result = doGroupBy(groupByBatch) + withResource(result) { spillable => + withResource(spillable.getColumnarBatch) { cb => + assertResult(3)(cb.numRows) + val gcv = cb.column(0).asInstanceOf[GpuColumnVector] + val aggv = cb.column(1).asInstanceOf[GpuColumnVector] + var rowsLeftToMatch = 3 + withResource(aggv.getBase.copyToHost()) { aggvh => + withResource(gcv.getBase.copyToHost()) { grph => + (0 until 3).foreach { row => + if (grph.isNull(row)) { + assertResult(2L)(aggvh.getLong(row)) + rowsLeftToMatch -= 1 + } else if (grph.getInt(row) == 5) { + assertResult(1L)(aggvh.getLong(row)) + rowsLeftToMatch -= 1 + } else if (grph.getInt(row) == 1) { + assertResult(7L)(aggvh.getLong(row)) + rowsLeftToMatch -= 1 + } } } } + assertResult(0)(rowsLeftToMatch) } - assertResult(0)(rowsLeftToMatch) } + // the second time we access this batch is to split it + verify(groupByBatch, times(2)).getColumnarBatch } - // the second time we access this batch is to split it - verify(groupByBatch, times(2)).getColumnarBatch() } test("computeAndAggregate group by with retry and forceMerge") { // with forceMerge we expect 1 batch to be returned at all costs val groupByBatch = 
buildGroupByBatch() - // we force a split because that would cause us to compute two aggs - RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1, - RmmSpark.OomInjectionType.GPU.ordinal, 0) - val result = doGroupBy(groupByBatch, forceMerge = true) - withResource(result) { spillable => - withResource(spillable.getColumnarBatch) { cb => - assertResult(3)(cb.numRows) - val gcv = cb.column(0).asInstanceOf[GpuColumnVector] - val aggv = cb.column(1).asInstanceOf[GpuColumnVector] - var rowsLeftToMatch = 3 - withResource(aggv.getBase.copyToHost()) { aggvh => - withResource(gcv.getBase.copyToHost()) { grph => - (0 until 3).foreach { row => - if (grph.isNull(row)) { - assertResult(2L)(aggvh.getLong(row)) - rowsLeftToMatch -= 1 - } else if (grph.getInt(row) == 5) { - assertResult(1L)(aggvh.getLong(row)) - rowsLeftToMatch -= 1 - } else if (grph.getInt(row) == 1) { - assertResult(7L)(aggvh.getLong(row)) - rowsLeftToMatch -= 1 + withResource(groupByBatch.incRefCount()) { _ => + // we force a split because that would cause us to compute two aggs + RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1, + RmmSpark.OomInjectionType.GPU.ordinal, 0) + val result = doGroupBy(groupByBatch, forceMerge = true) + withResource(result) { spillable => + withResource(spillable.getColumnarBatch) { cb => + assertResult(3)(cb.numRows) + val gcv = cb.column(0).asInstanceOf[GpuColumnVector] + val aggv = cb.column(1).asInstanceOf[GpuColumnVector] + var rowsLeftToMatch = 3 + withResource(aggv.getBase.copyToHost()) { aggvh => + withResource(gcv.getBase.copyToHost()) { grph => + (0 until 3).foreach { row => + if (grph.isNull(row)) { + assertResult(2L)(aggvh.getLong(row)) + rowsLeftToMatch -= 1 + } else if (grph.getInt(row) == 5) { + assertResult(1L)(aggvh.getLong(row)) + rowsLeftToMatch -= 1 + } else if (grph.getInt(row) == 1) { + assertResult(7L)(aggvh.getLong(row)) + rowsLeftToMatch -= 1 + } } } } + assertResult(0)(rowsLeftToMatch) } - assertResult(0)(rowsLeftToMatch) } + // we need to request a ColumnarBatch twice here for the retry + verify(groupByBatch, times(2)).getColumnarBatch } - // we need to request a ColumnarBatch twice here for the retry - verify(groupByBatch, times(2)).getColumnarBatch() } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/NonDeterministicRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/NonDeterministicRetrySuite.scala index d018726ef35..8ca525a1f5c 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/NonDeterministicRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/NonDeterministicRetrySuite.scala @@ -63,7 +63,7 @@ class NonDeterministicRetrySuite extends RmmSparkRetrySuiteBase { } } } - + test("GPU project retry with GPU rand") { def projectRand(): Seq[GpuExpression] = Seq( GpuAlias(GpuRand(GpuLiteral(RAND_SEED)), "rand")()) @@ -154,5 +154,4 @@ class NonDeterministicRetrySuite extends RmmSparkRetrySuiteBase { } } } - } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala deleted file mode 100644 index 9b5b37af480..00000000000 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala +++ /dev/null @@ -1,368 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.nvidia.spark.rapids - -import java.io.File - -import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer} -import com.nvidia.spark.rapids.Arm.withResource -import com.nvidia.spark.rapids.StorageTier.{DEVICE, DISK, HOST, StorageTier} -import com.nvidia.spark.rapids.format.TableMeta -import org.mockito.ArgumentMatchers.any -import org.mockito.Mockito._ -import org.scalatest.funsuite.AnyFunSuite -import org.scalatestplus.mockito.MockitoSugar - -import org.apache.spark.sql.rapids.RapidsDiskBlockManager -import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.vectorized.ColumnarBatch - -class RapidsBufferCatalogSuite extends AnyFunSuite with MockitoSugar { - test("lookup unknown buffer") { - val catalog = new RapidsBufferCatalog - val bufferId = new RapidsBufferId { - override val tableId: Int = 10 - override def getDiskPath(m: RapidsDiskBlockManager): File = null - } - val bufferHandle = new RapidsBufferHandle { - override val id: RapidsBufferId = bufferId - override def setSpillPriority(newPriority: Long): Unit = {} - override def close(): Unit = {} - } - - assertThrows[NoSuchElementException](catalog.acquireBuffer(bufferHandle)) - assertThrows[NoSuchElementException](catalog.getBufferMeta(bufferId)) - } - - test("buffer double register throws") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId) - catalog.registerNewBuffer(buffer) - val buffer2 = mockBuffer(bufferId) - assertThrows[DuplicateBufferException](catalog.registerNewBuffer(buffer2)) - } - - test("a second handle prevents buffer to be removed") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId) - catalog.registerNewBuffer(buffer) - val handle1 = - catalog.makeNewHandle(bufferId, -1) - val handle2 = - catalog.makeNewHandle(bufferId, -1) - - handle1.close() - - // this does not throw - catalog.acquireBuffer(handle2).close() - // actually this doesn't throw either - catalog.acquireBuffer(handle1).close() - - handle2.close() - - assertThrows[NoSuchElementException](catalog.acquireBuffer(handle1)) - assertThrows[NoSuchElementException](catalog.acquireBuffer(handle2)) - } - - test("spill priorities are updated as handles are registered and unregistered") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId, initialPriority = -1) - catalog.registerNewBuffer(buffer) - val handle1 = - catalog.makeNewHandle(bufferId, -1) - withResource(catalog.acquireBuffer(handle1)) { buff => - assertResult(-1)(buff.getSpillPriority) - } - val handle2 = - catalog.makeNewHandle(bufferId, 0) - withResource(catalog.acquireBuffer(handle2)) { buff => - assertResult(0)(buff.getSpillPriority) - } - - // removing the lower priority handle, keeps the high priority spill - handle1.close() - withResource(catalog.acquireBuffer(handle2)) { buff => - assertResult(0)(buff.getSpillPriority) - } - - // adding a lower priority -1000 handle keeps the high priority (0) spill - val handle3 = - catalog.makeNewHandle(bufferId, -1000) - 
withResource(catalog.acquireBuffer(handle3)) { buff => - assertResult(0)(buff.getSpillPriority) - } - - // removing the high priority spill (0) brings us down to the - // low priority that is remaining - handle2.close() - withResource(catalog.acquireBuffer(handle2)) { buff => - assertResult(-1000)(buff.getSpillPriority) - } - - handle3.close() - } - - test("buffer registering slower tier does not hide faster tier") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId, tier = DEVICE) - catalog.registerNewBuffer(buffer) - val handle = catalog.makeNewHandle(bufferId, 0) - val buffer2 = mockBuffer(bufferId, tier = HOST) - catalog.registerNewBuffer(buffer2) - val buffer3 = mockBuffer(bufferId, tier = DISK) - catalog.registerNewBuffer(buffer3) - val acquired = catalog.acquireBuffer(handle) - assertResult(5)(acquired.id.tableId) - assertResult(buffer)(acquired) - - // registering the handle acquires the buffer - verify(buffer, times(2)).addReference() - } - - test("acquire buffer") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId) - catalog.registerNewBuffer(buffer) - val handle = catalog.makeNewHandle(bufferId, 0) - val acquired = catalog.acquireBuffer(handle) - assertResult(5)(acquired.id.tableId) - assertResult(buffer)(acquired) - - // registering the handle acquires the buffer - verify(buffer, times(2)).addReference() - } - - test("acquire buffer retries automatically") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId, acquireAttempts = 9) - catalog.registerNewBuffer(buffer) - val handle = catalog.makeNewHandle(bufferId, 0) - val acquired = catalog.acquireBuffer(handle) - assertResult(5)(acquired.id.tableId) - assertResult(buffer)(acquired) - - // registering the handle acquires the buffer - verify(buffer, times(10)).addReference() - } - - test("acquire buffer at specific tier") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId, tier = DEVICE) - catalog.registerNewBuffer(buffer) - val buffer2 = mockBuffer(bufferId, tier = HOST) - catalog.registerNewBuffer(buffer2) - val acquired = catalog.acquireBuffer(MockBufferId(5), HOST).get - assertResult(5)(acquired.id.tableId) - assertResult(buffer2)(acquired) - verify(buffer2).addReference() - } - - test("acquire buffer at nonexistent tier") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId, tier = HOST) - catalog.registerNewBuffer(buffer) - assert(catalog.acquireBuffer(MockBufferId(5), DEVICE).isEmpty) - assert(catalog.acquireBuffer(MockBufferId(5), DISK).isEmpty) - } - - test("get buffer meta") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val expectedMeta = new TableMeta - val buffer = mockBuffer(bufferId, tableMeta = expectedMeta) - catalog.registerNewBuffer(buffer) - val meta = catalog.getBufferMeta(bufferId) - assertResult(expectedMeta)(meta) - } - - test("buffer is spilled to slower tier only") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId, tier = DEVICE) - catalog.registerNewBuffer(buffer) - val buffer2 = mockBuffer(bufferId, tier = HOST) - catalog.registerNewBuffer(buffer2) - val buffer3 = mockBuffer(bufferId, tier = DISK) - catalog.registerNewBuffer(buffer3) - assert(catalog.isBufferSpilled(bufferId, DEVICE)) - assert(catalog.isBufferSpilled(bufferId, 
HOST)) - assert(!catalog.isBufferSpilled(bufferId, DISK)) - } - - test("multiple calls to unspill return existing DEVICE buffer") { - withResource(spy(new RapidsDeviceMemoryStore)) { deviceStore => - val mockStore = mock[RapidsBufferStore] - withResource( - new RapidsHostMemoryStore(Some(10000))) { hostStore => - deviceStore.setSpillStore(hostStore) - hostStore.setSpillStore(mockStore) - val catalog = new RapidsBufferCatalog(deviceStore) - val handle = withResource(DeviceMemoryBuffer.allocate(1024)) { buff => - val meta = MetaUtils.getTableMetaNoTable(buff.getLength) - catalog.addBuffer( - buff, meta, -1) - } - withResource(handle) { _ => - catalog.synchronousSpill(deviceStore, 0) - val acquiredHostBuffer = catalog.acquireBuffer(handle) - val unspilled = withResource(acquiredHostBuffer) { _ => - assertResult(HOST)(acquiredHostBuffer.storageTier) - val unspilled = - catalog.unspillBufferToDeviceStore( - acquiredHostBuffer, - Cuda.DEFAULT_STREAM) - withResource(unspilled) { _ => - assertResult(DEVICE)(unspilled.storageTier) - } - val unspilledSame = catalog.unspillBufferToDeviceStore( - acquiredHostBuffer, - Cuda.DEFAULT_STREAM) - withResource(unspilledSame) { _ => - assertResult(unspilled)(unspilledSame) - } - // verify that we invoked the copy function exactly once - verify(deviceStore, times(1)).copyBuffer(any(), any(), any()) - unspilled - } - val unspilledSame = catalog.unspillBufferToDeviceStore( - acquiredHostBuffer, - Cuda.DEFAULT_STREAM) - withResource(unspilledSame) { _ => - assertResult(unspilled)(unspilledSame) - } - // verify that we invoked the copy function exactly once - verify(deviceStore, times(1)).copyBuffer(any(), any(), any()) - } - } - } - } - - test("remove buffer tier") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId, tier = DEVICE) - catalog.registerNewBuffer(buffer) - val buffer2 = mockBuffer(bufferId, tier = HOST) - catalog.registerNewBuffer(buffer2) - val buffer3 = mockBuffer(bufferId, tier = DISK) - catalog.registerNewBuffer(buffer3) - catalog.removeBufferTier(bufferId, DEVICE) - catalog.removeBufferTier(bufferId, DISK) - assert(catalog.acquireBuffer(MockBufferId(5), DEVICE).isEmpty) - assert(catalog.acquireBuffer(MockBufferId(5), HOST).isDefined) - assert(catalog.acquireBuffer(MockBufferId(5), DISK).isEmpty) - } - - test("remove nonexistent buffer tier") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId, tier = DEVICE) - catalog.registerNewBuffer(buffer) - catalog.removeBufferTier(bufferId, HOST) - catalog.removeBufferTier(bufferId, DISK) - assert(catalog.acquireBuffer(MockBufferId(5), DEVICE).isDefined) - assert(catalog.acquireBuffer(MockBufferId(5), HOST).isEmpty) - assert(catalog.acquireBuffer(MockBufferId(5), DISK).isEmpty) - } - - test("remove buffer releases buffer resources") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId) - catalog.registerNewBuffer(buffer) - val handle = catalog.makeNewHandle( - bufferId, -1) - handle.close() - verify(buffer).free() - } - - test("remove buffer releases buffer resources at all tiers") { - val catalog = new RapidsBufferCatalog - val bufferId = MockBufferId(5) - val buffer = mockBuffer(bufferId, tier = DEVICE) - catalog.registerNewBuffer(buffer) - val handle = catalog.makeNewHandle( - bufferId, -1) - - // these next registrations don't get their own handle. 
This is an internal - // operation from the store where it has spilled to host and disk the RapidsBuffer - val buffer2 = mockBuffer(bufferId, tier = HOST) - catalog.registerNewBuffer(buffer2) - val buffer3 = mockBuffer(bufferId, tier = DISK) - catalog.registerNewBuffer(buffer3) - - // removing the original handle removes all buffers from all tiers. - handle.close() - verify(buffer).free() - verify(buffer2).free() - verify(buffer3).free() - } - - private def mockBuffer( - bufferId: RapidsBufferId, - tableMeta: TableMeta = null, - tier: StorageTier = StorageTier.DEVICE, - acquireAttempts: Int = 1, - initialPriority: Long = -1): RapidsBuffer = { - spy(new RapidsBuffer { - var _acquireAttempts: Int = acquireAttempts - var currentPriority: Long = initialPriority - override val id: RapidsBufferId = bufferId - override val memoryUsedBytes: Long = 0 - override def meta: TableMeta = tableMeta - override val storageTier: StorageTier = tier - override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = null - override def getMemoryBuffer: MemoryBuffer = null - override def copyToMemoryBuffer( - srcOffset: Long, - dst: MemoryBuffer, - dstOffset: Long, - length: Long, - stream: Cuda.Stream): Unit = {} - override def getDeviceMemoryBuffer: DeviceMemoryBuffer = null - override def getHostMemoryBuffer: HostMemoryBuffer = null - override def addReference(): Boolean = { - if (_acquireAttempts > 0) { - _acquireAttempts -= 1 - } - _acquireAttempts == 0 - } - override def free(): Unit = {} - override def getSpillPriority: Long = currentPriority - override def setSpillPriority(priority: Long): Unit = { - currentPriority = priority - } - - override def withMemoryBufferReadLock[K](body: MemoryBuffer => K): K = { body(null) } - override def withMemoryBufferWriteLock[K](body: MemoryBuffer => K): K = { body(null) } - override def close(): Unit = {} - }) - } -} - -case class MockBufferId(override val tableId: Int) extends RapidsBufferId { - override def getDiskPath(dbm: RapidsDiskBlockManager): File = - throw new UnsupportedOperationException -} diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala deleted file mode 100644 index 45d96be4cb6..00000000000 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala +++ /dev/null @@ -1,489 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.nvidia.spark.rapids - -import java.io.File -import java.math.RoundingMode - -import scala.collection.mutable.ArrayBuffer - -import ai.rapids.cudf.{ColumnVector, ContiguousTable, Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, Table} -import com.nvidia.spark.rapids.Arm.withResource -import com.nvidia.spark.rapids.StorageTier.StorageTier -import com.nvidia.spark.rapids.format.TableMeta -import org.mockito.ArgumentCaptor -import org.mockito.Mockito.{spy, verify} -import org.scalatest.funsuite.AnyFunSuite -import org.scalatestplus.mockito.MockitoSugar - -import org.apache.spark.sql.rapids.RapidsDiskBlockManager -import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType, IntegerType, StringType} - -class RapidsDeviceMemoryStoreSuite extends AnyFunSuite with MockitoSugar { - private def buildTable(): Table = { - new Table.TestBuilder() - .column(5, null.asInstanceOf[java.lang.Integer], 3, 1) - .column("five", "two", null, null) - .column(5.0D, 2.0D, 3.0D, 1.0D) - .decimal64Column(-5, RoundingMode.UNNECESSARY, 0, null, -1.4, 10.123) - .build() - } - - private def buildTableWithDuplicate(): Table = { - withResource(ColumnVector.fromInts(5, null.asInstanceOf[java.lang.Integer], 3, 1)) { intCol => - withResource(ColumnVector.fromStrings("five", "two", null, null)) { stringCol => - withResource(ColumnVector.fromDoubles(5.0, 2.0, 3.0, 1.0)) { doubleCol => - // add intCol twice - new Table(intCol, intCol, stringCol, doubleCol) - } - } - } - } - - private def buildContiguousTable(): ContiguousTable = { - withResource(buildTable()) { table => - table.contiguousSplit()(0) - } - } - - test("add table registers with catalog") { - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = spy(new RapidsBufferCatalog(store)) - val spillPriority = 3 - val bufferId = MockRapidsBufferId(7) - withResource(buildContiguousTable()) { ct => - catalog.addContiguousTable( - bufferId, ct, spillPriority, false) - } - val captor: ArgumentCaptor[RapidsBuffer] = ArgumentCaptor.forClass(classOf[RapidsBuffer]) - verify(catalog).registerNewBuffer(captor.capture()) - val resultBuffer = captor.getValue - assertResult(bufferId)(resultBuffer.id) - assertResult(spillPriority)(resultBuffer.getSpillPriority) - } - } - - test("a non-contiguous table is spillable and it is handed over to the store") { - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = spy(new RapidsBufferCatalog(store)) - val spillPriority = 3 - val table = buildTable() - catalog.addTable(table, spillPriority) - val buffSize = GpuColumnVector.getTotalDeviceMemoryUsed(table) - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - } - } - - test("a non-contiguous table becomes non-spillable when batch is obtained") { - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = spy(new RapidsBufferCatalog(store)) - val spillPriority = 3 - val table = buildTable() - val handle = catalog.addTable(table, spillPriority) - val types: Array[DataType] = - Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray - val buffSize = GpuColumnVector.getTotalDeviceMemoryUsed(table) - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - val batch = withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => - rapidsBuffer.getColumnarBatch(types) - } - withResource(batch) { _ => - assertResult(buffSize)(store.currentSize) - assertResult(0)(store.currentSpillableSize) - } - 
assertResult(buffSize)(store.currentSpillableSize) - } - } - - test("a non-contiguous table is non-spillable until all columns are returned") { - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = spy(new RapidsBufferCatalog(store)) - val spillPriority = 3 - val table = buildTable() - val handle = catalog.addTable(table, spillPriority) - val types: Array[DataType] = - Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray - val buffSize = GpuColumnVector.getTotalDeviceMemoryUsed(table) - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - // incRefCount all the columns via `batch` - val batch = withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => - rapidsBuffer.getColumnarBatch(types) - } - val columns = GpuColumnVector.extractBases(batch) - withResource(columns.head) { _ => - columns.head.incRefCount() - withResource(batch) { _ => - assertResult(buffSize)(store.currentSize) - assertResult(0)(store.currentSpillableSize) - } - // still 0 after the batch is closed, because of the extra incRefCount - // for columns.head - assertResult(0)(store.currentSpillableSize) - } - // columns.head is closed, so now our RapidsTable is spillable again - assertResult(buffSize)(store.currentSpillableSize) - } - } - - test("an aliased non-contiguous table is not spillable (until closing the alias) ") { - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = spy(new RapidsBufferCatalog(store)) - val spillPriority = 3 - val table = buildTable() - val handle = catalog.addTable(table, spillPriority) - val types: Array[DataType] = - Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray - val buffSize = GpuColumnVector.getTotalDeviceMemoryUsed(table) - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - val aliasHandle = withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => - // extract the batch from the table we added, and add it back as a batch - withResource(rapidsBuffer.getColumnarBatch(types)) { batch => - catalog.addBatch(batch, spillPriority) - } - } // we now have two copies in the store - assertResult(buffSize*2)(store.currentSize) - assertResult(0)(store.currentSpillableSize) - - aliasHandle.close() // remove the alias - - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - } - } - - test("an aliased non-contiguous table is not spillable (until closing the original) ") { - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = spy(new RapidsBufferCatalog(store)) - val spillPriority = 3 - val table = buildTable() - val handle = catalog.addTable(table, spillPriority) - val types: Array[DataType] = - Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray - val buffSize = GpuColumnVector.getTotalDeviceMemoryUsed(table) - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => - // extract the batch from the table we added, and add it back as a batch - withResource(rapidsBuffer.getColumnarBatch(types)) { batch => - catalog.addBatch(batch, spillPriority) - } - } // we now have two copies in the store - assertResult(buffSize * 2)(store.currentSize) - assertResult(0)(store.currentSpillableSize) - - handle.close() // remove the original - - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - } - } - - test("an non-contiguous 
table supports duplicated columns") { - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = spy(new RapidsBufferCatalog(store)) - val spillPriority = 3 - val table = buildTableWithDuplicate() - val handle = catalog.addTable(table, spillPriority) - val types: Array[DataType] = - Seq(IntegerType, IntegerType, StringType, DoubleType).toArray - val buffSize = GpuColumnVector.getTotalDeviceMemoryUsed(table) - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => - // extract the batch from the table we added, and add it back as a batch - withResource(rapidsBuffer.getColumnarBatch(types)) { batch => - catalog.addBatch(batch, spillPriority) - } - } // we now have two copies in the store - assertResult(buffSize * 2)(store.currentSize) - assertResult(0)(store.currentSpillableSize) - - handle.close() // remove the original - - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - } - } - - test("a contiguous table is not spillable until the owner closes it") { - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = spy(new RapidsBufferCatalog(store)) - val spillPriority = 3 - val bufferId = MockRapidsBufferId(7) - val ct = buildContiguousTable() - val buffSize = ct.getBuffer.getLength - withResource(ct) { _ => - catalog.addContiguousTable( - bufferId, - ct, - spillPriority, - false) - assertResult(buffSize)(store.currentSize) - assertResult(0)(store.currentSpillableSize) - } - // after closing the original table, the RapidsBuffer should be spillable - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - } - } - - test("a buffer is not spillable until the owner closes columns referencing it") { - withResource(new RapidsDeviceMemoryStore) { store => - val spillPriority = 3 - val bufferId = MockRapidsBufferId(7) - val ct = buildContiguousTable() - val buffSize = ct.getBuffer.getLength - withResource(ct) { _ => - val meta = MetaUtils.buildTableMeta(bufferId.tableId, ct) - withResource(ct) { _ => - store.addBuffer( - bufferId, - ct.getBuffer, - meta, - spillPriority, - false) - assertResult(buffSize)(store.currentSize) - assertResult(0)(store.currentSpillableSize) - } - } - // after closing the original table, the RapidsBuffer should be spillable - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - } - } - - test("a buffer is not spillable when the underlying device buffer is obtained from it") { - withResource(new RapidsDeviceMemoryStore) { store => - val spillPriority = 3 - val bufferId = MockRapidsBufferId(7) - val ct = buildContiguousTable() - val buffSize = ct.getBuffer.getLength - val buffer = withResource(ct) { _ => - val meta = MetaUtils.buildTableMeta(bufferId.tableId, ct) - val buffer = store.addBuffer( - bufferId, - ct.getBuffer, - meta, - spillPriority, - false) - assertResult(buffSize)(store.currentSize) - assertResult(0)(store.currentSpillableSize) - buffer - } - - // after closing the original table, the RapidsBuffer should be spillable - assertResult(buffSize)(store.currentSize) - assertResult(buffSize)(store.currentSpillableSize) - - // if a device memory buffer is obtained from the buffer, it is no longer spillable - withResource(buffer.getDeviceMemoryBuffer) { deviceBuffer => - assertResult(buffSize)(store.currentSize) - assertResult(0)(store.currentSpillableSize) - } - - // once the DeviceMemoryBuffer is closed, the 
RapidsBuffer should be spillable again - assertResult(buffSize)(store.currentSpillableSize) - } - } - - test("add buffer registers with catalog") { - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = spy(new RapidsBufferCatalog(store)) - val spillPriority = 3 - val bufferId = MockRapidsBufferId(7) - val meta = withResource(buildContiguousTable()) { ct => - val meta = MetaUtils.buildTableMeta(bufferId.tableId, ct) - withResource(ct) { _ => - catalog.addBuffer( - bufferId, - ct.getBuffer, - meta, - spillPriority, - false) - } - meta - } - val captor: ArgumentCaptor[RapidsBuffer] = ArgumentCaptor.forClass(classOf[RapidsBuffer]) - verify(catalog).registerNewBuffer(captor.capture()) - val resultBuffer = captor.getValue - assertResult(bufferId)(resultBuffer.id) - assertResult(spillPriority)(resultBuffer.getSpillPriority) - assertResult(meta)(resultBuffer.meta) - } - } - - test("get memory buffer") { - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = spy(new RapidsBufferCatalog(store)) - val bufferId = MockRapidsBufferId(7) - withResource(buildContiguousTable()) { ct => - withResource(HostMemoryBuffer.allocate(ct.getBuffer.getLength)) { expectedHostBuffer => - expectedHostBuffer.copyFromDeviceBuffer(ct.getBuffer) - val meta = MetaUtils.buildTableMeta(bufferId.tableId, ct) - val handle = withResource(ct) { _ => - catalog.addBuffer( - bufferId, - ct.getBuffer, - meta, - initialSpillPriority = 3, - needsSync = false) - } - withResource(catalog.acquireBuffer(handle)) { buffer => - withResource(buffer.getMemoryBuffer.asInstanceOf[DeviceMemoryBuffer]) { devbuf => - withResource(HostMemoryBuffer.allocate(devbuf.getLength)) { actualHostBuffer => - actualHostBuffer.copyFromDeviceBuffer(devbuf) - assertResult(expectedHostBuffer.asByteBuffer())(actualHostBuffer.asByteBuffer()) - } - } - } - } - } - } - } - - test("get column batch") { - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = new RapidsBufferCatalog(store) - val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType, - DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5)) - val bufferId = MockRapidsBufferId(7) - withResource(buildContiguousTable()) { ct => - withResource(GpuColumnVector.from(ct.getTable, sparkTypes)) { - expectedBatch => - val meta = MetaUtils.buildTableMeta(bufferId.tableId, ct) - val handle = withResource(ct) { _ => - catalog.addBuffer( - bufferId, - ct.getBuffer, - meta, - initialSpillPriority = 3, - false) - } - withResource(catalog.acquireBuffer(handle)) { buffer => - withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => - TestUtils.compareBatches(expectedBatch, actualBatch) - } - } - } - } - } - } - - test("size statistics") { - - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = new RapidsBufferCatalog(store) - assertResult(0)(store.currentSize) - val bufferSizes = new Array[Long](2) - val bufferHandles = new Array[RapidsBufferHandle](2) - bufferSizes.indices.foreach { i => - withResource(buildContiguousTable()) { ct => - bufferSizes(i) = ct.getBuffer.getLength - // store takes ownership of the table - bufferHandles(i) = - catalog.addContiguousTable( - MockRapidsBufferId(i), - ct, - initialSpillPriority = 0, - false) - } - assertResult(bufferSizes.take(i+1).sum)(store.currentSize) - } - bufferHandles(0).close() - assertResult(bufferSizes(1))(store.currentSize) - bufferHandles(1).close() - assertResult(0)(store.currentSize) - } - } - - test("spill") { - val spillStore = new MockSpillStore - val spillPriorities = 
Array(0, -1, 2) - val bufferSizes = new Array[Long](spillPriorities.length) - withResource(new RapidsDeviceMemoryStore) { store => - val catalog = new RapidsBufferCatalog(store) - store.setSpillStore(spillStore) - spillPriorities.indices.foreach { i => - withResource(buildContiguousTable()) { ct => - bufferSizes(i) = ct.getBuffer.getLength - // store takes ownership of the table - catalog.addContiguousTable( - MockRapidsBufferId(i), ct, spillPriorities(i), - false) - } - } - assert(spillStore.spilledBuffers.isEmpty) - - // asking to spill 0 bytes should not spill - val sizeBeforeSpill = store.currentSize - catalog.synchronousSpill(store, sizeBeforeSpill) - assert(spillStore.spilledBuffers.isEmpty) - assertResult(sizeBeforeSpill)(store.currentSize) - catalog.synchronousSpill(store, sizeBeforeSpill + 1) - assert(spillStore.spilledBuffers.isEmpty) - assertResult(sizeBeforeSpill)(store.currentSize) - - // spilling 1 byte should force one buffer to spill in priority order - catalog.synchronousSpill(store, sizeBeforeSpill - 1) - assertResult(1)(spillStore.spilledBuffers.length) - assertResult(bufferSizes.drop(1).sum)(store.currentSize) - assertResult(1)(spillStore.spilledBuffers(0).tableId) - - // spilling to zero should force all buffers to spill in priority order - catalog.synchronousSpill(store, 0) - assertResult(3)(spillStore.spilledBuffers.length) - assertResult(0)(store.currentSize) - assertResult(0)(spillStore.spilledBuffers(1).tableId) - assertResult(2)(spillStore.spilledBuffers(2).tableId) - } - } - - case class MockRapidsBufferId(tableId: Int) extends RapidsBufferId { - override def getDiskPath(diskBlockManager: RapidsDiskBlockManager): File = - throw new UnsupportedOperationException - } - - class MockSpillStore extends RapidsBufferStore(StorageTier.HOST) { - val spilledBuffers = new ArrayBuffer[RapidsBufferId] - - override protected def createBuffer( - b: RapidsBuffer, - c: RapidsBufferCatalog, - s: Cuda.Stream): Option[RapidsBufferBase] = { - spilledBuffers += b.id - Some(new MockRapidsBuffer( - b.id, b.getPackedSizeBytes, b.meta, b.getSpillPriority)) - } - - class MockRapidsBuffer(id: RapidsBufferId, size: Long, meta: TableMeta, spillPriority: Long) - extends RapidsBufferBase(id, meta, spillPriority) { - override protected def releaseResources(): Unit = {} - - override val storageTier: StorageTier = StorageTier.HOST - - override def getMemoryBuffer: MemoryBuffer = - throw new UnsupportedOperationException - - /** The size of this buffer in bytes. */ - override val memoryUsedBytes: Long = size - } - } -} diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala deleted file mode 100644 index fce88f116b3..00000000000 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala +++ /dev/null @@ -1,607 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.nvidia.spark.rapids - -import java.io.File -import java.math.RoundingMode - -import ai.rapids.cudf.{ColumnVector, ContiguousTable, Cuda, DeviceMemoryBuffer, HostMemoryBuffer, Table} -import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} -import org.mockito.ArgumentMatchers -import org.mockito.ArgumentMatchers.any -import org.mockito.Mockito.{spy, times, verify, when} -import org.scalatestplus.mockito.MockitoSugar - -import org.apache.spark.SparkConf -import org.apache.spark.sql.rapids.{RapidsDiskBlockManager, TempSpillBufferId} -import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType, IntegerType, StringType} -import org.apache.spark.storage.{BlockId, ShuffleBlockId} - - -class RapidsDiskStoreSuite extends FunSuiteWithTempDir with MockitoSugar { - - private def buildContiguousTable(): ContiguousTable = { - withResource(buildTable()) { table => - table.contiguousSplit()(0) - } - } - - private def buildTable(): Table = { - new Table.TestBuilder() - .column(5, null.asInstanceOf[java.lang.Integer], 3, 1) - .column("five", "two", null, null) - .column(5.0, 2.0, 3.0, 1.0) - .decimal64Column(-5, RoundingMode.UNNECESSARY, 0, null, -1.4, 10.123) - .build() - } - - private def buildEmptyTable(): Table = { - withResource(buildTable()) { tbl => - withResource(ColumnVector.fromBooleans(false, false, false, false)) { mask => - tbl.filter(mask) // filter all out - } - } - } - - private val mockTableDataTypes: Array[DataType] = - Array(IntegerType, StringType, DoubleType, DecimalType(10, 5)) - - test("spill updates catalog") { - val bufferId = MockRapidsBufferId(7, canShareDiskPaths = false) - val mockDiskBlockManager = mock[RapidsDiskBlockManager] - when(mockDiskBlockManager.getSerializerManager()) - .thenReturn(new RapidsSerializerManager(new SparkConf())) - val spillPriority = -7 - val hostStoreMaxSize = 1L * 1024 * 1024 - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = spy(new RapidsBufferCatalog(devStore)) - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { - hostStore => - devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(mockDiskBlockManager)) { diskStore => - assertResult(0)(diskStore.currentSize) - hostStore.setSpillStore(diskStore) - val (bufferSize, handle) = - addContiguousTableToCatalog(catalog, bufferId, spillPriority) - val path = handle.id.getDiskPath(null) - assert(!path.exists()) - catalog.synchronousSpill(devStore, 0) - catalog.synchronousSpill(hostStore, 0) - assertResult(0)(hostStore.currentSize) - assertResult(bufferSize)(diskStore.currentSize) - assert(path.exists) - assertResult(bufferSize)(path.length) - verify(catalog, times(3)).registerNewBuffer(ArgumentMatchers.any[RapidsBuffer]) - verify(catalog).removeBufferTier( - ArgumentMatchers.eq(handle.id), ArgumentMatchers.eq(StorageTier.DEVICE)) - withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.DISK)(buffer.storageTier) - assertResult(bufferSize)(buffer.memoryUsedBytes) - assertResult(handle.id)(buffer.id) - assertResult(spillPriority)(buffer.getSpillPriority) - } - } - } - } - } - - test("Get columnar batch") { - val bufferId = MockRapidsBufferId(1, canShareDiskPaths = false) - val bufferPath = bufferId.getDiskPath(null) - val mockDiskBlockManager = mock[RapidsDiskBlockManager] - when(mockDiskBlockManager.getSerializerManager()) - .thenReturn(new RapidsSerializerManager(new SparkConf())) - assert(!bufferPath.exists) - val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType, - 
DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5)) - val spillPriority = -7 - val hostStoreMaxSize = 1L * 1024 * 1024 - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore) - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { - hostStore => - devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(mockDiskBlockManager)) { - diskStore => - hostStore.setSpillStore(diskStore) - val (_, handle) = addContiguousTableToCatalog(catalog, bufferId, spillPriority) - assert(!handle.id.getDiskPath(null).exists()) - val expectedTable = withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.DEVICE)(buffer.storageTier) - withResource(buffer.getColumnarBatch(sparkTypes)) { beforeSpill => - withResource(GpuColumnVector.from(beforeSpill)) { table => - table.contiguousSplit()(0) - } - } // closing the batch from the store so that we can spill it - } - withResource(expectedTable) { _ => - withResource( - GpuColumnVector.from(expectedTable.getTable, sparkTypes)) { expectedBatch => - catalog.synchronousSpill(devStore, 0) - catalog.synchronousSpill(hostStore, 0) - withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.DISK)(buffer.storageTier) - withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => - TestUtils.compareBatches(expectedBatch, actualBatch) - } - } - } - } - } - } - } - } - - test("get memory buffer") { - val bufferId = MockRapidsBufferId(1, canShareDiskPaths = false) - val bufferPath = bufferId.getDiskPath(null) - val mockDiskBlockManager = mock[RapidsDiskBlockManager] - when(mockDiskBlockManager.getSerializerManager()) - .thenReturn(new RapidsSerializerManager(new SparkConf())) - assert(!bufferPath.exists) - val spillPriority = -7 - val hostStoreMaxSize = 1L * 1024 * 1024 - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore) - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { - hostStore => - devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(mockDiskBlockManager)) { diskStore => - hostStore.setSpillStore(diskStore) - val (_, handle) = addContiguousTableToCatalog(catalog, bufferId, spillPriority) - assert(!handle.id.getDiskPath(mockDiskBlockManager).exists()) - val expectedBuffer = withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.DEVICE)(buffer.storageTier) - withResource(buffer.getMemoryBuffer) { devbuf => - closeOnExcept(HostMemoryBuffer.allocate(devbuf.getLength)) { hostbuf => - hostbuf.copyFromDeviceBuffer(devbuf.asInstanceOf[DeviceMemoryBuffer]) - hostbuf - } - } - } - withResource(expectedBuffer) { expectedBuffer => - catalog.synchronousSpill(devStore, 0) - catalog.synchronousSpill(hostStore, 0) - withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.DISK)(buffer.storageTier) - withResource(buffer.getMemoryBuffer) { actualBuffer => - assert(actualBuffer.isInstanceOf[HostMemoryBuffer]) - val actualHostBuffer = actualBuffer.asInstanceOf[HostMemoryBuffer] - assertResult(expectedBuffer. 
- asByteBuffer.limit())(actualHostBuffer.asByteBuffer.limit()) - } - } - } - } - } - } - } - - test("Compression on with or without encryption for spill block using single batch") { - Seq("true", "false").foreach { encryptionEnabled => - val conf = new SparkConf() - conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) - conf.set("spark.io.compression.codec", "zstd") - conf.set("spark.shuffle.spill.compress", "true") - conf.set("spark.shuffle.compress", "true") - readWriteTestWithBatches(conf, TempSpillBufferId.apply()) - } - } - - test("Compression off with or without encryption for spill block using single batch") { - Seq("true", "false").foreach { encryptionEnabled => - val conf = new SparkConf() - conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) - conf.set("spark.shuffle.spill.compress", "false") - conf.set("spark.shuffle.compress", "false") - readWriteTestWithBatches(conf, TempSpillBufferId.apply()) - } - } - - test("Compression on with or without encryption for spill block using multiple batches") { - Seq("true", "false").foreach { encryptionEnabled => - val conf = new SparkConf() - conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) - conf.set("spark.io.compression.codec", "zstd") - conf.set("spark.shuffle.spill.compress", "true") - conf.set("spark.shuffle.compress", "true") - readWriteTestWithBatches(conf, TempSpillBufferId.apply(), TempSpillBufferId.apply()) - } - } - - test("Compression off with or without encryption for spill block using multiple batches") { - Seq("true", "false").foreach { encryptionEnabled => - val conf = new SparkConf() - conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) - conf.set("spark.shuffle.spill.compress", "false") - conf.set("spark.shuffle.compress", "false") - readWriteTestWithBatches(conf, TempSpillBufferId.apply(), TempSpillBufferId.apply()) - } - } - - // ===== Tests for shuffle block ===== - - test("Compression on with or without encryption for shuffle block using single batch") { - Seq("true", "false").foreach { encryptionEnabled => - val conf = new SparkConf() - conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) - conf.set("spark.io.compression.codec", "zstd") - conf.set("spark.shuffle.spill.compress", "true") - conf.set("spark.shuffle.compress", "true") - readWriteTestWithBatches(conf, ShuffleBufferId(ShuffleBlockId(1, 1, 1), 1)) - } - } - - test("Compression off with or without encryption for shuffle block using single batch") { - Seq("true", "false").foreach { encryptionEnabled => - val conf = new SparkConf() - conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) - conf.set("spark.shuffle.spill.compress", "false") - conf.set("spark.shuffle.compress", "false") - readWriteTestWithBatches(conf, ShuffleBufferId(ShuffleBlockId(1, 1, 1), 1)) - } - } - - test("Compression on with or without encryption for shuffle block using multiple batches") { - Seq("true", "false").foreach { encryptionEnabled => - val conf = new SparkConf() - conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) - conf.set("spark.io.compression.codec", "zstd") - conf.set("spark.shuffle.spill.compress", "true") - conf.set("spark.shuffle.compress", "true") - readWriteTestWithBatches(conf, - ShuffleBufferId(ShuffleBlockId(1, 1, 1), 1), ShuffleBufferId(ShuffleBlockId(2, 2, 2), 2)) - } - } - - test("Compression off with or without encryption for shuffle block using multiple batches") { - Seq("true", "false").foreach { encryptionEnabled => - val conf = new SparkConf() - 
conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled) - conf.set("spark.shuffle.spill.compress", "false") - conf.set("spark.shuffle.compress", "false") - readWriteTestWithBatches(conf, - ShuffleBufferId(ShuffleBlockId(1, 1, 1), 1), ShuffleBufferId(ShuffleBlockId(2, 2, 2), 2)) - } - } - - test("No encryption and compression for shuffle block using multiple batches") { - readWriteTestWithBatches(new SparkConf(), - ShuffleBufferId(ShuffleBlockId(1, 1, 1), 1), ShuffleBufferId(ShuffleBlockId(2, 2, 2), 2)) - } - - private def readWriteTestWithBatches(conf: SparkConf, bufferIds: RapidsBufferId*) = { - assert(bufferIds.size != 0) - val mockDiskBlockManager = mock[RapidsDiskBlockManager] - when(mockDiskBlockManager.getSerializerManager()) - .thenReturn(new RapidsSerializerManager(conf)) - - if (bufferIds(0).canShareDiskPaths) { - // Return the same path - val bufferPath = new File(TEST_FILES_ROOT, s"diskbuffer-${bufferIds(0).tableId}") - when(mockDiskBlockManager.getFile(any[BlockId]())).thenReturn(bufferPath) - if (bufferPath.exists) bufferPath.delete() - } else { - when(mockDiskBlockManager.getFile(any[BlockId]())) - .thenAnswer { invocation => - new File(TEST_FILES_ROOT, s"diskbuffer-${invocation.getArgument[BlockId](0).name}") - } - } - - val spillPriority = -7 - val hostStoreMaxSize = 1L * 1024 * 1024 - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore) - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { - hostStore => - devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(mockDiskBlockManager)) { diskStore => - hostStore.setSpillStore(diskStore) - bufferIds.foreach { bufferId => - val (_, handle) = addContiguousTableToCatalog(catalog, bufferId, spillPriority) - val expectedBuffer = withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.DEVICE)(buffer.storageTier) - withResource(buffer.getMemoryBuffer) { devbuf => - closeOnExcept(HostMemoryBuffer.allocate(devbuf.getLength)) { hostbuf => - hostbuf.copyFromDeviceBuffer(devbuf.asInstanceOf[DeviceMemoryBuffer]) - hostbuf - } - } - } - withResource(expectedBuffer) { expectedBuffer => - catalog.synchronousSpill(devStore, 0) - catalog.synchronousSpill(hostStore, 0) - withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.DISK)(buffer.storageTier) - withResource(buffer.getMemoryBuffer) { actualBuffer => - assert(actualBuffer.isInstanceOf[HostMemoryBuffer]) - val actualHostBuffer = actualBuffer.asInstanceOf[HostMemoryBuffer] - assertResult(expectedBuffer.asByteBuffer)(actualHostBuffer.asByteBuffer) - } - } - } - } - } - } - } - } - - test("skip host: spill device memory buffer to disk") { - val bufferId = MockRapidsBufferId(1, canShareDiskPaths = false) - val bufferPath = bufferId.getDiskPath(null) - assert(!bufferPath.exists) - val mockDiskBlockManager = mock[RapidsDiskBlockManager] - when(mockDiskBlockManager.getSerializerManager()) - .thenReturn(new RapidsSerializerManager(new SparkConf())) - val spillPriority = -7 - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore) - withResource(new AlwaysFailingRapidsHostMemoryStore) { - hostStore => - devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(mockDiskBlockManager)) { diskStore => - hostStore.setSpillStore(diskStore) - val (_, handle) = addContiguousTableToCatalog(catalog, bufferId, spillPriority) - assert(!handle.id.getDiskPath(null).exists()) - val expectedBuffer = 
withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.DEVICE)(buffer.storageTier) - withResource(buffer.getMemoryBuffer) { devbuf => - closeOnExcept(HostMemoryBuffer.allocate(devbuf.getLength)) { hostbuf => - hostbuf.copyFromDeviceBuffer(devbuf.asInstanceOf[DeviceMemoryBuffer]) - hostbuf - } - } - } - withResource(expectedBuffer) { expectedBuffer => - catalog.synchronousSpill(devStore, 0) - withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.DISK)(buffer.storageTier) - withResource(buffer.getMemoryBuffer) { actualBuffer => - assert(actualBuffer.isInstanceOf[HostMemoryBuffer]) - val actualHostBuffer = actualBuffer.asInstanceOf[HostMemoryBuffer] - assertResult(expectedBuffer.asByteBuffer)(actualHostBuffer.asByteBuffer) - } - } - } - } - } - } - } - - test("skip host: spill table to disk") { - val bufferId = MockRapidsBufferId(1, canShareDiskPaths = false) - val bufferPath = bufferId.getDiskPath(null) - assert(!bufferPath.exists) - val mockDiskBlockManager = mock[RapidsDiskBlockManager] - when(mockDiskBlockManager.getSerializerManager()) - .thenReturn(new RapidsSerializerManager(new SparkConf())) - val spillPriority = -7 - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore) - withResource(new AlwaysFailingRapidsHostMemoryStore) { - hostStore => - devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(mockDiskBlockManager)) { diskStore => - hostStore.setSpillStore(diskStore) - val handle = addTableToCatalog(catalog, bufferId, spillPriority) - withResource(buildTable()) { expectedTable => - withResource( - GpuColumnVector.from(expectedTable, mockTableDataTypes)) { expectedBatch => - catalog.synchronousSpill(devStore, 0) - withResource(catalog.acquireBuffer(handle)) { buffer => - assert(handle.id.getDiskPath(null).exists()) - assertResult(StorageTier.DISK)(buffer.storageTier) - withResource(buffer.getColumnarBatch(mockTableDataTypes)) { fromDiskBatch => - TestUtils.compareBatches(expectedBatch, fromDiskBatch) - } - } - } - } - } - } - } - } - - test("skip host: spill table to disk with small host bounce buffer") { - val bufferId = MockRapidsBufferId(1, canShareDiskPaths = false) - val bufferPath = bufferId.getDiskPath(null) - assert(!bufferPath.exists) - val mockDiskBlockManager = mock[RapidsDiskBlockManager] - when(mockDiskBlockManager.getSerializerManager()) - .thenReturn(new RapidsSerializerManager(new SparkConf())) - val spillPriority = -7 - withResource(new RapidsDeviceMemoryStore(1L*1024*1024, 10)) { devStore => - val catalog = new RapidsBufferCatalog(devStore) - withResource(new AlwaysFailingRapidsHostMemoryStore) { - hostStore => - devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(mockDiskBlockManager)) { diskStore => - hostStore.setSpillStore(diskStore) - val handle = addTableToCatalog(catalog, bufferId, spillPriority) - withResource(buildTable()) { expectedTable => - withResource( - GpuColumnVector.from(expectedTable, mockTableDataTypes)) { expectedBatch => - catalog.synchronousSpill(devStore, 0) - withResource(catalog.acquireBuffer(handle)) { buffer => - assert(handle.id.getDiskPath(null).exists()) - assertResult(StorageTier.DISK)(buffer.storageTier) - withResource(buffer.getColumnarBatch(mockTableDataTypes)) { fromDiskBatch => - TestUtils.compareBatches(expectedBatch, fromDiskBatch) - } - } - } - } - } - } - } - } - - - test("0-byte table is never spillable as we would fail to mmap") { - val bufferId = MockRapidsBufferId(1, 
canShareDiskPaths = false) - val bufferPath = bufferId.getDiskPath(null) - val bufferId2 = MockRapidsBufferId(2, canShareDiskPaths = false) - assert(!bufferPath.exists) - val mockDiskBlockManager = mock[RapidsDiskBlockManager] - when(mockDiskBlockManager.getSerializerManager()) - .thenReturn(new RapidsSerializerManager(new SparkConf())) - val spillPriority = -7 - val hostStoreMaxSize = 1L * 1024 * 1024 - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore) - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(mockDiskBlockManager)) { diskStore => - hostStore.setSpillStore(diskStore) - val handle = addZeroRowsTableToCatalog(catalog, bufferId, spillPriority - 1) - val handle2 = addTableToCatalog(catalog, bufferId2, spillPriority) - withResource(handle2) { _ => - assert(!handle.id.getDiskPath(null).exists()) - withResource(buildTable()) { expectedTable => - withResource(buildEmptyTable()) { expectedEmptyTable => - withResource( - GpuColumnVector.from( - expectedTable, mockTableDataTypes)) { expectedCb => - withResource( - GpuColumnVector.from( - expectedEmptyTable, mockTableDataTypes)) { expectedEmptyCb => - catalog.synchronousSpill(devStore, 0) - catalog.synchronousSpill(hostStore, 0) - withResource(catalog.acquireBuffer(handle2)) { buffer => - withResource(catalog.acquireBuffer(handle)) { emptyBuffer => - // the 0-byte table never moved from device. It is not spillable - assertResult(StorageTier.DEVICE)(emptyBuffer.storageTier) - withResource(emptyBuffer.getColumnarBatch(mockTableDataTypes)) { cb => - TestUtils.compareBatches(expectedEmptyCb, cb) - } - // the second table (with rows) did spill - assertResult(StorageTier.DISK)(buffer.storageTier) - withResource(buffer.getColumnarBatch(mockTableDataTypes)) { cb => - TestUtils.compareBatches(expectedCb, cb) - } - } - } - assertResult(0)(devStore.currentSize) - assertResult(0)(hostStore.currentSize) - } - } - } - } - } - } - } - } - } - - test("exclusive spill files are deleted when buffer deleted") { - testBufferFileDeletion(canShareDiskPaths = false) - } - - test("shared spill files are not deleted when a buffer is deleted") { - testBufferFileDeletion(canShareDiskPaths = true) - } - - class AlwaysFailingRapidsHostMemoryStore extends RapidsHostMemoryStore(Some(0L)){ - override def createBuffer( - other: RapidsBuffer, - catalog: RapidsBufferCatalog, - stream: Cuda.Stream): Option[RapidsBufferBase] = { - None - } - } - - private def testBufferFileDeletion(canShareDiskPaths: Boolean): Unit = { - val bufferId = MockRapidsBufferId(1, canShareDiskPaths) - val bufferPath = bufferId.getDiskPath(null) - assert(!bufferPath.exists) - val mockDiskBlockManager = mock[RapidsDiskBlockManager] - when(mockDiskBlockManager.getSerializerManager()) - .thenReturn(new RapidsSerializerManager(new SparkConf())) - val spillPriority = -7 - val hostStoreMaxSize = 1L * 1024 * 1024 - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore) - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { - hostStore => - devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(mockDiskBlockManager)) { diskStore => - hostStore.setSpillStore(diskStore) - val (_, handle) = addContiguousTableToCatalog(catalog, bufferId, spillPriority) - val bufferPath = handle.id.getDiskPath(null) - assert(!bufferPath.exists()) - catalog.synchronousSpill(devStore, 0) - 
catalog.synchronousSpill(hostStore, 0) - assert(bufferPath.exists) - handle.close() - if (canShareDiskPaths) { - assert(bufferPath.exists()) - } else { - assert(!bufferPath.exists) - } - } - } - } - } - - private def addContiguousTableToCatalog( - catalog: RapidsBufferCatalog, - bufferId: RapidsBufferId, - spillPriority: Long): (Long, RapidsBufferHandle) = { - withResource(buildContiguousTable()) { ct => - val bufferSize = ct.getBuffer.getLength - // store takes ownership of the table - val handle = catalog.addContiguousTable( - bufferId, - ct, - spillPriority, - false) - (bufferSize, handle) - } - } - - private def addTableToCatalog( - catalog: RapidsBufferCatalog, - bufferId: RapidsBufferId, - spillPriority: Long): RapidsBufferHandle = { - // store takes ownership of the table - catalog.addTable( - bufferId, - buildTable(), - spillPriority, - false) - } - - private def addZeroRowsTableToCatalog( - catalog: RapidsBufferCatalog, - bufferId: RapidsBufferId, - spillPriority: Long): RapidsBufferHandle = { - val table = buildEmptyTable() - // store takes ownership of the table - catalog.addTable( - bufferId, - table, - spillPriority, - false) - } - - case class MockRapidsBufferId( - tableId: Int, - override val canShareDiskPaths: Boolean) extends RapidsBufferId { - override def getDiskPath(diskBlockManager: RapidsDiskBlockManager): File = - new File(TEST_FILES_ROOT, s"diskbuffer-$tableId") - } -} diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala deleted file mode 100644 index 1ffad031451..00000000000 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala +++ /dev/null @@ -1,614 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.nvidia.spark.rapids - -import java.io.File -import java.math.RoundingMode - -import ai.rapids.cudf.{ContiguousTable, Cuda, HostColumnVector, HostMemoryBuffer, Table} -import com.nvidia.spark.rapids.Arm._ -import org.mockito.{ArgumentCaptor, ArgumentMatchers} -import org.mockito.Mockito.{spy, times, verify} -import org.scalatest.funsuite.AnyFunSuite -import org.scalatestplus.mockito.MockitoSugar - -import org.apache.spark.SparkConf -import org.apache.spark.sql.rapids.RapidsDiskBlockManager -import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType, IntegerType, LongType, StringType} -import org.apache.spark.sql.vectorized.ColumnarBatch - -class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar { - private def buildContiguousTable(): ContiguousTable = { - withResource(new Table.TestBuilder() - .column(5, null.asInstanceOf[java.lang.Integer], 3, 1) - .column("five", "two", null, null) - .column(5.0, 2.0, 3.0, 1.0) - .decimal64Column(-5, RoundingMode.UNNECESSARY, 0, null, -1.4, 10.123) - .build()) { table => - table.contiguousSplit()(0) - } - } - - private def buildContiguousTable(numRows: Int): ContiguousTable = { - val vals = (0 until numRows).map(_.toLong) - withResource(HostColumnVector.fromLongs(vals: _*)) { hcv => - withResource(hcv.copyToDevice()) { cv => - withResource(new Table(cv)) { table => - table.contiguousSplit()(0) - } - } - } - } - - private def buildHostBatch(): ColumnarBatch = { - val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType, - DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5)) - val hostCols = withResource(buildContiguousTable()) { ct => - withResource(ct.getTable) { tbl => - (0 until tbl.getNumberOfColumns) - .map(c => tbl.getColumn(c).copyToHost()) - } - }.toArray - new ColumnarBatch( - hostCols.zip(sparkTypes).map { case (hostCol, dataType) => - new RapidsHostColumnVector(dataType, hostCol) - }, hostCols.head.getRowCount.toInt) - } - - private def buildHostBatchWithDuplicate(): ColumnarBatch = { - val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType, - DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5)) - val hostCols = withResource(buildContiguousTable()) { ct => - withResource(ct.getTable) { tbl => - (0 until tbl.getNumberOfColumns) - .map(c => tbl.getColumn(c).copyToHost()) - } - }.toArray - hostCols.foreach(_.incRefCount()) - new ColumnarBatch( - (hostCols ++ hostCols).zip(sparkTypes ++ sparkTypes).map { case (hostCol, dataType) => - new RapidsHostColumnVector(dataType, hostCol) - }, hostCols.head.getRowCount.toInt) - } - - test("spill updates catalog") { - val spillPriority = -7 - val hostStoreMaxSize = 1L * 1024 * 1024 - val mockStore = mock[RapidsHostMemoryStore] - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = spy(new RapidsBufferCatalog(devStore)) - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { - hostStore => - assertResult(0)(hostStore.currentSize) - assertResult(hostStoreMaxSize)(hostStore.numBytesFree.get) - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(mockStore) - - val (bufferSize, handle) = withResource(buildContiguousTable()) { ct => - val len = ct.getBuffer.getLength - // store takes ownership of the table - val handle = catalog.addContiguousTable( - ct, - spillPriority) - (len, handle) - } - - catalog.synchronousSpill(devStore, 0) - assertResult(bufferSize)(hostStore.currentSize) - assertResult(hostStoreMaxSize - bufferSize)(hostStore.numBytesFree.get) - verify(catalog, 
times(2)).registerNewBuffer(ArgumentMatchers.any[RapidsBuffer]) - verify(catalog).removeBufferTier( - ArgumentMatchers.eq(handle.id), ArgumentMatchers.eq(StorageTier.DEVICE)) - withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.HOST)(buffer.storageTier) - assertResult(bufferSize)(buffer.memoryUsedBytes) - assertResult(handle.id)(buffer.id) - assertResult(spillPriority)(buffer.getSpillPriority) - } - } - } - } - - test("get columnar batch") { - val spillPriority = -10 - val hostStoreMaxSize = 1L * 1024 * 1024 - val mockStore = mock[RapidsHostMemoryStore] - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore) - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { - hostStore => - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(mockStore) - var expectedBuffer: HostMemoryBuffer = null - val handle = withResource(buildContiguousTable()) { ct => - expectedBuffer = HostMemoryBuffer.allocate(ct.getBuffer.getLength) - expectedBuffer.copyFromDeviceBuffer(ct.getBuffer) - catalog.addContiguousTable( - ct, - spillPriority) - } - withResource(expectedBuffer) { _ => - catalog.synchronousSpill(devStore, 0) - withResource(catalog.acquireBuffer(handle)) { buffer => - withResource(buffer.getMemoryBuffer) { actualBuffer => - assert(actualBuffer.isInstanceOf[HostMemoryBuffer]) - assertResult(expectedBuffer.asByteBuffer) { - actualBuffer.asInstanceOf[HostMemoryBuffer].asByteBuffer - } - } - } - } - } - } - } - - test("get memory buffer") { - val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType, - DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5)) - val spillPriority = -10 - val hostStoreMaxSize = 1L * 1024 * 1024 - val mockStore = mock[RapidsHostMemoryStore] - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore) - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { - hostStore => - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(mockStore) - var expectedBatch: ColumnarBatch = null - val handle = withResource(buildContiguousTable()) { ct => - // make a copy of the table so we can compare it later to the - // one reconstituted after the spill - withResource(ct.getTable.contiguousSplit()) { copied => - expectedBatch = GpuColumnVector.from(copied(0).getTable, sparkTypes) - } - catalog.addContiguousTable( - ct, - spillPriority) - } - withResource(expectedBatch) { _ => - catalog.synchronousSpill(devStore, 0) - withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.HOST)(buffer.storageTier) - withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => - TestUtils.compareBatches(expectedBatch, actualBatch) - } - } - } - } - } - } - - test("get memory buffer after host spill") { - RapidsBufferCatalog.close() - val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType, - DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5)) - val spillPriority = -10 - val hostStoreMaxSize = 1L * 1024 * 1024 - try { - val bm = new RapidsDiskBlockManager(new SparkConf()) - val (catalog, devStore, hostStore, diskStore) = - closeOnExcept(new RapidsDiskStore(bm)) { diskStore => - closeOnExcept(new RapidsDeviceMemoryStore()) { devStore => - closeOnExcept(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(diskStore) - val catalog = closeOnExcept( - new RapidsBufferCatalog(devStore, hostStore)) { catalog 
=> catalog } - (catalog, devStore, hostStore, diskStore) - } - } - } - - RapidsBufferCatalog.setDeviceStorage(devStore) - RapidsBufferCatalog.setHostStorage(hostStore) - RapidsBufferCatalog.setDiskStorage(diskStore) - RapidsBufferCatalog.setCatalog(catalog) - - var expectedBatch: ColumnarBatch = null - val handle = withResource(buildContiguousTable()) { ct => - // make a copy of the table so we can compare it later to the - // one reconstituted after the spill - withResource(ct.getTable.contiguousSplit()) { copied => - expectedBatch = GpuColumnVector.from(copied(0).getTable, sparkTypes) - } - RapidsBufferCatalog.addContiguousTable( - ct, - spillPriority) - } - withResource(expectedBatch) { _ => - val spilledToHost = - RapidsBufferCatalog.synchronousSpill( - RapidsBufferCatalog.getDeviceStorage, 0) - assert(spilledToHost.isDefined && spilledToHost.get > 0) - - val spilledToDisk = - RapidsBufferCatalog.synchronousSpill( - RapidsBufferCatalog.getHostStorage, 0) - assert(spilledToDisk.isDefined && spilledToDisk.get > 0) - - withResource(RapidsBufferCatalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.DISK)(buffer.storageTier) - withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => - TestUtils.compareBatches(expectedBatch, actualBatch) - } - } - } - } finally { - RapidsBufferCatalog.close() - } - } - - test("host buffer originated: get host memory buffer") { - val spillPriority = -10 - val hostStoreMaxSize = 1L * 1024 * 1024 - val mockStore = mock[RapidsDiskStore] - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore, hostStore) - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(mockStore) - val hmb = HostMemoryBuffer.allocate(1L * 1024) - val spillableBuffer = - SpillableHostBuffer(hmb, hmb.getLength, spillPriority, catalog) - withResource(spillableBuffer) { _ => - // the refcount of 1 is the store - assertResult(1)(hmb.getRefCount) - withResource(spillableBuffer.getHostBuffer()) { memoryBuffer => - assertResult(hmb)(memoryBuffer) - assertResult(2)(memoryBuffer.getRefCount) - } - } - assertResult(0)(hmb.getRefCount) - } - } - } - - test("host buffer originated: get host memory buffer after spill") { - RapidsBufferCatalog.close() - val spillPriority = -10 - val hostStoreMaxSize = 1L * 1024 * 1024 - try { - val bm = new RapidsDiskBlockManager(new SparkConf()) - val (catalog, devStore, hostStore, diskStore) = - closeOnExcept(new RapidsDiskStore(bm)) { diskStore => - closeOnExcept(new RapidsDeviceMemoryStore()) { devStore => - closeOnExcept(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(diskStore) - val catalog = closeOnExcept( - new RapidsBufferCatalog(devStore, hostStore)) { catalog => catalog } - (catalog, devStore, hostStore, diskStore) - } - } - } - - RapidsBufferCatalog.setDeviceStorage(devStore) - RapidsBufferCatalog.setHostStorage(hostStore) - RapidsBufferCatalog.setDiskStorage(diskStore) - RapidsBufferCatalog.setCatalog(catalog) - - val hmb = HostMemoryBuffer.allocate(1L * 1024) - val spillableBuffer = SpillableHostBuffer( - hmb, - hmb.getLength, - spillPriority) - assertResult(1)(hmb.getRefCount) - // we spill it - RapidsBufferCatalog.synchronousSpill(RapidsBufferCatalog.getHostStorage, 0) - withResource(spillableBuffer) { _ => - // the refcount of the original buffer is 0 because it spilled - assertResult(0)(hmb.getRefCount) - 
withResource(spillableBuffer.getHostBuffer()) { memoryBuffer => - assertResult(memoryBuffer.getLength)(hmb.getLength) - } - } - } finally { - RapidsBufferCatalog.close() - } - } - - test("host buffer originated: get host memory buffer OOM when unable to spill") { - val spillPriority = -10 - val hostStoreMaxSize = 1L * 1024 * 1024 - val bm = new RapidsDiskBlockManager(new SparkConf()) - withResource(new RapidsDiskStore(bm)) { diskStore => - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore, hostStore) - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(diskStore) - val hmb = HostMemoryBuffer.allocate(1L * 1024) - val spillableBuffer = SpillableHostBuffer( - hmb, - hmb.getLength, - spillPriority, - catalog) - // spillable is 1K - assertResult(hmb.getLength)(hostStore.currentSpillableSize) - withResource(spillableBuffer.getHostBuffer()) { memoryBuffer => - // 0 because we have a reference to the memoryBuffer - assertResult(0)(hostStore.currentSpillableSize) - val spilled = catalog.synchronousSpill(hostStore, 0) - assertResult(0)(spilled.get) - } - assertResult(hmb.getLength)(hostStore.currentSpillableSize) - val spilled = catalog.synchronousSpill(hostStore, 0) - assertResult(1L * 1024)(spilled.get) - spillableBuffer.close() - } - } - } - } - - test("host batch originated: get host memory batch") { - val spillPriority = -10 - val hostStoreMaxSize = 1L * 1024 * 1024 - val bm = new RapidsDiskBlockManager(new SparkConf()) - withResource(new RapidsDiskStore(bm)) { diskStore => - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore, hostStore) - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(diskStore) - - val hostCb = buildHostBatch() - - val sizeOnHost = RapidsHostColumnVector.getTotalHostMemoryUsed(hostCb) - - withResource( - SpillableHostColumnarBatch(hostCb, spillPriority, catalog)) { spillableBuffer => - assertResult(sizeOnHost)(hostStore.currentSpillableSize) - - withResource(spillableBuffer.getColumnarBatch()) { hostCb => - // 0 because we have a reference to the memoryBuffer - assertResult(0)(hostStore.currentSpillableSize) - val spilled = catalog.synchronousSpill(hostStore, 0) - assertResult(0)(spilled.get) - } - - assertResult(sizeOnHost)(hostStore.currentSpillableSize) - val spilled = catalog.synchronousSpill(hostStore, 0) - assertResult(sizeOnHost)(spilled.get) - - val sizeOnDisk = diskStore.currentSpillableSize - - // reconstitute batch from disk - withResource(spillableBuffer.getColumnarBatch()) { hostCbFromDisk => - // disk has a different size, so this spillable batch has a different sizeInBytes - // right now, because this is the serialized represenation size - assertResult(sizeOnDisk)(spillableBuffer.sizeInBytes) - // lets recreate our original batch and compare to make sure contents match - withResource(buildHostBatch()) { expectedHostCb => - TestUtils.compareBatches(expectedHostCb, hostCbFromDisk) - } - } - } - } - } - } - } - - test("a host batch is not spillable when we leak it") { - val spillPriority = -10 - val hostStoreMaxSize = 1L * 1024 * 1024 - val bm = new RapidsDiskBlockManager(new SparkConf()) - withResource(new RapidsDiskStore(bm)) { diskStore => - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - withResource(new RapidsDeviceMemoryStore) { devStore 
=> - val catalog = new RapidsBufferCatalog(devStore, hostStore) - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(diskStore) - - val hostCb = buildHostBatch() - - val sizeOnHost = RapidsHostColumnVector.getTotalHostMemoryUsed(hostCb) - - val leakedBatch = withResource( - SpillableHostColumnarBatch(hostCb, spillPriority, catalog)) { spillableBuffer => - assertResult(sizeOnHost)(hostStore.currentSpillableSize) - - val leakedBatch = spillableBuffer.getColumnarBatch() - // 0 because we have a reference to the host batch - assertResult(0)(hostStore.currentSpillableSize) - val spilled = catalog.synchronousSpill(hostStore, 0) - assertResult(0)(spilled.get) - leakedBatch - } - - withResource(leakedBatch) { _ => - // 0 because we have leaked that the host batch - assertResult(0)(hostStore.currentSize) - assertResult(0)(hostStore.currentSpillableSize) - val spilled = catalog.synchronousSpill(hostStore, 0) - assertResult(0)(spilled.get) - } - // after closing we still have 0 bytes in the store or available to spill - assertResult(0)(hostStore.currentSize) - assertResult(0)(hostStore.currentSpillableSize) - } - } - } - } - - test("a host batch is not spillable when columns are incRefCounted") { - val spillPriority = -10 - val hostStoreMaxSize = 1L * 1024 * 1024 - val bm = new RapidsDiskBlockManager(new SparkConf()) - withResource(new RapidsDiskStore(bm)) { diskStore => - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore, hostStore) - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(diskStore) - - val hostCb = buildHostBatch() - - val sizeOnHost = RapidsHostColumnVector.getTotalHostMemoryUsed(hostCb) - - withResource( - SpillableHostColumnarBatch(hostCb, spillPriority, catalog)) { spillableBuffer => - assertResult(sizeOnHost)(hostStore.currentSpillableSize) - - val leakedFirstColumn = withResource(spillableBuffer.getColumnarBatch()) { hostCb => - // 0 because we have a reference to the host batch - assertResult(0)(hostStore.currentSpillableSize) - val spilled = catalog.synchronousSpill(hostStore, 0) - assertResult(0)(spilled.get) - // leak it by increasing the ref count of the underlying cuDF column - RapidsHostColumnVector.extractBases(hostCb).head.incRefCount() - } - withResource(leakedFirstColumn) { _ => - // 0 because we have a reference to the first column - assertResult(0)(hostStore.currentSpillableSize) - val spilled = catalog.synchronousSpill(hostStore, 0) - assertResult(0)(spilled.get) - } - // batch is now spillable because we close our reference to the column - assertResult(sizeOnHost)(hostStore.currentSpillableSize) - val spilled = catalog.synchronousSpill(hostStore, 0) - assertResult(sizeOnHost)(spilled.get) - } - } - } - } - } - - test("an aliased host batch is not spillable (until closing the original) ") { - val spillPriority = -10 - val hostStoreMaxSize = 1L * 1024 * 1024 - val bm = new RapidsDiskBlockManager(new SparkConf()) - withResource(new RapidsDiskStore(bm)) { diskStore => - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore, hostStore) - val hostBatch = buildHostBatch() - val sizeOnHost = RapidsHostColumnVector.getTotalHostMemoryUsed(hostBatch) - val handle = withResource(hostBatch) { _ => - catalog.addBatch(hostBatch, spillPriority) - } - withResource(handle) { _ => - val types: 
Array[DataType] = - Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray - assertResult(sizeOnHost)(hostStore.currentSize) - assertResult(sizeOnHost)(hostStore.currentSpillableSize) - withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => - // extract the batch from the table we added, and add it back as a batch - withResource(rapidsBuffer.getHostColumnarBatch(types)) { batch => - catalog.addBatch(batch, spillPriority) - } - } // we now have two copies in the store - assertResult(sizeOnHost * 2)(hostStore.currentSize) - assertResult(0)(hostStore.currentSpillableSize) - } // remove the original - assertResult(sizeOnHost)(hostStore.currentSize) - assertResult(sizeOnHost)(hostStore.currentSpillableSize) - } - } - } - } - - test("an aliased host batch supports duplicated columns") { - val spillPriority = -10 - val hostStoreMaxSize = 1L * 1024 * 1024 - val bm = new RapidsDiskBlockManager(new SparkConf()) - withResource(new RapidsDiskStore(bm)) { diskStore => - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore, hostStore) - val hostBatch = buildHostBatchWithDuplicate() - val sizeOnHost = RapidsHostColumnVector.getTotalHostMemoryUsed(hostBatch) - val handle = withResource(hostBatch) { _ => - catalog.addBatch(hostBatch, spillPriority) - } - withResource(handle) { _ => - val types: Array[DataType] = - Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray - assertResult(sizeOnHost)(hostStore.currentSize) - assertResult(sizeOnHost)(hostStore.currentSpillableSize) - withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => - // extract the batch from the table we added, and add it back as a batch - withResource(rapidsBuffer.getHostColumnarBatch(types)) { batch => - catalog.addBatch(batch, spillPriority) - } - } // we now have two copies in the store - assertResult(sizeOnHost * 2)(hostStore.currentSize) - assertResult(0)(hostStore.currentSpillableSize) - } // remove the original - assertResult(sizeOnHost)(hostStore.currentSize) - assertResult(sizeOnHost)(hostStore.currentSpillableSize) - } - } - } - } - - test("buffer exceeds maximum size") { - val sparkTypes = Array[DataType](LongType) - val spillPriority = -10 - val hostStoreMaxSize = 256 - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore) - val spyStore = spy(new RapidsDiskStore(new RapidsDiskBlockManager(new SparkConf()))) - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(spyStore) - var bigHandle: RapidsBufferHandle = null - var bigTable = buildContiguousTable(1024 * 1024) - closeOnExcept(bigTable) { _ => - // make a copy of the table so we can compare it later to the - // one reconstituted after the spill - val expectedBatch = - withResource(bigTable.getTable.contiguousSplit()) { expectedTable => - GpuColumnVector.from(expectedTable(0).getTable, sparkTypes) - } - withResource(expectedBatch) { _ => - bigHandle = withResource(bigTable) { _ => - catalog.addContiguousTable( - bigTable, - spillPriority) - } // close the bigTable so it can be spilled - bigTable = null - withResource(catalog.acquireBuffer(bigHandle)) { buffer => - assertResult(StorageTier.DEVICE)(buffer.storageTier) - withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => - TestUtils.compareBatches(expectedBatch, actualBatch) - } - } - 
catalog.synchronousSpill(devStore, 0) - val rapidsBufferCaptor: ArgumentCaptor[RapidsBuffer] = - ArgumentCaptor.forClass(classOf[RapidsBuffer]) - verify(spyStore).copyBuffer( - rapidsBufferCaptor.capture(), - ArgumentMatchers.any[RapidsBufferCatalog], - ArgumentMatchers.any[Cuda.Stream]) - assertResult(bigHandle.id)(rapidsBufferCaptor.getValue.id) - withResource(catalog.acquireBuffer(bigHandle)) { buffer => - assertResult(StorageTier.DISK)(buffer.storageTier) - withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => - TestUtils.compareBatches(expectedBatch, actualBatch) - } - } - } - } - } - } - } - - case class MockRapidsBufferId(tableId: Int) extends RapidsBufferId { - override def getDiskPath(diskBlockManager: RapidsDiskBlockManager): File = - throw new UnsupportedOperationException - } -} diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RmmSparkRetrySuiteBase.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RmmSparkRetrySuiteBase.scala index 5776b2f99a8..0cc017f3792 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RmmSparkRetrySuiteBase.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RmmSparkRetrySuiteBase.scala @@ -18,16 +18,14 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.{Rmm, RmmAllocationMode, RmmEventHandler} import com.nvidia.spark.rapids.jni.RmmSpark -import org.mockito.Mockito.spy import org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite +import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession trait RmmSparkRetrySuiteBase extends AnyFunSuite with BeforeAndAfterEach { private var rmmWasInitialized = false - protected var deviceStorage: RapidsDeviceMemoryStore = _ - override def beforeEach(): Unit = { super.beforeEach() SparkSession.getActiveSession.foreach(_.stop()) @@ -37,14 +35,14 @@ trait RmmSparkRetrySuiteBase extends AnyFunSuite with BeforeAndAfterEach { rmmWasInitialized = true Rmm.initialize(RmmAllocationMode.CUDA_DEFAULT, null, 512 * 1024 * 1024) } - deviceStorage = spy(new RapidsDeviceMemoryStore()) - val hostStore = new RapidsHostMemoryStore(Some(1L * 1024 * 1024)) - deviceStorage.setSpillStore(hostStore) - val catalog = new RapidsBufferCatalog(deviceStorage, hostStore) - // set these against the singleton so we close them later - RapidsBufferCatalog.setDeviceStorage(deviceStorage) - RapidsBufferCatalog.setHostStorage(hostStore) - RapidsBufferCatalog.setCatalog(catalog) + val sc = new SparkConf + sc.set(RapidsConf.HOST_SPILL_STORAGE_SIZE.key, "1MB") + val conf = new RapidsConf(sc) + SpillFramework.shutdown() + SpillFramework.initialize(conf) + + RmmSpark.clearEventHandler() + val mockEventHandler = new BaseRmmEventHandler() RmmSpark.setEventHandler(mockEventHandler) RmmSpark.currentThreadIsDedicatedToTask(1) @@ -53,11 +51,11 @@ trait RmmSparkRetrySuiteBase extends AnyFunSuite with BeforeAndAfterEach { override def afterEach(): Unit = { super.afterEach() + SpillFramework.shutdown() SparkSession.getActiveSession.foreach(_.stop()) SparkSession.clearActiveSession() RmmSpark.removeAllCurrentThreadAssociation() RmmSpark.clearEventHandler() - RapidsBufferCatalog.close() GpuSemaphore.shutdown() if (rmmWasInitialized) { Rmm.shutdown() diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SerializationSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SerializationSuite.scala index 56ecb1a8c57..545aef2915a 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/SerializationSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/SerializationSuite.scala @@ -1,5 
+1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,13 +18,14 @@ package com.nvidia.spark.rapids import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} -import ai.rapids.cudf.Table +import ai.rapids.cudf.{Rmm, RmmAllocationMode, Table} import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.RapidsPluginImplicits._ import org.apache.commons.lang3.SerializationUtils import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuite +import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, SerializeBatchDeserializeHostBuffer, SerializeConcatHostBuffersDeserializeBatch} import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StringType} @@ -33,12 +34,17 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} class SerializationSuite extends AnyFunSuite with BeforeAndAfterAll { + override def beforeAll(): Unit = { - RapidsBufferCatalog.setDeviceStorage(new RapidsDeviceMemoryStore()) + super.beforeAll() + Rmm.initialize(RmmAllocationMode.CUDA_DEFAULT, null, 512 * 1024 * 1024) + SpillFramework.initialize(new RapidsConf(new SparkConf)) } override def afterAll(): Unit = { - RapidsBufferCatalog.close() + super.afterAll() + SpillFramework.shutdown() + Rmm.shutdown() } private def buildBatch(): ColumnarBatch = { @@ -170,7 +176,7 @@ class SerializationSuite extends AnyFunSuite withResource(toHostBatch(gpuBatch)) { expectedHostBatch => val broadcast = makeBroadcastBatch(gpuBatch) withBroadcast(broadcast) { _ => - withResource(broadcast.batch.getColumnarBatch()) { materialized => + withResource(broadcast.batch.getColumnarBatch) { materialized => TestUtils.compareBatches(gpuBatch, materialized) } // the host batch here is obtained from the GPU batch since @@ -193,12 +199,12 @@ class SerializationSuite extends AnyFunSuite batches.foreach { gpuExpected => val broadcast = makeBroadcastBatch(gpuExpected) withBroadcast(broadcast) { _ => - withResource(broadcast.batch.getColumnarBatch()) { gpuBatch => + withResource(broadcast.batch.getColumnarBatch) { gpuBatch => TestUtils.compareBatches(gpuExpected, gpuBatch) } // clone via serialization after manifesting the GPU batch withBroadcast(SerializationUtils.clone(broadcast)) { clonedObj => - withResource(clonedObj.batch.getColumnarBatch()) { gpuClonedBatch => + withResource(clonedObj.batch.getColumnarBatch) { gpuClonedBatch => TestUtils.compareBatches(gpuExpected, gpuClonedBatch) } // try to clone it again from the cloned object @@ -214,12 +220,12 @@ class SerializationSuite extends AnyFunSuite batches.foreach { gpuExpected => val broadcast = makeBroadcastBatch(gpuExpected) withBroadcast(broadcast) { _ => - withResource(broadcast.batch.getColumnarBatch()) { gpuBatch => + withResource(broadcast.batch.getColumnarBatch) { gpuBatch => TestUtils.compareBatches(gpuExpected, gpuBatch) } // clone via serialization after manifesting the GPU batch withBroadcast(SerializationUtils.clone(broadcast)) { clonedObj => - withResource(clonedObj.batch.getColumnarBatch()) { gpuClonedBatch => + withResource(clonedObj.batch.getColumnarBatch) { gpuClonedBatch => TestUtils.compareBatches(gpuExpected, gpuClonedBatch) } // try to clone it again from the cloned object 
@@ -234,12 +240,12 @@ class SerializationSuite extends AnyFunSuite withResource(buildBatch()) { gpuExpected => val broadcast = makeBroadcastBatch(gpuExpected) withBroadcast(broadcast) { _ => - withResource(broadcast.batch.getColumnarBatch()) { gpuBatch => + withResource(broadcast.batch.getColumnarBatch) { gpuBatch => TestUtils.compareBatches(gpuExpected, gpuBatch) } // clone via serialization after manifesting the GPU batch withBroadcast(SerializationUtils.clone(broadcast)) { clonedObj => - withResource(clonedObj.batch.getColumnarBatch()) { gpuClonedBatch => + withResource(clonedObj.batch.getColumnarBatch) { gpuClonedBatch => TestUtils.compareBatches(gpuExpected, gpuClonedBatch) } // try to clone it again from the cloned object @@ -263,7 +269,7 @@ class SerializationSuite extends AnyFunSuite broadcast.doReadObject(inputStream) // use it now - withResource(broadcast.batch.getColumnarBatch()) { gpuBatch => + withResource(broadcast.batch.getColumnarBatch) { gpuBatch => TestUtils.compareBatches(gpuExpected, gpuBatch) } } @@ -275,7 +281,7 @@ class SerializationSuite extends AnyFunSuite val broadcast = makeBroadcastBatch(gpuExpected) withBroadcast(broadcast) { _ => // materialize - withResource(broadcast.batch.getColumnarBatch()) { cb => + withResource(broadcast.batch.getColumnarBatch) { cb => TestUtils.compareBatches(gpuExpected, cb) } @@ -292,7 +298,7 @@ class SerializationSuite extends AnyFunSuite assertResult(before)(broadcast.batch) // it is the same as before // use it now - withResource(broadcast.batch.getColumnarBatch()) { gpuBatch => + withResource(broadcast.batch.getColumnarBatch) { gpuBatch => TestUtils.compareBatches(gpuExpected, gpuBatch) } } @@ -311,7 +317,7 @@ class SerializationSuite extends AnyFunSuite .deserialize[SerializeConcatHostBuffersDeserializeBatch]( inputStream)) { materialized => // this materializes a new batch from what was deserialized - withResource(materialized.batch.getColumnarBatch()) { gpuBatch => + withResource(materialized.batch.getColumnarBatch) { gpuBatch => TestUtils.compareBatches(gpuExpected, gpuBatch) } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ShuffleBufferCatalogSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ShuffleBufferCatalogSuite.scala index 00209461d3c..d86aa8fccb2 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/ShuffleBufferCatalogSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/ShuffleBufferCatalogSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -19,18 +19,15 @@ package com.nvidia.spark.rapids import org.scalatest.funsuite.AnyFunSuite import org.scalatestplus.mockito.MockitoSugar -import org.apache.spark.sql.rapids.RapidsDiskBlockManager - class ShuffleBufferCatalogSuite extends AnyFunSuite with MockitoSugar { + // TODO: AB: more tests please test("registered shuffles should be active") { - val catalog = mock[RapidsBufferCatalog] - val rapidsDiskBlockManager = mock[RapidsDiskBlockManager] - val shuffleCatalog = new ShuffleBufferCatalog(catalog, rapidsDiskBlockManager) - + val shuffleCatalog = new ShuffleBufferCatalog() assertResult(false)(shuffleCatalog.hasActiveShuffle(123)) shuffleCatalog.registerShuffle(123) assertResult(true)(shuffleCatalog.hasActiveShuffle(123)) shuffleCatalog.unregisterShuffle(123) assertResult(false)(shuffleCatalog.hasActiveShuffle(123)) } + } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SpillFrameworkSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SpillFrameworkSuite.scala new file mode 100644 index 00000000000..2d92785078c --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/SpillFrameworkSuite.scala @@ -0,0 +1,984 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import java.io.File +import java.math.RoundingMode + +import ai.rapids.cudf.{ColumnVector, ContiguousTable, Cuda, DeviceMemoryBuffer, HostColumnVector, HostMemoryBuffer, Table} +import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} +import com.nvidia.spark.rapids.format.CodecType +import org.mockito.Mockito.when +import org.scalatest.BeforeAndAfterAll +import org.scalatestplus.mockito.MockitoSugar + +import org.apache.spark.SparkConf +import org.apache.spark.sql.rapids.RapidsDiskBlockManager +import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType, IntegerType, LongType, StringType, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatch + +class SpillFrameworkSuite + extends FunSuiteWithTempDir + with MockitoSugar + with BeforeAndAfterAll { + + override def beforeEach(): Unit = { + super.beforeEach() + val sc = new SparkConf + sc.set(RapidsConf.HOST_SPILL_STORAGE_SIZE.key, "1024") + SpillFramework.initialize(new RapidsConf(sc)) + } + + override def afterEach(): Unit = { + super.afterEach() + SpillFramework.shutdown() + } + + private def buildContiguousTable(): (ContiguousTable, Array[DataType]) = { + val (tbl, dataTypes) = buildTable() + withResource(tbl) { _ => + (tbl.contiguousSplit()(0), dataTypes) + } + } + + private def buildTableOfLongs(numRows: Int): (ContiguousTable, Array[DataType])= { + val vals = (0 until numRows).map(_.toLong) + withResource(HostColumnVector.fromLongs(vals: _*)) { hcv => + withResource(hcv.copyToDevice()) { cv => + withResource(new Table(cv)) { table => + (table.contiguousSplit()(0), Array[DataType](LongType)) + } + } + } + } + + private def buildTable(): (Table, Array[DataType]) = { + val tbl = new Table.TestBuilder() + .column(5, 
null.asInstanceOf[java.lang.Integer], 3, 1) + .column("five", "two", null, null) + .column(5.0, 2.0, 3.0, 1.0) + .decimal64Column(-5, RoundingMode.UNNECESSARY, 0, null, -1.4, 10.123) + .build() + val types: Array[DataType] = + Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray + (tbl, types) + } + + private def buildTableWithDuplicate(): (Table, Array[DataType]) = { + withResource(ColumnVector.fromInts(5, null.asInstanceOf[java.lang.Integer], 3, 1)) { intCol => + withResource(ColumnVector.fromStrings("five", "two", null, null)) { stringCol => + withResource(ColumnVector.fromDoubles(5.0, 2.0, 3.0, 1.0)) { doubleCol => + // add intCol twice + (new Table(intCol, intCol, stringCol, doubleCol), + Array(IntegerType, IntegerType, StringType, DoubleType)) + } + } + } + } + + private def buildEmptyTable(): (Table, Array[DataType]) = { + val (tbl, types) = buildTable() + val emptyTbl = withResource(tbl) { _ => + withResource(ColumnVector.fromBooleans(false, false, false, false)) { mask => + tbl.filter(mask) // filter all out + } + } + (emptyTbl, types) + } + + private def testBufferFileDeletion(canShareDiskPaths: Boolean): Unit = { + val (_, handle, _) = addContiguousTableToCatalog() + var path: File = null + withResource(handle) { _ => + SpillFramework.stores.deviceStore.spill(handle.sizeInBytes) + SpillFramework.stores.hostStore.spill(handle.sizeInBytes) + assert(handle.host.isDefined) + assert(handle.host.map(_.disk.isDefined).get) + path = SpillFramework.stores.diskStore.getFile(handle.host.flatMap(_.disk).get.blockId) + assert(path.exists) + } + assert(!path.exists) + } + + private def addContiguousTableToCatalog(): ( + Long, SpillableColumnarBatchFromBufferHandle, Array[DataType]) = { + val (ct, dataTypes) = buildContiguousTable() + val bufferSize = ct.getBuffer.getLength + val handle = SpillableColumnarBatchFromBufferHandle(ct, dataTypes) + (bufferSize, handle, dataTypes) + } + + private def addTableToCatalog(): (SpillableColumnarBatchHandle, Array[DataType]) = { + // store takes ownership of the table + val (tbl, dataTypes) = buildTable() + val cb = withResource(tbl) { _ => GpuColumnVector.from(tbl, dataTypes) } + val handle = SpillableColumnarBatchHandle(cb) + (handle, dataTypes) + } + + private def addZeroRowsTableToCatalog(): (SpillableColumnarBatchHandle, Array[DataType]) = { + val (table, dataTypes) = buildEmptyTable() + val cb = withResource(table) { _ => GpuColumnVector.from(table, dataTypes) } + val handle = SpillableColumnarBatchHandle(cb) + (handle, dataTypes) + } + + private def buildHostBatch(): (ColumnarBatch, Array[DataType]) = { + val (ct, dataTypes) = buildContiguousTable() + val hostCols = withResource(ct) { _ => + withResource(ct.getTable) { tbl => + (0 until tbl.getNumberOfColumns) + .map(c => tbl.getColumn(c).copyToHost()) + } + }.toArray + (new ColumnarBatch( + hostCols.zip(dataTypes).map { case (hostCol, dataType) => + new RapidsHostColumnVector(dataType, hostCol) + }, hostCols.head.getRowCount.toInt), dataTypes) + } + + private def buildHostBatchWithDuplicate(): (ColumnarBatch, Array[DataType]) = { + val (ct, dataTypes) = buildContiguousTable() + val hostCols = withResource(ct) { _ => + withResource(ct.getTable) { tbl => + (0 until tbl.getNumberOfColumns) + .map(c => tbl.getColumn(c).copyToHost()) + } + }.toArray + hostCols.foreach(_.incRefCount()) + (new ColumnarBatch( + (hostCols ++ hostCols).zip(dataTypes ++ dataTypes).map { case (hostCol, dataType) => + new RapidsHostColumnVector(dataType, hostCol) + }, hostCols.head.getRowCount.toInt), dataTypes) 
+ } + + // TODO: AB: add tests that span multiple byte buffers for host->disk, and + // test that span multiple chunked pack bounce buffers + + test("add table registers with device store") { + val (ct, dataTypes) = buildContiguousTable() + withResource(SpillableColumnarBatchFromBufferHandle(ct, dataTypes)) { _ => + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + } + } + + test("a non-contiguous table is spillable and it is handed over to the store") { + val (tbl, dataTypes) = buildTable() + withResource(SpillableColumnarBatchHandle(tbl, dataTypes)) { handle => + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + assert(handle.spillable) + } + } + + test("a non-contiguous table becomes non-spillable when batch is obtained") { + val (tbl, dataTypes) = buildTable() + withResource(SpillableColumnarBatchHandle(tbl, dataTypes)) { handle => + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + assert(handle.spillable) + withResource(handle.materialize(dataTypes)) { _ => + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + assert(!handle.spillable) + assertResult(0)(SpillFramework.stores.deviceStore.spill(handle.sizeInBytes)) + } + assert(handle.spillable) + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + assertResult(handle.sizeInBytes)(SpillFramework.stores.deviceStore.spill(handle.sizeInBytes)) + } + } + + test("a non-contiguous table is non-spillable until all columns are returned") { + val (table, dataTypes) = buildTable() + withResource(SpillableColumnarBatchHandle(table, dataTypes)) { handle => + assert(handle.spillable) + val cb = handle.materialize(dataTypes) + assert(!handle.spillable) + val columns = GpuColumnVector.extractBases(cb) + withResource(columns.head) { _ => + columns.head.incRefCount() + withResource(cb) { _ => + assert(!handle.spillable) + } + // still 0 after the batch is closed, because of the extra incRefCount + // for columns.head + assert(!handle.spillable) + } + // columns.head is closed, so now our RapidsTable is spillable again + assert(handle.spillable) + } + } + + test("an aliased non-contiguous table is not spillable (until closing the alias) ") { + val (table, dataTypes) = buildTable() + withResource(SpillableColumnarBatchHandle(table, dataTypes)) { handle => + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + assert(handle.spillable) + withResource(SpillableColumnarBatchHandle(handle.materialize(dataTypes))) { aliasHandle => + assertResult(2)(SpillFramework.stores.deviceStore.numHandles) + assert(!handle.spillable) + assert(!aliasHandle.spillable) + } // we now have two copies in the store + assert(handle.spillable) + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + } + } + + test("an aliased contiguous table is not spillable (until closing the original) ") { + val (table, dataTypes) = buildContiguousTable() + withResource(SpillableColumnarBatchFromBufferHandle(table, dataTypes)) { handle => + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + assert(handle.spillable) + val materialized = handle.materialize(dataTypes) + // note that materialized is a batch "from buffer", it is not a regular batch + withResource(SpillableColumnarBatchFromBufferHandle(materialized)) { aliasHandle => + assertResult(2)(SpillFramework.stores.deviceStore.numHandles) + assert(!handle.spillable) + assert(!aliasHandle.spillable) + } // we now have two copies in the store + assert(handle.spillable) + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + } + } + + 
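+  // The aliasing tests above and the duplicated-column tests below all exercise
+  // the same contract: a handle owns the data it wraps, and it only reports
+  // spillable == true while no materialized reference (a batch, a column, or an
+  // aliased handle) is outstanding. A minimal sketch of that pattern, assuming a
+  // single-threaded caller and using only helpers defined in this suite:
+  //
+  //   val (tbl, types) = buildTable()
+  //   withResource(SpillableColumnarBatchHandle(tbl, types)) { handle =>
+  //     assert(handle.spillable)                 // nothing materialized yet
+  //     withResource(handle.materialize(types)) { _ =>
+  //       assert(!handle.spillable)              // a live reference exists
+  //     }
+  //     assert(handle.spillable)                 // reference closed, spillable again
+  //   }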
test("an non-contiguous table supports duplicated columns") { + val (table, dataTypes) = buildTableWithDuplicate() + withResource(SpillableColumnarBatchHandle(table, dataTypes)) { handle => + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + assert(handle.spillable) + withResource(SpillableColumnarBatchHandle(handle.materialize(dataTypes))) { aliasHandle => + assertResult(2)(SpillFramework.stores.deviceStore.numHandles) + assert(!handle.spillable) + assert(!aliasHandle.spillable) + } // we now have two copies in the store + assert(handle.spillable) + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + } + } + + test("a buffer is not spillable until the owner closes columns referencing it") { + val (ct, _) = buildContiguousTable() + // the contract for spillable handles is that they take ownership + // incRefCount to follow that pattern + val buff = ct.getBuffer + buff.incRefCount() + withResource(SpillableDeviceBufferHandle(buff)) { handle => + withResource(ct) { _ => + assert(!handle.spillable) + } + assert(handle.spillable) + } + } + + private def buildContiguousTable(start: Int, numRows: Int): ContiguousTable = { + val vals = (0 until numRows).map(_.toLong + start) + withResource(HostColumnVector.fromLongs(vals: _*)) { hcv => + withResource(hcv.copyToDevice()) { cv => + withResource(HostColumnVector.decimalFromLongs(-3, vals: _*)) { decHcv => + withResource(decHcv.copyToDevice()) { decCv => + withResource(new Table(cv, decCv)) { table => + table.contiguousSplit()(0) + } + } + } + } + } + } + + private def buildCompressedBatch(start: Int, numRows: Int): ColumnarBatch = { + val codec = TableCompressionCodec.getCodec( + CodecType.NVCOMP_LZ4, TableCompressionCodec.makeCodecConfig(new RapidsConf(new SparkConf))) + withResource(codec.createBatchCompressor(0, Cuda.DEFAULT_STREAM)) { compressor => + compressor.addTableToCompress(buildContiguousTable(start, numRows)) + withResource(compressor.finish()) { compressed => + GpuCompressedColumnVector.from(compressed.head) + } + } + } + + private def decompressBach(cb: ColumnarBatch): ColumnarBatch = { + val schema = new StructType().add("i", LongType) + .add("j", DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 3)) + val sparkTypes = GpuColumnVector.extractTypes(schema) + val codec = TableCompressionCodec.getCodec( + CodecType.NVCOMP_LZ4, TableCompressionCodec.makeCodecConfig(new RapidsConf(new SparkConf))) + withResource(codec.createBatchDecompressor(0, Cuda.DEFAULT_STREAM)) { decompressor => + val gcv = cb.column(0).asInstanceOf[GpuCompressedColumnVector] + // we need to incRefCount since the decompressor closes its inputs + gcv.getTableBuffer.incRefCount() + decompressor.addBufferToDecompress(gcv.getTableBuffer, gcv.getTableMeta.bufferMeta()) + withResource(decompressor.finishAsync()) { decompressed => + MetaUtils.getBatchFromMeta( + decompressed.head, + MetaUtils.dropCodecs(gcv.getTableMeta), + sparkTypes) + } + } + } + + test("a compressed batch can be added and recovered") { + val ct = buildCompressedBatch(0, 1000) + withResource(SpillableCompressedColumnarBatchHandle(ct)) { handle => + assert(handle.spillable) + withResource(handle.materialize) { materialized => + assert(!handle.spillable) + // since we didn't spill, these buffers are exactly the same + assert( + ct.column(0).asInstanceOf[GpuCompressedColumnVector].getTableBuffer == + materialized.column(0).asInstanceOf[GpuCompressedColumnVector].getTableBuffer) + } + assert(handle.spillable) + } + } + + test("a compressed batch can be added and recovered 
after being spilled to host") { + val ct = buildCompressedBatch(0, 1000) + withResource(decompressBach(ct)) { decompressedExpected => + withResource(SpillableCompressedColumnarBatchHandle(ct)) { handle => + assert(handle.spillable) + SpillFramework.stores.deviceStore.spill(handle.sizeInBytes) + assert(!handle.spillable) + assert(handle.dev.isEmpty) + assert(handle.host.isDefined) + withResource(handle.materialize) { materialized => + withResource(decompressBach(materialized)) { decompressed => + TestUtils.compareBatches(decompressedExpected, decompressed) + } + } + } + } + } + + test("a compressed batch can be added and recovered after being spilled to disk") { + val ct = buildCompressedBatch(0, 1000) + withResource(decompressBach(ct)) { decompressedExpected => + withResource(SpillableCompressedColumnarBatchHandle(ct)) { handle => + assert(handle.spillable) + SpillFramework.stores.deviceStore.spill(handle.sizeInBytes) + assert(!handle.spillable) + SpillFramework.stores.hostStore.spill(handle.sizeInBytes) + assert(handle.dev.isEmpty) + assert(handle.host.isDefined) + assert(handle.host.get.host.isEmpty) + assert(handle.host.get.disk.isDefined) + withResource(handle.materialize) { materialized => + withResource(decompressBach(materialized)) { decompressed => + TestUtils.compareBatches(decompressedExpected, decompressed) + } + } + } + } + } + + + test("a second handle prevents buffer to be spilled") { + val buffer = DeviceMemoryBuffer.allocate(123) + val handle1 = SpillableDeviceBufferHandle(buffer) + // materialize will incRefCount `buffer`. This looks a little weird + // but it simulates aliasing as it happens in real code + val handle2 = SpillableDeviceBufferHandle(handle1.materialize) + + withResource(handle1) { _ => + withResource(handle2) { _ => + assertResult(2)(handle1.dev.get.getRefCount) + assertResult(2)(handle2.dev.get.getRefCount) + assertResult(false)(handle1.spillable) + assertResult(false)(handle2.spillable) + } + assertResult(1)(handle1.dev.get.getRefCount) + assertResult(true)(handle1.spillable) + } + } + + test("removing handle releases buffer resources in all stores") { + val handle = SpillableDeviceBufferHandle(DeviceMemoryBuffer.allocate(123)) + withResource(handle) { _ => + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + assertResult(0)(SpillFramework.stores.hostStore.numHandles) + assertResult(0)(SpillFramework.stores.diskStore.numHandles) + + assertResult(123)(SpillFramework.stores.deviceStore.spill(123)) // spill to host memory + assertResult(0)(SpillFramework.stores.deviceStore.numHandles) + assertResult(1)(SpillFramework.stores.hostStore.numHandles) + assertResult(0)(SpillFramework.stores.diskStore.numHandles) + assert(handle.dev.isEmpty) + assert(handle.host.isDefined) + assert(handle.host.get.host.isDefined) + + assertResult(123)(SpillFramework.stores.hostStore.spill(123)) // spill to disk + assertResult(0)(SpillFramework.stores.deviceStore.numHandles) + assertResult(0)(SpillFramework.stores.hostStore.numHandles) + assertResult(1)(SpillFramework.stores.diskStore.numHandles) + assert(handle.dev.isEmpty) + assert(handle.host.isDefined) + assert(handle.host.get.host.isEmpty) + assert(handle.host.get.disk.isDefined) + } + assert(handle.host.isEmpty) + assert(handle.dev.isEmpty) + assertResult(0)(SpillFramework.stores.deviceStore.numHandles) + assertResult(0)(SpillFramework.stores.hostStore.numHandles) + assertResult(0)(SpillFramework.stores.diskStore.numHandles) + } + + test("spill updates catalog") { + val diskStore = SpillFramework.stores.diskStore 
+ val hostStore = SpillFramework.stores.hostStore + val deviceStore = SpillFramework.stores.deviceStore + + val (bufferSize, handle, _) = + addContiguousTableToCatalog() + + withResource(handle) { _ => + assertResult(1)(deviceStore.numHandles) + assertResult(0)(diskStore.numHandles) + assertResult(0)(hostStore.numHandles) + + assertResult(bufferSize)(SpillFramework.stores.deviceStore.spill(bufferSize)) + assertResult(bufferSize)(SpillFramework.stores.hostStore.spill(bufferSize)) + + assertResult(0)(deviceStore.numHandles) + assertResult(0)(hostStore.numHandles) + assertResult(1)(diskStore.numHandles) + + val diskHandle = handle.host.flatMap(_.disk).get + val path = diskStore.getFile(diskHandle.blockId) + assert(path.exists) + } + } + + test("get columnar batch after host spill") { + val (ct, dataTypes) = buildContiguousTable() + val expectedBatch = GpuColumnVector.from(ct.getTable, dataTypes) + withResource(SpillableColumnarBatchFromBufferHandle( + ct, dataTypes)) { handle => + withResource(expectedBatch) { _ => + SpillFramework.stores.deviceStore.spill(handle.sizeInBytes) + withResource(handle.materialize(dataTypes)) { cb => + TestUtils.compareBatches(expectedBatch, cb) + } + } + } + } + + test("get memory buffer after host spill") { + val (ct, dataTypes) = buildContiguousTable() + val expectedBatch = closeOnExcept(ct) { _ => + // make a copy of the table so we can compare it later to the + // one reconstituted after the spill + withResource(ct.getTable.contiguousSplit()) { copied => + GpuColumnVector.from(copied(0).getTable, dataTypes) + } + } + val handle = SpillableColumnarBatchFromBufferHandle(ct, dataTypes) + withResource(handle) { _ => + withResource(expectedBatch) { _ => + assertResult(SpillFramework.stores.deviceStore.spill(handle.sizeInBytes))( + handle.sizeInBytes) + val hostSize = handle.host.get.sizeInBytes + assertResult(SpillFramework.stores.hostStore.spill(hostSize))(hostSize) + withResource(handle.materialize(dataTypes)) { actualBatch => + TestUtils.compareBatches(expectedBatch, actualBatch) + } + } + } + } + + test("host originated: get host memory buffer") { + val spillPriority = -10 + val hmb = HostMemoryBuffer.allocate(1L * 1024) + val spillableBuffer = SpillableHostBuffer(hmb, hmb.getLength, spillPriority) + withResource(spillableBuffer) { _ => + // the refcount of 1 is the store + assertResult(1)(hmb.getRefCount) + withResource(spillableBuffer.getHostBuffer) { memoryBuffer => + assertResult(hmb)(memoryBuffer) + assertResult(2)(memoryBuffer.getRefCount) + } + } + assertResult(0)(hmb.getRefCount) + } + + test("host originated: get host memory buffer after spill to disk") { + val spillPriority = -10 + val hmb = HostMemoryBuffer.allocate(1L * 1024) + val spillableBuffer = SpillableHostBuffer( + hmb, + hmb.getLength, + spillPriority) + assertResult(1)(hmb.getRefCount) + // we spill it + SpillFramework.stores.hostStore.spill(hmb.getLength) + withResource(spillableBuffer) { _ => + // the refcount of the original buffer is 0 because it spilled + assertResult(0)(hmb.getRefCount) + withResource(spillableBuffer.getHostBuffer) { memoryBuffer => + assertResult(memoryBuffer.getLength)(hmb.getLength) + } + } + } + + test("host originated: a buffer is not spillable when we leak it") { + val spillPriority = -10 + val hmb = HostMemoryBuffer.allocate(1L * 1024) + withResource(SpillableHostBuffer(hmb, hmb.getLength, spillPriority)) { spillableBuffer => + withResource(spillableBuffer.getHostBuffer) { _ => + assertResult(0)(SpillFramework.stores.hostStore.spill(hmb.getLength)) + } + 
assertResult(hmb.getLength)(SpillFramework.stores.hostStore.spill(hmb.getLength)) + } + } + + test("host originated: a host batch is not spillable when we leak it") { + val (hostCb, sparkTypes) = buildHostBatch() + val sizeOnHost = RapidsHostColumnVector.getTotalHostMemoryUsed(hostCb) + withResource(SpillableHostColumnarBatchHandle(hostCb)) { handle => + assertResult(true)(handle.spillable) + + withResource(handle.materialize(sparkTypes)) { _ => + // 0 because we have a reference to the host batch + assertResult(false)(handle.spillable) + assertResult(0)(SpillFramework.stores.hostStore.spill(sizeOnHost)) + } + + // after closing we still have 0 bytes in the store or available to spill + assertResult(true)(handle.spillable) + } + } + + test("host originated: a host batch is not spillable when columns are incRefCounted") { + val (hostCb, sparkTypes) = buildHostBatch() + val sizeOnHost = RapidsHostColumnVector.getTotalHostMemoryUsed(hostCb) + withResource(SpillableHostColumnarBatchHandle(hostCb)) { handle => + assertResult(true)(handle.spillable) + val leakedFirstColumn = withResource(handle.materialize(sparkTypes)) { cb => + // 0 because we have a reference to the host batch + assertResult(false)(handle.spillable) + assertResult(0)(SpillFramework.stores.hostStore.spill(sizeOnHost)) + // leak it by increasing the ref count of the underlying cuDF column + RapidsHostColumnVector.extractBases(cb).head.incRefCount() + } + withResource(leakedFirstColumn) { _ => + // 0 because we have a reference to the first column + assertResult(false)(handle.spillable) + assertResult(0)(SpillFramework.stores.hostStore.spill(sizeOnHost)) + } + // batch is now spillable because we close our reference to the column + assertResult(true)(handle.spillable) + assertResult(sizeOnHost)(SpillFramework.stores.hostStore.spill(sizeOnHost)) + } + } + + test("host originated: an aliased host batch is not spillable (until closing the original) ") { + val (hostBatch, sparkTypes) = buildHostBatch() + val handle = SpillableHostColumnarBatchHandle(hostBatch) + withResource(handle) { _ => + assertResult(1)(SpillFramework.stores.hostStore.numHandles) + assertResult(true)(handle.spillable) + withResource(handle.materialize(sparkTypes)) { _ => + assertResult(false)(handle.spillable) + } // we now have two copies in the store + assertResult(true)(handle.spillable) + } + } + + test("host originated: a host batch supports duplicated columns") { + val (hostBatch, sparkTypes) = buildHostBatchWithDuplicate() + val handle = SpillableHostColumnarBatchHandle(hostBatch) + withResource(handle) { _ => + assertResult(1)(SpillFramework.stores.hostStore.numHandles) + assertResult(true)(handle.spillable) + withResource(handle.materialize(sparkTypes)) { _ => + assertResult(false)(handle.spillable) + } // we now have two copies in the store + assertResult(true)(handle.spillable) + } + } + + test("host originated: a host batch supports aliasing and duplicated columns") { + SpillFramework.shutdown() + val sc = new SparkConf + // disables the host store limit + sc.set(RapidsConf.OFF_HEAP_LIMIT_ENABLED.key, "true") + SpillFramework.initialize(new RapidsConf(sc)) + + try { + val (hostBatch, sparkTypes) = buildHostBatchWithDuplicate() + withResource(SpillableHostColumnarBatchHandle(hostBatch)) { handle => + withResource(SpillableHostColumnarBatchHandle(handle.materialize(sparkTypes))) { handle2 => + assertResult(2)(SpillFramework.stores.hostStore.numHandles) + assertResult(false)(handle.spillable) + assertResult(false)(handle2.spillable) + } + 
assertResult(true)(handle.spillable) + } + } finally { + SpillFramework.shutdown() + } + } + + test("direct spill to disk: when buffer exceeds maximum size") { + var (bigTable, sparkTypes) = buildTableOfLongs(2 * 1024 * 1024) + closeOnExcept(bigTable) { _ => + // make a copy of the table so we can compare it later to the + // one reconstituted after the spill + val expectedBatch = + withResource(bigTable.getTable.contiguousSplit()) { expectedTable => + GpuColumnVector.from(expectedTable(0).getTable, sparkTypes) + } + withResource(expectedBatch) { _ => + withResource(SpillableColumnarBatchFromBufferHandle( + bigTable, sparkTypes)) { bigHandle => + bigTable = null + withResource(bigHandle.materialize(sparkTypes)) { actualBatch => + TestUtils.compareBatches(expectedBatch, actualBatch) + } + SpillFramework.stores.deviceStore.spill(bigHandle.sizeInBytes) + assertResult(true)(bigHandle.dev.isEmpty) + assertResult(true)(bigHandle.host.get.host.isEmpty) + assertResult(false)(bigHandle.host.get.disk.isEmpty) + + withResource(bigHandle.materialize(sparkTypes)) { actualBatch => + TestUtils.compareBatches(expectedBatch, actualBatch) + } + } + } + } + } + + test("get columnar batch after spilling to disk") { + val (size, handle, dataTypes) = addContiguousTableToCatalog() + val diskStore = SpillFramework.stores.diskStore + val hostStore = SpillFramework.stores.hostStore + val deviceStore = SpillFramework.stores.deviceStore + withResource(handle) { _ => + assertResult(1)(deviceStore.numHandles) + assertResult(0)(diskStore.numHandles) + assertResult(0)(hostStore.numHandles) + + val expectedTable = + withResource(handle.materialize(dataTypes)) { beforeSpill => + withResource(GpuColumnVector.from(beforeSpill)) { table => + table.contiguousSplit()(0) + } + } // closing the batch from the store so that we can spill it + + withResource(expectedTable) { _ => + withResource( + GpuColumnVector.from(expectedTable.getTable, dataTypes)) { expectedBatch => + deviceStore.spill(size) + hostStore.spill(size) + + assertResult(0)(deviceStore.numHandles) + assertResult(0)(hostStore.numHandles) + assertResult(1)(diskStore.numHandles) + + val diskHandle = handle.host.flatMap(_.disk).get + val path = diskStore.getFile(diskHandle.blockId) + assert(path.exists) + withResource(handle.materialize(dataTypes)) { actualBatch => + TestUtils.compareBatches(expectedBatch, actualBatch) + } + } + } + } + } + + test("get memory buffer after spilling to disk") { + val handle = SpillableDeviceBufferHandle(DeviceMemoryBuffer.allocate(123)) + val diskStore = SpillFramework.stores.diskStore + val hostStore = SpillFramework.stores.hostStore + val deviceStore = SpillFramework.stores.deviceStore + withResource(handle) { _ => + assertResult(1)(deviceStore.numHandles) + assertResult(0)(diskStore.numHandles) + assertResult(0)(hostStore.numHandles) + val expectedBuffer = + withResource(handle.materialize) { devbuf => + closeOnExcept(HostMemoryBuffer.allocate(devbuf.getLength)) { hostbuf => + hostbuf.copyFromDeviceBuffer(devbuf) + hostbuf + } + } + withResource(expectedBuffer) { expectedBuffer => + deviceStore.spill(handle.sizeInBytes) + hostStore.spill(handle.sizeInBytes) + withResource(handle.host.map(_.materialize).get) { actualHostBuffer => + assertResult(expectedBuffer. 
+            asByteBuffer.limit())(actualHostBuffer.asByteBuffer.limit())
+        }
+      }
+    }
+  }
+
+  test("Compression on with or without encryption for spill block using single batch") {
+    Seq("true", "false").foreach { encryptionEnabled =>
+      val conf = new SparkConf()
+      conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled)
+      conf.set("spark.io.compression.codec", "zstd")
+      conf.set("spark.shuffle.spill.compress", "true")
+      conf.set("spark.shuffle.compress", "true")
+      readWriteTestWithBatches(conf, false)
+    }
+  }
+
+  test("Compression off with or without encryption for spill block using single batch") {
+    Seq("true", "false").foreach { encryptionEnabled =>
+      val conf = new SparkConf()
+      conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled)
+      conf.set("spark.shuffle.spill.compress", "false")
+      conf.set("spark.shuffle.compress", "false")
+      readWriteTestWithBatches(conf, false)
+    }
+  }
+
+  test("Compression on with or without encryption for spill block using multiple batches") {
+    Seq("true", "false").foreach { encryptionEnabled =>
+      val conf = new SparkConf()
+      conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled)
+      conf.set("spark.io.compression.codec", "zstd")
+      conf.set("spark.shuffle.spill.compress", "true")
+      conf.set("spark.shuffle.compress", "true")
+      readWriteTestWithBatches(conf, false, false)
+    }
+  }
+
+  test("Compression off with or without encryption for spill block using multiple batches") {
+    Seq("true", "false").foreach { encryptionEnabled =>
+      val conf = new SparkConf()
+      conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled)
+      conf.set("spark.shuffle.spill.compress", "false")
+      conf.set("spark.shuffle.compress", "false")
+      readWriteTestWithBatches(conf, false, false)
+    }
+  }
+
+  // ===== Tests for shuffle block =====
+
+  test("Compression on with or without encryption for shuffle block using single batch") {
+    Seq("true", "false").foreach { encryptionEnabled =>
+      val conf = new SparkConf()
+      conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled)
+      conf.set("spark.io.compression.codec", "zstd")
+      conf.set("spark.shuffle.spill.compress", "true")
+      conf.set("spark.shuffle.compress", "true")
+      readWriteTestWithBatches(conf, true)
+    }
+  }
+
+  test("Compression off with or without encryption for shuffle block using single batch") {
+    Seq("true", "false").foreach { encryptionEnabled =>
+      val conf = new SparkConf()
+      conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled)
+      conf.set("spark.shuffle.spill.compress", "false")
+      conf.set("spark.shuffle.compress", "false")
+      readWriteTestWithBatches(conf, true)
+    }
+  }
+
+  test("Compression on with or without encryption for shuffle block using multiple batches") {
+    Seq("true", "false").foreach { encryptionEnabled =>
+      val conf = new SparkConf()
+      conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled)
+      conf.set("spark.io.compression.codec", "zstd")
+      conf.set("spark.shuffle.spill.compress", "true")
+      conf.set("spark.shuffle.compress", "true")
+      readWriteTestWithBatches(conf, true, true)
+    }
+  }
+
+  test("Compression off with or without encryption for shuffle block using multiple batches") {
+    Seq("true", "false").foreach { encryptionEnabled =>
+      val conf = new SparkConf()
+      conf.set(RapidsConf.TEST_IO_ENCRYPTION.key, encryptionEnabled)
+      conf.set("spark.shuffle.spill.compress", "false")
+      conf.set("spark.shuffle.compress", "false")
+      readWriteTestWithBatches(conf, true, true)
+    }
+  }
+
+  test("No encryption and compression for shuffle block using multiple batches") {
+    readWriteTestWithBatches(new SparkConf(),
true, true)
+  }
+
+  private def readWriteTestWithBatches(conf: SparkConf, shareDiskPaths: Boolean*) = {
+    assert(shareDiskPaths.nonEmpty)
+    val mockDiskBlockManager = mock[RapidsDiskBlockManager]
+    when(mockDiskBlockManager.getSerializerManager())
+      .thenReturn(new RapidsSerializerManager(conf))
+
+    shareDiskPaths.foreach { _ =>
+      val (_, handle, dataTypes) = addContiguousTableToCatalog()
+      withResource(handle) { _ =>
+        val expectedCt = withResource(handle.materialize(dataTypes)) { devbatch =>
+          withResource(GpuColumnVector.from(devbatch)) { tmpTbl =>
+            tmpTbl.contiguousSplit()(0)
+          }
+        }
+        withResource(expectedCt) { _ =>
+          val expectedBatch = withResource(expectedCt.getTable) { expectedTbl =>
+            GpuColumnVector.from(expectedTbl, dataTypes)
+          }
+          withResource(expectedBatch) { _ =>
+            assertResult(true)(SpillFramework.stores.deviceStore.spill(handle.sizeInBytes) > 0)
+            assertResult(true)(SpillFramework.stores.hostStore.spill(handle.sizeInBytes) > 0)
+            withResource(handle.materialize(dataTypes)) { actualBatch =>
+              TestUtils.compareBatches(expectedBatch, actualBatch)
+            }
+          }
+        }
+      }
+    }
+  }
+
+  test("skip host: spill device memory buffer to disk") {
+    SpillFramework.shutdown()
+    try {
+      val sc = new SparkConf
+      // set a small host store limit
+      sc.set(RapidsConf.HOST_SPILL_STORAGE_SIZE.key, "1KB")
+      SpillFramework.initialize(new RapidsConf(sc))
+      // buffer is too big for host store limit, so we will skip host
+      val handle = SpillableDeviceBufferHandle(DeviceMemoryBuffer.allocate(1025))
+      val deviceStore = SpillFramework.stores.deviceStore
+      withResource(handle) { _ =>
+        val expectedBuffer =
+          withResource(handle.materialize) { devbuf =>
+            closeOnExcept(HostMemoryBuffer.allocate(devbuf.getLength)) { hostbuf =>
+              hostbuf.copyFromDeviceBuffer(devbuf)
+              hostbuf
+            }
+          }
+
+        withResource(expectedBuffer) { _ =>
+          // the host store has no room, so the spill skips host and goes straight to disk
+          deviceStore.spill(handle.sizeInBytes)
+          assert(handle.host.map(_.host.isEmpty).get)
+          assert(handle.host.map(_.disk.isDefined).get)
+          withResource(handle.host.map(_.materialize).get) { buffer =>
+            assertResult(expectedBuffer.asByteBuffer)(buffer.asByteBuffer)
+          }
+        }
+      }
+    } finally {
+      SpillFramework.shutdown()
+    }
+  }
+
+  test("skip host: spill table to disk") {
+    SpillFramework.shutdown()
+    try {
+      val sc = new SparkConf
+      sc.set(RapidsConf.HOST_SPILL_STORAGE_SIZE.key, "1KB")
+      SpillFramework.initialize(new RapidsConf(sc))
+      // fill up the host store
+      withResource(SpillableHostBufferHandle(HostMemoryBuffer.allocate(1024))) { hostHandle =>
+        // make sure the host handle isn't spillable
+        withResource(hostHandle.materialize) { _ =>
+          val (handle, _) = addTableToCatalog()
+          withResource(handle) { _ =>
+            val (expectedTable, dataTypes) = buildTable()
+            withResource(expectedTable) { _ =>
+              withResource(
+                GpuColumnVector.from(expectedTable, dataTypes)) { expectedBatch =>
+                SpillFramework.stores.deviceStore.spill(handle.sizeInBytes)
+                assert(handle.host.map(_.host.isEmpty).get)
+                assert(handle.host.map(_.disk.isDefined).get)
+                withResource(handle.materialize(dataTypes)) { fromDiskBatch =>
+                  TestUtils.compareBatches(expectedBatch, fromDiskBatch)
+                  assert(handle.dev.isEmpty)
+                  assert(handle.host.map(_.host.isEmpty).get)
+                  assert(handle.host.map(_.disk.isDefined).get)
+                }
+              }
+            }
+          }
+        }
+      }
+    } finally {
+      SpillFramework.shutdown()
+    }
+  }
+
+  test("skip host: spill table to disk with small host bounce buffer") {
+    try {
+      SpillFramework.shutdown()
+      val sc = new SparkConf
+      // make this super small so we skip the host
+
sc.set(RapidsConf.HOST_SPILL_STORAGE_SIZE.key, "1") + sc.set(RapidsConf.SPILL_TO_DISK_BOUNCE_BUFFER_SIZE.key, "10") + sc.set(RapidsConf.CHUNKED_PACK_BOUNCE_BUFFER_SIZE.key, "1MB") + val rapidsConf = new RapidsConf(sc) + SpillFramework.initialize(rapidsConf) + val (handle, _) = addTableToCatalog() + withResource(handle) { _ => + val (expectedTable, dataTypes) = buildTable() + withResource(expectedTable) { _ => + withResource( + GpuColumnVector.from(expectedTable, dataTypes)) { expectedBatch => + SpillFramework.stores.deviceStore.spill(handle.sizeInBytes) + assert(handle.dev.isEmpty) + assert(handle.host.map(_.host.isEmpty).get) + assert(handle.host.map(_.disk.isDefined).get) + withResource(handle.materialize(dataTypes)) { fromDiskBatch => + TestUtils.compareBatches(expectedBatch, fromDiskBatch) + } + } + } + } + } finally { + SpillFramework.shutdown() + } + } + + test("0-byte table is never spillable") { + val (handle, _) = addZeroRowsTableToCatalog() + val (handle2, _) = addTableToCatalog() + + withResource(handle) { _ => + withResource(handle2) { _ => + assert(handle2.host.isEmpty) + val (expectedTable, expectedTypes) = buildTable() + withResource(expectedTable) { _ => + withResource( + GpuColumnVector.from(expectedTable, expectedTypes)) { expectedCb => + SpillFramework.stores.deviceStore.spill(handle.sizeInBytes + handle2.sizeInBytes) + SpillFramework.stores.hostStore.spill(handle.sizeInBytes + handle2.sizeInBytes) + // the 0-byte table never moved from device. It is not spillable + assert(handle.host.isEmpty) + assert(!handle.spillable) + // the second table (with rows) did spill + assert(handle2.host.isDefined) + assert(handle2.host.map(_.host.isEmpty).get) + assert(handle2.host.map(_.disk.isDefined).get) + + withResource(handle2.materialize(expectedTypes)) { spilledBatch => + TestUtils.compareBatches(expectedCb, spilledBatch) + } + } + } + } + } + } + + test("exclusive spill files are deleted when buffer deleted") { + testBufferFileDeletion(canShareDiskPaths = false) + } + + test("shared spill files are not deleted when a buffer is deleted") { + testBufferFileDeletion(canShareDiskPaths = true) + } + +} \ No newline at end of file diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/WindowRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/WindowRetrySuite.scala index 0ecc5faf1a3..bf61e495c06 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/WindowRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/WindowRetrySuite.scala @@ -80,7 +80,7 @@ class WindowRetrySuite assertResult(4)(hostCol.getLong(row)) } } - verify(inputBatch, times(2)).getColumnarBatch() + verify(inputBatch, times(2)).getColumnarBatch verify(inputBatch, times(1)).close() } } @@ -102,7 +102,7 @@ class WindowRetrySuite assertResult(row + 1)(hostCol.getLong(row)) } } - verify(inputBatch, times(2)).getColumnarBatch() + verify(inputBatch, times(2)).getColumnarBatch verify(inputBatch, times(1)).close() } } @@ -126,7 +126,7 @@ class WindowRetrySuite assertResult(4)(hostCol.getLong(row)) } } - verify(inputBatch, times(2)).getColumnarBatch() + verify(inputBatch, times(2)).getColumnarBatch verify(inputBatch, times(1)).close() } } @@ -143,7 +143,7 @@ class WindowRetrySuite assertThrows[GpuSplitAndRetryOOM] { it.next() } - verify(inputBatch, times(1)).getColumnarBatch() + verify(inputBatch, times(1)).getColumnarBatch verify(inputBatch, times(1)).close() } @@ -173,7 +173,7 @@ class WindowRetrySuite } } } - verify(inputBatch, times(2)).getColumnarBatch() + verify(inputBatch, 
times(2)).getColumnarBatch verify(inputBatch, times(1)).close() } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/WithRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/WithRetrySuite.scala index aa003c454f1..d617726709d 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/WithRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/WithRetrySuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ import org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite import org.scalatestplus.mockito.MockitoSugar +import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession import org.apache.spark.sql.types.{DataType, LongType} @@ -52,19 +53,16 @@ class WithRetrySuite rmmWasInitialized = true Rmm.initialize(RmmAllocationMode.CUDA_DEFAULT, null, 512 * 1024 * 1024) } - val deviceStorage = new RapidsDeviceMemoryStore() - val catalog = new RapidsBufferCatalog(deviceStorage) - RapidsBufferCatalog.setDeviceStorage(deviceStorage) - RapidsBufferCatalog.setCatalog(catalog) + SpillFramework.initialize(new RapidsConf(new SparkConf)) val mockEventHandler = new BaseRmmEventHandler() RmmSpark.setEventHandler(mockEventHandler) RmmSpark.currentThreadIsDedicatedToTask(1) } override def afterEach(): Unit = { + SpillFramework.shutdown() RmmSpark.removeAllCurrentThreadAssociation() RmmSpark.clearEventHandler() - RapidsBufferCatalog.close() if (rmmWasInitialized) { Rmm.shutdown() } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala index e9873b0bc5f..515ff75ffaa 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -255,7 +255,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { verify(client, times(1)).track(any[DeviceMemoryBuffer](), tmCaptor.capture()) verifyTableMeta(tableMeta, tmCaptor.getValue.asInstanceOf[TableMeta]) verify(mockCatalog, times(1)) - .addBuffer(dmbCaptor.capture(), any(), any(), any()) + .addBuffer(dmbCaptor.capture(), any(), any()) val receivedBuff = dmbCaptor.getValue.asInstanceOf[DeviceMemoryBuffer] assertResult(tableMeta.bufferMeta().size())(receivedBuff.getLength) @@ -310,8 +310,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { verify(client, times(1)).track(any[DeviceMemoryBuffer](), tmCaptor.capture()) verifyTableMeta(tableMeta, tmCaptor.getValue.asInstanceOf[TableMeta]) verify(mockCatalog, times(1)) - .addBuffer(dmbCaptor.capture(), any(), any(), any()) - verify(mockCatalog, times(1)).removeBuffer(any()) + .addBuffer(dmbCaptor.capture(), any(), any()) val receivedBuff = dmbCaptor.getValue.asInstanceOf[DeviceMemoryBuffer] assertResult(tableMeta.bufferMeta().size())(receivedBuff.getLength) @@ -367,7 +366,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { } verify(mockCatalog, times(5)) - .addBuffer(dmbCaptor.capture(), any(), any(), any()) + .addBuffer(dmbCaptor.capture(), any(), any()) assertResult(totalExpectedSize)( dmbCaptor.getAllValues().toArray().map(_.asInstanceOf[DeviceMemoryBuffer].getLength).sum) @@ -424,7 +423,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper { } verify(mockCatalog, times(20)) - .addBuffer(dmbCaptor.capture(), any(), any(), any()) + .addBuffer(dmbCaptor.capture(), any(), any()) assertResult(totalExpectedSize)( dmbCaptor.getAllValues().toArray().map(_.asInstanceOf[DeviceMemoryBuffer].getLength).sum) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala index 70064682ed0..bf379f02937 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.shuffle -import com.nvidia.spark.rapids.{RapidsBuffer, RapidsBufferHandle} +import com.nvidia.spark.rapids.{RapidsShuffleHandle, SpillableDeviceBufferHandle} import com.nvidia.spark.rapids.jni.RmmSpark import org.mockito.ArgumentCaptor import org.mockito.ArgumentMatchers._ @@ -214,13 +214,14 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler]) when(mockTransport.makeClient(any())).thenReturn(client) doNothing().when(client).doFetch(any(), ac.capture()) - val mockBuffer = mock[RapidsBuffer] + val mockBuffer = RapidsShuffleHandle(mock[SpillableDeviceBufferHandle], null) + when(mockBuffer.spillable.sizeInBytes).thenReturn(123L) val cb = new ColumnarBatch(Array.empty, 10) - val handle = mock[RapidsBufferHandle] - when(mockBuffer.getColumnarBatch(Array.empty)).thenReturn(cb) - when(mockCatalog.acquireBuffer(any[RapidsBufferHandle]())).thenReturn(mockBuffer) - doNothing().when(mockCatalog).removeBuffer(any()) + val handle = mock[RapidsShuffleHandle] + doAnswer(_ => (cb, 123L)).when(mockCatalog) + .getColumnarBatchAndRemove(any[RapidsShuffleHandle](), any()) + cl.start() val handler = ac.getValue.asInstanceOf[RapidsShuffleFetchHandler] @@ -232,7 +233,7 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { assert(cl.hasNext) assertResult(cb)(cl.next()) assertResult(1)(testMetricsUpdater.totalRemoteBlocksFetched) - assertResult(mockBuffer.memoryUsedBytes)(testMetricsUpdater.totalRemoteBytesRead) + assertResult(123L)(testMetricsUpdater.totalRemoteBytesRead) assertResult(10)(testMetricsUpdater.totalRowsFetched) } finally { RmmSpark.taskDone(taskId) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala index 3eb73ef0f13..ccc63a597fd 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -18,73 +18,57 @@ package com.nvidia.spark.rapids.shuffle import java.io.IOException import java.nio.ByteBuffer -import java.util -import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer} -import com.nvidia.spark.rapids.{MetaUtils, RapidsBuffer, ShuffleMetadata} +import ai.rapids.cudf.{DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer} +import com.nvidia.spark.rapids.{MetaUtils, RapidsShuffleHandle, ShuffleMetadata, SpillableDeviceBufferHandle} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.format.TableMeta import org.mockito.{ArgumentCaptor, ArgumentMatchers} -import org.mockito.ArgumentMatchers.{any, anyLong} +import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ import org.apache.spark.storage.ShuffleBlockBatchId -class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { +class MockRapidsShuffleRequestHandler(mockBuffers: Seq[RapidsShuffleHandle]) + extends RapidsShuffleRequestHandler with AutoCloseable { + var acquiredTables = Seq[Int]() + override def getShuffleBufferMetas( + shuffleBlockBatchId: ShuffleBlockBatchId): Seq[TableMeta] = { + throw new NotImplementedError("getShuffleBufferMetas") + } - def setupMocks(deviceBuffers: Seq[DeviceMemoryBuffer]): (RapidsShuffleRequestHandler, - Seq[RapidsBuffer], util.HashMap[RapidsBuffer, Int]) = { + override def getShuffleHandle(tableId: Int): RapidsShuffleHandle = { + acquiredTables = acquiredTables :+ tableId + mockBuffers(tableId) + } - val numCloses = new util.HashMap[RapidsBuffer, Int]() + override def close(): Unit = { + // a removeShuffle action would likewise remove handles + mockBuffers.foreach(_.close()) + } +} + +class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { + + def setupMocks(deviceBuffers: Seq[DeviceMemoryBuffer]): MockRapidsShuffleRequestHandler = { val mockBuffers = deviceBuffers.map { deviceBuffer => withResource(HostMemoryBuffer.allocate(deviceBuffer.getLength)) { hostBuff => fillBuffer(hostBuff) deviceBuffer.copyFromHostBuffer(hostBuff) - val mockBuffer = mock[RapidsBuffer] val mockMeta = RapidsShuffleTestHelper.mockTableMeta(100000) - when(mockBuffer.copyToMemoryBuffer(anyLong(), any[MemoryBuffer](), anyLong(), anyLong(), - any[Cuda.Stream]())).thenAnswer { invocation => - // start at 1 close, since we'll need to close at refcount 0 too - val newNumCloses = numCloses.getOrDefault(mockBuffer, 1) + 1 - numCloses.put(mockBuffer, newNumCloses) - val srcOffset = invocation.getArgument[Long](0) - val dst = invocation.getArgument[MemoryBuffer](1) - val dstOffset = invocation.getArgument[Long](2) - val length = invocation.getArgument[Long](3) - val stream = invocation.getArgument[Cuda.Stream](4) - dst.copyFromMemoryBuffer(dstOffset, deviceBuffer, srcOffset, length, stream) - } - when(mockBuffer.getPackedSizeBytes).thenReturn(deviceBuffer.getLength) - when(mockBuffer.meta).thenReturn(mockMeta) - mockBuffer + RapidsShuffleHandle(SpillableDeviceBufferHandle(deviceBuffer), mockMeta) } } - - val handler = new RapidsShuffleRequestHandler { - var acquiredTables = Seq[Int]() - override def getShuffleBufferMetas( - shuffleBlockBatchId: ShuffleBlockBatchId): Seq[TableMeta] = { - throw new NotImplementedError("getShuffleBufferMetas") - } - - override def acquireShuffleBuffer(tableId: Int): RapidsBuffer = { - acquiredTables = acquiredTables :+ tableId - mockBuffers(tableId) - } - } - (handler, mockBuffers, numCloses) + new MockRapidsShuffleRequestHandler(mockBuffers) } - class MockBlockWithSize(val b: DeviceMemoryBuffer) extends 
BlockWithSize { - override def size: Long = b.getLength - } + class MockBlockWithSize(override val size: Long) extends BlockWithSize {} def compareRanges( bounceBuffer: SendBounceBuffers, - receiveBlocks: Seq[BlockRange[MockBlockWithSize]]): Unit = { + receiveBlocks: Seq[(BlockRange[MockBlockWithSize], DeviceMemoryBuffer)]): Unit = { var bounceBuffOffset = 0L - receiveBlocks.foreach { range => - val deviceBuff = range.block.b + receiveBlocks.foreach { case (range, deviceBuff) => val deviceBounceBuff = bounceBuffer.deviceBounceBuffer.buffer withResource(deviceBounceBuff.slice(bounceBuffOffset, range.rangeSize())) { bbSlice => bounceBuffOffset = bounceBuffOffset + range.rangeSize() @@ -104,26 +88,24 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { when(mockTx.releaseMessage()).thenReturn(transferRequest) val bb = closeOnExcept(getSendBounceBuffer(10000)) { bounceBuffer => - withResource((0 until 10).map(_ => DeviceMemoryBuffer.allocate(1000))) { deviceBuffers => - val receiveSide = deviceBuffers.map(b => new MockBlockWithSize(b)) + val deviceBuffers = (0 until 10).map(_ => DeviceMemoryBuffer.allocate(1000)) + val receiveSide = deviceBuffers.map(_ => new MockBlockWithSize(1000)) + withResource(setupMocks(deviceBuffers)) { handler => val receiveWindow = new WindowedBlockIterator[MockBlockWithSize](receiveSide, 10000) - val (handler, mockBuffers, numCloses) = setupMocks(deviceBuffers) withResource(new BufferSendState(mockTx, bounceBuffer, handler)) { bss => assert(bss.hasMoreSends) withResource(bss.getBufferToSend()) { mb => val receiveBlocks = receiveWindow.next() - compareRanges(bounceBuffer, receiveBlocks) + compareRanges(bounceBuffer, receiveBlocks.zip(deviceBuffers)) assertResult(10000)(mb.getLength) assert(!bss.hasMoreSends) bss.releaseAcquiredToCatalog() - mockBuffers.foreach { b: RapidsBuffer => - // should have seen 2 closes, one for BufferSendState acquiring for metadata - // and the second acquisition for copying - verify(b, times(numCloses.get(b))).close() - } } } } + deviceBuffers.foreach { b => + assertResult(0)(b.getRefCount) + } bounceBuffer } assert(bb.deviceBounceBuffer.isClosed) @@ -136,32 +118,29 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { when(mockTx.releaseMessage()).thenReturn(transferRequest) val bb = closeOnExcept(getSendBounceBuffer(10000)) { bounceBuffer => - withResource((0 until 20).map(_ => DeviceMemoryBuffer.allocate(1000))) { deviceBuffers => - val receiveSide = deviceBuffers.map(b => new MockBlockWithSize(b)) - val receiveWindow = new WindowedBlockIterator[MockBlockWithSize](receiveSide, 10000) - val (handler, mockBuffers, numCloses) = setupMocks(deviceBuffers) + val deviceBuffers = (0 until 20).map(_ => DeviceMemoryBuffer.allocate(1000)) + val receiveSide = deviceBuffers.map(_ => new MockBlockWithSize(1000)) + val receiveWindow = new WindowedBlockIterator[MockBlockWithSize](receiveSide, 10000) + withResource(setupMocks(deviceBuffers)) { handler => withResource(new BufferSendState(mockTx, bounceBuffer, handler)) { bss => withResource(bss.getBufferToSend()) { _ => val receiveBlocks = receiveWindow.next() - compareRanges(bounceBuffer, receiveBlocks) + compareRanges(bounceBuffer, receiveBlocks.zip(deviceBuffers)) assert(bss.hasMoreSends) bss.releaseAcquiredToCatalog() } withResource(bss.getBufferToSend()) { _ => val receiveBlocks = receiveWindow.next() - compareRanges(bounceBuffer, receiveBlocks) + compareRanges(bounceBuffer, receiveBlocks.zip(deviceBuffers)) assert(!bss.hasMoreSends) bss.releaseAcquiredToCatalog() } - - 
mockBuffers.foreach { b: RapidsBuffer => - // should have seen 2 closes, one for BufferSendState acquiring for metadata - // and the second acquisition for copying - verify(b, times(numCloses.get(b))).close() - } } } + deviceBuffers.foreach { b => + assertResult(0)(b.getRefCount) + } bounceBuffer } assert(bb.deviceBounceBuffer.isClosed) @@ -174,24 +153,23 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { when(mockTx.releaseMessage()).thenReturn(transferRequest) val bb = closeOnExcept(getSendBounceBuffer(10000)) { bounceBuffer => - withResource((0 until 20).map(_ => DeviceMemoryBuffer.allocate(123000))) { deviceBuffers => - val (handler, mockBuffers, numCloses) = setupMocks(deviceBuffers) - - val receiveSide = deviceBuffers.map(b => new MockBlockWithSize(b)) + val deviceBuffers = (0 until 20).map(_ => DeviceMemoryBuffer.allocate(123000)) + withResource(setupMocks(deviceBuffers)) { handler => + val receiveSide = deviceBuffers.map(_ => new MockBlockWithSize(123000)) val receiveWindow = new WindowedBlockIterator[MockBlockWithSize](receiveSide, 10000) withResource(new BufferSendState(mockTx, bounceBuffer, handler)) { bss => (0 until 246).foreach { _ => withResource(bss.getBufferToSend()) { _ => val receiveBlocks = receiveWindow.next() - compareRanges(bounceBuffer, receiveBlocks) + compareRanges(bounceBuffer, receiveBlocks.zip(deviceBuffers)) bss.releaseAcquiredToCatalog() } } assert(!bss.hasMoreSends) } - mockBuffers.foreach { b: RapidsBuffer => - verify(b, times(numCloses.get(b))).close() - } + } + deviceBuffers.foreach { b => + assertResult(0)(b.getRefCount) } bounceBuffer } @@ -224,14 +202,13 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { any(), any(), any(), any[MemoryBuffer](), ac.capture())).thenReturn(mockTransaction) val mockRequestHandler = mock[RapidsShuffleRequestHandler] - val rapidsBuffer = mock[RapidsBuffer] val bb = ByteBuffer.allocateDirect(123) withResource(new RefCountedDirectByteBuffer(bb)) { _ => val tableMeta = MetaUtils.buildTableMeta(1, 456, bb, 100) - when(rapidsBuffer.meta).thenReturn(tableMeta) - when(rapidsBuffer.getPackedSizeBytes).thenReturn(tableMeta.bufferMeta().size()) - when(mockRequestHandler.acquireShuffleBuffer(ArgumentMatchers.eq(1))) + val testHandle = SpillableDeviceBufferHandle(DeviceMemoryBuffer.allocate(456)) + val rapidsBuffer = RapidsShuffleHandle(testHandle, tableMeta) + when(mockRequestHandler.getShuffleHandle(ArgumentMatchers.eq(1))) .thenReturn(rapidsBuffer) val server = new RapidsShuffleServer( @@ -245,7 +222,7 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { server.start() - val bss = new BufferSendState(mockTransaction, mockSendBuffer, mockRequestHandler, null) + val bss = new BufferSendState(mockTransaction, mockSendBuffer, mockRequestHandler) server.doHandleTransferRequest(Seq(bss)) val cb = ac.getValue.asInstanceOf[TransactionCallback] cb(mockTransaction) @@ -253,10 +230,10 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { cb(mockTransaction) // bounce buffers are freed verify(mockSendBuffer, times(1)).close() - // acquire 3 times, and close 3 times - verify(mockRequestHandler, times(3)) - .acquireShuffleBuffer(ArgumentMatchers.eq(1)) - verify(rapidsBuffer, times(3)).close() + // acquire once at the beginning, and closed at the end + verify(mockRequestHandler, times(1)) + .getShuffleHandle(ArgumentMatchers.eq(1)) + assertResult(1)(rapidsBuffer.spillable.dev.get.getRefCount) } } } @@ -264,74 +241,80 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { 
test("when we fail to prepare a send, throw if nothing can be handled") { val mockSendBuffer = mock[SendBounceBuffers] val mockDeviceBounceBuffer = mock[BounceBuffer] - withResource(DeviceMemoryBuffer.allocate(123)) { buff => - when(mockDeviceBounceBuffer.buffer).thenReturn(buff) - when(mockSendBuffer.bounceBufferSize).thenReturn(buff.getLength) - when(mockSendBuffer.hostBounceBuffer).thenReturn(None) - when(mockSendBuffer.deviceBounceBuffer).thenReturn(mockDeviceBounceBuffer) - - when(mockTransport.tryGetSendBounceBuffers(any(), any())) - .thenReturn(Seq(mockSendBuffer)) - - val tr = ShuffleMetadata.buildTransferRequest(0, Seq(1)) - when(mockTransaction.getStatus) - .thenReturn(TransactionStatus.Success) - when(mockTransaction.releaseMessage()).thenReturn( - new MetadataTransportBuffer(new RefCountedDirectByteBuffer(tr))) - - val mockServerConnection = mock[ServerConnection] - val mockRequestHandler = mock[RapidsShuffleRequestHandler] - val rapidsBuffer = mock[RapidsBuffer] - - val bb = ByteBuffer.allocateDirect(123) - withResource(new RefCountedDirectByteBuffer(bb)) { _ => - val tableMeta = MetaUtils.buildTableMeta(1, 456, bb, 100) - when(rapidsBuffer.meta).thenReturn(tableMeta) - when(rapidsBuffer.getPackedSizeBytes).thenReturn(tableMeta.bufferMeta().size()) - when(mockRequestHandler.acquireShuffleBuffer(ArgumentMatchers.eq(1))) - .thenReturn(rapidsBuffer) - - val server = spy(new RapidsShuffleServer( - mockTransport, - mockServerConnection, - RapidsShuffleTestHelper.makeMockBlockManager("1", "foo"), - mockRequestHandler, - mockExecutor, - mockBssExecutor, - mockConf)) - - server.start() - - val ioe = new IOException("mmap failed in test") - - when(rapidsBuffer.copyToMemoryBuffer(any(), any(), any(), any(), any())) - .thenAnswer(_ => throw ioe) - - val bss = new BufferSendState(mockTransaction, mockSendBuffer, mockRequestHandler, null) - // if nothing else can be handled, we throw - assertThrows[IllegalStateException] { - try { - server.doHandleTransferRequest(Seq(bss)) - } catch { - case e: Throwable => - assertResult(1)(e.getSuppressed.length) - assertResult(ioe)(e.getSuppressed()(0).getCause) - throw e - } + val mockDeviceMemoryBuffer = mock[DeviceMemoryBuffer] + when(mockDeviceBounceBuffer.buffer).thenReturn(mockDeviceMemoryBuffer) + when(mockSendBuffer.bounceBufferSize).thenReturn(1024) + when(mockSendBuffer.hostBounceBuffer).thenReturn(None) + when(mockSendBuffer.deviceBounceBuffer).thenReturn(mockDeviceBounceBuffer) + + when(mockTransport.tryGetSendBounceBuffers(any(), any())) + .thenReturn(Seq(mockSendBuffer)) + + val tr = ShuffleMetadata.buildTransferRequest(0, Seq(1, 2)) + when(mockTransaction.getStatus) + .thenReturn(TransactionStatus.Success) + when(mockTransaction.releaseMessage()).thenReturn( + new MetadataTransportBuffer(new RefCountedDirectByteBuffer(tr))) + + val mockServerConnection = mock[ServerConnection] + val mockRequestHandler = mock[RapidsShuffleRequestHandler] + + val bb = ByteBuffer.allocateDirect(123) + withResource(new RefCountedDirectByteBuffer(bb)) { _ => + val tableMeta = MetaUtils.buildTableMeta(1, 456, bb, 100) + val mockHandle = mock[SpillableDeviceBufferHandle] + val mockHandleThatThrows = mock[SpillableDeviceBufferHandle] + val mockMaterialized = mock[DeviceMemoryBuffer] + when(mockHandle.sizeInBytes).thenReturn(tableMeta.bufferMeta().size()) + when(mockHandle.materialize).thenAnswer(_ => mockMaterialized) + + when(mockHandleThatThrows.sizeInBytes).thenReturn(tableMeta.bufferMeta().size()) + val ex = new IllegalStateException("something happened") + 
when(mockHandleThatThrows.materialize).thenThrow(ex) + + val rapidsBuffer = RapidsShuffleHandle(mockHandle, tableMeta) + val rapidsBufferThatThrows = RapidsShuffleHandle(mockHandleThatThrows, tableMeta) + + when(mockRequestHandler.getShuffleHandle(ArgumentMatchers.eq(1))) + .thenReturn(rapidsBuffer) + when(mockRequestHandler.getShuffleHandle(ArgumentMatchers.eq(2))) + .thenReturn(rapidsBufferThatThrows) + + val server = spy(new RapidsShuffleServer( + mockTransport, + mockServerConnection, + RapidsShuffleTestHelper.makeMockBlockManager("1", "foo"), + mockRequestHandler, + mockExecutor, + mockBssExecutor, + mockConf)) + + server.start() + + val bss = new BufferSendState(mockTransaction, mockSendBuffer, mockRequestHandler, null) + // if nothing else can be handled, we throw + assertThrows[IllegalStateException] { + try { + server.doHandleTransferRequest(Seq(bss)) + } catch { + case e: Throwable => + assertResult(1)(e.getSuppressed.length) + assertResult(ex)(e.getSuppressed()(0).getCause) + throw e } + } - // since nothing could be handled, we don't try again - verify(server, times(0)).addToContinueQueue(any()) + // since nothing could be handled, we don't try again + verify(server, times(0)).addToContinueQueue(any()) - // bounce buffers are freed - verify(mockSendBuffer, times(1)).close() + // bounce buffers are freed + verify(mockSendBuffer, times(1)).close() - // acquire 2 times, 1 to make the ranges, and the 2 before the copy - // close 2 times corresponding to each open - verify(mockRequestHandler, times(2)) - .acquireShuffleBuffer(ArgumentMatchers.eq(1)) - verify(rapidsBuffer, times(2)).close() - } + verify(mockRequestHandler, times(1)) + .getShuffleHandle(ArgumentMatchers.eq(1)) + + // the spillable that materialized we need to close + verify(mockMaterialized, times(1)).close() } } @@ -367,13 +350,24 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { val mockRequestHandler = mock[RapidsShuffleRequestHandler] - def makeMockBuffer(tableId: Int, bb: ByteBuffer): RapidsBuffer = { - val rapidsBuffer = mock[RapidsBuffer] + def makeMockBuffer(tableId: Int, bb: ByteBuffer, error: Boolean): RapidsShuffleHandle = { val tableMeta = MetaUtils.buildTableMeta(tableId, 456, bb, 100) - when(rapidsBuffer.meta).thenReturn(tableMeta) - when(rapidsBuffer.getPackedSizeBytes).thenReturn(tableMeta.bufferMeta().size()) - when(mockRequestHandler.acquireShuffleBuffer(ArgumentMatchers.eq(tableId))) - .thenReturn(rapidsBuffer) + val rapidsBuffer = if (error) { + val mockHandle = mock[SpillableDeviceBufferHandle] + val rapidsBuffer = RapidsShuffleHandle(mockHandle, tableMeta) + when(mockHandle.sizeInBytes).thenReturn(tableMeta.bufferMeta().size()) + // mock an error with the copy + when(rapidsBuffer.spillable.materialize) + .thenAnswer(_ => { + throw new IOException("mmap failed in test") + }) + rapidsBuffer + } else { + val testHandle = spy(SpillableDeviceBufferHandle(spy(DeviceMemoryBuffer.allocate(456)))) + RapidsShuffleHandle(testHandle, tableMeta) + } + when(mockRequestHandler.getShuffleHandle(ArgumentMatchers.eq(tableId))) + .thenAnswer(_ => rapidsBuffer) rapidsBuffer } @@ -381,19 +375,8 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { val bb2 = ByteBuffer.allocateDirect(123) withResource(new RefCountedDirectByteBuffer(bb)) { _ => withResource(new RefCountedDirectByteBuffer(bb2)) { _ => - val rapidsBuffer = makeMockBuffer(1, bb) - val rapidsBuffer2 = makeMockBuffer(2, bb2) - - // error with copy - when(rapidsBuffer.copyToMemoryBuffer(any(), any(), any(), any(), any())) - 
.thenAnswer(_ => { - throw new IOException("mmap failed in test") - }) - - // successful copy - doNothing() - .when(rapidsBuffer2) - .copyToMemoryBuffer(any(), any(), any(), any(), any()) + val rapidsHandle = makeMockBuffer(1, bb, error = true) + val rapidsHandle2 = makeMockBuffer(2, bb2, error = false) val server = spy(new RapidsShuffleServer( mockTransport, @@ -407,14 +390,21 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { server.start() val bssFailed = new BufferSendState( - mockTransaction, mockSendBuffer, mockRequestHandler, null) + mockTransaction, mockSendBuffer, mockRequestHandler) val bssSuccess = spy(new BufferSendState( - mockTransaction2, mockSendBuffer, mockRequestHandler, null)) - - when(bssSuccess.hasMoreSends) - .thenReturn(true) // send 1 bounce buffer length - .thenReturn(false) + mockTransaction2, mockSendBuffer, mockRequestHandler)) + + var callCount = 0 + doAnswer { _ => + callCount += 1 + // send 1 buffer length + if (callCount > 1){ + false + } else { + true + } + }.when(bssSuccess).hasMoreSends // if something else can be handled we don't throw, and re-queue server.doHandleTransferRequest(Seq(bssFailed, bssSuccess)) @@ -427,15 +417,17 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { // the bounce buffer is freed 1 time for `bssSuccess`, but not for `bssFailed` verify(mockSendBuffer, times(1)).close() - // acquire/close 4 times => - // we had two requests for 1 buffer, and each request acquires 2 times and closes - // 2 times. - verify(mockRequestHandler, times(2)) - .acquireShuffleBuffer(ArgumentMatchers.eq(1)) - verify(mockRequestHandler, times(2)) - .acquireShuffleBuffer(ArgumentMatchers.eq(2)) - verify(rapidsBuffer, times(2)).close() - verify(rapidsBuffer2, times(2)).close() + // we obtained the handles once, we don't need to get them again + verify(mockRequestHandler, times(1)) + .getShuffleHandle(ArgumentMatchers.eq(1)) + verify(mockRequestHandler, times(1)) + .getShuffleHandle(ArgumentMatchers.eq(2)) + // this handle fails to materialize + verify(rapidsHandle.spillable, times(1)).materialize + + // this handle materializes, so make sure we close it + verify(rapidsHandle2.spillable, times(1)).materialize + verify(rapidsHandle2.spillable.dev.get, times(1)).close() } } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala index a9618a448cf..cd918694b43 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala @@ -51,7 +51,7 @@ import org.apache.spark.sql.types._ */ class TimeZonePerfSuite extends SparkQueryCompareTestSuite with BeforeAndAfterAll { // perf test is disabled by default since it's a long running time in UT. 
- private val enablePerfTest = java.lang.Boolean.getBoolean("enableTimeZonePerf") + private val enablePerfTest = true // java.lang.Boolean.getBoolean("enableTimeZonePerf") private val timeZoneStrings = System.getProperty("TZs", "Asia/Shanghai") // rows for perf test @@ -65,11 +65,13 @@ class TimeZonePerfSuite extends SparkQueryCompareTestSuite with BeforeAndAfterAl * Create a Parquet file to test */ override def beforeAll(): Unit = { + super.beforeAll() withCpuSparkSession( spark => createDF(spark).write.mode("overwrite").parquet(path)) } override def afterAll(): Unit = { + super.afterAll() FileUtils.deleteRecursively(new File(path)) } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZoneSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZoneSuite.scala index dcfbc508034..b22f827916b 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZoneSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZoneSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -319,6 +319,7 @@ class TimeZoneSuite extends SparkQueryCompareTestSuite with BeforeAndAfterAll { } override def afterAll(): Unit = { + super.afterAll() if (useGPU) { GpuTimeZoneDB.shutdown() } diff --git a/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala b/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala index d52c8b47ae7..54866a19556 100644 --- a/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala +++ b/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala @@ -15,8 +15,8 @@ */ package org.apache.spark.sql.rapids -import ai.rapids.cudf.TableWriter -import com.nvidia.spark.rapids.{ColumnarOutputWriter, ColumnarOutputWriterFactory, GpuColumnVector, GpuLiteral, RapidsBufferCatalog, RapidsDeviceMemoryStore, ScalableTaskCompletion} +import ai.rapids.cudf.{Rmm, RmmAllocationMode, TableWriter} +import com.nvidia.spark.rapids.{ColumnarOutputWriter, ColumnarOutputWriterFactory, GpuColumnVector, GpuLiteral, RapidsConf, ScalableTaskCompletion, SpillFramework} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.jni.{GpuRetryOOM, GpuSplitAndRetryOOM} import org.apache.hadoop.conf.Configuration @@ -28,6 +28,7 @@ import org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite import org.scalatestplus.mockito.MockitoSugar.mock +import org.apache.spark.SparkConf import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, ExprId, SortOrder} @@ -42,7 +43,6 @@ class GpuFileFormatDataWriterSuite extends AnyFunSuite with BeforeAndAfterEach { private var mockCommitter: FileCommitProtocol = _ private var mockOutputWriterFactory: ColumnarOutputWriterFactory = _ private var mockOutputWriter: NoTransformColumnarOutputWriter = _ - private var devStore: RapidsDeviceMemoryStore = _ private var allCols: Seq[AttributeReference] = _ private var partSpec: Seq[AttributeReference] = _ private var dataSpec: Seq[AttributeReference] = _ @@ -175,16 +175,13 @@ class GpuFileFormatDataWriterSuite extends AnyFunSuite with BeforeAndAfterEach { } override def 
beforeEach(): Unit = { - devStore = new RapidsDeviceMemoryStore() - val catalog = new RapidsBufferCatalog(devStore) - RapidsBufferCatalog.setCatalog(catalog) + Rmm.initialize(RmmAllocationMode.CUDA_DEFAULT, null, 512 * 1024 * 1024) + SpillFramework.initialize(new RapidsConf(new SparkConf)) } override def afterEach(): Unit = { - // test that no buffers we left in the spill framework - assertResult(0)(RapidsBufferCatalog.numBuffers) - RapidsBufferCatalog.close() - devStore.close() + SpillFramework.shutdown() + Rmm.shutdown() } def buildEmptyBatch: ColumnarBatch = diff --git a/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala b/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala index 001f82ab3a0..3c4be17fdf9 100644 --- a/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala +++ b/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,53 +16,29 @@ package org.apache.spark.sql.rapids -import java.util.UUID - -import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer} -import com.nvidia.spark.rapids.{RapidsBuffer, RapidsBufferCatalog, RapidsBufferId, SpillableColumnarBatchImpl, StorageTier} -import com.nvidia.spark.rapids.StorageTier.StorageTier -import com.nvidia.spark.rapids.format.TableMeta +import ai.rapids.cudf.DeviceMemoryBuffer +import com.nvidia.spark.rapids.{RapidsConf, SpillableBuffer, SpillFramework} +import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuite -import org.apache.spark.sql.types.{DataType, IntegerType} -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.storage.TempLocalBlockId +import org.apache.spark.SparkConf -class SpillableColumnarBatchSuite extends AnyFunSuite { +class SpillableColumnarBatchSuite extends AnyFunSuite with BeforeAndAfterAll { + override def beforeAll(): Unit = { + super.beforeAll() + SpillFramework.initialize(new RapidsConf(new SparkConf())) + } - test("close updates catalog") { - val id = TempSpillBufferId(0, TempLocalBlockId(new UUID(1, 2))) - val mockBuffer = new MockBuffer(id) - val catalog = RapidsBufferCatalog.singleton - val oldBufferCount = catalog.numBuffers - catalog.registerNewBuffer(mockBuffer) - val handle = catalog.makeNewHandle(id, -1) - assertResult(oldBufferCount + 1)(catalog.numBuffers) - val spillableBatch = new SpillableColumnarBatchImpl( - handle, - 5, - Array[DataType](IntegerType)) - spillableBatch.close() - assertResult(oldBufferCount)(catalog.numBuffers) + override def afterAll(): Unit = { + super.afterAll() + SpillFramework.shutdown() } - class MockBuffer(override val id: RapidsBufferId) extends RapidsBuffer { - override val memoryUsedBytes: Long = 123 - override def meta: TableMeta = null - override val storageTier: StorageTier = StorageTier.DEVICE - override def getMemoryBuffer: MemoryBuffer = null - override def copyToMemoryBuffer(srcOffset: Long, dst: MemoryBuffer, dstOffset: Long, - length: Long, stream: Cuda.Stream): Unit = {} - override def getDeviceMemoryBuffer: DeviceMemoryBuffer = null - override def getHostMemoryBuffer: HostMemoryBuffer = null - override def addReference(): Boolean = true - override def free(): Unit = {} - override def getSpillPriority: Long = 0 - 
override def setSpillPriority(priority: Long): Unit = {} - override def close(): Unit = {} - override def getColumnarBatch( - sparkTypes: Array[DataType]): ColumnarBatch = null - override def withMemoryBufferReadLock[K](body: MemoryBuffer => K): K = { body(null) } - override def withMemoryBufferWriteLock[K](body: MemoryBuffer => K): K = { body(null) } + test("close updates catalog") { + assertResult(0)(SpillFramework.stores.deviceStore.numHandles) + val deviceHandle = SpillableBuffer(DeviceMemoryBuffer.allocate(1234), -1) + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + deviceHandle.close() + assertResult(0)(SpillFramework.stores.deviceStore.numHandles) } } diff --git a/tests/src/test/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala b/tests/src/test/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala index ab303d8098e..7c73de691c6 100644 --- a/tests/src/test/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala +++ b/tests/src/test/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala @@ -39,7 +39,7 @@ import java.util.concurrent.Executor import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{ColumnVector, ContiguousTable, DeviceMemoryBuffer, HostMemoryBuffer} -import com.nvidia.spark.rapids.{GpuColumnVector, MetaUtils, RapidsBufferHandle, RapidsConf, RapidsDeviceMemoryStore, RmmSparkRetrySuiteBase, ShuffleMetadata, ShuffleReceivedBufferCatalog} +import com.nvidia.spark.rapids.{GpuColumnVector, MetaUtils, RapidsConf, RapidsShuffleHandle, RmmSparkRetrySuiteBase, ShuffleMetadata, ShuffleReceivedBufferCatalog} import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.format.TableMeta import org.mockito.ArgumentCaptor @@ -79,7 +79,6 @@ abstract class RapidsShuffleTestHelper var mockCopyExecutor: Executor = _ var mockBssExecutor: Executor = _ var mockHandler: RapidsShuffleFetchHandler = _ - var mockStorage: RapidsDeviceMemoryStore = _ var mockCatalog: ShuffleReceivedBufferCatalog = _ var mockConf: RapidsConf = _ var testMetricsUpdater: TestShuffleMetricsUpdater = _ @@ -160,11 +159,11 @@ abstract class RapidsShuffleTestHelper testMetricsUpdater = spy(new TestShuffleMetricsUpdater) val dmbCaptor = ArgumentCaptor.forClass(classOf[DeviceMemoryBuffer]) - when(mockCatalog.addBuffer(dmbCaptor.capture(), any(), any(), any())) + when(mockCatalog.addBuffer(dmbCaptor.capture(), any(), any())) .thenAnswer(_ => { val buffer = dmbCaptor.getValue.asInstanceOf[DeviceMemoryBuffer] buffersToClose.append(buffer) - mock[RapidsBufferHandle] + mock[RapidsShuffleHandle] }) client = spy(new RapidsShuffleClient( diff --git a/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala b/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala index b5f0674ca3f..a8962e7fe59 100644 --- a/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala +++ b/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala @@ -35,7 +35,7 @@ import scala.collection import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{ColumnVector, ContiguousTable, DeviceMemoryBuffer, HostMemoryBuffer} -import com.nvidia.spark.rapids.{GpuColumnVector, MetaUtils, RapidsBufferHandle, RapidsConf, RapidsDeviceMemoryStore, RmmSparkRetrySuiteBase, ShuffleMetadata, ShuffleReceivedBufferCatalog} +import com.nvidia.spark.rapids.{GpuColumnVector, MetaUtils, RapidsConf, 
RapidsShuffleHandle, RmmSparkRetrySuiteBase, ShuffleMetadata, ShuffleReceivedBufferCatalog} import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.format.TableMeta import org.mockito.ArgumentCaptor @@ -75,7 +75,6 @@ abstract class RapidsShuffleTestHelper var mockCopyExecutor: Executor = _ var mockBssExecutor: Executor = _ var mockHandler: RapidsShuffleFetchHandler = _ - var mockStorage: RapidsDeviceMemoryStore = _ var mockCatalog: ShuffleReceivedBufferCatalog = _ var mockConf: RapidsConf = _ var testMetricsUpdater: TestShuffleMetricsUpdater = _ @@ -156,11 +155,11 @@ abstract class RapidsShuffleTestHelper testMetricsUpdater = spy(new TestShuffleMetricsUpdater) val dmbCaptor = ArgumentCaptor.forClass(classOf[DeviceMemoryBuffer]) - when(mockCatalog.addBuffer(dmbCaptor.capture(), any(), any(), any())) + when(mockCatalog.addBuffer(dmbCaptor.capture(), any(), any())) .thenAnswer(_ => { val buffer = dmbCaptor.getValue.asInstanceOf[DeviceMemoryBuffer] buffersToClose.append(buffer) - mock[RapidsBufferHandle] + mock[RapidsShuffleHandle] }) client = spy(new RapidsShuffleClient( @@ -285,4 +284,3 @@ class MockClientConnection(mockTransaction: Transaction) extends ClientConnectio override def registerReceiveHandler(messageType: MessageType.Value): Unit = {} } - From 4cde90535214c3eff7c7ecab240c8cb9b9972934 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Fri, 22 Nov 2024 07:00:05 -0800 Subject: [PATCH 02/36] code review comment changes --- .../com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala | 2 +- .../main/scala/com/nvidia/spark/rapids/SpillFramework.scala | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala index 0e415ca06dc..7a19d3ece5d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala @@ -124,7 +124,7 @@ class DeviceMemoryEventHandler( val amountSpilled = store.spill(allocSize) retryState.resetIfNeeded(retryCount, amountSpilled > 0) logInfo(s"Device allocation of $allocSize bytes failed. " + - s"Device store spilled $amountSpilled bytes. $attemptMsg " + + s"Device store spilled $amountSpilled bytes. $attemptMsg" + s"Total RMM allocated is ${Rmm.getTotalBytesAllocated} bytes.") val shouldRetry = if (amountSpilled == 0) { if (retryState.shouldTrySynchronizing(retryCount)) { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillFramework.scala index aa529e83bbf..9773f6a850f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillFramework.scala @@ -84,7 +84,7 @@ import org.apache.spark.storage.BlockId * * Every store handle supports a `materialize` method that isn't part of the interface. * The reason is that to materialize certain objects, you may need some data (for example, - * spark schema descriptors). `materialize` incRefCounts the object if it's resident in the + * Spark schema descriptors). `materialize` incRefCounts the object if it's resident in the * intended store (`DeviceSpillableHandle` incRefCounts an object if it is in the device store), * and otherwise it will create a new copy from the spilled version and hand it to the user. 
* Any time a user calls `materialize`, they are responsible for closing the returned object. @@ -120,8 +120,8 @@ import org.apache.spark.storage.BlockId * * We hold the handle lock when we are spilling (performing IO). That means that no other consumer * can access this spillable device handle while it is being spilled, including a second thread - * trying to spill and asking each of the handles wether they are spillable or not, as that requires - * the handle lock. We will relax this likely in follow on work. + * trying to spill and asking each of the handles whether they are spillable or not, as that + * requires the handle lock. We will relax this likely in follow on work. * * We never hold a store-wide coarse grain lock in the stores when we do IO. */ From 2619cab167bcab92e29b2fd237e30b8053cedf19 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Fri, 22 Nov 2024 07:26:09 -0800 Subject: [PATCH 03/36] private[spill] --- .../rapids/DeviceMemoryEventHandler.scala | 1 + .../spark/rapids/GpuDeviceManager.scala | 1 + .../com/nvidia/spark/rapids/HostAlloc.scala | 1 + .../spark/rapids/ShuffleBufferCatalog.scala | 1 + .../rapids/ShuffleReceivedBufferCatalog.scala | 1 + .../spark/rapids/SpillableColumnarBatch.scala | 1 + .../rapids/{ => spill}/SpillFramework.scala | 62 ++++++++++--------- .../nvidia/spark/rapids/HostAllocSuite.scala | 1 + .../DeviceMemoryEventHandlerSuite.scala | 1 + ...ternalRowToCudfRowIteratorRetrySuite.scala | 1 + .../spark/rapids/RmmSparkRetrySuiteBase.scala | 1 + .../spark/rapids/SerializationSuite.scala | 1 + .../nvidia/spark/rapids/WithRetrySuite.scala | 1 + .../shuffle/RapidsShuffleIteratorSuite.scala | 3 +- .../shuffle/RapidsShuffleServerSuite.scala | 15 ++++- .../{ => spill}/SpillFrameworkSuite.scala | 7 ++- .../rapids/GpuFileFormatDataWriterSuite.scala | 3 +- .../rapids/SpillableColumnarBatchSuite.scala | 3 +- 18 files changed, 66 insertions(+), 39 deletions(-) rename sql-plugin/src/main/scala/com/nvidia/spark/rapids/{ => spill}/SpillFramework.scala (96%) rename tests/src/test/scala/com/nvidia/spark/rapids/{ => spill}/SpillFrameworkSuite.scala (99%) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala index 7a19d3ece5d..3f450e075da 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala @@ -21,6 +21,7 @@ import java.lang.management.ManagementFactory import java.util.concurrent.atomic.AtomicLong import ai.rapids.cudf.{Cuda, Rmm, RmmEventHandler} +import com.nvidia.spark.rapids.spill.SpillableDeviceStore import com.sun.management.HotSpotDiagnosticMXBean import org.apache.spark.internal.Logging diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala index f71ae813010..42776a6cab0 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala @@ -23,6 +23,7 @@ import scala.util.control.NonFatal import ai.rapids.cudf._ import com.nvidia.spark.rapids.jni.RmmSpark +import com.nvidia.spark.rapids.spill.SpillFramework import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.internal.Logging diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala 
b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala index ea61d640906..80e1e7175fc 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala @@ -18,6 +18,7 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.{DefaultHostMemoryAllocator, HostMemoryAllocator, HostMemoryBuffer, MemoryBuffer, PinnedMemoryPool} import com.nvidia.spark.rapids.jni.{CpuRetryOOM, RmmSpark} +import com.nvidia.spark.rapids.spill.SpillFramework import org.apache.spark.internal.Logging diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala index 6c17a29fdc9..bf47e97e683 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala @@ -25,6 +25,7 @@ import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, Table} import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.format.TableMeta +import com.nvidia.spark.rapids.spill.{SpillableDeviceBufferHandle, SpillableHandle} import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.internal.Logging diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala index 87c8deb2c5d..46ca800053a 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala @@ -19,6 +19,7 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.{DeviceMemoryBuffer, Table} import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.format.TableMeta +import com.nvidia.spark.rapids.spill.SpillableDeviceBufferHandle import org.apache.spark.internal.Logging import org.apache.spark.sql.types.DataType diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala index fc10477520b..60a0f5f55a5 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala @@ -18,6 +18,7 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, HostMemoryBuffer} import com.nvidia.spark.rapids.Arm.closeOnExcept +import com.nvidia.spark.rapids.spill.{SpillableColumnarBatchFromBufferHandle, SpillableColumnarBatchHandle, SpillableCompressedColumnarBatchHandle, SpillableDeviceBufferHandle, SpillableHostBufferHandle, SpillableHostColumnarBatchHandle} import org.apache.spark.TaskContext import org.apache.spark.sql.types.DataType diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala similarity index 96% rename from sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillFramework.scala rename to sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index 9773f6a850f..58f39ec56ee 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -14,9 +14,9 
@@ * limitations under the License. */ -package com.nvidia.spark.rapids +package com.nvidia.spark.rapids.spill -import java.io.{DataOutputStream, File, FileInputStream, InputStream, OutputStream} +import java.io._ import java.nio.ByteBuffer import java.nio.channels.{Channels, FileChannel, WritableByteChannel} import java.nio.file.StandardOpenOption @@ -26,10 +26,12 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable -import ai.rapids.cudf.{ColumnVector, ContiguousTable, Cuda, DeviceMemoryBuffer, HostColumnVector, HostMemoryBuffer, JCudfSerialization, NvtxColor, NvtxRange, Table} +import ai.rapids.cudf._ +import com.nvidia.spark.rapids.{GpuColumnVector, GpuColumnVectorFromBuffer, GpuCompressedColumnVector, GpuDeviceManager, HostAlloc, HostMemoryOutputStream, MemoryBufferToHostByteBufferIterator, RapidsConf, RapidsHostColumnVector} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableSeq import com.nvidia.spark.rapids.format.TableMeta +import com.nvidia.spark.rapids.internal.HostByteBufferIterator import org.apache.commons.io.IOUtils import org.apache.spark.{SparkConf, SparkEnv, TaskContext} @@ -154,7 +156,7 @@ trait SpillableHandle extends StoreHandle { * is a `val`. * @return true if currently spillable, false otherwise */ - def spillable: Boolean = sizeInBytes > 0 + private[spill] def spillable: Boolean = sizeInBytes > 0 } /** @@ -163,9 +165,9 @@ trait SpillableHandle extends StoreHandle { * on the device. */ trait DeviceSpillableHandle[T <: AutoCloseable] extends SpillableHandle { - var dev: Option[T] + private[spill] var dev: Option[T] - override def spillable: Boolean = synchronized { + private[spill] override def spillable: Boolean = synchronized { super.spillable && dev.isDefined } } @@ -176,9 +178,9 @@ trait DeviceSpillableHandle[T <: AutoCloseable] extends SpillableHandle { * on the host. 
*/ trait HostSpillableHandle[T <: AutoCloseable] extends SpillableHandle { - var host: Option[T] + private[spill] var host: Option[T] - override def spillable: Boolean = synchronized { + private[spill] override def spillable: Boolean = synchronized { super.spillable && host.isDefined } } @@ -194,7 +196,7 @@ object SpillableHostBufferHandle extends Logging { handle } - def createHostHandleWithPacker( + private[spill] def createHostHandleWithPacker( chunkedPacker: ChunkedPacker): SpillableHostBufferHandle = { val handle = new SpillableHostBufferHandle(chunkedPacker.getTotalContiguousSize) withResource( @@ -213,7 +215,7 @@ object SpillableHostBufferHandle extends Logging { } } - def createHostHandleFromDeviceBuff( + private[spill] def createHostHandleFromDeviceBuff( buff: DeviceMemoryBuffer): SpillableHostBufferHandle = { val handle = new SpillableHostBufferHandle(buff.getLength) withResource( @@ -227,11 +229,11 @@ object SpillableHostBufferHandle extends Logging { class SpillableHostBufferHandle private ( override val sizeInBytes: Long, - override var host: Option[HostMemoryBuffer] = None, - var disk: Option[DiskHandle] = None) + private[spill] override var host: Option[HostMemoryBuffer] = None, + private[spill] var disk: Option[DiskHandle] = None) extends HostSpillableHandle[HostMemoryBuffer] { - override def spillable: Boolean = synchronized { + private[spill] override def spillable: Boolean = synchronized { if (super.spillable) { host.getOrElse { throw new IllegalStateException( @@ -307,7 +309,7 @@ class SpillableHostBufferHandle private ( } } - def getHostBuffer: Option[HostMemoryBuffer] = synchronized { + private[spill] def getHostBuffer: Option[HostMemoryBuffer] = synchronized { host.foreach(_.incRefCount()) host } @@ -320,7 +322,7 @@ class SpillableHostBufferHandle private ( } } - def materializeToDeviceMemoryBuffer(dmb: DeviceMemoryBuffer): Unit = { + private[spill] def materializeToDeviceMemoryBuffer(dmb: DeviceMemoryBuffer): Unit = { var hostBuffer: HostMemoryBuffer = null var diskHandle: DiskHandle = null synchronized { @@ -351,11 +353,11 @@ class SpillableHostBufferHandle private ( } } - def setHost(singleShotBuffer: HostMemoryBuffer): Unit = synchronized { + private[spill] def setHost(singleShotBuffer: HostMemoryBuffer): Unit = synchronized { host = Some(singleShotBuffer) } - def setDisk(handle: DiskHandle): Unit = synchronized { + private[spill] def setDisk(handle: DiskHandle): Unit = synchronized { disk = Some(handle) } } @@ -370,11 +372,11 @@ object SpillableDeviceBufferHandle { class SpillableDeviceBufferHandle private ( override val sizeInBytes: Long, - override var dev: Option[DeviceMemoryBuffer], - var host: Option[SpillableHostBufferHandle] = None) + private[spill] override var dev: Option[DeviceMemoryBuffer], + private[spill] var host: Option[SpillableHostBufferHandle] = None) extends DeviceSpillableHandle[DeviceMemoryBuffer] { - override def spillable: Boolean = synchronized { + private[spill] override def spillable: Boolean = synchronized { if (super.spillable) { dev.getOrElse { throw new IllegalStateException( @@ -445,8 +447,8 @@ class SpillableDeviceBufferHandle private ( class SpillableColumnarBatchHandle private ( override val sizeInBytes: Long, - override var dev: Option[ColumnarBatch], - var host: Option[SpillableHostBufferHandle] = None) + private[spill] override var dev: Option[ColumnarBatch], + private[spill] var host: Option[SpillableHostBufferHandle] = None) extends DeviceSpillableHandle[ColumnarBatch] with Logging { override def spillable: Boolean = 
synchronized { @@ -573,13 +575,13 @@ object SpillableColumnarBatchFromBufferHandle { class SpillableColumnarBatchFromBufferHandle private ( override val sizeInBytes: Long, - override var dev: Option[ColumnarBatch], - var host: Option[SpillableHostBufferHandle] = None) + private[spill] override var dev: Option[ColumnarBatch], + private[spill] var host: Option[SpillableHostBufferHandle] = None) extends DeviceSpillableHandle[ColumnarBatch] { private var meta: Option[TableMeta] = None - override def spillable: Boolean = synchronized { + private[spill] override def spillable: Boolean = synchronized { if (super.spillable) { val dcvs = GpuColumnVector.extractBases(dev.get) val colRepetition = mutable.HashMap[ColumnVector, Int]() @@ -670,8 +672,8 @@ object SpillableCompressedColumnarBatchHandle { class SpillableCompressedColumnarBatchHandle private ( val compressedSizeInBytes: Long, - override var dev: Option[ColumnarBatch], - var host: Option[SpillableHostBufferHandle] = None) + private[spill] override var dev: Option[ColumnarBatch], + private[spill] var host: Option[SpillableHostBufferHandle] = None) extends DeviceSpillableHandle[ColumnarBatch] { override val sizeInBytes: Long = compressedSizeInBytes @@ -760,8 +762,8 @@ object SpillableHostColumnarBatchHandle { class SpillableHostColumnarBatchHandle private ( val sizeInBytes: Long, val numRows: Int, - override var host: Option[ColumnarBatch], - var disk: Option[DiskHandle] = None) + private[spill] override var host: Option[ColumnarBatch], + private[spill] var disk: Option[DiskHandle] = None) extends HostSpillableHandle[ColumnarBatch] { override def spillable: Boolean = synchronized { @@ -1116,7 +1118,7 @@ class SpillableHostStore(val maxSize: Option[Long] = None) HostAlloc.tryAlloc(handle.sizeInBytes).foreach { hmb => withResource(hmb) { _ => if (trackInternal(handle)) { - hmb.incRefCount + hmb.incRefCount() // the host store made room or fit this buffer builder = Some(new SpillableHostBufferHandleBuilderForHost(handle, hmb)) } diff --git a/sql-plugin/src/test/scala/com/nvidia/spark/rapids/HostAllocSuite.scala b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/HostAllocSuite.scala index fb990b99d3b..fbd61b9a7fb 100644 --- a/sql-plugin/src/test/scala/com/nvidia/spark/rapids/HostAllocSuite.scala +++ b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/HostAllocSuite.scala @@ -21,6 +21,7 @@ import java.util.concurrent.{ExecutionException, Future, LinkedBlockingQueue, Ti import ai.rapids.cudf.{HostMemoryBuffer, PinnedMemoryPool, Rmm, RmmAllocationMode} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.jni.{RmmSpark, RmmSparkThreadState} +import com.nvidia.spark.rapids.spill._ import org.mockito.Mockito.when import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.scalatest.concurrent.{Signaler, TimeLimits} diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala index 8045691b504..9ba0147d878 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala @@ -16,6 +16,7 @@ package com.nvidia.spark.rapids +import com.nvidia.spark.rapids.spill.SpillableDeviceStore import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.when import org.scalatestplus.mockito.MockitoSugar diff --git 
a/tests/src/test/scala/com/nvidia/spark/rapids/GeneratedInternalRowToCudfRowIteratorRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GeneratedInternalRowToCudfRowIteratorRetrySuite.scala index 202c12eb774..8e1f404119c 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GeneratedInternalRowToCudfRowIteratorRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GeneratedInternalRowToCudfRowIteratorRetrySuite.scala @@ -19,6 +19,7 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.Table import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.jni.{GpuSplitAndRetryOOM, RmmSpark} +import com.nvidia.spark.rapids.spill.{SpillableColumnarBatchHandle, SpillableDeviceStore, SpillFramework} import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.{doAnswer, spy} import org.mockito.invocation.InvocationOnMock diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RmmSparkRetrySuiteBase.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RmmSparkRetrySuiteBase.scala index 0cc017f3792..24d5984f749 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RmmSparkRetrySuiteBase.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RmmSparkRetrySuiteBase.scala @@ -18,6 +18,7 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.{Rmm, RmmAllocationMode, RmmEventHandler} import com.nvidia.spark.rapids.jni.RmmSpark +import com.nvidia.spark.rapids.spill.SpillFramework import org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SerializationSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SerializationSuite.scala index 545aef2915a..c6c251aeabb 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/SerializationSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/SerializationSuite.scala @@ -21,6 +21,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, import ai.rapids.cudf.{Rmm, RmmAllocationMode, Table} import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.RapidsPluginImplicits._ +import com.nvidia.spark.rapids.spill.SpillFramework import org.apache.commons.lang3.SerializationUtils import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuite diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/WithRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/WithRetrySuite.scala index d617726709d..0f37b7566d9 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/WithRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/WithRetrySuite.scala @@ -21,6 +21,7 @@ import com.nvidia.spark.Retryable import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{splitTargetSizeInHalfGpu, withRestoreOnRetry, withRetry, withRetryNoSplit} import com.nvidia.spark.rapids.jni.{GpuRetryOOM, GpuSplitAndRetryOOM, RmmSpark} +import com.nvidia.spark.rapids.spill.SpillFramework import org.mockito.Mockito._ import org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala index bf379f02937..4962d3fa219 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala @@ -16,8 +16,9 @@ 
package com.nvidia.spark.rapids.shuffle -import com.nvidia.spark.rapids.{RapidsShuffleHandle, SpillableDeviceBufferHandle} +import com.nvidia.spark.rapids.RapidsShuffleHandle import com.nvidia.spark.rapids.jni.RmmSpark +import com.nvidia.spark.rapids.spill.SpillableDeviceBufferHandle import org.mockito.ArgumentCaptor import org.mockito.ArgumentMatchers._ import org.mockito.Mockito._ diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala index ccc63a597fd..f564089c2c9 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala @@ -20,9 +20,10 @@ import java.io.IOException import java.nio.ByteBuffer import ai.rapids.cudf.{DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer} -import com.nvidia.spark.rapids.{MetaUtils, RapidsShuffleHandle, ShuffleMetadata, SpillableDeviceBufferHandle} +import com.nvidia.spark.rapids.{MetaUtils, RapidsShuffleHandle, ShuffleMetadata} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.format.TableMeta +import com.nvidia.spark.rapids.spill.SpillableDeviceBufferHandle import org.mockito.{ArgumentCaptor, ArgumentMatchers} import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ @@ -233,7 +234,11 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { // acquire once at the beginning, and closed at the end verify(mockRequestHandler, times(1)) .getShuffleHandle(ArgumentMatchers.eq(1)) - assertResult(1)(rapidsBuffer.spillable.dev.get.getRefCount) + withResource(rapidsBuffer.spillable.materialize) { dmb => + // refcount=2 because it was on the device, and we +1 to materialize. + // but it shows no leaks. + assertResult(2)(dmb.getRefCount) + } } } } @@ -427,7 +432,11 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { // this handle materializes, so make sure we close it verify(rapidsHandle2.spillable, times(1)).materialize - verify(rapidsHandle2.spillable.dev.get, times(1)).close() + withResource(rapidsHandle2.spillable.materialize) { dmb => + // refcount=2 because it was on the device, and we +1 to materialize. + // but it shows no leaks. + assertResult(2)(dmb.getRefCount) + } } } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SpillFrameworkSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala similarity index 99% rename from tests/src/test/scala/com/nvidia/spark/rapids/SpillFrameworkSuite.scala rename to tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala index 2d92785078c..747894f9b08 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/SpillFrameworkSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala @@ -14,12 +14,13 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids +package com.nvidia.spark.rapids.spill import java.io.File import java.math.RoundingMode -import ai.rapids.cudf.{ColumnVector, ContiguousTable, Cuda, DeviceMemoryBuffer, HostColumnVector, HostMemoryBuffer, Table} +import ai.rapids.cudf._ +import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.format.CodecType import org.mockito.Mockito.when @@ -28,7 +29,7 @@ import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.SparkConf import org.apache.spark.sql.rapids.RapidsDiskBlockManager -import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType, IntegerType, LongType, StringType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch class SpillFrameworkSuite diff --git a/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala b/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala index 54866a19556..033173468fd 100644 --- a/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala +++ b/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala @@ -16,9 +16,10 @@ package org.apache.spark.sql.rapids import ai.rapids.cudf.{Rmm, RmmAllocationMode, TableWriter} -import com.nvidia.spark.rapids.{ColumnarOutputWriter, ColumnarOutputWriterFactory, GpuColumnVector, GpuLiteral, RapidsConf, ScalableTaskCompletion, SpillFramework} +import com.nvidia.spark.rapids.{ColumnarOutputWriter, ColumnarOutputWriterFactory, GpuColumnVector, GpuLiteral, RapidsConf, ScalableTaskCompletion} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.jni.{GpuRetryOOM, GpuSplitAndRetryOOM} +import com.nvidia.spark.rapids.spill.SpillFramework import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FSDataOutputStream import org.apache.hadoop.mapred.TaskAttemptContext diff --git a/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala b/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala index 3c4be17fdf9..9b7a1dbdddd 100644 --- a/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala +++ b/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.rapids import ai.rapids.cudf.DeviceMemoryBuffer -import com.nvidia.spark.rapids.{RapidsConf, SpillableBuffer, SpillFramework} +import com.nvidia.spark.rapids.{RapidsConf, SpillableBuffer} +import com.nvidia.spark.rapids.spill.SpillFramework import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuite From b48b801201ea43254b5e3f9550656f6fa7c4adbe Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Fri, 22 Nov 2024 07:31:01 -0800 Subject: [PATCH 04/36] remove extra sync, and make sure copyNext is always synchronous with the cuda stream --- .../com/nvidia/spark/rapids/spill/SpillFramework.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index 58f39ec56ee..5e62e908d53 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -205,9 +205,8 @@ object 
SpillableHostBufferHandle extends Logging { while (chunkedPacker.hasNext) { withResource(chunkedPacker.next(bb)) { n => builder.copyNext(n, Cuda.DEFAULT_STREAM) - // we are calling chunked packer on `bb` again each time, we need - // to synchronize before we ask for the next chunk - Cuda.DEFAULT_STREAM.sync() + // copyNext is synchronous w.r.t. the cuda stream passed, + // no need to synchronize here. } } } @@ -221,7 +220,6 @@ object SpillableHostBufferHandle extends Logging { withResource( SpillFramework.stores.hostStore.makeBuilder(handle)) { builder => builder.copyNext(buff, Cuda.DEFAULT_STREAM) - Cuda.DEFAULT_STREAM.sync() builder.build } } @@ -1163,7 +1161,7 @@ class SpillableHostStore(val maxSize: Option[Long] = None) override def copyNext(mb: DeviceMemoryBuffer, stream: Cuda.Stream): Unit = { GpuTaskMetrics.get.spillToHostTime { - singleShotBuffer.copyFromMemoryBufferAsync( + singleShotBuffer.copyFromMemoryBuffer( copied, mb, 0, From 8c42a01e30f42fd31aa8c7274a5361bdd6872f7c Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Fri, 22 Nov 2024 07:32:15 -0800 Subject: [PATCH 05/36] private getHostBuffer --- .../scala/com/nvidia/spark/rapids/spill/SpillFramework.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index 5e62e908d53..d16a4ad34d1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -307,7 +307,7 @@ class SpillableHostBufferHandle private ( } } - private[spill] def getHostBuffer: Option[HostMemoryBuffer] = synchronized { + private def getHostBuffer: Option[HostMemoryBuffer] = synchronized { host.foreach(_.incRefCount()) host } From a53f10b7b68c7381330e3bfdf2386eaa2ac32c95 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Fri, 22 Nov 2024 07:34:32 -0800 Subject: [PATCH 06/36] bring back comment on ownership --- .../src/main/scala/com/nvidia/spark/rapids/JoinGatherer.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/JoinGatherer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/JoinGatherer.scala index 771d91ddada..c4584086173 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/JoinGatherer.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/JoinGatherer.scala @@ -307,6 +307,7 @@ class LazySpillableColumnarBatchImpl( spill = Some(SpillableColumnarBatch(cached.get, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)) } finally { + // Putting data in a SpillableColumnarBatch takes ownership of it. 
cached = None } } From 433746be3fc8e37c271e058ed54b6931f997b642 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Fri, 22 Nov 2024 10:48:49 -0800 Subject: [PATCH 07/36] remove comment, as metrics are being fixed in a different pr --- .../scala/com/nvidia/spark/rapids/spill/SpillFramework.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index d16a4ad34d1..5b571500e91 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -1271,7 +1271,6 @@ class DiskHandleStore(conf: SparkConf) } } - // TODO: AB: yeah GpuTaskMetrics are messed up override def track(handle: DiskHandle): Unit = { // protects the off chance that someone adds this handle twice.. if (doTrack(handle)) { From 116478faa8d8694a6fa6aeb929348fa16f85ebb8 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Mon, 2 Dec 2024 09:28:41 -0800 Subject: [PATCH 08/36] Upmerge --- .../nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala | 2 +- .../com/nvidia/spark/rapids/SpillableColumnarBatch.scala | 4 +--- .../scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala | 2 -- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala index 89cb1403cfb..063d8877b50 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala @@ -1132,7 +1132,7 @@ class KudoSpillableHostConcatResult(kudoTableHeader: KudoTableHeader, require(kudoTableHeader != null, "KudoTableHeader cannot be null") require(hmb != null, "HostMemoryBuffer cannot be null") - override def toBatch: ColumnarBatch = closeOnExcept(buffer.getHostBuffer()) { hostBuf => + override def toBatch: ColumnarBatch = closeOnExcept(buffer.getHostBuffer) { hostBuf => KudoSerializedTableColumn.from(new KudoTable(kudoTableHeader, hostBuf)) } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala index 677f41c21a2..1d0c4c8b68b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala @@ -445,9 +445,7 @@ class SpillableBuffer( } override def toString: String = { - val size = withResource(RapidsBufferCatalog.acquireBuffer(handle)) { rapidsBuffer => - rapidsBuffer.memoryUsedBytes - } + val size = handle.sizeInBytes s"SpillableBuffer size:$size, handle:$handle" } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala index 3d72e5b438f..a8892c85194 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala @@ -145,7 +145,6 @@ class GpuTaskMetrics extends Serializable { GpuTaskMetrics.decHostBytesAllocated(bytes) } - def incDiskBytesAllocated(bytes: Long): Unit = { GpuTaskMetrics.incDiskBytesAllocated(bytes) maxDiskBytesAllocated = 
maxDiskBytesAllocated.max(GpuTaskMetrics.diskBytesAllocated) @@ -153,7 +152,6 @@ class GpuTaskMetrics extends Serializable { def decDiskBytesAllocated(bytes: Long): Unit = { GpuTaskMetrics.decHostBytesAllocated(bytes) ->>>>>>> nvidia/branch-25.02 } private val metrics = Map[String, AccumulatorV2[_, _]]( From 513d084edeba807ca4aca82d94fcd7a7c70e22db Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Mon, 2 Dec 2024 11:59:19 -0800 Subject: [PATCH 09/36] ChunkedPacker as a true iterator using a pool --- .../spark/rapids/spill/SpillFramework.scala | 176 ++++++++++++------ 1 file changed, 120 insertions(+), 56 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index 5b571500e91..a3734186b4f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -201,13 +201,12 @@ object SpillableHostBufferHandle extends Logging { val handle = new SpillableHostBufferHandle(chunkedPacker.getTotalContiguousSize) withResource( SpillFramework.stores.hostStore.makeBuilder(handle)) { builder => - SpillFramework.withChunkedPackBounceBuffer { bb => - while (chunkedPacker.hasNext) { - withResource(chunkedPacker.next(bb)) { n => - builder.copyNext(n, Cuda.DEFAULT_STREAM) - // copyNext is synchronous w.r.t. the cuda stream passed, - // no need to synchronize here. - } + while (chunkedPacker.hasNext) { + val (bb, len) = chunkedPacker.next() + withResource(bb) { _ => + builder.copyNext(bb.dmb, len, Cuda.DEFAULT_STREAM) + // copyNext is synchronous w.r.t. the cuda stream passed, + // no need to synchronize here. } } builder.build @@ -219,7 +218,7 @@ object SpillableHostBufferHandle extends Logging { val handle = new SpillableHostBufferHandle(buff.getLength) withResource( SpillFramework.stores.hostStore.makeBuilder(handle)) { builder => - builder.copyNext(buff, Cuda.DEFAULT_STREAM) + builder.copyNext(buff, buff.getLength, Cuda.DEFAULT_STREAM) builder.build } } @@ -522,7 +521,7 @@ class SpillableColumnarBatchHandle private ( GpuColumnVector.from(dev.get) } withResource(tbl) { _ => - withResource(new ChunkedPacker(tbl, SpillFramework.chunkedPackBounceBufferSize)) { packer => + withResource(new ChunkedPacker(tbl, SpillFramework.chunkedPackBounceBufferPool)) { packer => body(packer) } } @@ -1131,15 +1130,16 @@ class SpillableHostStore(val maxSize: Option[Long] = None) trait SpillableHostBufferHandleBuilder extends AutoCloseable { /** - * Copy the entirety of `mb` (0 to getLength) to host or disk. + * Copy `mb` from offset 0 to len to host or disk. * * We synchronize after each copy since we do not manage the lifetime * of `mb`. 
* * @param mb buffer to copy from + * @param len the amount of bytes that should be copied from `mb` * @param stream CUDA stream to use, and synchronize against */ - def copyNext(mb: DeviceMemoryBuffer, stream: Cuda.Stream): Unit + def copyNext(mb: DeviceMemoryBuffer, len: Long, stream: Cuda.Stream): Unit /** * Returns a usable `SpillableHostBufferHandle` with either the @@ -1159,15 +1159,15 @@ class SpillableHostStore(val maxSize: Option[Long] = None) extends SpillableHostBufferHandleBuilder with Logging { private var copied = 0L - override def copyNext(mb: DeviceMemoryBuffer, stream: Cuda.Stream): Unit = { + override def copyNext(mb: DeviceMemoryBuffer, len: Long, stream: Cuda.Stream): Unit = { GpuTaskMetrics.get.spillToHostTime { singleShotBuffer.copyFromMemoryBuffer( copied, mb, 0, - mb.getLength, + len, stream) - copied += mb.getLength + copied += len } } @@ -1201,22 +1201,24 @@ class SpillableHostStore(val maxSize: Option[Long] = None) private var copied = 0L private var diskHandleBuilder = DiskHandleStore.makeBuilder - override def copyNext(mb: DeviceMemoryBuffer, stream: Cuda.Stream): Unit = { + override def copyNext(mb: DeviceMemoryBuffer, len: Long, stream: Cuda.Stream): Unit = { SpillFramework.withHostSpillBounceBuffer { hostSpillBounceBuffer => GpuTaskMetrics.get.spillToDiskTime { val outputChannel = diskHandleBuilder.getChannel - val iter = new MemoryBufferToHostByteBufferIterator( - mb, - hostSpillBounceBuffer, - Cuda.DEFAULT_STREAM) - iter.foreach { byteBuff => - try { - while (byteBuff.hasRemaining) { - outputChannel.write(byteBuff) + withResource(mb.slice(0, len)) { slice => + val iter = new MemoryBufferToHostByteBufferIterator( + slice, + hostSpillBounceBuffer, + Cuda.DEFAULT_STREAM) + iter.foreach { byteBuff => + try { + while (byteBuff.hasRemaining) { + outputChannel.write(byteBuff) + } + copied += byteBuff.capacity() + } finally { + RapidsStorageUtils.dispose(byteBuff) } - copied += byteBuff.capacity() - } finally { - RapidsStorageUtils.dispose(byteBuff) } } } @@ -1422,12 +1424,8 @@ object SpillFramework extends Logging { storesInternal } - // used by the chunked packer to construct the cuDF-side packer object - var chunkedPackBounceBufferSize: Long = -1L - // TODO: these should be pools, instead of individual buffers private var hostSpillBounceBuffer: HostMemoryBuffer = _ - private var chunkedPackBounceBuffer: DeviceMemoryBuffer = _ private lazy val conf: SparkConf = { val env = SparkEnv.get @@ -1449,10 +1447,22 @@ object SpillFramework extends Logging { } else { Some(rapidsConf.hostSpillStorageSize) } - chunkedPackBounceBufferSize = rapidsConf.chunkedPackBounceBufferSize // this should hopefully be pinned, but it would work without hostSpillBounceBuffer = HostMemoryBuffer.allocate(rapidsConf.spillToDiskBounceBufferSize) - chunkedPackBounceBuffer = DeviceMemoryBuffer.allocate(rapidsConf.chunkedPackBounceBufferSize) + + chunkedPackBounceBufferPool = new DeviceBounceBufferPool { + private val bounceBuffer: DeviceBounceBuffer = + DeviceBounceBuffer(DeviceMemoryBuffer.allocate(rapidsConf.chunkedPackBounceBufferSize)) + override def bufferSize: Long = rapidsConf.chunkedPackBounceBufferSize + override def nextBuffer(): DeviceBounceBuffer = { + // can block waiting for bounceBuffer to be released + bounceBuffer.acquire() + } + override def close(): Unit = { + // this closes the DeviceMemoryBuffer wrapped by the bounce buffer class + bounceBuffer.release() + } + } storesInternal = new SpillableStores { override var deviceStore: SpillableDeviceStore = new SpillableDeviceStore 
override var hostStore: SpillableHostStore = new SpillableHostStore(hostSpillStorageSize) @@ -1467,9 +1477,9 @@ object SpillFramework extends Logging { hostSpillBounceBuffer.close() hostSpillBounceBuffer = null } - if (chunkedPackBounceBuffer != null) { - chunkedPackBounceBuffer.close() - chunkedPackBounceBuffer = null + if (chunkedPackBounceBufferPool != null) { + chunkedPackBounceBufferPool.close() + chunkedPackBounceBufferPool = null } if (storesInternal != null) { storesInternal.close() @@ -1482,10 +1492,7 @@ object SpillFramework extends Logging { body(hostSpillBounceBuffer) } - def withChunkedPackBounceBuffer[T](body: DeviceMemoryBuffer => T): T = - chunkedPackBounceBuffer.synchronized { - body(chunkedPackBounceBuffer) - } + var chunkedPackBounceBufferPool: DeviceBounceBufferPool = null def removeFromDeviceStore(value: SpillableHandle): Unit = { // if the stores have already shut down, we don't want to create them here @@ -1510,6 +1517,65 @@ object SpillFramework extends Logging { } maybeStores.foreach(_.remove(value)) } + +} + +/** + * A bounce buffer wrapper class that supports the concept of acquisition. + * + * The bounce buffer is acquired exclusively. So any calls to acquire while the + * buffer is in use will block at `acquire`. Calls to `release` notify the blocked + * threads, and they will check to see if they can acquire. + * + * `close` is the interface to unacquire the bounce buffer. + * + * `release` actually closes the underlying DeviceMemoryBuffer, and should be called + * once at the end of the lifetime of the executor. + * + * @param dmb - actual cudf DeviceMemoryBuffer that this class is protecting. + */ +private[spill] case class DeviceBounceBuffer(var dmb: DeviceMemoryBuffer) extends AutoCloseable { + private var acquired: Boolean = false + def acquire(): DeviceBounceBuffer = synchronized { + while (acquired) { + wait() + } + acquired = true + this + } + + private def unaquire(): Unit = synchronized { + acquired = false + notifyAll() + } + + override def close(): Unit = { + unaquire() + } + + def release(): Unit = synchronized { + if (acquired) { + throw new IllegalStateException( + "closing device buffer pool, but some bounce buffers are in use.") + } + if (dmb != null) { + dmb.close() + dmb = null + } + } +} + +/** + * A bounce buffer pool with buffers of size `bufferSize` + * + * This pool returns instances of `DeviceBounceBuffer`, that should + * be closed in order to be reused. + * + * Callers should synchronize before calling close on their `DeviceMemoryBuffer`s. + */ +trait DeviceBounceBufferPool extends AutoCloseable { + def bufferSize: Long + def nextBuffer(): DeviceBounceBuffer } /** @@ -1524,12 +1590,11 @@ object SpillFramework extends Logging { * associated with it. * * @param table cuDF Table to chunk_pack - * @param bounceBufferSize size of GPU memory to be used for packing. The buffer will be - * obtained during the iterator-like 'next(DeviceMemoryBuffer)' + * @param bounceBufferPool bounce buffer pool to use during the lifetime of this packer. 
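The acquisition pattern the pool expects from its callers (ChunkedPacker.next() below follows the same shape) is a blocking acquire via nextBuffer(), with close() handing the buffer back. A minimal sketch, assuming code that lives in the same com.nvidia.spark.rapids.spill package, since DeviceBounceBuffer is private[spill]:

    package com.nvidia.spark.rapids.spill

    import ai.rapids.cudf.DeviceMemoryBuffer
    import com.nvidia.spark.rapids.Arm.withResource

    object BounceBufferUsageSketch {
      // nextBuffer() blocks until the pool's buffer is free. withResource closes the
      // DeviceBounceBuffer on exit, which un-acquires it and wakes any waiting thread;
      // pool.close(), and with it release(), is reserved for executor shutdown.
      def withBounceBuffer[T](pool: DeviceBounceBufferPool)(body: DeviceMemoryBuffer => T): T =
        withResource(pool.nextBuffer()) { bb =>
          body(bb.dmb) // bb.dmb is exclusively ours until close()
        }
    }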
*/ class ChunkedPacker(table: Table, - bounceBufferSize: Long) - extends Logging with AutoCloseable { + bounceBufferPool: DeviceBounceBufferPool) + extends Iterator[(DeviceBounceBuffer, Long)] with Logging with AutoCloseable { private var closed: Boolean = false @@ -1539,7 +1604,7 @@ class ChunkedPacker(table: Table, val pool = GpuDeviceManager.chunkedPackMemoryResource val cudfChunkedPack = try { pool.flatMap { chunkedPool => - Some(table.makeChunkedPack(bounceBufferSize, chunkedPool)) + Some(table.makeChunkedPack(bounceBufferPool.bufferSize, chunkedPool)) } } catch { case _: OutOfMemoryError => @@ -1554,7 +1619,7 @@ class ChunkedPacker(table: Table, // if the pool is not configured, or we got an OOM, try again with the per-device pool cudfChunkedPack.getOrElse { - table.makeChunkedPack(bounceBufferSize) + table.makeChunkedPack(bounceBufferPool.bufferSize) } } @@ -1572,24 +1637,23 @@ class ChunkedPacker(table: Table, packedMeta } - def hasNext: Boolean = { + override def hasNext: Boolean = { if (closed) { throw new IllegalStateException(s"ChunkedPacker is closed") } chunkedPack.hasNext } - def next(bounceBuffer: DeviceMemoryBuffer): DeviceMemoryBuffer = { - require(bounceBuffer.getLength == bounceBufferSize, - s"Bounce buffer ${bounceBuffer} doesn't match size ${bounceBufferSize} B.") - - if (closed) { - throw new IllegalStateException(s"ChunkedPacker is closed") + override def next(): (DeviceBounceBuffer, Long) = { + withResource(bounceBufferPool.nextBuffer()) { bounceBuffer => + if (closed) { + throw new IllegalStateException(s"ChunkedPacker is closed") + } + val bytesWritten = chunkedPack.next(bounceBuffer.dmb) + // we increment the refcount because the caller has no idea where + // this memory came from, so it should close it. + (bounceBuffer, bytesWritten) } - val bytesWritten = chunkedPack.next(bounceBuffer) - // we increment the refcount because the caller has no idea where - // this memory came from, so it should close it. - bounceBuffer.slice(0, bytesWritten) } override def close(): Unit = { From d9490ee2f91771f68fe511bd864aac4e731cfa9e Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Mon, 2 Dec 2024 13:15:20 -0800 Subject: [PATCH 10/36] approxSizeInBytes --- .../spark/rapids/SpillableColumnarBatch.scala | 6 +- .../spark/rapids/spill/SpillFramework.scala | 57 +++++++++++-------- ...ternalRowToCudfRowIteratorRetrySuite.scala | 2 +- .../rapids/spill/SpillFrameworkSuite.scala | 47 ++++++++------- 4 files changed, 62 insertions(+), 50 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala index 1d0c4c8b68b..cef85ed6691 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala @@ -106,7 +106,7 @@ class SpillableColumnarBatchImpl ( */ override def numRows(): Int = rowCount - override lazy val sizeInBytes: Long = handle.sizeInBytes + override lazy val sizeInBytes: Long = handle.approxSizeInBytes /** * Set a new spill priority. @@ -299,9 +299,7 @@ class SpillableHostColumnarBatchImpl ( */ override def numRows(): Int = rowCount - override lazy val sizeInBytes: Long = { - handle.sizeInBytes - } + override lazy val sizeInBytes: Long = handle.approxSizeInBytes /** * Set a new spill priority. 
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index a3734186b4f..13a83adf140 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -70,7 +70,7 @@ import org.apache.spark.storage.BlockId * Spillability: * * An object is spillable (it will be copied to host or disk during OOM) if: - * - it has a sizeInBytes > 0 + * - it has a approxSizeInBytes > 0 * - it is not actively being referenced by the user (call to `materialize`, or aliased) * - it hasn't already spilled * - it hasn't been closed @@ -133,30 +133,31 @@ import org.apache.spark.storage.BlockId */ trait StoreHandle extends AutoCloseable { /** - * Approximate size of this handle, used in two scenarios: + * Approximate size of this handle, used in three scenarios: + * - Used by callers when accumulating up to a batch size for size goals. * - Used from the host store to figure out how much host memory total it is tracking. - * - If sizeInBytes is 0, the object is tracked by the stores so it can be + * - If approxSizeInBytes is 0, the object is tracked by the stores so it can be * removed on shutdown, or by handle.close, but 0-byte handles are not spillable. */ - val sizeInBytes: Long + val approxSizeInBytes: Long } trait SpillableHandle extends StoreHandle { /** * Method called to spill this handle. It can be triggered from the spill store, * or directly against the handle. - * @return sizeInBytes if spilled, 0 for any other reason (not spillable, closed) + * @return approxSizeInBytes if spilled, 0 for any other reason (not spillable, closed) */ def spill: Long /** * Method used to determine whether a handle tracks an object that could be spilled * @note At the level of `SpillableHandle`, the only requirement of spillability - * is that the size of the handle is > 0. `sizeInBytes` is known at construction, and - * is a `val`. + * is that the size of the handle is > 0. `approxSizeInBytes` is known at + * construction, and is immutable. 
* @return true if currently spillable, false otherwise */ - private[spill] def spillable: Boolean = sizeInBytes > 0 + private[spill] def spillable: Boolean = approxSizeInBytes > 0 } /** @@ -225,11 +226,13 @@ object SpillableHostBufferHandle extends Logging { } class SpillableHostBufferHandle private ( - override val sizeInBytes: Long, + val sizeInBytes: Long, private[spill] override var host: Option[HostMemoryBuffer] = None, private[spill] var disk: Option[DiskHandle] = None) extends HostSpillableHandle[HostMemoryBuffer] { + override val approxSizeInBytes: Long = sizeInBytes + private[spill] override def spillable: Boolean = synchronized { if (super.spillable) { host.getOrElse { @@ -368,11 +371,13 @@ object SpillableDeviceBufferHandle { } class SpillableDeviceBufferHandle private ( - override val sizeInBytes: Long, + val sizeInBytes: Long, private[spill] override var dev: Option[DeviceMemoryBuffer], private[spill] var host: Option[SpillableHostBufferHandle] = None) extends DeviceSpillableHandle[DeviceMemoryBuffer] { + override val approxSizeInBytes: Long = sizeInBytes + private[spill] override def spillable: Boolean = synchronized { if (super.spillable) { dev.getOrElse { @@ -443,7 +448,7 @@ class SpillableDeviceBufferHandle private ( } class SpillableColumnarBatchHandle private ( - override val sizeInBytes: Long, + override val approxSizeInBytes: Long, private[spill] override var dev: Option[ColumnarBatch], private[spill] var host: Option[SpillableHostBufferHandle] = None) extends DeviceSpillableHandle[ColumnarBatch] with Logging { @@ -509,7 +514,7 @@ class SpillableColumnarBatchHandle private ( // We return the size we were created with. This is not the actual size // of this batch when it is packed, and it is used by the calling code // to figure out more or less how much did we free in the device. 
- sizeInBytes + approxSizeInBytes } } @@ -571,11 +576,13 @@ object SpillableColumnarBatchFromBufferHandle { } class SpillableColumnarBatchFromBufferHandle private ( - override val sizeInBytes: Long, + val sizeInBytes: Long, private[spill] override var dev: Option[ColumnarBatch], private[spill] var host: Option[SpillableHostBufferHandle] = None) extends DeviceSpillableHandle[ColumnarBatch] { + override val approxSizeInBytes: Long = sizeInBytes + private var meta: Option[TableMeta] = None private[spill] override def spillable: Boolean = synchronized { @@ -673,7 +680,7 @@ class SpillableCompressedColumnarBatchHandle private ( private[spill] var host: Option[SpillableHostBufferHandle] = None) extends DeviceSpillableHandle[ColumnarBatch] { - override val sizeInBytes: Long = compressedSizeInBytes + override val approxSizeInBytes: Long = compressedSizeInBytes protected var meta: Option[TableMeta] = None @@ -757,7 +764,7 @@ object SpillableHostColumnarBatchHandle { } class SpillableHostColumnarBatchHandle private ( - val sizeInBytes: Long, + override val approxSizeInBytes: Long, val numRows: Int, private[spill] override var host: Option[ColumnarBatch], private[spill] var disk: Option[DiskHandle] = None) @@ -823,7 +830,7 @@ class SpillableHostColumnarBatchHandle private ( } } releaseHostResource() - sizeInBytes + approxSizeInBytes } } @@ -865,9 +872,11 @@ object DiskHandle { class DiskHandle private( val blockId: BlockId, val offset: Long, - override val sizeInBytes: Long) + val sizeInBytes: Long) extends StoreHandle { + override val approxSizeInBytes: Long = sizeInBytes + private def withInputChannel[T](body: FileChannel => T): T = synchronized { val file = SpillFramework.stores.diskStore.diskBlockManager.getFile(blockId) GpuTaskMetrics.get.readSpillFromDiskTime { @@ -979,7 +988,7 @@ trait SpillableStore extends HandleStore[SpillableHandle] with Logging { while (allHandles.hasNext && amountToSpill < spillNeeded) { val handle = allHandles.next() if (handle.spillable) { - amountToSpill += handle.sizeInBytes + amountToSpill += handle.approxSizeInBytes spillables.add(handle) } } @@ -1011,24 +1020,24 @@ class SpillableHostStore(val maxSize: Option[Long] = None) private var totalSize: Long = 0L private def tryTrack(handle: SpillableHandle): Boolean = { - if (maxSize.isEmpty || handle.sizeInBytes == 0) { + if (maxSize.isEmpty || handle.approxSizeInBytes == 0) { super.doTrack(handle) // for now, keep this totalSize part, we technically // do not need to track `totalSize` if we don't have a limit synchronized { - totalSize += handle.sizeInBytes + totalSize += handle.approxSizeInBytes } true } else { synchronized { val storeMaxSize = maxSize.get - if (totalSize + handle.sizeInBytes > storeMaxSize) { + if (totalSize + handle.approxSizeInBytes > storeMaxSize) { // we want to try to make room for this buffer false } else { // it fits if (super.doTrack(handle)) { - totalSize += handle.sizeInBytes + totalSize += handle.approxSizeInBytes } true } @@ -1051,7 +1060,7 @@ class SpillableHostStore(val maxSize: Option[Long] = None) // we are going to try to track again, in a loop, // since we want to release var canFit = true - val handleSize = handle.sizeInBytes + val handleSize = handle.approxSizeInBytes var amountSpilled = 0L val hadHandlesToSpill = !handles.isEmpty while (canFit && !tracked && numRetries < 5) { @@ -1093,7 +1102,7 @@ class SpillableHostStore(val maxSize: Option[Long] = None) override def remove(handle: SpillableHandle): Unit = { synchronized { if (doRemove(handle)) { - totalSize -= 
handle.sizeInBytes + totalSize -= handle.approxSizeInBytes } } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GeneratedInternalRowToCudfRowIteratorRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GeneratedInternalRowToCudfRowIteratorRetrySuite.scala index 8e1f404119c..da03992f72f 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GeneratedInternalRowToCudfRowIteratorRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GeneratedInternalRowToCudfRowIteratorRetrySuite.scala @@ -125,7 +125,7 @@ class GeneratedInternalRowToCudfRowIteratorRetrySuite RmmSpark.OomInjectionType.GPU.ordinal, 0) // at this point we have created a buffer in the Spill Framework // lets spill it - SpillFramework.stores.deviceStore.spill(handle.sizeInBytes) + SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes) true } }).when(SpillFramework.stores.deviceStore) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala index 747894f9b08..ae2232234ca 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala @@ -105,8 +105,8 @@ class SpillFrameworkSuite val (_, handle, _) = addContiguousTableToCatalog() var path: File = null withResource(handle) { _ => - SpillFramework.stores.deviceStore.spill(handle.sizeInBytes) - SpillFramework.stores.hostStore.spill(handle.sizeInBytes) + SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes) + SpillFramework.stores.hostStore.spill(handle.approxSizeInBytes) assert(handle.host.isDefined) assert(handle.host.map(_.disk.isDefined).get) path = SpillFramework.stores.diskStore.getFile(handle.host.flatMap(_.disk).get.blockId) @@ -193,11 +193,12 @@ class SpillFrameworkSuite withResource(handle.materialize(dataTypes)) { _ => assertResult(1)(SpillFramework.stores.deviceStore.numHandles) assert(!handle.spillable) - assertResult(0)(SpillFramework.stores.deviceStore.spill(handle.sizeInBytes)) + assertResult(0)(SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes)) } assert(handle.spillable) assertResult(1)(SpillFramework.stores.deviceStore.numHandles) - assertResult(handle.sizeInBytes)(SpillFramework.stores.deviceStore.spill(handle.sizeInBytes)) + assertResult(handle.approxSizeInBytes)( + SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes)) } } @@ -349,7 +350,7 @@ class SpillFrameworkSuite withResource(decompressBach(ct)) { decompressedExpected => withResource(SpillableCompressedColumnarBatchHandle(ct)) { handle => assert(handle.spillable) - SpillFramework.stores.deviceStore.spill(handle.sizeInBytes) + SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes) assert(!handle.spillable) assert(handle.dev.isEmpty) assert(handle.host.isDefined) @@ -367,9 +368,9 @@ class SpillFrameworkSuite withResource(decompressBach(ct)) { decompressedExpected => withResource(SpillableCompressedColumnarBatchHandle(ct)) { handle => assert(handle.spillable) - SpillFramework.stores.deviceStore.spill(handle.sizeInBytes) + SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes) assert(!handle.spillable) - SpillFramework.stores.hostStore.spill(handle.sizeInBytes) + SpillFramework.stores.hostStore.spill(handle.approxSizeInBytes) assert(handle.dev.isEmpty) assert(handle.host.isDefined) assert(handle.host.get.host.isEmpty) @@ -466,7 +467,7 @@ class SpillFrameworkSuite 
withResource(SpillableColumnarBatchFromBufferHandle( ct, dataTypes)) { handle => withResource(expectedBatch) { _ => - SpillFramework.stores.deviceStore.spill(handle.sizeInBytes) + SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes) withResource(handle.materialize(dataTypes)) { cb => TestUtils.compareBatches(expectedBatch, cb) } @@ -486,9 +487,9 @@ class SpillFrameworkSuite val handle = SpillableColumnarBatchFromBufferHandle(ct, dataTypes) withResource(handle) { _ => withResource(expectedBatch) { _ => - assertResult(SpillFramework.stores.deviceStore.spill(handle.sizeInBytes))( - handle.sizeInBytes) - val hostSize = handle.host.get.sizeInBytes + assertResult(SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes))( + handle.approxSizeInBytes) + val hostSize = handle.host.get.approxSizeInBytes assertResult(SpillFramework.stores.hostStore.spill(hostSize))(hostSize) withResource(handle.materialize(dataTypes)) { actualBatch => TestUtils.compareBatches(expectedBatch, actualBatch) @@ -646,7 +647,7 @@ class SpillFrameworkSuite withResource(bigHandle.materialize(sparkTypes)) { actualBatch => TestUtils.compareBatches(expectedBatch, actualBatch) } - SpillFramework.stores.deviceStore.spill(bigHandle.sizeInBytes) + SpillFramework.stores.deviceStore.spill(bigHandle.approxSizeInBytes) assertResult(true)(bigHandle.dev.isEmpty) assertResult(true)(bigHandle.host.get.host.isEmpty) assertResult(false)(bigHandle.host.get.disk.isEmpty) @@ -714,8 +715,8 @@ class SpillFrameworkSuite } } withResource(expectedBuffer) { expectedBuffer => - deviceStore.spill(handle.sizeInBytes) - hostStore.spill(handle.sizeInBytes) + deviceStore.spill(handle.approxSizeInBytes) + hostStore.spill(handle.approxSizeInBytes) withResource(handle.host.map(_.materialize).get) { actualHostBuffer => assertResult(expectedBuffer. 
asByteBuffer.limit())(actualHostBuffer.asByteBuffer.limit()) @@ -833,8 +834,10 @@ class SpillFrameworkSuite GpuColumnVector.from(expectedTbl, dataTypes) } withResource(expectedBatch) { _ => - assertResult(true)(SpillFramework.stores.deviceStore.spill(handle.sizeInBytes) > 0) - assertResult(true)(SpillFramework.stores.hostStore.spill(handle.sizeInBytes) > 0) + assertResult(true)( + SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes) > 0) + assertResult(true)( + SpillFramework.stores.hostStore.spill(handle.approxSizeInBytes) > 0) withResource(handle.materialize(dataTypes)) { actualBatch => TestUtils.compareBatches(expectedBatch, actualBatch) } @@ -865,7 +868,7 @@ class SpillFrameworkSuite withResource(expectedBuffer) { _ => // host store will fail to spill - deviceStore.spill(handle.sizeInBytes) + deviceStore.spill(handle.approxSizeInBytes) assert(handle.host.map(_.host.isEmpty).get) assert(handle.host.map(_.disk.isDefined).get) withResource(handle.host.map(_.materialize).get) { buffer => @@ -894,7 +897,7 @@ class SpillFrameworkSuite withResource(expectedTable) { _ => withResource( GpuColumnVector.from(expectedTable, dataTypes)) { expectedBatch => - SpillFramework.stores.deviceStore.spill(handle.sizeInBytes) + SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes) assert(handle.host.map(_.host.isEmpty).get) assert(handle.host.map(_.disk.isDefined).get) withResource(handle.materialize(dataTypes)) { fromDiskBatch => @@ -929,7 +932,7 @@ class SpillFrameworkSuite withResource(expectedTable) { _ => withResource( GpuColumnVector.from(expectedTable, dataTypes)) { expectedBatch => - SpillFramework.stores.deviceStore.spill(handle.sizeInBytes) + SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes) assert(handle.dev.isEmpty) assert(handle.host.map(_.host.isEmpty).get) assert(handle.host.map(_.disk.isDefined).get) @@ -955,8 +958,10 @@ class SpillFrameworkSuite withResource(expectedTable) { _ => withResource( GpuColumnVector.from(expectedTable, expectedTypes)) { expectedCb => - SpillFramework.stores.deviceStore.spill(handle.sizeInBytes + handle2.sizeInBytes) - SpillFramework.stores.hostStore.spill(handle.sizeInBytes + handle2.sizeInBytes) + SpillFramework.stores.deviceStore.spill( + handle.approxSizeInBytes + handle2.approxSizeInBytes) + SpillFramework.stores.hostStore.spill( + handle.approxSizeInBytes + handle2.approxSizeInBytes) // the 0-byte table never moved from device. It is not spillable assert(handle.host.isEmpty) assert(!handle.spillable) From 50e6b216eb790f1a5c70d82ab85dbdff6268000b Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Mon, 2 Dec 2024 13:19:54 -0800 Subject: [PATCH 11/36] add note to spill method --- .../scala/com/nvidia/spark/rapids/spill/SpillFramework.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index 13a83adf140..f16b4b591a2 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -146,6 +146,9 @@ trait SpillableHandle extends StoreHandle { /** * Method called to spill this handle. It can be triggered from the spill store, * or directly against the handle. + * @note The size returned from this method is only used by the spill framework + * to track the approximate size. 
It should just return `approxSizeInBytes`, as + * that's the size that it used when it first started tracking the object. * @return approxSizeInBytes if spilled, 0 for any other reason (not spillable, closed) */ def spill: Long From d9dbf36706938df44c3b722f6e0495338951037c Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Mon, 2 Dec 2024 13:35:21 -0800 Subject: [PATCH 12/36] bring parens back --- .../rapids/DeviceMemoryEventHandler.scala | 4 +-- .../rapids/GpuShuffledSizedHashJoinExec.scala | 4 +-- .../spark/rapids/ShuffleBufferCatalog.scala | 2 +- .../rapids/ShuffleReceivedBufferCatalog.scala | 2 +- .../spark/rapids/SpillableColumnarBatch.scala | 8 ++--- .../rapids/shuffle/BufferSendState.scala | 2 +- .../spark/rapids/spill/SpillFramework.scala | 34 +++++++++---------- .../shuffle/RapidsShuffleServerSuite.scala | 14 ++++---- .../rapids/spill/SpillFrameworkSuite.scala | 24 ++++++------- 9 files changed, 46 insertions(+), 48 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala index 3f450e075da..9c867bb6a90 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala @@ -127,7 +127,7 @@ class DeviceMemoryEventHandler( logInfo(s"Device allocation of $allocSize bytes failed. " + s"Device store spilled $amountSpilled bytes. $attemptMsg" + s"Total RMM allocated is ${Rmm.getTotalBytesAllocated} bytes.") - val shouldRetry = if (amountSpilled == 0) { + if (amountSpilled == 0) { if (retryState.shouldTrySynchronizing(retryCount)) { Cuda.deviceSynchronize() logWarning(s"[RETRY ${retryState.getRetriesSoFar}] " + @@ -150,8 +150,6 @@ class DeviceMemoryEventHandler( TrampolineUtil.incTaskMetricsMemoryBytesSpilled(amountSpilled) true } - - shouldRetry } catch { case t: Throwable => logError(s"Error handling allocation failure", t) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala index 063d8877b50..177710fea81 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala @@ -1116,7 +1116,7 @@ class CudfSpillableHostConcatResult( val hmb: HostMemoryBuffer) extends SpillableHostConcatResult { override def toBatch: ColumnarBatch = { - closeOnExcept(buffer.getHostBuffer) { hostBuf => + closeOnExcept(buffer.getHostBuffer()) { hostBuf => SerializedTableColumn.from(header, hostBuf) } } @@ -1132,7 +1132,7 @@ class KudoSpillableHostConcatResult(kudoTableHeader: KudoTableHeader, require(kudoTableHeader != null, "KudoTableHeader cannot be null") require(hmb != null, "HostMemoryBuffer cannot be null") - override def toBatch: ColumnarBatch = closeOnExcept(buffer.getHostBuffer) { hostBuf => + override def toBatch: ColumnarBatch = closeOnExcept(buffer.getHostBuffer()) { hostBuf => KudoSerializedTableColumn.from(new KudoTable(kudoTableHeader, hostBuf)) } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala index bf47e97e683..9dd31a3df9e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala +++ 
b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala @@ -210,7 +210,7 @@ class ShuffleBufferCatalog extends Logging { GpuSemaphore.acquireIfNecessary(TaskContext.get) val (maybeHandle, meta) = bufferIdToHandle.get(bId) maybeHandle.map { handle => - withResource(handle.materialize) { buff => + withResource(handle.materialize()) { buff => val bufferMeta = meta.bufferMeta() if (bufferMeta == null || bufferMeta.codecBufferDescrsLength == 0) { MetaUtils.getBatchFromMeta(buff, meta, sparkTypes) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala index 46ca800053a..e591c825d15 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala @@ -71,7 +71,7 @@ class ShuffleReceivedBufferCatalog() extends Logging { var memoryUsedBytes = 0L val cb = if (spillable != null) { memoryUsedBytes = spillable.sizeInBytes - withResource(spillable.materialize) { buff => + withResource(spillable.materialize()) { buff => MetaUtils.getBatchFromMeta(buff, handle.tableMeta, sparkTypes) } } else { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala index cef85ed6691..7b247af42d3 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala @@ -170,7 +170,7 @@ class SpillableCompressedColumnarBatchImpl( override def getColumnarBatch(): ColumnarBatch = { GpuSemaphore.acquireIfNecessary(TaskContext.get()) - handle.materialize + handle.materialize() } override def incRefCount(): SpillableColumnarBatch = { @@ -432,7 +432,7 @@ class SpillableBuffer( * Use the device buffer. 
*/ def getDeviceBuffer(): DeviceMemoryBuffer = { - handle.materialize + handle.materialize() } /** @@ -473,8 +473,8 @@ class SpillableHostBuffer(handle: SpillableHostBufferHandle, handle.close() } - def getHostBuffer: HostMemoryBuffer = { - handle.materialize + def getHostBuffer(): HostMemoryBuffer = { + handle.materialize() } override def toString: String = diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala index 3512871726a..790b39aac5d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala @@ -182,7 +182,7 @@ class BufferSendState( // using `releaseAcquiredToCatalog` //these are closed later, after we synchronize streams val spillable = blockRange.block.bufferHandle.spillable - val buff = spillable.materialize + val buff = spillable.materialize() buff match { case _: DeviceMemoryBuffer => deviceBuffs += blockRange.rangeSize() diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index f16b4b591a2..a0e4bb59ab1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -151,7 +151,7 @@ trait SpillableHandle extends StoreHandle { * that's the size that it used when it first started tracking the object. * @return approxSizeInBytes if spilled, 0 for any other reason (not spillable, closed) */ - def spill: Long + def spill(): Long /** * Method used to determine whether a handle tracks an object that could be spilled @@ -247,7 +247,7 @@ class SpillableHostBufferHandle private ( } } - def materialize: HostMemoryBuffer = { + def materialize(): HostMemoryBuffer = { var materialized: HostMemoryBuffer = null var diskHandle: DiskHandle = null synchronized { @@ -270,7 +270,7 @@ class SpillableHostBufferHandle private ( materialized } - override def spill: Long = { + override def spill(): Long = { if (!spillable) { 0L } else { @@ -392,7 +392,7 @@ class SpillableDeviceBufferHandle private ( } } - def materialize: DeviceMemoryBuffer = { + def materialize(): DeviceMemoryBuffer = { var materialized: DeviceMemoryBuffer = null var hostHandle: SpillableHostBufferHandle = null synchronized { @@ -419,7 +419,7 @@ class SpillableDeviceBufferHandle private ( materialized } - override def spill: Long = { + override def spill(): Long = { if (!spillable) { 0L } else { @@ -501,7 +501,7 @@ class SpillableColumnarBatchHandle private ( materialized } - override def spill: Long = { + override def spill(): Long = { if (!spillable) { 0L } else { @@ -631,7 +631,7 @@ class SpillableColumnarBatchFromBufferHandle private ( materialized } - override def spill: Long = { + override def spill(): Long = { if (!spillable) { 0 } else { @@ -697,7 +697,7 @@ class SpillableCompressedColumnarBatchHandle private ( } } - def materialize: ColumnarBatch = { + def materialize(): ColumnarBatch = { var materialized: ColumnarBatch = null var hostHandle: SpillableHostBufferHandle = null synchronized { @@ -722,7 +722,7 @@ class SpillableCompressedColumnarBatchHandle private ( materialized } - override def spill: Long = { + override def spill(): Long = { if (!spillable) { 0L } else { @@ -813,7 +813,7 @@ class SpillableHostColumnarBatchHandle private ( materialized } - override 
def spill: Long = { + override def spill(): Long = { if (!spillable) { 0L } else { @@ -999,7 +999,7 @@ trait SpillableStore extends HandleStore[SpillableHandle] with Logging { var numSpilled = 0 while (it.hasNext) { val handle = it.next() - val spilled = handle.spill + val spilled = handle.spill() if (spilled > 0) { // this thread was successful at spilling handle. amountSpilled += spilled @@ -1187,7 +1187,7 @@ class SpillableHostStore(val maxSize: Option[Long] = None) // add some sort of setter method to Host Handle require(handle != null, "Called build too many times") require(copied == handle.sizeInBytes, - s"Expected ${handle.sizeInBytes} B but copied ${copied} B instead") + s"Expected ${handle.sizeInBytes} B but copied $copied B instead") handle.setHost(singleShotBuffer) singleShotBuffer = null val res = handle @@ -1241,7 +1241,7 @@ class SpillableHostStore(val maxSize: Option[Long] = None) // add some sort of setter method to Host Handle require(handle != null, "Called build too many times") require(copied == handle.sizeInBytes, - s"Expected ${handle.sizeInBytes} B but copied ${copied} B instead") + s"Expected ${handle.sizeInBytes} B but copied $copied B instead") handle.setDisk(diskHandleBuilder.build) val res = handle handle = null @@ -1319,7 +1319,7 @@ object DiskHandleStore { // as it is, we could use `DiskWriter` to start writing at other offsets private var closed = false - private var fc: FileChannel = null + private var fc: FileChannel = _ private def getFileChannel: FileChannel = { val options = Seq(StandardOpenOption.CREATE, StandardOpenOption.WRITE) @@ -1428,7 +1428,7 @@ object SpillFramework extends Logging { // public for tests var storesInternal: SpillableStores = _ - def stores = { + def stores: SpillableStores = { if (storesInternal == null) { throw new IllegalStateException( "Cannot use SpillFramework without calling SpillFramework.initialize first") @@ -1504,7 +1504,7 @@ object SpillFramework extends Logging { body(hostSpillBounceBuffer) } - var chunkedPackBounceBufferPool: DeviceBounceBufferPool = null + var chunkedPackBounceBufferPool: DeviceBounceBufferPool = _ def removeFromDeviceStore(value: SpillableHandle): Unit = { // if the stores have already shut down, we don't want to create them here @@ -1676,6 +1676,6 @@ class ChunkedPacker(table: Table, } } -object ChunkedPacker { +private object ChunkedPacker { private var warnedAboutPoolFallback: Boolean = false } \ No newline at end of file diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala index f564089c2c9..8d7415fba04 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala @@ -234,7 +234,7 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { // acquire once at the beginning, and closed at the end verify(mockRequestHandler, times(1)) .getShuffleHandle(ArgumentMatchers.eq(1)) - withResource(rapidsBuffer.spillable.materialize) { dmb => + withResource(rapidsBuffer.spillable.materialize()) { dmb => // refcount=2 because it was on the device, and we +1 to materialize. // but it shows no leaks. 
assertResult(2)(dmb.getRefCount) @@ -271,11 +271,11 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { val mockHandleThatThrows = mock[SpillableDeviceBufferHandle] val mockMaterialized = mock[DeviceMemoryBuffer] when(mockHandle.sizeInBytes).thenReturn(tableMeta.bufferMeta().size()) - when(mockHandle.materialize).thenAnswer(_ => mockMaterialized) + when(mockHandle.materialize()).thenAnswer(_ => mockMaterialized) when(mockHandleThatThrows.sizeInBytes).thenReturn(tableMeta.bufferMeta().size()) val ex = new IllegalStateException("something happened") - when(mockHandleThatThrows.materialize).thenThrow(ex) + when(mockHandleThatThrows.materialize()).thenThrow(ex) val rapidsBuffer = RapidsShuffleHandle(mockHandle, tableMeta) val rapidsBufferThatThrows = RapidsShuffleHandle(mockHandleThatThrows, tableMeta) @@ -362,7 +362,7 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { val rapidsBuffer = RapidsShuffleHandle(mockHandle, tableMeta) when(mockHandle.sizeInBytes).thenReturn(tableMeta.bufferMeta().size()) // mock an error with the copy - when(rapidsBuffer.spillable.materialize) + when(rapidsBuffer.spillable.materialize()) .thenAnswer(_ => { throw new IOException("mmap failed in test") }) @@ -428,11 +428,11 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper { verify(mockRequestHandler, times(1)) .getShuffleHandle(ArgumentMatchers.eq(2)) // this handle fails to materialize - verify(rapidsHandle.spillable, times(1)).materialize + verify(rapidsHandle.spillable, times(1)).materialize() // this handle materializes, so make sure we close it - verify(rapidsHandle2.spillable, times(1)).materialize - withResource(rapidsHandle2.spillable.materialize) { dmb => + verify(rapidsHandle2.spillable, times(1)).materialize() + withResource(rapidsHandle2.spillable.materialize()) { dmb => // refcount=2 because it was on the device, and we +1 to materialize. // but it shows no leaks. 
assertResult(2)(dmb.getRefCount) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala index ae2232234ca..f1b53ad1ccb 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala @@ -334,7 +334,7 @@ class SpillFrameworkSuite val ct = buildCompressedBatch(0, 1000) withResource(SpillableCompressedColumnarBatchHandle(ct)) { handle => assert(handle.spillable) - withResource(handle.materialize) { materialized => + withResource(handle.materialize()) { materialized => assert(!handle.spillable) // since we didn't spill, these buffers are exactly the same assert( @@ -354,7 +354,7 @@ class SpillFrameworkSuite assert(!handle.spillable) assert(handle.dev.isEmpty) assert(handle.host.isDefined) - withResource(handle.materialize) { materialized => + withResource(handle.materialize()) { materialized => withResource(decompressBach(materialized)) { decompressed => TestUtils.compareBatches(decompressedExpected, decompressed) } @@ -375,7 +375,7 @@ class SpillFrameworkSuite assert(handle.host.isDefined) assert(handle.host.get.host.isEmpty) assert(handle.host.get.disk.isDefined) - withResource(handle.materialize) { materialized => + withResource(handle.materialize()) { materialized => withResource(decompressBach(materialized)) { decompressed => TestUtils.compareBatches(decompressedExpected, decompressed) } @@ -390,7 +390,7 @@ class SpillFrameworkSuite val handle1 = SpillableDeviceBufferHandle(buffer) // materialize will incRefCount `buffer`. This looks a little weird // but it simulates aliasing as it happens in real code - val handle2 = SpillableDeviceBufferHandle(handle1.materialize) + val handle2 = SpillableDeviceBufferHandle(handle1.materialize()) withResource(handle1) { _ => withResource(handle2) { _ => @@ -505,7 +505,7 @@ class SpillFrameworkSuite withResource(spillableBuffer) { _ => // the refcount of 1 is the store assertResult(1)(hmb.getRefCount) - withResource(spillableBuffer.getHostBuffer) { memoryBuffer => + withResource(spillableBuffer.getHostBuffer()) { memoryBuffer => assertResult(hmb)(memoryBuffer) assertResult(2)(memoryBuffer.getRefCount) } @@ -526,7 +526,7 @@ class SpillFrameworkSuite withResource(spillableBuffer) { _ => // the refcount of the original buffer is 0 because it spilled assertResult(0)(hmb.getRefCount) - withResource(spillableBuffer.getHostBuffer) { memoryBuffer => + withResource(spillableBuffer.getHostBuffer()) { memoryBuffer => assertResult(memoryBuffer.getLength)(hmb.getLength) } } @@ -536,7 +536,7 @@ class SpillFrameworkSuite val spillPriority = -10 val hmb = HostMemoryBuffer.allocate(1L * 1024) withResource(SpillableHostBuffer(hmb, hmb.getLength, spillPriority)) { spillableBuffer => - withResource(spillableBuffer.getHostBuffer) { _ => + withResource(spillableBuffer.getHostBuffer()) { _ => assertResult(0)(SpillFramework.stores.hostStore.spill(hmb.getLength)) } assertResult(hmb.getLength)(SpillFramework.stores.hostStore.spill(hmb.getLength)) @@ -708,7 +708,7 @@ class SpillFrameworkSuite assertResult(0)(diskStore.numHandles) assertResult(0)(hostStore.numHandles) val expectedBuffer = - withResource(handle.materialize) { devbuf => + withResource(handle.materialize()) { devbuf => closeOnExcept(HostMemoryBuffer.allocate(devbuf.getLength)) { hostbuf => hostbuf.copyFromDeviceBuffer(devbuf) hostbuf @@ -717,7 +717,7 @@ class SpillFrameworkSuite 
withResource(expectedBuffer) { expectedBuffer => deviceStore.spill(handle.approxSizeInBytes) hostStore.spill(handle.approxSizeInBytes) - withResource(handle.host.map(_.materialize).get) { actualHostBuffer => + withResource(handle.host.map(_.materialize()).get) { actualHostBuffer => assertResult(expectedBuffer. asByteBuffer.limit())(actualHostBuffer.asByteBuffer.limit()) } @@ -859,7 +859,7 @@ class SpillFrameworkSuite val deviceStore = SpillFramework.stores.deviceStore withResource(handle) { _ => val expectedBuffer = - withResource(handle.materialize) { devbuf => + withResource(handle.materialize()) { devbuf => closeOnExcept(HostMemoryBuffer.allocate(devbuf.getLength)) { hostbuf => hostbuf.copyFromDeviceBuffer(devbuf) hostbuf @@ -871,7 +871,7 @@ class SpillFrameworkSuite deviceStore.spill(handle.approxSizeInBytes) assert(handle.host.map(_.host.isEmpty).get) assert(handle.host.map(_.disk.isDefined).get) - withResource(handle.host.map(_.materialize).get) { buffer => + withResource(handle.host.map(_.materialize()).get) { buffer => assertResult(expectedBuffer.asByteBuffer)(buffer.asByteBuffer) } } @@ -890,7 +890,7 @@ class SpillFrameworkSuite // fill up the host store withResource(SpillableHostBufferHandle(HostMemoryBuffer.allocate(1024))) { hostHandle => // make sure the host handle isn't spillable - withResource(hostHandle.materialize) { _ => + withResource(hostHandle.materialize()) { _ => val (handle, _) = addTableToCatalog() withResource(handle) { _ => val (expectedTable, dataTypes) = buildTable() From d1117e8d78d1f957d29bc609590803103bf8fd76 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Mon, 2 Dec 2024 13:44:59 -0800 Subject: [PATCH 13/36] make sure we return from spill non zero only when we actually spill --- .../spark/rapids/spill/SpillFramework.scala | 44 ++++++++++++------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index a0e4bb59ab1..c21a7cc491c 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -274,7 +274,7 @@ class SpillableHostBufferHandle private ( if (!spillable) { 0L } else { - synchronized { + val spilled = synchronized { if (disk.isEmpty) { withResource(DiskHandleStore.makeBuilder) { diskHandleBuilder => val outputChannel = diskHandleBuilder.getChannel @@ -293,14 +293,14 @@ class SpillableHostBufferHandle private ( } } disk = Some(diskHandleBuilder.build) - disk + sizeInBytes } } else { - None + 0L } } releaseHostResource() - sizeInBytes + spilled } } @@ -423,13 +423,16 @@ class SpillableDeviceBufferHandle private ( if (!spillable) { 0L } else { - synchronized { + val spilled = synchronized { if (host.isEmpty) { host = Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff(dev.get)) + sizeInBytes + } else { + 0L } } releaseDeviceResources() - sizeInBytes + spilled } } @@ -505,19 +508,22 @@ class SpillableColumnarBatchHandle private ( if (!spillable) { 0L } else { - synchronized { + val spilled = synchronized { if (host.isEmpty) { withChunkedPacker { chunkedPacker => meta = Some(chunkedPacker.getPackedMeta) host = Some(SpillableHostBufferHandle.createHostHandleWithPacker(chunkedPacker)) } + approxSizeInBytes + } else { + 0L } } releaseDeviceResource() // We return the size we were created with. 
This is not the actual size // of this batch when it is packed, and it is used by the calling code // to figure out more or less how much did we free in the device. - approxSizeInBytes + spilled } } @@ -635,16 +641,19 @@ class SpillableColumnarBatchFromBufferHandle private ( if (!spillable) { 0 } else { - synchronized { + val spilled = synchronized { if (host.isEmpty) { val cvFromBuffer = dev.get.column(0).asInstanceOf[GpuColumnVectorFromBuffer] meta = Some(cvFromBuffer.getTableMeta) host = Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff( cvFromBuffer.getBuffer)) + sizeInBytes + } else { + 0L } } releaseDeviceResource() - sizeInBytes + spilled } } @@ -726,16 +735,19 @@ class SpillableCompressedColumnarBatchHandle private ( if (!spillable) { 0L } else { - synchronized { + val spilled = synchronized { if (host.isEmpty) { val cvFromBuffer = dev.get.column(0).asInstanceOf[GpuCompressedColumnVector] meta = Some(cvFromBuffer.getTableMeta) host = Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff( cvFromBuffer.getTableBuffer)) + compressedSizeInBytes + } else { + 0L } } releaseDeviceResource() - compressedSizeInBytes + spilled } } @@ -817,7 +829,7 @@ class SpillableHostColumnarBatchHandle private ( if (!spillable) { 0L } else { - synchronized { + val spilled = synchronized { if (disk.isEmpty) { withResource(DiskHandleStore.makeBuilder) { diskHandleBuilder => GpuTaskMetrics.get.spillToDiskTime { @@ -826,14 +838,14 @@ class SpillableHostColumnarBatchHandle private ( JCudfSerialization.writeToStream(columns, dos, 0, host.get.numRows()) } disk = Some(diskHandleBuilder.build) - disk + approxSizeInBytes } } else { - None + 0L } } releaseHostResource() - approxSizeInBytes + spilled } } From 91006c5eff32c04e02ec7895b6967129d1703c36 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Mon, 2 Dec 2024 13:49:59 -0800 Subject: [PATCH 14/36] make ShuffleInfo a type alias --- .../spark/rapids/ShuffleBufferCatalog.scala | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala index 9dd31a3df9e..dede76cbd7e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala @@ -50,8 +50,8 @@ class ShuffleBufferCatalog extends Logging { * * @param blockMap mapping of block ID to array of buffers for the block */ - private case class ShuffleInfo( - blockMap: ConcurrentHashMap[ShuffleBlockId, ArrayBuffer[ShuffleBufferId]]) + private type ShuffleInfo = + ConcurrentHashMap[ShuffleBlockId, ArrayBuffer[ShuffleBufferId]] private val bufferIdToHandle = new ConcurrentHashMap[ @@ -117,8 +117,7 @@ class ShuffleBufferCatalog extends Logging { * Adds a buffer to the device storage, taking ownership of the buffer. 
* * @param blockId Spark's `ShuffleBlockId` that identifies this buffer - * @param buffer buffer that will be owned by the store - * @param tableMeta metadata describing the buffer layout + * @param compressedBatch Compressed ColumnarBatch * @param initialSpillPriority starting spill priority value for the buffer * @return RapidsBufferHandle associated with this buffer */ @@ -157,8 +156,7 @@ class ShuffleBufferCatalog extends Logging { * @param shuffleId shuffle identifier */ def registerShuffle(shuffleId: Int): Unit = { - activeShuffles.computeIfAbsent(shuffleId, _ => ShuffleInfo( - new ConcurrentHashMap[ShuffleBlockId, ArrayBuffer[ShuffleBufferId]])) + activeShuffles.computeIfAbsent(shuffleId, _ => new ShuffleInfo) } /** Frees all buffers that correspond to the specified shuffle. */ @@ -176,7 +174,7 @@ class ShuffleBufferCatalog extends Logging { handleAndMeta._1.foreach(_.close()) } } - info.blockMap.forEachValue(Long.MaxValue, bufferRemover) + info.forEachValue(Long.MaxValue, bufferRemover) } else { // currently shuffle unregister can get called on the driver which never saw a register if (!TrampolineUtil.isDriver(SparkEnv.get)) { @@ -193,7 +191,7 @@ class ShuffleBufferCatalog extends Logging { if (info == null) { throw new NoSuchElementException(s"unknown shuffle ${blockId.shuffleId}") } - val entries = info.blockMap.get(blockId) + val entries = info.get(blockId) if (entries == null) { throw new NoSuchElementException(s"unknown shuffle block $blockId") } @@ -244,7 +242,7 @@ class ShuffleBufferCatalog extends Logging { if (info == null) { throw new NoSuchElementException(s"unknown shuffle ${blockId.shuffleId}") } - val entries = info.blockMap.get(blockId) + val entries = info.get(blockId) if (entries == null) { throw new NoSuchElementException(s"unknown shuffle block $blockId") } @@ -270,7 +268,7 @@ class ShuffleBufferCatalog extends Logging { } // associate this new buffer with the shuffle block - val blockBufferIds = info.blockMap.computeIfAbsent(blockId, _ => + val blockBufferIds = info.computeIfAbsent(blockId, _ => new ArrayBuffer[ShuffleBufferId]) blockBufferIds.synchronized { blockBufferIds.append(id) From 455e99c590e04379ffdb31be1b73fcdccf7c84af Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Mon, 2 Dec 2024 14:19:40 -0800 Subject: [PATCH 15/36] catch Exception not all of Throwable --- .../scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala index 790b39aac5d..0a7942bd581 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala @@ -214,7 +214,7 @@ class BufferSendState( } needsCleanup = false } catch { - case ex: Throwable => + case ex: Exception => throw new RapidsShuffleSendPrepareException( s"Error while copying to bounce buffer for executor ${peerExecutorId} and " + s"header ${TransportUtils.toHex(peerBufferReceiveHeader)}", ex) From 16a16859e044e7138fcaf2d3d3f9745308c002e9 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Mon, 2 Dec 2024 15:03:51 -0800 Subject: [PATCH 16/36] rework comments around locking and cuda synchronization --- .../spark/rapids/spill/SpillFramework.scala | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git 
a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index c21a7cc491c..a6536fef329 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -62,10 +62,11 @@ import org.apache.spark.storage.BlockId * CUDA/Host synchronization: * * We assume all device backed handles are completely materialized on the device (before adding - * to the store, the CUDA stream has been synchronized with the CPU thread creating the handle). - * - * We assume all host memory backed handles are completely materialized and not mutated by - * other CPU threads once handed to the framework. + * to the store, the CUDA stream has been synchronized with the CPU thread creating the handle), + * and that all host memory backed handles are completely materialized and not mutated by + * other CPU threads, because the contents of the handle may spill at any time, using any CUDA + * stream or thread, without synchronization. If handles added to the store are not synchronized + * we could write incomplete data to host memory or to disk. * * Spillability: * @@ -113,17 +114,17 @@ import org.apache.spark.storage.BlockId * with extra locking is the `SpillableHostStore`, to maintain a `totalSize` number that is * used to figure out cheaply when it is full. * - * Handles hold a lock to protect the user against when it is either in the middle of - * spilling, or closed. Most handles, except for disk handles, hold a reference to an object - * in their respective store (`SpillableDeviceBufferHandle` has a `dev` reference that holds - * a `DeviceMemoryBuffer`, and a `host` reference to `SpillableHostBufferHandle` that is only - * set if spilled), and the handle guarantees that one of these will be set, or none if the handle - * is closed. + * All handles, except for disk handles, hold a reference to an object in their respective store: + * `SpillableDeviceBufferHandle` has a `dev` reference that holds a `DeviceMemoryBuffer`, and a + * `host` reference to `SpillableHostBufferHandle` that is only set if spilled. Disk handles are + * different because they don't spill, as disk is considered the final store. When a user calls + * `materialize` on a handle, the handle must guarantee that it can satisfy that, even if the caller + * should wait until a spill happens. This is currently implemented using the handle lock. * - * We hold the handle lock when we are spilling (performing IO). That means that no other consumer - * can access this spillable device handle while it is being spilled, including a second thread - * trying to spill and asking each of the handles whether they are spillable or not, as that - * requires the handle lock. We will relax this likely in follow on work. + * Note that we hold the handle lock while we are spilling (performing IO). That means that no other + * consumer can access this spillable device handle while it is being spilled, including a second + * thread that is trying to spill and is generating a spill plan, as the handle lock is likely held + * up with IO. We will relax this likely in follow on work. * * We never hold a store-wide coarse grain lock in the stores when we do IO. 
*/ From 6b692df7384e6b6795539ea9a37df1c8a9a2e02e Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Tue, 3 Dec 2024 14:16:59 -0800 Subject: [PATCH 17/36] add a comment around how host objects are behaving with this change, which is different than before --- .../spark/rapids/spill/SpillFramework.scala | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index a6536fef329..a28ead8be33 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -104,6 +104,21 @@ import org.apache.spark.storage.BlockId * will either create a host handle tracking an object on the host store (if we made room), or it * will create a host handle that points to a disk handle, tracking a file on disk. * + * Host handles created directly, via the factory methods `SpillableHostBufferHandle(...)` or + * `SpillableHostColumnarBatchHandle(...)`, can trigger immediate spills, if we have host store + * limits. For example: if the host store limit is set to 1GB, and we add a 1.5GB host buffer via + * its factory method, we are going to spill all 1GB that we had in the store, and then track the + * 1.5GB buffer, which is above the limit. Next time we add a host object in this way, or via a + * device -> host spill, we are going to spill the 1.5GB buffer. This is a departure from how + * the spill framework used to work, where the host memory added directly did not cause spills + * directly, and only device->host spills would trigger the spill. + * + * If we don't have a host store limit, spilling from the host store is done entirely via + * host memory allocation failure callbacks. All objects added to the host store are tracked + * immediately, since they were successfully allocated. If we fail to allocate host memory + * during a device->host spill, however, the spill framework will bypass host memory and + * go straight to disk (this last part works the same whether there are host limits or not). + * * If the disk is full, we do not handle this in any special way. We expect this to be a * terminal state to the executor. Every handle spills to its own file on disk, identified * as a "temporary block" `BlockId` from Spark. @@ -194,10 +209,6 @@ object SpillableHostBufferHandle extends Logging { def apply(hmb: HostMemoryBuffer): SpillableHostBufferHandle = { val handle = new SpillableHostBufferHandle(hmb.getLength, host = Some(hmb)) SpillFramework.stores.hostStore.track(handle) - // do we care if: - // handle didn't fit in the store as it is too large. - // we made memory for this so we are going to hold our noses and keep going - // we could spill `handle` at this point. 
handle } From 301a163bef9a3212178903705221527e7a50c162 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Tue, 3 Dec 2024 20:59:18 -0800 Subject: [PATCH 18/36] Add columnar batch spill to host/disk and reconstitute tests --- .../rapids/spill/SpillFrameworkSuite.scala | 98 ++++++++++++++++--- 1 file changed, 83 insertions(+), 15 deletions(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala index f1b53ad1ccb..bd53d750d29 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala @@ -67,6 +67,16 @@ class SpillFrameworkSuite } } + private def buildNonContiguousTableOfLongs( + numRows: Int): (Table, Array[DataType])= { + val vals = (0 until numRows).map(_.toLong) + withResource(HostColumnVector.fromLongs(vals: _*)) { hcv => + withResource(hcv.copyToDevice()) { cv => + (new Table(cv), Array[DataType](LongType)) + } + } + } + private def buildTable(): (Table, Array[DataType]) = { val tbl = new Table.TestBuilder() .column(5, null.asInstanceOf[java.lang.Integer], 3, 1) @@ -102,7 +112,7 @@ class SpillFrameworkSuite } private def testBufferFileDeletion(canShareDiskPaths: Boolean): Unit = { - val (_, handle, _) = addContiguousTableToCatalog() + val (_, handle, _) = addContiguousTableToFramework() var path: File = null withResource(handle) { _ => SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes) @@ -115,7 +125,7 @@ class SpillFrameworkSuite assert(!path.exists) } - private def addContiguousTableToCatalog(): ( + private def addContiguousTableToFramework(): ( Long, SpillableColumnarBatchFromBufferHandle, Array[DataType]) = { val (ct, dataTypes) = buildContiguousTable() val bufferSize = ct.getBuffer.getLength @@ -123,7 +133,7 @@ class SpillFrameworkSuite (bufferSize, handle, dataTypes) } - private def addTableToCatalog(): (SpillableColumnarBatchHandle, Array[DataType]) = { + private def addTableToFramework(): (SpillableColumnarBatchHandle, Array[DataType]) = { // store takes ownership of the table val (tbl, dataTypes) = buildTable() val cb = withResource(tbl) { _ => GpuColumnVector.from(tbl, dataTypes) } @@ -131,7 +141,7 @@ class SpillFrameworkSuite (handle, dataTypes) } - private def addZeroRowsTableToCatalog(): (SpillableColumnarBatchHandle, Array[DataType]) = { + private def addZeroRowsTableToFramework(): (SpillableColumnarBatchHandle, Array[DataType]) = { val (table, dataTypes) = buildEmptyTable() val cb = withResource(table) { _ => GpuColumnVector.from(table, dataTypes) } val handle = SpillableColumnarBatchHandle(cb) @@ -167,9 +177,6 @@ class SpillFrameworkSuite }, hostCols.head.getRowCount.toInt), dataTypes) } - // TODO: AB: add tests that span multiple byte buffers for host->disk, and - // test that span multiple chunked pack bounce buffers - test("add table registers with device store") { val (ct, dataTypes) = buildContiguousTable() withResource(SpillableColumnarBatchFromBufferHandle(ct, dataTypes)) { _ => @@ -435,13 +442,13 @@ class SpillFrameworkSuite assertResult(0)(SpillFramework.stores.diskStore.numHandles) } - test("spill updates catalog") { + test("spill updates store state") { val diskStore = SpillFramework.stores.diskStore val hostStore = SpillFramework.stores.hostStore val deviceStore = SpillFramework.stores.deviceStore val (bufferSize, handle, _) = - addContiguousTableToCatalog() + addContiguousTableToFramework() withResource(handle) { 
_ => assertResult(1)(deviceStore.numHandles) @@ -661,7 +668,7 @@ class SpillFrameworkSuite } test("get columnar batch after spilling to disk") { - val (size, handle, dataTypes) = addContiguousTableToCatalog() + val (size, handle, dataTypes) = addContiguousTableToFramework() val diskStore = SpillFramework.stores.diskStore val hostStore = SpillFramework.stores.hostStore val deviceStore = SpillFramework.stores.deviceStore @@ -698,6 +705,67 @@ class SpillFrameworkSuite } } + val hostSpillStorageSizes = Seq("-1", "1MB", "16MB") + val spillToDiskBounceBuffers = Seq("128KB", "2MB", "128MB") + val chunkedPackBounceBuffers = Seq("1MB", "8MB", "128MB") + hostSpillStorageSizes.foreach { hostSpillStorageSize => + spillToDiskBounceBuffers.foreach { spillToDiskBounceBufferSize => + chunkedPackBounceBuffers.foreach { chunkedPackBounceBufferSize => + test("materialize non-contiguous batch after " + + s"host_storage_size=$hostSpillStorageSize " + + s"spilling chunked_pack_bb=$chunkedPackBounceBufferSize " + + s"spill_to_disk_bb=$spillToDiskBounceBufferSize") { + SpillFramework.shutdown() + try { + val sc = new SparkConf + sc.set(RapidsConf.HOST_SPILL_STORAGE_SIZE.key, hostSpillStorageSize) + sc.set(RapidsConf.CHUNKED_PACK_BOUNCE_BUFFER_SIZE.key, chunkedPackBounceBufferSize) + sc.set(RapidsConf.SPILL_TO_DISK_BOUNCE_BUFFER_SIZE.key, spillToDiskBounceBufferSize) + SpillFramework.initialize(new RapidsConf(sc)) + val (largeTable, dataTypes) = buildNonContiguousTableOfLongs(numRows = 1000000) + val handle = SpillableColumnarBatchHandle(largeTable, dataTypes) + val diskStore = SpillFramework.stores.diskStore + val hostStore = SpillFramework.stores.hostStore + val deviceStore = SpillFramework.stores.deviceStore + withResource(handle) { _ => + assertResult(1)(deviceStore.numHandles) + assertResult(0)(diskStore.numHandles) + assertResult(0)(hostStore.numHandles) + + val expectedTable = + withResource(handle.materialize(dataTypes)) { beforeSpill => + withResource(GpuColumnVector.from(beforeSpill)) { table => + table.contiguousSplit()(0) + } + } // closing the batch from the store so that we can spill it + + withResource(expectedTable) { _ => + withResource( + GpuColumnVector.from(expectedTable.getTable, dataTypes)) { expectedBatch => + deviceStore.spill(handle.approxSizeInBytes) + hostStore.spill(handle.approxSizeInBytes) + + assertResult(0)(deviceStore.numHandles) + assertResult(0)(hostStore.numHandles) + assertResult(1)(diskStore.numHandles) + + val diskHandle = handle.host.flatMap(_.disk).get + val path = diskStore.getFile(diskHandle.blockId) + assert(path.exists) + withResource(handle.materialize(dataTypes)) { actualBatch => + TestUtils.compareBatches(expectedBatch, actualBatch) + } + } + } + } + } finally { + SpillFramework.shutdown() + } + } + } + } + } + test("get memory buffer after spilling to disk") { val handle = SpillableDeviceBufferHandle(DeviceMemoryBuffer.allocate(123)) val diskStore = SpillFramework.stores.diskStore @@ -822,7 +890,7 @@ class SpillFrameworkSuite .thenReturn(new RapidsSerializerManager(conf)) shareDiskPaths.foreach { _ => - val (_, handle, dataTypes) = addContiguousTableToCatalog() + val (_, handle, dataTypes) = addContiguousTableToFramework() withResource(handle) { _ => val expectedCt = withResource(handle.materialize(dataTypes)) { devbatch => withResource(GpuColumnVector.from(devbatch)) { tmpTbl => @@ -891,7 +959,7 @@ class SpillFrameworkSuite withResource(SpillableHostBufferHandle(HostMemoryBuffer.allocate(1024))) { hostHandle => // make sure the host handle isn't spillable 
withResource(hostHandle.materialize()) { _ => - val (handle, _) = addTableToCatalog() + val (handle, _) = addTableToFramework() withResource(handle) { _ => val (expectedTable, dataTypes) = buildTable() withResource(expectedTable) { _ => @@ -926,7 +994,7 @@ class SpillFrameworkSuite sc.set(RapidsConf.CHUNKED_PACK_BOUNCE_BUFFER_SIZE.key, "1MB") val rapidsConf = new RapidsConf(sc) SpillFramework.initialize(rapidsConf) - val (handle, _) = addTableToCatalog() + val (handle, _) = addTableToFramework() withResource(handle) { _ => val (expectedTable, dataTypes) = buildTable() withResource(expectedTable) { _ => @@ -948,8 +1016,8 @@ class SpillFrameworkSuite } test("0-byte table is never spillable") { - val (handle, _) = addZeroRowsTableToCatalog() - val (handle2, _) = addTableToCatalog() + val (handle, _) = addZeroRowsTableToFramework() + val (handle2, _) = addTableToFramework() withResource(handle) { _ => withResource(handle2) { _ => From a882fb2cf90a0663338f681b6aa8f3e86332e582 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Wed, 4 Dec 2024 07:27:12 -0800 Subject: [PATCH 19/36] Add more tests in shuffle catalog suite --- .../rapids/ShuffleBufferCatalogSuite.scala | 70 ++++++++++++++++++- .../shuffle/RapidsShuffleTestHelper.scala | 13 ++-- 2 files changed, 77 insertions(+), 6 deletions(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ShuffleBufferCatalogSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ShuffleBufferCatalogSuite.scala index d86aa8fccb2..a7bea7a5b37 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/ShuffleBufferCatalogSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/ShuffleBufferCatalogSuite.scala @@ -16,11 +16,31 @@ package com.nvidia.spark.rapids +import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.format.TableMeta +import com.nvidia.spark.rapids.shuffle.RapidsShuffleTestHelper +import com.nvidia.spark.rapids.spill.SpillFramework +import org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite import org.scalatestplus.mockito.MockitoSugar -class ShuffleBufferCatalogSuite extends AnyFunSuite with MockitoSugar { - // TODO: AB: more tests please +import org.apache.spark.SparkConf +import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.storage.ShuffleBlockId + +class ShuffleBufferCatalogSuite + extends AnyFunSuite with MockitoSugar with BeforeAndAfterEach { + + override def beforeEach(): Unit = { + super.beforeEach() + SpillFramework.initialize(new RapidsConf(new SparkConf)) + } + + override def afterEach(): Unit = { + super.afterEach() + SpillFramework.shutdown() + } + test("registered shuffles should be active") { val shuffleCatalog = new ShuffleBufferCatalog() assertResult(false)(shuffleCatalog.hasActiveShuffle(123)) @@ -30,4 +50,50 @@ class ShuffleBufferCatalogSuite extends AnyFunSuite with MockitoSugar { assertResult(false)(shuffleCatalog.hasActiveShuffle(123)) } + test("adding a degenerate batch") { + val shuffleCatalog = new ShuffleBufferCatalog() + val tableMeta = mock[TableMeta] + // need to register the shuffle id first + assertThrows[IllegalStateException] { + shuffleCatalog.addDegenerateRapidsBuffer(ShuffleBlockId(1, 1L, 1), tableMeta) + } + shuffleCatalog.registerShuffle(1) + shuffleCatalog.addDegenerateRapidsBuffer(ShuffleBlockId(1,1L,1), tableMeta) + val storedMetas = shuffleCatalog.blockIdToMetas(ShuffleBlockId(1, 1L, 1)) + assertResult(1)(storedMetas.size) + assertResult(tableMeta)(storedMetas.head) + } + + test("adding a contiguous batch adds 
it to the spill store") { + val shuffleCatalog = new ShuffleBufferCatalog() + val ct = RapidsShuffleTestHelper.buildContiguousTable(1000) + shuffleCatalog.registerShuffle(1) + assertResult(0)(SpillFramework.stores.deviceStore.numHandles) + shuffleCatalog.addContiguousTable(ShuffleBlockId(1, 1L, 1), ct, -1) + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + val storedMetas = shuffleCatalog.blockIdToMetas(ShuffleBlockId(1, 1L, 1)) + assertResult(1)(storedMetas.size) + shuffleCatalog.unregisterShuffle(1) + } + + test("get a columnar batch iterator from catalog") { + val shuffleCatalog = new ShuffleBufferCatalog() + shuffleCatalog.registerShuffle(1) + // add metadata only table + val tableMeta = RapidsShuffleTestHelper.mockTableMeta(0) + shuffleCatalog.addDegenerateRapidsBuffer(ShuffleBlockId(1, 1L, 1), tableMeta) + val ct = RapidsShuffleTestHelper.buildContiguousTable(1000) + shuffleCatalog.addContiguousTable(ShuffleBlockId(1, 1L, 1), ct, -1) + val iter = + shuffleCatalog.getColumnarBatchIterator( + ShuffleBlockId(1, 1L, 1), Array[DataType](IntegerType)) + withResource(iter.toArray) { cbs => + assertResult(2)(cbs.length) + assertResult(0)(cbs.head.numRows()) + assertResult(1)(cbs.head.numCols()) + assertResult(1000)(cbs.last.numRows()) + assertResult(1)(cbs.last.numCols()) + shuffleCatalog.unregisterShuffle(1) + } + } } diff --git a/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala b/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala index deb31f8476e..b31d20ecd36 100644 --- a/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala +++ b/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala @@ -182,21 +182,26 @@ object RapidsShuffleTestHelper extends MockitoSugar { MetaUtils.buildDegenerateTableMeta(new ColumnarBatch(Array.empty, 123)) } - def withMockContiguousTable[T](numRows: Long)(body: ContiguousTable => T): T = { + def buildContiguousTable(numRows: Long): ContiguousTable = { val rows: Seq[Integer] = (0 until numRows.toInt).map(Int.box) withResource(ColumnVector.fromBoxedInts(rows:_*)) { cvBase => cvBase.incRefCount() val gpuCv = GpuColumnVector.from(cvBase, IntegerType) withResource(new ColumnarBatch(Array(gpuCv))) { cb => withResource(GpuColumnVector.from(cb)) { table => - withResource(table.contiguousSplit(0, numRows.toInt)) { ct => - body(ct(1)) // we get a degenerate table at 0 and another at 2 - } + val cts = table.contiguousSplit() + cts(0) } } } } + def withMockContiguousTable[T](numRows: Long)(body: ContiguousTable => T): T = { + withResource(buildContiguousTable(numRows)) { ct => + body(ct) + } + } + def mockMetaResponse( mockTransaction: Transaction, numRows: Long, From 82fdff62d2c7acc899e6216ecc0d65d21fda00cc Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Wed, 4 Dec 2024 07:27:38 -0800 Subject: [PATCH 20/36] Check number of host store handles in order to decide whether to retry or not --- .../com/nvidia/spark/rapids/HostAlloc.scala | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala index e25f6b48a40..6079c0352df 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala @@ -143,19 +143,24 @@ private class HostAlloc(nonPinnedLimit: Long) 
extends HostMemoryAllocator with L currentPinnedAllocated + currentNonPinnedAllocated } - // TODO: AB fix this - //val attemptMsg = if (retryCount > 0) { - // s"Attempt $retryCount" - //} else { - // "First attempt" - //} + val attemptMsg = if (retryCount > 0) { + s"Attempt $retryCount" + } else { + "First attempt" + } val amountSpilled = store.spill(allocSize) if (amountSpilled == 0) { - logWarning(s"Host store exhausted, unable to allocate $allocSize bytes. " + - s"Total host allocated is $totalSize bytes.") - false + val shouldRetry = store.numHandles > 0 + val exhaustedMsg = s"Host store exhausted, unable to allocate $allocSize bytes. " + + s"Total host allocated is $totalSize bytes. $attemptMsg." + if (!shouldRetry) { + logWarning(exhaustedMsg) + } else { + logWarning(s"$exhaustedMsg Attempting a retry.") + } + shouldRetry } else { logInfo(s"Spilled $amountSpilled bytes from the host store") true From 4d4651184d076eb64379818b827d02c4dec2cbe9 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Wed, 4 Dec 2024 07:44:33 -0800 Subject: [PATCH 21/36] scala 2.13 fixes --- .../com/nvidia/spark/rapids/implicits.scala | 56 +------------------ .../spark/rapids/ShuffleBufferCatalog.scala | 2 +- 2 files changed, 2 insertions(+), 56 deletions(-) diff --git a/sql-plugin/src/main/scala-2.13/com/nvidia/spark/rapids/implicits.scala b/sql-plugin/src/main/scala-2.13/com/nvidia/spark/rapids/implicits.scala index 5bdded6dbd4..1e4c5e39a19 100644 --- a/sql-plugin/src/main/scala-2.13/com/nvidia/spark/rapids/implicits.scala +++ b/sql-plugin/src/main/scala-2.13/com/nvidia/spark/rapids/implicits.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -63,26 +63,6 @@ object RapidsPluginImplicits { } } - implicit class RapidsBufferColumn(rapidsBuffer: RapidsBuffer) { - - /** - * safeFree: Is an implicit on RapidsBuffer class that tries to free the resource, if an - * Exception was thrown prior to this free, it adds the new exception to the suppressed - * exceptions, otherwise just throws - * - * @param e Exception which we don't want to suppress - */ - def safeFree(e: Throwable = null): Unit = { - if (rapidsBuffer != null) { - try { - rapidsBuffer.free() - } catch { - case suppressed: Throwable if e != null => e.addSuppressed(suppressed) - } - } - } - } - implicit class AutoCloseableSeq[A <: AutoCloseable](val in: collection.Iterable[A]) { /** * safeClose: Is an implicit on a sequence of AutoCloseable classes that tries to close each @@ -111,46 +91,12 @@ object RapidsPluginImplicits { } } - implicit class RapidsBufferSeq[A <: RapidsBuffer](val in: collection.SeqLike[A, _]) { - /** - * safeFree: Is an implicit on a sequence of RapidsBuffer classes that tries to free each - * element of the sequence, even if prior free calls fail. In case of failure in any of the - * free calls, an Exception is thrown containing the suppressed exceptions (getSuppressed), - * if any. 
- */ - def safeFree(error: Throwable = null): Unit = if (in != null) { - var freeException: Throwable = null - in.foreach { element => - if (element != null) { - try { - element.free() - } catch { - case e: Throwable if error != null => error.addSuppressed(e) - case e: Throwable if freeException == null => freeException = e - case e: Throwable => freeException.addSuppressed(e) - } - } - } - if (freeException != null) { - // an exception happened while we were trying to safely free - // resources, throw the exception to alert the caller - throw freeException - } - } - } - implicit class AutoCloseableArray[A <: AutoCloseable](val in: Array[A]) { def safeClose(e: Throwable = null): Unit = if (in != null) { in.toSeq.safeClose(e) } } - implicit class RapidsBufferArray[A <: RapidsBuffer](val in: Array[A]) { - def safeFree(e: Throwable = null): Unit = if (in != null) { - in.toSeq.safeFree(e) - } - } - class IterableMapsSafely[A, From[A] <: collection.Iterable[A] with collection.IterableOps[A, From, _]] { /** diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala index dede76cbd7e..111db8e7ff2 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala @@ -250,7 +250,7 @@ class ShuffleBufferCatalog extends Logging { entries.map(bufferIdToHandle.get).map { case (_, meta) => meta } - } + }.toSeq } /** Allocate a new shuffle buffer identifier and update the shuffle block mapping. */ From 0f1f03176506b650745c005aceed5d7a66c5c5a0 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Wed, 4 Dec 2024 10:02:00 -0800 Subject: [PATCH 22/36] more scala 2.13 changes, and fix bug in host spill tracking --- .../com/nvidia/spark/rapids/implicits.scala | 2 +- .../spark/rapids/spill/SpillFramework.scala | 2 +- .../shuffle/RapidsShuffleTestHelper.scala | 27 +++++++++++-------- .../shuffle/RapidsShuffleTestHelper.scala | 8 +++--- 4 files changed, 22 insertions(+), 17 deletions(-) diff --git a/sql-plugin/src/main/scala-2.13/com/nvidia/spark/rapids/implicits.scala b/sql-plugin/src/main/scala-2.13/com/nvidia/spark/rapids/implicits.scala index 1e4c5e39a19..6c5a0f145f2 100644 --- a/sql-plugin/src/main/scala-2.13/com/nvidia/spark/rapids/implicits.scala +++ b/sql-plugin/src/main/scala-2.13/com/nvidia/spark/rapids/implicits.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index a28ead8be33..4aaf5910d1f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -1058,7 +1058,7 @@ class SpillableHostStore(val maxSize: Option[Long] = None) } else { synchronized { val storeMaxSize = maxSize.get - if (totalSize + handle.approxSizeInBytes > storeMaxSize) { + if (totalSize > 0 && totalSize + handle.approxSizeInBytes > storeMaxSize) { // we want to try to make room for this buffer false } else { diff --git a/tests/src/test/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala b/tests/src/test/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala index 7c73de691c6..e3eff2b2125 100644 --- a/tests/src/test/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala +++ b/tests/src/test/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala @@ -36,6 +36,7 @@ package com.nvidia.spark.rapids.shuffle import java.nio.ByteBuffer import java.util.concurrent.Executor +import scala.collection.immutable import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{ColumnVector, ContiguousTable, DeviceMemoryBuffer, HostMemoryBuffer} @@ -184,25 +185,30 @@ object RapidsShuffleTestHelper extends MockitoSugar { MetaUtils.buildDegenerateTableMeta(new ColumnarBatch(Array.empty, 123)) } - def withMockContiguousTable[T](numRows: Long)(body: ContiguousTable => T): T = { + def buildContiguousTable(numRows: Long): ContiguousTable = { val rows: Seq[Integer] = (0 until numRows.toInt).map(Int.box) withResource(ColumnVector.fromBoxedInts(rows:_*)) { cvBase => cvBase.incRefCount() val gpuCv = GpuColumnVector.from(cvBase, IntegerType) withResource(new ColumnarBatch(Array(gpuCv))) { cb => withResource(GpuColumnVector.from(cb)) { table => - withResource(table.contiguousSplit(0, numRows.toInt)) { ct => - body(ct(1)) // we get a degenerate table at 0 and another at 2 - } + val cts = table.contiguousSplit() + cts(0) } } } } + def withMockContiguousTable[T](numRows: Long)(body: ContiguousTable => T): T = { + withResource(buildContiguousTable(numRows)) { ct => + body(ct) + } + } + def mockMetaResponse( mockTransaction: Transaction, numRows: Long, - numBatches: Int): (Seq[TableMeta], MetadataTransportBuffer) = + numBatches: Int): (immutable.Seq[TableMeta], MetadataTransportBuffer) = withMockContiguousTable(numRows) { ct => val tableMetas = (0 until numBatches).map(b => buildMockTableMeta(b, ct)) val res = ShuffleMetadata.buildMetaResponse(tableMetas) @@ -213,7 +219,7 @@ object RapidsShuffleTestHelper extends MockitoSugar { def mockDegenerateMetaResponse( mockTransaction: Transaction, - numBatches: Int): (Seq[TableMeta], MetadataTransportBuffer) = { + numBatches: Int): (immutable.Seq[TableMeta], MetadataTransportBuffer) = { val tableMetas = (0 until numBatches).map(b => buildDegenerateMockTableMeta()) val res = ShuffleMetadata.buildMetaResponse(tableMetas) val refCountedRes = new MetadataTransportBuffer(new RefCountedDirectByteBuffer(res)) @@ -245,8 +251,8 @@ object RapidsShuffleTestHelper extends MockitoSugar { tableMeta } - def getShuffleBlocks: Seq[(ShuffleBlockBatchId, Long, Int)] = { - Seq( + def getShuffleBlocks: Array[(ShuffleBlockBatchId, Long, Int)] = { + Array( (ShuffleBlockBatchId(1,1,1,1), 123L, 1), 
(ShuffleBlockBatchId(2,2,2,2), 456L, 2), (ShuffleBlockBatchId(3,3,3,3), 456L, 3) @@ -260,8 +266,8 @@ object RapidsShuffleTestHelper extends MockitoSugar { bmId } - def getBlocksByAddress: Array[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { - val blocksByAddress = new ArrayBuffer[(BlockManagerId, Seq[(BlockId, Long, Int)])]() + def getBlocksByAddress: Array[(BlockManagerId, immutable.Seq[(BlockId, Long, Int)])] = { + val blocksByAddress = new ArrayBuffer[(BlockManagerId, immutable.Seq[(BlockId, Long, Int)])]() val blocks = getShuffleBlocks blocksByAddress.append((makeMockBlockManager("2", "2"), blocks)) blocksByAddress.toArray @@ -288,4 +294,3 @@ class MockClientConnection(mockTransaction: Transaction) extends ClientConnectio override def registerReceiveHandler(messageType: MessageType.Value): Unit = {} } - diff --git a/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala b/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala index b31d20ecd36..f4796ff5bbb 100644 --- a/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala +++ b/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala @@ -205,7 +205,7 @@ object RapidsShuffleTestHelper extends MockitoSugar { def mockMetaResponse( mockTransaction: Transaction, numRows: Long, - numBatches: Int): (Seq[TableMeta], MetadataTransportBuffer) = + numBatches: Int): (collection.Seq[TableMeta], MetadataTransportBuffer) = withMockContiguousTable(numRows) { ct => val tableMetas = (0 until numBatches).map(b => buildMockTableMeta(b, ct)) val res = ShuffleMetadata.buildMetaResponse(tableMetas) @@ -216,7 +216,7 @@ object RapidsShuffleTestHelper extends MockitoSugar { def mockDegenerateMetaResponse( mockTransaction: Transaction, - numBatches: Int): (Seq[TableMeta], MetadataTransportBuffer) = { + numBatches: Int): (collection.Seq[TableMeta], MetadataTransportBuffer) = { val tableMetas = (0 until numBatches).map(b => buildDegenerateMockTableMeta()) val res = ShuffleMetadata.buildMetaResponse(tableMetas) val refCountedRes = new MetadataTransportBuffer(new RefCountedDirectByteBuffer(res)) @@ -248,8 +248,8 @@ object RapidsShuffleTestHelper extends MockitoSugar { tableMeta } - def getShuffleBlocks: collection.Seq[(ShuffleBlockBatchId, Long, Int)] = { - collection.Seq( + def getShuffleBlocks: Array[(ShuffleBlockBatchId, Long, Int)] = { + Array( (ShuffleBlockBatchId(1,1,1,1), 123L, 1), (ShuffleBlockBatchId(2,2,2,2), 456L, 2), (ShuffleBlockBatchId(3,3,3,3), 456L, 3) From 3682434109bb79a32fea7fd7287511b8cc75805a Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Wed, 4 Dec 2024 11:55:58 -0800 Subject: [PATCH 23/36] scala 2.13 fixes in rapids shuffle iterator suite --- .../shuffle/RapidsShuffleIterator.scala | 7 +- .../shuffle/RapidsShuffleIteratorSuite.scala | 77 ++++--------------- .../shuffle/RapidsShuffleTestHelper.scala | 19 ++++- .../shuffle/RapidsShuffleTestHelper.scala | 18 ++++- 4 files changed, 53 insertions(+), 68 deletions(-) diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala index e06f708b6f5..f2f655efef5 100644 --- a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala +++ b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala @@ -36,6 +36,7 @@ package 
com.nvidia.spark.rapids.shuffle import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} +import scala.collection import scala.collection.mutable import ai.rapids.cudf.{NvtxColor, NvtxRange} @@ -71,7 +72,7 @@ class RapidsShuffleIterator( localBlockManagerId: BlockManagerId, rapidsConf: RapidsConf, transport: RapidsShuffleTransport, - blocksByAddress: Array[(BlockManagerId, Seq[(BlockId, Long, Int)])], + blocksByAddress: Array[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])], metricsUpdater: ShuffleMetricsUpdater, sparkTypes: Array[DataType], taskAttemptId: Long, @@ -180,7 +181,7 @@ class RapidsShuffleIterator( val (local, remote) = blocksByAddress.partition(ba => ba._1.host == localHost) (local ++ remote).foreach { - case (blockManagerId: BlockManagerId, blockIds: Seq[(BlockId, Long, Int)]) => { + case (blockManagerId: BlockManagerId, blockIds: collection.Seq[(BlockId, Long, Int)]) => { val shuffleRequestsMapIndex: Seq[BlockIdMapIndex] = blockIds.map { case (blockId, _, mapIndex) => /** @@ -200,7 +201,7 @@ class RapidsShuffleIterator( throw new IllegalArgumentException( s"${blockId.getClass} $blockId is not currently supported") } - } + }.toSeq val client = try { transport.makeClient(blockManagerId) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala index 4962d3fa219..af24b332c83 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala @@ -31,18 +31,9 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { val taskId = 1 try { RmmSpark.currentThreadIsDedicatedToTask(taskId) - val blocksByAddress = RapidsShuffleTestHelper.getBlocksByAddress - - val cl = new RapidsShuffleIterator( - RapidsShuffleTestHelper.makeMockBlockManager("1", "1"), - mockConf, - mockTransport, - blocksByAddress, - testMetricsUpdater, - Array.empty, - taskId, - mockCatalog, - 123) + val cl = + RapidsShuffleTestHelper.makeIterator( + mockConf, mockTransport, testMetricsUpdater, taskId, mockCatalog) when(mockTransaction.getStatus).thenReturn(TransactionStatus.Error) @@ -65,18 +56,9 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { RmmSpark.currentThreadIsDedicatedToTask(taskId) when(mockTransaction.getStatus).thenReturn(status) - val blocksByAddress = RapidsShuffleTestHelper.getBlocksByAddress - - val cl = spy(new RapidsShuffleIterator( - RapidsShuffleTestHelper.makeMockBlockManager("1", "1"), - mockConf, - mockTransport, - blocksByAddress, - testMetricsUpdater, - Array.empty, - taskId, - mockCatalog, - 123)) + val cl = + RapidsShuffleTestHelper.makeIterator( + mockConf, mockTransport, testMetricsUpdater, taskId, mockCatalog) val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler]) when(mockTransport.makeClient(any())).thenReturn(client) @@ -113,18 +95,10 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { val taskId = 1 try { RmmSpark.currentThreadIsDedicatedToTask(taskId) - val blocksByAddress = RapidsShuffleTestHelper.getBlocksByAddress - - val cl = spy(new RapidsShuffleIterator( - RapidsShuffleTestHelper.makeMockBlockManager("1", "1"), - mockConf, - mockTransport, - blocksByAddress, - testMetricsUpdater, - Array.empty, - taskId, - mockCatalog, - 123)) + + val cl = + RapidsShuffleTestHelper.makeIterator( + mockConf, mockTransport, testMetricsUpdater, taskId, mockCatalog) val ac = 
ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler]) when(mockTransport.makeClient(any())).thenReturn(client) @@ -163,18 +137,10 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { val taskId = 1 try { RmmSpark.currentThreadIsDedicatedToTask(taskId) - val blocksByAddress = RapidsShuffleTestHelper.getBlocksByAddress - - val cl = spy(new RapidsShuffleIterator( - RapidsShuffleTestHelper.makeMockBlockManager("1", "1"), - mockConf, - mockTransport, - blocksByAddress, - testMetricsUpdater, - Array.empty, - taskId, - mockCatalog, - 123)) + + val cl = + RapidsShuffleTestHelper.makeIterator( + mockConf, mockTransport, testMetricsUpdater, taskId, mockCatalog) when(mockTransport.makeClient(any())).thenReturn(client) doNothing().when(client).doFetch(any(), any()) @@ -199,18 +165,9 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { val taskId = 1 try { RmmSpark.currentThreadIsDedicatedToTask(taskId) - val blocksByAddress = RapidsShuffleTestHelper.getBlocksByAddress - - val cl = new RapidsShuffleIterator( - RapidsShuffleTestHelper.makeMockBlockManager("1", "1"), - mockConf, - mockTransport, - blocksByAddress, - testMetricsUpdater, - Array.empty, - taskId, - mockCatalog, - 123) + val cl = + RapidsShuffleTestHelper.makeIterator( + mockConf, mockTransport, testMetricsUpdater, taskId, mockCatalog) val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler]) when(mockTransport.makeClient(any())).thenReturn(client) diff --git a/tests/src/test/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala b/tests/src/test/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala index e3eff2b2125..525f30c7a87 100644 --- a/tests/src/test/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala +++ b/tests/src/test/spark320/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala @@ -266,11 +266,24 @@ object RapidsShuffleTestHelper extends MockitoSugar { bmId } - def getBlocksByAddress: Array[(BlockManagerId, immutable.Seq[(BlockId, Long, Int)])] = { - val blocksByAddress = new ArrayBuffer[(BlockManagerId, immutable.Seq[(BlockId, Long, Int)])]() + def makeIterator(conf: RapidsConf, + transport: RapidsShuffleTransport, + testMetricsUpdater: TestShuffleMetricsUpdater, + taskId: Long, + catalog: ShuffleReceivedBufferCatalog): RapidsShuffleIterator = { + val blocksByAddress = new ArrayBuffer[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])]() val blocks = getShuffleBlocks blocksByAddress.append((makeMockBlockManager("2", "2"), blocks)) - blocksByAddress.toArray + spy(new RapidsShuffleIterator( + RapidsShuffleTestHelper.makeMockBlockManager("1", "1"), + conf, + transport, + blocksByAddress.toArray, + testMetricsUpdater, + Array.empty, + taskId, + catalog, + 123)) } } diff --git a/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala b/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala index f4796ff5bbb..0efcb4f1d7d 100644 --- a/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala +++ b/tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala @@ -263,11 +263,25 @@ object RapidsShuffleTestHelper extends MockitoSugar { bmId } - def getBlocksByAddress: Array[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])] = { + def makeIterator( + conf: RapidsConf, + transport: RapidsShuffleTransport, + testMetricsUpdater: TestShuffleMetricsUpdater, + taskId: 
Long, + catalog: ShuffleReceivedBufferCatalog): RapidsShuffleIterator = { val blocksByAddress = new ArrayBuffer[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])]() val blocks = getShuffleBlocks blocksByAddress.append((makeMockBlockManager("2", "2"), blocks)) - blocksByAddress.toArray + spy(new RapidsShuffleIterator( + RapidsShuffleTestHelper.makeMockBlockManager("1", "1"), + conf, + transport, + blocksByAddress.toArray, + testMetricsUpdater, + Array.empty, + taskId, + catalog, + 123)) } } From 6434057cca657f52d0761f6053ffd37767843da5 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Wed, 4 Dec 2024 14:37:47 -0800 Subject: [PATCH 24/36] Check that the dev/host optional is defined when we are spilling it --- .../spark/rapids/spill/SpillFramework.scala | 33 ++++++++----------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index 4aaf5910d1f..b016a5693f1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -287,20 +287,18 @@ class SpillableHostBufferHandle private ( 0L } else { val spilled = synchronized { - if (disk.isEmpty) { + if (disk.isEmpty && host.isDefined) { withResource(DiskHandleStore.makeBuilder) { diskHandleBuilder => val outputChannel = diskHandleBuilder.getChannel GpuTaskMetrics.get.spillToDiskTime { - withResource(getHostBuffer.get) { hmb => - val iter = new HostByteBufferIterator(hmb) - iter.foreach { bb => - try { - while (bb.hasRemaining) { - outputChannel.write(bb) - } - } finally { - RapidsStorageUtils.dispose(bb) + val iter = new HostByteBufferIterator(host.get) + iter.foreach { bb => + try { + while (bb.hasRemaining) { + outputChannel.write(bb) } + } finally { + RapidsStorageUtils.dispose(bb) } } } @@ -324,11 +322,6 @@ class SpillableHostBufferHandle private ( } } - private def getHostBuffer: Option[HostMemoryBuffer] = synchronized { - host.foreach(_.incRefCount()) - host - } - override def close(): Unit = { releaseHostResource() synchronized { @@ -436,7 +429,7 @@ class SpillableDeviceBufferHandle private ( 0L } else { val spilled = synchronized { - if (host.isEmpty) { + if (host.isEmpty && dev.isDefined) { host = Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff(dev.get)) sizeInBytes } else { @@ -521,7 +514,7 @@ class SpillableColumnarBatchHandle private ( 0L } else { val spilled = synchronized { - if (host.isEmpty) { + if (host.isEmpty && dev.isDefined) { withChunkedPacker { chunkedPacker => meta = Some(chunkedPacker.getPackedMeta) host = Some(SpillableHostBufferHandle.createHostHandleWithPacker(chunkedPacker)) @@ -654,7 +647,7 @@ class SpillableColumnarBatchFromBufferHandle private ( 0 } else { val spilled = synchronized { - if (host.isEmpty) { + if (host.isEmpty && dev.isDefined) { val cvFromBuffer = dev.get.column(0).asInstanceOf[GpuColumnVectorFromBuffer] meta = Some(cvFromBuffer.getTableMeta) host = Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff( @@ -748,7 +741,7 @@ class SpillableCompressedColumnarBatchHandle private ( 0L } else { val spilled = synchronized { - if (host.isEmpty) { + if (host.isEmpty && dev.isDefined) { val cvFromBuffer = dev.get.column(0).asInstanceOf[GpuCompressedColumnVector] meta = Some(cvFromBuffer.getTableMeta) host = Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff( @@ -842,7 +835,7 @@ class 
SpillableHostColumnarBatchHandle private ( 0L } else { val spilled = synchronized { - if (disk.isEmpty) { + if (disk.isEmpty && host.isDefined) { withResource(DiskHandleStore.makeBuilder) { diskHandleBuilder => GpuTaskMetrics.get.spillToDiskTime { val dos = diskHandleBuilder.getDataOutputStream From 1990c276258de0d3854bfdc26b337f79416f2ba5 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Wed, 4 Dec 2024 20:47:53 -0800 Subject: [PATCH 25/36] A direct-added host handle doesnt trigger spill --- .../spark/rapids/spill/SpillFramework.scala | 37 ++++++++++----- .../rapids/spill/SpillFrameworkSuite.scala | 47 ++++++++++++++++++- 2 files changed, 72 insertions(+), 12 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index b016a5693f1..4165b1a6cc0 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -105,13 +105,11 @@ import org.apache.spark.storage.BlockId * will create a host handle that points to a disk handle, tracking a file on disk. * * Host handles created directly, via the factory methods `SpillableHostBufferHandle(...)` or - * `SpillableHostColumnarBatchHandle(...)`, can trigger immediate spills, if we have host store - * limits. For example: if the host store limit is set to 1GB, and we add a 1.5GB host buffer via - * its factory method, we are going to spill all 1GB that we had in the store, and then track the - * 1.5GB buffer, which is above the limit. Next time we add a host object in this way, or via a - * device -> host spill, we are going to spill the 1.5GB buffer. This is a departure from how - * the spill framework used to work, where the host memory added directly did not cause spills - * directly, and only device->host spills would trigger the spill. + * `SpillableHostColumnarBatchHandle(...)`, do not trigger immediate spills. For example: + * if the host store limit is set to 1GB, and we add a 1.5GB host buffer via + * its factory method, we are going to have 2.5GB worth of host memory in the host store. + * That said, if we run out of device memory and we need to spill to host, 1.5GB will be spilled + * to disk, as device OOM triggers the pipeline spill. * * If we don't have a host store limit, spilling from the host store is done entirely via * host memory allocation failure callbacks. 
All objects added to the host store are tracked @@ -208,7 +206,7 @@ trait HostSpillableHandle[T <: AutoCloseable] extends SpillableHandle { object SpillableHostBufferHandle extends Logging { def apply(hmb: HostMemoryBuffer): SpillableHostBufferHandle = { val handle = new SpillableHostBufferHandle(hmb.getLength, host = Some(hmb)) - SpillFramework.stores.hostStore.track(handle) + SpillFramework.stores.hostStore.trackNoSpill(handle) handle } @@ -778,7 +776,7 @@ object SpillableHostColumnarBatchHandle { def apply(cb: ColumnarBatch): SpillableHostColumnarBatchHandle = { val sizeInBytes = RapidsHostColumnVector.getTotalHostMemoryUsed(cb) val handle = new SpillableHostColumnarBatchHandle(sizeInBytes, cb.numRows(), host = Some(cb)) - SpillFramework.stores.hostStore.track(handle) + SpillFramework.stores.hostStore.trackNoSpill(handle) handle } } @@ -1037,7 +1035,7 @@ class SpillableHostStore(val maxSize: Option[Long] = None) extends SpillableStore with Logging { - private var totalSize: Long = 0L + private[spill] var totalSize: Long = 0L private def tryTrack(handle: SpillableHandle): Boolean = { if (maxSize.isEmpty || handle.approxSizeInBytes == 0) { @@ -1119,6 +1117,20 @@ class SpillableHostStore(val maxSize: Option[Long] = None) tracked } + /** + * This is a special method in the host store where spillable handles can be added + * but they will not trigger the cascade host->disk spill logic. This is to replicate + * how the stores used to work in the past, and is only called from factory + * methods that are used by client code. + */ + def trackNoSpill(handle: SpillableHandle): Unit = { + synchronized { + if (doTrack(handle)) { + totalSize += handle.approxSizeInBytes + } + } + } + override def remove(handle: SpillableHandle): Unit = { synchronized { if (doRemove(handle)) { @@ -1135,12 +1147,15 @@ class SpillableHostStore(val maxSize: Option[Long] = None) * Host store locks and disk store locks will be taken/released during this call, but * after the builder is created, no locks are held in the store. * + * @note When creating the host buffer handle, never call the factory Spillable* methods, + * instead, construct the handles directly. This is because the factory methods + * trigger a spill to disk, and that standard behavior of the spill framework so far. * @param handle a host handle that only has a size set, and no backing store. 
* @return the builder to be closed by caller */ def makeBuilder(handle: SpillableHostBufferHandle): SpillableHostBufferHandleBuilder = { var builder: Option[SpillableHostBufferHandleBuilder] = None - if (handle.sizeInBytes < maxSize.getOrElse(Long.MaxValue)) { + if (handle.sizeInBytes <= maxSize.getOrElse(Long.MaxValue)) { HostAlloc.tryAlloc(handle.sizeInBytes).foreach { hmb => withResource(hmb) { _ => if (trackInternal(handle)) { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala index bd53d750d29..0694e5d8319 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala @@ -19,6 +19,8 @@ package com.nvidia.spark.rapids.spill import java.io.File import java.math.RoundingMode +import scala.collection.mutable.ArrayBuffer + import ai.rapids.cudf._ import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} @@ -638,6 +640,49 @@ class SpillFrameworkSuite } } + // this is a key behavior that we wanted to keep during the spill refactor + // where host objects that are added directly to the store do not cause a + // host->disk spill on their own, instead they will get spilled later + // due to device->host spills. + test("host factory methods do not spill on addition") { + SpillFramework.shutdown() + val sc = new SparkConf + // set a very small store size + sc.set(RapidsConf.HOST_SPILL_STORAGE_SIZE.key, "1KB") + SpillFramework.initialize(new RapidsConf(sc)) + + try { + // add a lot of batches, surpassing the limits of the store + val handles = new ArrayBuffer[SpillableHostColumnarBatchHandle]() + var dataTypes: Array[DataType] = null + (0 until 100).foreach { _ => + val (hostBatch, dt) = buildHostBatch() + if (dataTypes == null) { + dataTypes = dt + } + handles.append(SpillableHostColumnarBatchHandle(hostBatch)) + } + // no spill to disk + assertResult(100)(SpillFramework.stores.hostStore.numHandles) + + val dmb = DeviceMemoryBuffer.allocate(1024) + withResource(SpillableDeviceBufferHandle(dmb)) { _ => + // simulate an OOM by spilling device memory + SpillFramework.stores.deviceStore.spill(1024) + assertResult(0)(SpillFramework.stores.deviceStore.numHandles) + } + + val buffersSpilledToDisk = SpillFramework.stores.diskStore.numHandles + // we spilled to disk + assert(SpillFramework.stores.diskStore.numHandles > 0) + // and the remaining objects that didn't spill, are still in the host store + assertResult(100 - buffersSpilledToDisk)(SpillFramework.stores.hostStore.numHandles) + assert(SpillFramework.stores.hostStore.totalSize <= 1024) + } finally { + SpillFramework.shutdown() + } + } + test("direct spill to disk: when buffer exceeds maximum size") { var (bigTable, sparkTypes) = buildTableOfLongs(2 * 1024 * 1024) closeOnExcept(bigTable) { _ => @@ -1055,4 +1100,4 @@ class SpillFrameworkSuite testBufferFileDeletion(canShareDiskPaths = true) } -} \ No newline at end of file +} From ec8a267498f09df0888e492a1a2f0281037edd70 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Wed, 4 Dec 2024 20:58:50 -0800 Subject: [PATCH 26/36] Make sure we dont try to close a TableMeta only RapidsShuffleHandle --- .../com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala 
b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala
index e591c825d15..450622ef3ba 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala
@@ -18,6 +18,7 @@ package com.nvidia.spark.rapids
 
 import ai.rapids.cudf.{DeviceMemoryBuffer, Table}
 import com.nvidia.spark.rapids.Arm.withResource
+import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableColumn
 import com.nvidia.spark.rapids.format.TableMeta
 import com.nvidia.spark.rapids.spill.SpillableDeviceBufferHandle
 
@@ -28,7 +29,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
 case class RapidsShuffleHandle(
     spillable: SpillableDeviceBufferHandle, tableMeta: TableMeta) extends AutoCloseable {
   override def close(): Unit = {
-    spillable.close()
+    spillable.safeClose()
   }
 }
 

From 418bfe40cbc83dcdc21f25b775c18d291736b1cc Mon Sep 17 00:00:00 2001
From: Alessandro Bellina
Date: Thu, 5 Dec 2024 05:03:22 -0800
Subject: [PATCH 27/36] Update copyright

---
 .../src/main/scala-2.13/com/nvidia/spark/rapids/implicits.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql-plugin/src/main/scala-2.13/com/nvidia/spark/rapids/implicits.scala b/sql-plugin/src/main/scala-2.13/com/nvidia/spark/rapids/implicits.scala
index 6c5a0f145f2..1e4c5e39a19 100644
--- a/sql-plugin/src/main/scala-2.13/com/nvidia/spark/rapids/implicits.scala
+++ b/sql-plugin/src/main/scala-2.13/com/nvidia/spark/rapids/implicits.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023, NVIDIA CORPORATION.
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

From 8e8aefcbbfff20ccf7d61b5d296cbb712e98820f Mon Sep 17 00:00:00 2001
From: Alessandro Bellina
Date: Fri, 6 Dec 2024 08:23:29 -0800
Subject: [PATCH 28/36] Break down spill into a two-step process, so we can inject a device synchronize in between

---
 .../spark/rapids/spill/SpillFramework.scala | 165 +++++++++---------
 1 file changed, 78 insertions(+), 87 deletions(-)

diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala
index 4165b1a6cc0..f0ba32a9b89 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala
@@ -99,6 +99,11 @@ import org.apache.spark.storage.BlockId
 * responsibility to initiate that spill, and to track the spilled handle (a device spillable
 * would have a `host` handle, which tracks the host spilled object).
 *
+ * Spill is broken down into two methods: `spill` and `releaseSpilled`. This is a two-stage
+ * process because we need to make sure that there is no code running kernels on the spilled
+ * data before we actually free it. See the method documentation for `spill` and `releaseSpilled`
+ * for more info.
+ *
 * A cascade of spills can occur device -> host -> disk, given that host allocations can fail, or
 * could not fit in the SpillableHostStore's limit (if defined). In this case, the call to spill
 * will either create a host handle tracking an object on the host store (if we made room), or it
@@ -160,6 +165,10 @@ trait SpillableHandle extends StoreHandle {
  /**
   * Method called to spill this handle. It can be triggered from the spill store,
   * or directly against the handle.
+ * + * This will not free the spilled data. If you would like to free the spill + * call `releaseSpilled` + * * @note The size returned from this method is only used by the spill framework * to track the approximate size. It should just return `approxSizeInBytes`, as * that's the size that it used when it first started tracking the object. @@ -167,6 +176,16 @@ trait SpillableHandle extends StoreHandle { */ def spill(): Long + /** + * Part two of the two-stage process for spilling. We call `releaseSpilled` after + * a handle has spilled, and after a device synchronize. This prevents a race + * between threads working on cuDF kernels, that did not synchronize while holding the + * materialized handle's refCount, and the spiller thread (the spiller thread cannot + * free a device buffer that the worker thread isn't done with). + * See https://github.com/NVIDIA/spark-rapids/issues/8610 for more info. + */ + def releaseSpilled(): Unit + /** * Method used to determine whether a handle tracks an object that could be spilled * @note At the level of `SpillableHandle`, the only requirement of spillability @@ -188,6 +207,18 @@ trait DeviceSpillableHandle[T <: AutoCloseable] extends SpillableHandle { private[spill] override def spillable: Boolean = synchronized { super.spillable && dev.isDefined } + + protected def releaseDeviceResource(): Unit = { + SpillFramework.removeFromDeviceStore(this) + synchronized { + dev.foreach(_.close()) + dev = None + } + } + + override def releaseSpilled(): Unit = { + releaseDeviceResource() + } } /** @@ -201,6 +232,18 @@ trait HostSpillableHandle[T <: AutoCloseable] extends SpillableHandle { private[spill] override def spillable: Boolean = synchronized { super.spillable && host.isDefined } + + protected def releaseHostResource(): Unit = { + SpillFramework.removeFromHostStore(this) + synchronized { + host.foreach(_.close()) + host = None + } + } + + override def releaseSpilled(): Unit = { + releaseHostResource() + } } object SpillableHostBufferHandle extends Logging { @@ -284,7 +327,7 @@ class SpillableHostBufferHandle private ( if (!spillable) { 0L } else { - val spilled = synchronized { + synchronized { if (disk.isEmpty && host.isDefined) { withResource(DiskHandleStore.makeBuilder) { diskHandleBuilder => val outputChannel = diskHandleBuilder.getChannel @@ -307,19 +350,9 @@ class SpillableHostBufferHandle private ( 0L } } - releaseHostResource() - spilled } } - private def releaseHostResource(): Unit = { - SpillFramework.removeFromHostStore(this) - synchronized { - host.foreach(_.close()) - host = None - } - } - override def close(): Unit = { releaseHostResource() synchronized { @@ -399,12 +432,12 @@ class SpillableDeviceBufferHandle private ( var materialized: DeviceMemoryBuffer = null var hostHandle: SpillableHostBufferHandle = null synchronized { - if (dev.isDefined) { - materialized = dev.get - materialized.incRefCount() - } else if (host.isDefined) { + if (host.isDefined) { // since we spilled, host must be set. 
hostHandle = host.get + } else if (dev.isDefined) { + materialized = dev.get + materialized.incRefCount() } else { throw new IllegalStateException( "attempting to materialize a closed handle") @@ -426,7 +459,7 @@ class SpillableDeviceBufferHandle private ( if (!spillable) { 0L } else { - val spilled = synchronized { + synchronized { if (host.isEmpty && dev.isDefined) { host = Some(SpillableHostBufferHandle.createHostHandleFromDeviceBuff(dev.get)) sizeInBytes @@ -434,21 +467,11 @@ class SpillableDeviceBufferHandle private ( 0L } } - releaseDeviceResources() - spilled - } - } - - private def releaseDeviceResources(): Unit = { - SpillFramework.removeFromDeviceStore(this) - synchronized { - dev.foreach(_.close()) - dev = None } } override def close(): Unit = { - releaseDeviceResources() + releaseDeviceResource() synchronized { host.foreach(_.close()) host = None @@ -483,10 +506,10 @@ class SpillableColumnarBatchHandle private ( var materialized: ColumnarBatch = null var hostHandle: SpillableHostBufferHandle = null synchronized { - if (dev.isDefined) { - materialized = GpuColumnVector.incRefCounts(dev.get) - } else if (host.isDefined) { + if (host.isDefined) { hostHandle = host.get + } else if (dev.isDefined) { + materialized = GpuColumnVector.incRefCounts(dev.get) } else { throw new IllegalStateException( "attempting to materialize a closed handle") @@ -511,22 +534,20 @@ class SpillableColumnarBatchHandle private ( if (!spillable) { 0L } else { - val spilled = synchronized { + synchronized { if (host.isEmpty && dev.isDefined) { withChunkedPacker { chunkedPacker => meta = Some(chunkedPacker.getPackedMeta) host = Some(SpillableHostBufferHandle.createHostHandleWithPacker(chunkedPacker)) } + // We return the size we were created with. This is not the actual size + // of this batch when it is packed, and it is used by the calling code + // to figure out more or less how much did we free in the device. approxSizeInBytes } else { 0L } } - releaseDeviceResource() - // We return the size we were created with. This is not the actual size - // of this batch when it is packed, and it is used by the calling code - // to figure out more or less how much did we free in the device. 
- spilled } } @@ -544,14 +565,6 @@ class SpillableColumnarBatchHandle private ( } } - private def releaseDeviceResource(): Unit = { - SpillFramework.removeFromDeviceStore(this) - synchronized { - dev.foreach(_.close()) - dev = None - } - } - override def close(): Unit = { releaseDeviceResource() synchronized { @@ -616,10 +629,10 @@ class SpillableColumnarBatchFromBufferHandle private ( var materialized: ColumnarBatch = null var hostHandle: SpillableHostBufferHandle = null synchronized { - if (dev.isDefined) { - materialized = GpuColumnVector.incRefCounts(dev.get) - } else if (host.isDefined) { + if (host.isDefined) { hostHandle = host.get + } else if (dev.isDefined) { + materialized = GpuColumnVector.incRefCounts(dev.get) } else { throw new IllegalStateException( "attempting to materialize a closed handle") @@ -644,7 +657,7 @@ class SpillableColumnarBatchFromBufferHandle private ( if (!spillable) { 0 } else { - val spilled = synchronized { + synchronized { if (host.isEmpty && dev.isDefined) { val cvFromBuffer = dev.get.column(0).asInstanceOf[GpuColumnVectorFromBuffer] meta = Some(cvFromBuffer.getTableMeta) @@ -655,16 +668,6 @@ class SpillableColumnarBatchFromBufferHandle private ( 0L } } - releaseDeviceResource() - spilled - } - } - - private def releaseDeviceResource(): Unit = { - SpillFramework.removeFromDeviceStore(this) - synchronized { - dev.foreach(_.close()) - dev = None } } @@ -713,10 +716,10 @@ class SpillableCompressedColumnarBatchHandle private ( var materialized: ColumnarBatch = null var hostHandle: SpillableHostBufferHandle = null synchronized { - if (dev.isDefined) { - materialized = GpuCompressedColumnVector.incRefCounts(dev.get) - } else if (host.isDefined) { + if (host.isDefined) { hostHandle = host.get + } else if (dev.isDefined) { + materialized = GpuCompressedColumnVector.incRefCounts(dev.get) } else { throw new IllegalStateException( "attempting to materialize a closed handle") @@ -738,7 +741,7 @@ class SpillableCompressedColumnarBatchHandle private ( if (!spillable) { 0L } else { - val spilled = synchronized { + synchronized { if (host.isEmpty && dev.isDefined) { val cvFromBuffer = dev.get.column(0).asInstanceOf[GpuCompressedColumnVector] meta = Some(cvFromBuffer.getTableMeta) @@ -749,16 +752,6 @@ class SpillableCompressedColumnarBatchHandle private ( 0L } } - releaseDeviceResource() - spilled - } - } - - private def releaseDeviceResource(): Unit = { - SpillFramework.removeFromDeviceStore(this) - synchronized { - dev.foreach(_.close()) - dev = None } } @@ -832,7 +825,7 @@ class SpillableHostColumnarBatchHandle private ( if (!spillable) { 0L } else { - val spilled = synchronized { + synchronized { if (disk.isEmpty && host.isDefined) { withResource(DiskHandleStore.makeBuilder) { diskHandleBuilder => GpuTaskMetrics.get.spillToDiskTime { @@ -847,16 +840,6 @@ class SpillableHostColumnarBatchHandle private ( 0L } } - releaseHostResource() - spilled - } - } - - private def releaseHostResource(): Unit = { - SpillFramework.removeFromHostStore(this) - synchronized { - host.foreach(_.close()) - host = None } } @@ -1019,11 +1002,19 @@ trait SpillableStore extends HandleStore[SpillableHandle] with Logging { // this thread was successful at spilling handle. amountSpilled += spilled numSpilled += 1 - } // else, either: - // - this thread lost the race and the handle was closed - // - another thread spilled it - // - the handle isn't spillable anymore, due to ref count. 
+        } else {
+          // else, either:
+          // - this thread lost the race and the handle was closed
+          // - another thread spilled it
+          // - the handle isn't spillable anymore, due to ref count.
+          it.remove()
+        }
       }
+      // spillables is the list of handles that have to be closed
+      // we synchronize every thread before we release what was spilled
+      Cuda.deviceSynchronize()
+      // this is safe to be called unconditionally if another thread spilled
+      spillables.forEach(_.releaseSpilled())
       amountSpilled
     }
 

From a7db69b956a5607a9bd5ecdae66a6563c55db50b Mon Sep 17 00:00:00 2001
From: Alessandro Bellina
Date: Mon, 9 Dec 2024 14:00:32 -0800
Subject: [PATCH 29/36] Remove unintended change in TimeZonePerfSuite

---
 .../com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala
index cd918694b43..26ac7a177a3 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala
@@ -51,7 +51,7 @@ import org.apache.spark.sql.types._
 */
 class TimeZonePerfSuite extends SparkQueryCompareTestSuite with BeforeAndAfterAll {
   // perf test is disabled by default since it's a long running time in UT.
-  private val enablePerfTest = true // java.lang.Boolean.getBoolean("enableTimeZonePerf")
+  private val enablePerfTest = java.lang.Boolean.getBoolean("enableTimeZonePerf")
   private val timeZoneStrings = System.getProperty("TZs", "Asia/Shanghai")
 
   // rows for perf test

From 03b75d5c5a723f331919d57ca9bfc2b9f3bcbfa0 Mon Sep 17 00:00:00 2001
From: Alessandro Bellina
Date: Mon, 9 Dec 2024 14:00:56 -0800
Subject: [PATCH 30/36] Fix comment in ShuffleBufferCatalog

---
 .../com/nvidia/spark/rapids/ShuffleBufferCatalog.scala | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala
index 111db8e7ff2..542fb2deb2d 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala
@@ -46,9 +46,8 @@ case class ShuffleBufferId(
 class ShuffleBufferCatalog extends Logging {
   /**
    * Information stored for each active shuffle.
-   * NOTE: ArrayBuffer in blockMap must be explicitly locked when using it!
-   *
-   * @param blockMap mapping of block ID to array of buffers for the block
+   * A shuffle block can be comprised of multiple batches. Each batch
+   * is given a `ShuffleBufferId`.
   */
  private type ShuffleInfo =
    ConcurrentHashMap[ShuffleBlockId, ArrayBuffer[ShuffleBufferId]]
 
@@ -149,7 +148,6 @@ class ShuffleBufferCatalog extends Logging {
     trackDegenerate(bufferId, meta)
   }
 
-
   /**
    * Register a new shuffle.
    * This must be called before any buffer identifiers associated with this shuffle can be tracked.
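
The two-step spill that patch 28 introduces is easiest to see end to end. The following is a minimal sketch, not part of the patch series: the handle and store names come from SpillFramework.scala above, but the driver object and method are hypothetical, and a real store holds its own locks and bookkeeping around these calls.

import scala.collection.mutable.ArrayBuffer
import ai.rapids.cudf.Cuda
import com.nvidia.spark.rapids.spill.SpillableDeviceBufferHandle

object TwoStepSpillSketch {
  // Hypothetical driver: (1) ask each spillable handle to copy itself out
  // (device -> host), (2) synchronize the device so no in-flight kernel can
  // still be reading the buffers, (3) only then free the device copies.
  def spillThenRelease(candidates: Seq[SpillableDeviceBufferHandle]): Long = {
    var amountSpilled = 0L
    val spilled = new ArrayBuffer[SpillableDeviceBufferHandle]()
    candidates.foreach { handle =>
      val bytes = handle.spill() // copies to host, does not free device memory yet
      if (bytes > 0) {
        amountSpilled += bytes
        spilled += handle
      }
    }
    Cuda.deviceSynchronize()             // wait for kernels that may still use the buffers
    spilled.foreach(_.releaseSpilled())  // now it is safe to free the device copies
    amountSpilled
  }
}

This mirrors the `spill()` body added in patch 28 and the `postSpill` hook that patch 31 later factors out for the device store.
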
From 29190308e56091ffbf10d8ae6224a44dbfca67e4 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Tue, 10 Dec 2024 07:16:41 -0800 Subject: [PATCH 31/36] Code review comments Signed-off-by: Alessandro Bellina --- .../spark/rapids/spill/SpillFramework.scala | 212 +++++++++++------- .../rapids/spill/SpillFrameworkSuite.scala | 18 +- 2 files changed, 135 insertions(+), 95 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index f0ba32a9b89..ff38f25fbe2 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -81,7 +81,8 @@ import org.apache.spark.storage.BlockId * We handle aliasing of objects, either in the spill framework or outside, by looking at the * reference count. All objects added to the store should support a ref count. * If the ref count is greater than the expected value, we assume it is being aliased, - * and therefore we don't waste time spilling the aliased object. + * and therefore we don't waste time spilling the aliased object. Please take a look at the + * `spillable` method in each of the handles on how this is implemented. * * Materialization: * @@ -176,16 +177,6 @@ trait SpillableHandle extends StoreHandle { */ def spill(): Long - /** - * Part two of the two-stage process for spilling. We call `releaseSpilled` after - * a handle has spilled, and after a device synchronize. This prevents a race - * between threads working on cuDF kernels, that did not synchronize while holding the - * materialized handle's refCount, and the spiller thread (the spiller thread cannot - * free a device buffer that the worker thread isn't done with). - * See https://github.com/NVIDIA/spark-rapids/issues/8610 for more info. - */ - def releaseSpilled(): Unit - /** * Method used to determine whether a handle tracks an object that could be spilled * @note At the level of `SpillableHandle`, the only requirement of spillability @@ -209,14 +200,22 @@ trait DeviceSpillableHandle[T <: AutoCloseable] extends SpillableHandle { } protected def releaseDeviceResource(): Unit = { - SpillFramework.removeFromDeviceStore(this) + SpillFramework.remove(this) synchronized { dev.foreach(_.close()) dev = None } } - override def releaseSpilled(): Unit = { + /** + * Part two of the two-stage process for spilling device buffers. We call `releaseSpilled` after + * a handle has spilled, and after a device synchronize. This prevents a race + * between threads working on cuDF kernels, that did not synchronize while holding the + * materialized handle's refCount, and the spiller thread (the spiller thread cannot + * free a device buffer that the worker thread isn't done with). + * See https://github.com/NVIDIA/spark-rapids/issues/8610 for more info. 
+ */ + def releaseSpilled(): Unit = { releaseDeviceResource() } } @@ -234,16 +233,12 @@ trait HostSpillableHandle[T <: AutoCloseable] extends SpillableHandle { } protected def releaseHostResource(): Unit = { - SpillFramework.removeFromHostStore(this) + SpillFramework.remove(this) synchronized { host.foreach(_.close()) host = None } } - - override def releaseSpilled(): Unit = { - releaseHostResource() - } } object SpillableHostBufferHandle extends Logging { @@ -327,7 +322,7 @@ class SpillableHostBufferHandle private ( if (!spillable) { 0L } else { - synchronized { + val spilled = synchronized { if (disk.isEmpty && host.isDefined) { withResource(DiskHandleStore.makeBuilder) { diskHandleBuilder => val outputChannel = diskHandleBuilder.getChannel @@ -350,6 +345,8 @@ class SpillableHostBufferHandle private ( 0L } } + releaseHostResource() + spilled } } @@ -825,7 +822,7 @@ class SpillableHostColumnarBatchHandle private ( if (!spillable) { 0L } else { - synchronized { + val spilled = synchronized { if (disk.isEmpty && host.isDefined) { withResource(DiskHandleStore.makeBuilder) { diskHandleBuilder => GpuTaskMetrics.get.spillToDiskTime { @@ -840,6 +837,8 @@ class SpillableHostColumnarBatchHandle private ( 0L } } + releaseHostResource() + spilled } } @@ -905,7 +904,7 @@ class DiskHandle private( } override def close(): Unit = { - SpillFramework.removeFromDiskStore(this) + SpillFramework.remove(this) SpillFramework.stores.diskStore.deleteFile(blockId) } @@ -973,49 +972,81 @@ trait HandleStore[T <: StoreHandle] extends AutoCloseable with Logging { } } -trait SpillableStore extends HandleStore[SpillableHandle] with Logging { +trait SpillableStore[T <: SpillableHandle] + extends HandleStore[T] with Logging { protected def spillNvtxRange: NvtxRange + /** + * Internal class to provide an interface to our plan for this spill. + * + * We will build up this SpillPlan by adding spillables: handles + * that are marked spillable given the `spillable` method returning true. + * The spill store will call `trySpill`, which moves handles from the + * `spillableHandles` array to the `spilledHandles` array. + * + * At any point in time, a spill framework can call `getSpilled` + * to obtain the list of spilled handles. The device store does this + * to inject CUDA synchronization before actually releasing device handles. + */ + class SpillPlan { + private val spillableHandles = new util.ArrayList[T]() + private val spilledHandles = new util.ArrayList[T]() + + def add(spillable: T): Unit = { + spillableHandles.add(spillable) + } + + def trySpill(): Long = { + var amountSpilled = 0L + val it = spillableHandles.iterator() + while (it.hasNext) { + val handle = it.next() + val spilled = handle.spill() + if (spilled > 0) { + // this thread was successful at spilling handle. + amountSpilled += spilled + spilledHandles.add(handle) + } else { + // else, either: + // - this thread lost the race and the handle was closed + // - another thread spilled it + // - the handle isn't spillable anymore, due to ref count. 
+ it.remove() + } + } + amountSpilled + } + + def getSpilled: util.ArrayList[T] = { + spilledHandles + } + } + + private def makeSpillPlan(spillNeeded: Long): SpillPlan = { + val plan = new SpillPlan() + var amountToSpill = 0L + val allHandles = handles.keySet().iterator() + // two threads could be here trying to spill and creating a list of spillables + while (allHandles.hasNext && amountToSpill < spillNeeded) { + val handle = allHandles.next() + if (handle.spillable) { + amountToSpill += handle.approxSizeInBytes + plan.add(handle) + } + } + plan + } + + protected def postSpill(plan: SpillPlan): Unit = {} + def spill(spillNeeded: Long): Long = { if (spillNeeded == 0) { 0L } else { withResource(spillNvtxRange) { _ => - var amountSpilled = 0L - val spillables = new util.ArrayList[SpillableHandle]() - var amountToSpill = 0L - val allHandles = handles.keySet().iterator() - // two threads could be here trying to spill and creating a list of spillables - while (allHandles.hasNext && amountToSpill < spillNeeded) { - val handle = allHandles.next() - if (handle.spillable) { - amountToSpill += handle.approxSizeInBytes - spillables.add(handle) - } - } - val it = spillables.iterator() - var numSpilled = 0 - while (it.hasNext) { - val handle = it.next() - val spilled = handle.spill() - if (spilled > 0) { - // this thread was successful at spilling handle. - amountSpilled += spilled - numSpilled += 1 - } else { - // else, either: - // - this thread lost the race and the handle was closed - // - another thread spilled it - // - the handle isn't spillable anymore, due to ref count. - it.remove() - } - } - // spillables is the list of handles that have to be closed - // we synchronize every thread before we release what was spilled - Cuda.deviceSynchronize() - // this is safe to be called unconditionally if another thread spilled - spillables.forEach(_.releaseSpilled()) - + val plan = makeSpillPlan(spillNeeded) + val amountSpilled = plan.trySpill() + postSpill(plan) amountSpilled } } @@ -1023,12 +1054,12 @@ trait SpillableStore extends HandleStore[SpillableHandle] with Logging { } class SpillableHostStore(val maxSize: Option[Long] = None) - extends SpillableStore + extends SpillableStore[HostSpillableHandle[_]] with Logging { private[spill] var totalSize: Long = 0L - private def tryTrack(handle: SpillableHandle): Boolean = { + private def tryTrack(handle: HostSpillableHandle[_]): Boolean = { if (maxSize.isEmpty || handle.approxSizeInBytes == 0) { super.doTrack(handle) // for now, keep this totalSize part, we technically @@ -1054,11 +1085,11 @@ class SpillableHostStore(val maxSize: Option[Long] = None) } } - override def track(handle: SpillableHandle): Unit = { + override def track(handle: HostSpillableHandle[_]): Unit = { trackInternal(handle) } - private def trackInternal(handle: SpillableHandle): Boolean = { + private def trackInternal(handle: HostSpillableHandle[_]): Boolean = { // try to track the handle: in the case of no limits // this should just be add to the store var tracked = false @@ -1114,7 +1145,7 @@ class SpillableHostStore(val maxSize: Option[Long] = None) * how the stores used to work in the past, and is only called from factory * methods that are used by client code. 
*/ - def trackNoSpill(handle: SpillableHandle): Unit = { + def trackNoSpill(handle: HostSpillableHandle[_]): Unit = { synchronized { if (doTrack(handle)) { totalSize += handle.approxSizeInBytes @@ -1122,7 +1153,7 @@ class SpillableHostStore(val maxSize: Option[Long] = None) } } - override def remove(handle: SpillableHandle): Unit = { + override def remove(handle: HostSpillableHandle[_]): Unit = { synchronized { if (doRemove(handle)) { totalSize -= handle.approxSizeInBytes @@ -1287,9 +1318,17 @@ class SpillableHostStore(val maxSize: Option[Long] = None) new NvtxRange("disk spill", NvtxColor.RED) } -class SpillableDeviceStore extends SpillableStore { +class SpillableDeviceStore extends SpillableStore[DeviceSpillableHandle[_]] { override protected def spillNvtxRange: NvtxRange = new NvtxRange("device spill", NvtxColor.ORANGE) + + override def postSpill(plan: SpillPlan): Unit = { + // spillables is the list of handles that have to be closed + // we synchronize every thread before we release what was spilled + Cuda.deviceSynchronize() + // this is safe to be called unconditionally if another thread spilled + plan.getSpilled.forEach(_.releaseSpilled()) + } } class DiskHandleStore(conf: SparkConf) @@ -1448,8 +1487,8 @@ object SpillableColumnarBatchHandle { } object SpillFramework extends Logging { - // public for tests - var storesInternal: SpillableStores = _ + // pivate[spill] for tests + private[spill] var storesInternal: SpillableStores = _ def stores: SpillableStores = { if (storesInternal == null) { @@ -1473,6 +1512,8 @@ object SpillFramework extends Logging { } def initialize(rapidsConf: RapidsConf): Unit = synchronized { + require(storesInternal != null, + "cannot initialize SpillFramework multiple times") val hostSpillStorageSize = if (rapidsConf.offHeapLimitEnabled) { // Disable the limit because it is handled by the RapidsHostMemoryStore None @@ -1529,30 +1570,27 @@ object SpillFramework extends Logging { var chunkedPackBounceBufferPool: DeviceBounceBufferPool = _ - def removeFromDeviceStore(value: SpillableHandle): Unit = { - // if the stores have already shut down, we don't want to create them here - val deviceStore = synchronized { - Option(storesInternal).map(_.deviceStore) - } - deviceStore.foreach(_.remove(value)) - } - - def removeFromHostStore(value: SpillableHandle): Unit = { + def remove(handle: StoreHandle): Unit = { // if the stores have already shut down, we don't want to create them here - val hostStore = synchronized { - Option(storesInternal).map(_.hostStore) - } - hostStore.foreach(_.remove(value)) - } - - def removeFromDiskStore(value: DiskHandle): Unit = { - // if the stores have already shut down, we don't want to create them here - val maybeStores = synchronized { - Option(storesInternal).map(_.diskStore) + // so we use `storesInternal` directly. 
+ handle match { + case ds: DeviceSpillableHandle[_] => + synchronized { + Option(storesInternal).map(_.deviceStore) + }.foreach(_.remove(ds)) + case hs: HostSpillableHandle[_] => + synchronized { + Option(storesInternal).map(_.hostStore) + }.foreach(_.remove(hs)) + case dh: DiskHandle => + synchronized { + Option(storesInternal).map(_.diskStore) + }.foreach(_.remove(dh)) + case _ => + throw new IllegalStateException( + s"unknown handle ${handle} cannot be removed") } - maybeStores.foreach(_.remove(value)) } - } /** diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala index 0694e5d8319..31377695fe4 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SpillFrameworkSuite.scala @@ -247,7 +247,7 @@ class SpillFrameworkSuite } } - test("an aliased contiguous table is not spillable (until closing the original) ") { + test("an aliased contiguous table is not spillable (until closing the alias) ") { val (table, dataTypes) = buildContiguousTable() withResource(SpillableColumnarBatchFromBufferHandle(table, dataTypes)) { handle => assertResult(1)(SpillFramework.stores.deviceStore.numHandles) @@ -255,10 +255,11 @@ class SpillFrameworkSuite val materialized = handle.materialize(dataTypes) // note that materialized is a batch "from buffer", it is not a regular batch withResource(SpillableColumnarBatchFromBufferHandle(materialized)) { aliasHandle => + // we now have two copies in the store assertResult(2)(SpillFramework.stores.deviceStore.numHandles) assert(!handle.spillable) assert(!aliasHandle.spillable) - } // we now have two copies in the store + } assert(handle.spillable) assertResult(1)(SpillFramework.stores.deviceStore.numHandles) } @@ -319,7 +320,7 @@ class SpillFrameworkSuite } } - private def decompressBach(cb: ColumnarBatch): ColumnarBatch = { + private def decompressBatch(cb: ColumnarBatch): ColumnarBatch = { val schema = new StructType().add("i", LongType) .add("j", DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 3)) val sparkTypes = GpuColumnVector.extractTypes(schema) @@ -356,7 +357,7 @@ class SpillFrameworkSuite test("a compressed batch can be added and recovered after being spilled to host") { val ct = buildCompressedBatch(0, 1000) - withResource(decompressBach(ct)) { decompressedExpected => + withResource(decompressBatch(ct)) { decompressedExpected => withResource(SpillableCompressedColumnarBatchHandle(ct)) { handle => assert(handle.spillable) SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes) @@ -364,7 +365,7 @@ class SpillFrameworkSuite assert(handle.dev.isEmpty) assert(handle.host.isDefined) withResource(handle.materialize()) { materialized => - withResource(decompressBach(materialized)) { decompressed => + withResource(decompressBatch(materialized)) { decompressed => TestUtils.compareBatches(decompressedExpected, decompressed) } } @@ -374,7 +375,7 @@ class SpillFrameworkSuite test("a compressed batch can be added and recovered after being spilled to disk") { val ct = buildCompressedBatch(0, 1000) - withResource(decompressBach(ct)) { decompressedExpected => + withResource(decompressBatch(ct)) { decompressedExpected => withResource(SpillableCompressedColumnarBatchHandle(ct)) { handle => assert(handle.spillable) SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes) @@ -385,7 +386,7 @@ class SpillFrameworkSuite 
assert(handle.host.get.host.isEmpty) assert(handle.host.get.disk.isDefined) withResource(handle.materialize()) { materialized => - withResource(decompressBach(materialized)) { decompressed => + withResource(decompressBatch(materialized)) { decompressed => TestUtils.compareBatches(decompressedExpected, decompressed) } } @@ -621,7 +622,7 @@ class SpillFrameworkSuite test("host originated: a host batch supports aliasing and duplicated columns") { SpillFramework.shutdown() val sc = new SparkConf - // disables the host store limit + // disables the host store limit by enabling off heap limits sc.set(RapidsConf.OFF_HEAP_LIMIT_ENABLED.key, "true") SpillFramework.initialize(new RapidsConf(sc)) @@ -750,6 +751,7 @@ class SpillFrameworkSuite } } + // -1 disables the host store limit val hostSpillStorageSizes = Seq("-1", "1MB", "16MB") val spillToDiskBounceBuffers = Seq("128KB", "2MB", "128MB") val chunkedPackBounceBuffers = Seq("1MB", "8MB", "128MB") From 754a2f9ff5bdab567ef60fa047009d2c83de91d6 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Tue, 10 Dec 2024 07:20:27 -0800 Subject: [PATCH 32/36] SpillFramework.remove should be private[spill] --- .../scala/com/nvidia/spark/rapids/spill/SpillFramework.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index ff38f25fbe2..5ea7ef54e2b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -1570,7 +1570,7 @@ object SpillFramework extends Logging { var chunkedPackBounceBufferPool: DeviceBounceBufferPool = _ - def remove(handle: StoreHandle): Unit = { + private[spill] def remove(handle: StoreHandle): Unit = { // if the stores have already shut down, we don't want to create them here // so we use `storesInternal` directly. 
handle match { From 5bea81bf5567352cd613863cc049e49c5191c878 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Tue, 10 Dec 2024 07:31:59 -0800 Subject: [PATCH 33/36] Remove whitespace changes --- .../com/nvidia/spark/rapids/NonDeterministicRetrySuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/NonDeterministicRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/NonDeterministicRetrySuite.scala index 8ca525a1f5c..d018726ef35 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/NonDeterministicRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/NonDeterministicRetrySuite.scala @@ -63,7 +63,7 @@ class NonDeterministicRetrySuite extends RmmSparkRetrySuiteBase { } } } - + test("GPU project retry with GPU rand") { def projectRand(): Seq[GpuExpression] = Seq( GpuAlias(GpuRand(GpuLiteral(RAND_SEED)), "rand")()) @@ -154,4 +154,5 @@ class NonDeterministicRetrySuite extends RmmSparkRetrySuiteBase { } } } + } From 0702bd4b62f1623c143ea69b16fa9554d18a19de Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Tue, 10 Dec 2024 08:24:53 -0800 Subject: [PATCH 34/36] storesInternal cant be private[spill] due to HostAllocSuite --- .../scala/com/nvidia/spark/rapids/spill/SpillFramework.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index 5ea7ef54e2b..ce4219c2a73 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -1487,8 +1487,9 @@ object SpillableColumnarBatchHandle { } object SpillFramework extends Logging { - // pivate[spill] for tests - private[spill] var storesInternal: SpillableStores = _ + // public for tests. Some tests not in the `spill` package require setting this + // because they need fine control over allocations. 
+ var storesInternal: SpillableStores = _ def stores: SpillableStores = { if (storesInternal == null) { From 5aacedb8131dea3cda69b2bc20450870a2ca6ae7 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Tue, 10 Dec 2024 10:59:16 -0800 Subject: [PATCH 35/36] fix require condition --- .../com/nvidia/spark/rapids/spill/SpillFramework.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index ce4219c2a73..5fb1421b083 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -1511,10 +1511,11 @@ object SpillFramework extends Logging { new SparkConf() } } - + def initialize(rapidsConf: RapidsConf): Unit = synchronized { - require(storesInternal != null, - "cannot initialize SpillFramework multiple times") + require(storesInternal == null, + s"cannot initialize SpillFramework multiple times.") + val hostSpillStorageSize = if (rapidsConf.offHeapLimitEnabled) { // Disable the limit because it is handled by the RapidsHostMemoryStore None From 88c18b73e25f87e9dbfd2dbfb63f9b218c340365 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Wed, 11 Dec 2024 14:01:19 -0800 Subject: [PATCH 36/36] Split SpillFramework.remove Signed-off-by: Alessandro Bellina --- .../spark/rapids/spill/SpillFramework.scala | 45 +++++++++---------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index 5fb1421b083..57f2a823432 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -200,7 +200,7 @@ trait DeviceSpillableHandle[T <: AutoCloseable] extends SpillableHandle { } protected def releaseDeviceResource(): Unit = { - SpillFramework.remove(this) + SpillFramework.removeFromDeviceStore(this) synchronized { dev.foreach(_.close()) dev = None @@ -233,7 +233,7 @@ trait HostSpillableHandle[T <: AutoCloseable] extends SpillableHandle { } protected def releaseHostResource(): Unit = { - SpillFramework.remove(this) + SpillFramework.removeFromHostStore(this) synchronized { host.foreach(_.close()) host = None @@ -904,7 +904,7 @@ class DiskHandle private( } override def close(): Unit = { - SpillFramework.remove(this) + SpillFramework.removeFromDiskStore(this) SpillFramework.stores.diskStore.deleteFile(blockId) } @@ -1572,26 +1572,25 @@ object SpillFramework extends Logging { var chunkedPackBounceBufferPool: DeviceBounceBufferPool = _ - private[spill] def remove(handle: StoreHandle): Unit = { - // if the stores have already shut down, we don't want to create them here - // so we use `storesInternal` directly. 
- handle match { - case ds: DeviceSpillableHandle[_] => - synchronized { - Option(storesInternal).map(_.deviceStore) - }.foreach(_.remove(ds)) - case hs: HostSpillableHandle[_] => - synchronized { - Option(storesInternal).map(_.hostStore) - }.foreach(_.remove(hs)) - case dh: DiskHandle => - synchronized { - Option(storesInternal).map(_.diskStore) - }.foreach(_.remove(dh)) - case _ => - throw new IllegalStateException( - s"unknown handle ${handle} cannot be removed") - } + // if the stores have already shut down, we don't want to create them here + // so we use `storesInternal` directly in these remove functions. + + private[spill] def removeFromDeviceStore(handle: DeviceSpillableHandle[_]): Unit = { + synchronized { + Option(storesInternal).map(_.deviceStore) + }.foreach(_.remove(handle)) + } + + private[spill] def removeFromHostStore(handle: HostSpillableHandle[_]): Unit = { + synchronized { + Option(storesInternal).map(_.hostStore) + }.foreach(_.remove(handle)) + } + + private[spill] def removeFromDiskStore(handle: DiskHandle): Unit = { + synchronized { + Option(storesInternal).map(_.diskStore) + }.foreach(_.remove(handle)) } }
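
Reviewer note (not part of the patch series): the hunk above, from PATCH 36/36, splits the single SpillFramework.remove(handle), which pattern-matched on the handle type and threw IllegalStateException for unknown handles, into one removal entry point per store (removeFromDeviceStore, removeFromHostStore, removeFromDiskStore), with the call sites in DeviceSpillableHandle.releaseDeviceResource, HostSpillableHandle.releaseHostResource, and DiskHandle.close each invoking the function for the store they actually live in. What follows is a minimal, self-contained Scala sketch of that shape only. Every name in it (Handle, DeviceHandle, DiskHandle, SimpleStore, Stores, Framework, RemoveSplitDemo) is a hypothetical stand-in, not the plugin's actual type; the host store is omitted for brevity, and the sketch only illustrates the store-specific entry points and the null-safe no-op behavior after shutdown.

import scala.collection.mutable

// Hypothetical stand-ins; not the spark-rapids classes.
trait Handle
final class DeviceHandle extends Handle
final class DiskHandle extends Handle

// A toy store that just tracks registered handles.
final class SimpleStore[H <: Handle] {
  private val handles = mutable.Set.empty[H]
  def add(h: H): Unit = synchronized { handles += h }
  def remove(h: H): Unit = synchronized { handles -= h }
  def size: Int = synchronized { handles.size }
}

final case class Stores(device: SimpleStore[DeviceHandle], disk: SimpleStore[DiskHandle])

object Framework {
  // null after shutdown, mirroring how the patch uses `storesInternal` directly
  // so that the stores are never re-created during removal
  var storesInternal: Stores = _

  // One removal function per store: callers state which store they touch, no
  // type match is needed, and a shut-down framework is a harmless no-op.
  def removeFromDeviceStore(h: DeviceHandle): Unit =
    synchronized { Option(storesInternal).map(_.device) }.foreach(_.remove(h))

  def removeFromDiskStore(h: DiskHandle): Unit =
    synchronized { Option(storesInternal).map(_.disk) }.foreach(_.remove(h))
}

object RemoveSplitDemo extends App {
  Framework.storesInternal = Stores(new SimpleStore[DeviceHandle], new SimpleStore[DiskHandle])
  val d = new DeviceHandle
  Framework.storesInternal.device.add(d)
  Framework.removeFromDeviceStore(d)                 // removes from the device store
  assert(Framework.storesInternal.device.size == 0)
  Framework.storesInternal = null                    // simulate SpillFramework.shutdown()
  Framework.removeFromDeviceStore(new DeviceHandle)  // safe no-op after shutdown
}

Compared with the previous single remove, each handle type calls exactly the removal function for the store it belongs to, so the catch-all match arm (and its IllegalStateException for unknown handles) is no longer needed, while the null check on storesInternal preserves the original behavior of not re-creating stores after shutdown.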