From 7b5c05b51c6f6a88023aa87ea30df2589d6fb295 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Wed, 18 Dec 2024 13:29:45 -0800 Subject: [PATCH] Change order of initialization so pinned pool is available for spill framework buffers Signed-off-by: Alessandro Bellina --- .../spark/rapids/GpuDeviceManager.scala | 56 ++++++++++--------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala index 42776a6cab0..4e64b73ca50 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala @@ -284,10 +284,29 @@ object GpuDeviceManager extends Logging { private var memoryEventHandler: DeviceMemoryEventHandler = _ - private def initializeRmm(gpuId: Int, rapidsConf: Option[RapidsConf]): Unit = { - if (!Rmm.isInitialized) { - val conf = rapidsConf.getOrElse(new RapidsConf(SparkEnv.get.conf)) + private def initializeSpillAndMemoryEvents(conf: RapidsConf): Unit = { + SpillFramework.initialize(conf) + + memoryEventHandler = new DeviceMemoryEventHandler( + SpillFramework.stores.deviceStore, + conf.gpuOomDumpDir, + conf.gpuOomMaxRetries) + + if (conf.sparkRmmStateEnable) { + val debugLoc = if (conf.sparkRmmDebugLocation.isEmpty) { + null + } else { + conf.sparkRmmDebugLocation + } + RmmSpark.setEventHandler(memoryEventHandler, debugLoc) + } else { + logWarning("SparkRMM retry has been disabled") + Rmm.setEventHandler(memoryEventHandler) + } + } + private def initializeRmmGpuPool(gpuId: Int, conf: RapidsConf): Unit = { + if (!Rmm.isInitialized) { val poolSize = conf.chunkedPackPoolSize chunkedPackMemoryResource = if (poolSize > 0) { @@ -391,30 +410,10 @@ object GpuDeviceManager extends Logging { } } - SpillFramework.initialize(conf) - - memoryEventHandler = new DeviceMemoryEventHandler( - SpillFramework.stores.deviceStore, - conf.gpuOomDumpDir, - conf.gpuOomMaxRetries) - - if (conf.sparkRmmStateEnable) { - val debugLoc = if (conf.sparkRmmDebugLocation.isEmpty) { - null - } else { - conf.sparkRmmDebugLocation - } - RmmSpark.setEventHandler(memoryEventHandler, debugLoc) - } else { - logWarning("SparkRMM retry has been disabled") - Rmm.setEventHandler(memoryEventHandler) - } - GpuShuffleEnv.init(conf) } } - private def initializeOffHeapLimits(gpuId: Int, rapidsConf: Option[RapidsConf]): Unit = { - val conf = rapidsConf.getOrElse(new RapidsConf(SparkEnv.get.conf)) + private def initializePinnedPoolAndOffHeapLimits(gpuId: Int, conf: RapidsConf): Unit = { val setCuioDefaultResource = conf.pinnedPoolCuioDefault val (pinnedSize, nonPinnedLimit) = if (conf.offHeapLimitEnabled) { logWarning("OFF HEAP MEMORY LIMITS IS ENABLED. " + @@ -508,8 +507,13 @@ object GpuDeviceManager extends Logging { "Cannot initialize memory due to previous shutdown failing") } else if (singletonMemoryInitialized == Uninitialized) { val gpu = gpuId.getOrElse(findGpuAndAcquire()) - initializeRmm(gpu, rapidsConf) - initializeOffHeapLimits(gpu, rapidsConf) + val conf = rapidsConf.getOrElse(new RapidsConf(SparkEnv.get.conf)) + initializePinnedPoolAndOffHeapLimits(gpu, conf) + initializeRmmGpuPool(gpu, conf) + // we want to initialize this last because we want to take advantage + // of pinned memory if it is configured + initializeSpillAndMemoryEvents(conf) + GpuShuffleEnv.init(conf) singletonMemoryInitialized = Initialized } }