From 391b11f26ace3f6ec7a82719c431a71a9c5619f6 Mon Sep 17 00:00:00 2001
From: Zach Puller
Date: Tue, 19 Nov 2024 11:34:40 -0800
Subject: [PATCH] pr comments

Signed-off-by: Zach Puller
---
 .../com/nvidia/spark/rapids/HostAlloc.scala   | 19 +++++++------------
 .../spark/sql/rapids/GpuTaskMetrics.scala     |  3 ---
 2 files changed, 7 insertions(+), 15 deletions(-)

diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala
index 6f083b78092..501749433fc 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala
@@ -54,16 +54,13 @@ private class HostAlloc(nonPinnedLimit: Long) extends HostMemoryAllocator with L
     }
   }
 
-  private def reportHostAllocMetrics(metrics: GpuTaskMetrics): String = {
-    try {
-      val taskId = TaskContext.get().taskAttemptId()
+  private def getHostAllocMetricsLogStr(metrics: GpuTaskMetrics): String = {
+    Option(TaskContext.get()).map({ context =>
+      val taskId = context.taskAttemptId()
       val totalSize = metrics.getHostBytesAllocated
       val maxSize = metrics.getMaxHostBytesAllocated
       s"total size for task $taskId is $totalSize, max size is $maxSize"
-    } catch {
-      case _: NullPointerException =>
-        "allocated memory outside of a task context"
-    }
+    }).getOrElse("allocated memory outside of a task context")
   }
 
   private def releasePinned(ptr: Long, amount: Long): Unit = {
@@ -72,7 +69,7 @@ private class HostAlloc(nonPinnedLimit: Long) extends HostMemoryAllocator with L
     }
     val metrics = GpuTaskMetrics.get
     metrics.decHostBytesAllocated(amount)
-    logDebug(reportHostAllocMetrics(metrics))
+    logDebug(getHostAllocMetricsLogStr(metrics))
     RmmSpark.cpuDeallocate(ptr, amount)
   }
 
@@ -82,7 +79,7 @@ private class HostAlloc(nonPinnedLimit: Long) extends HostMemoryAllocator with L
     }
     val metrics = GpuTaskMetrics.get
     metrics.decHostBytesAllocated(amount)
-    logDebug(reportHostAllocMetrics(metrics))
+    logDebug(getHostAllocMetricsLogStr(metrics))
     RmmSpark.cpuDeallocate(ptr, amount)
   }
 
@@ -206,11 +203,9 @@ private class HostAlloc(nonPinnedLimit: Long) extends HostMemoryAllocator with L
       allocAttemptFinishedWithoutException = true
     } finally {
       if (ret.isDefined) {
-        // Alternatively we could do the host watermark tracking in the JNI code to make
-        // it consistent with how we handle device memory tracking
        val metrics = GpuTaskMetrics.get
         metrics.incHostBytesAllocated(amount)
-        logDebug(reportHostAllocMetrics(metrics))
+        logDebug(getHostAllocMetricsLogStr(metrics))
         RmmSpark.postCpuAllocSuccess(ret.get.getAddress, amount, blocking, isRecursive)
       } else {
         // shouldRetry should indicate if spill did anything for us and we should try again.
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala
index 043da5d4279..ec0c7050044 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuTaskMetrics.scala
@@ -145,9 +145,6 @@ class GpuTaskMetrics extends Serializable {
 
   def decHostBytesAllocated(bytes: Long): Unit = {
     hostBytesAllocated -= bytes
-    // For some reason it's possible for the task to start out by releasing resources,
-    // possibly from a previous task, in such case we probably should just ignore it.
-    hostBytesAllocated = hostBytesAllocated.max(0)
   }
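
Note on the refactor above: the patch replaces a try/catch around a possible
NullPointerException with Option(TaskContext.get()), relying on Option(x)
being None when x is null (Spark's TaskContext.get() returns null when called
outside of a task). Below is a minimal standalone sketch of that pattern;
TaskContextStub, currentContext, and HostAllocLogStrSketch are hypothetical
stand-ins invented for this note so the example runs without a Spark session.

object HostAllocLogStrSketch {
  // Hypothetical stand-in for Spark's TaskContext: carries only the task attempt id.
  final case class TaskContextStub(taskAttemptId: Long)

  // Stand-in for TaskContext.get(), which returns null outside of a running task.
  var currentContext: TaskContextStub = null

  def getHostAllocMetricsLogStr(totalSize: Long, maxSize: Long): String = {
    // Option(x) is Some(x) for non-null x and None for null, so the
    // "outside of a task" case falls through to getOrElse instead of
    // surfacing as a NullPointerException to be caught.
    Option(currentContext).map { context =>
      s"total size for task ${context.taskAttemptId} is $totalSize, max size is $maxSize"
    }.getOrElse("allocated memory outside of a task context")
  }

  def main(args: Array[String]): Unit = {
    println(getHostAllocMetricsLogStr(1024, 4096))  // no task context yet
    currentContext = TaskContextStub(42)
    println(getHostAllocMetricsLogStr(1024, 4096))  // with a task context
  }
}

The Option-based version expresses the null check as data flow rather than
control flow, and it only converts a null context into the fallback string,
whereas the old catch would also have masked any unrelated
NullPointerException thrown while building the message.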