Skip to content

Commit

Permalink
pr comments
Browse files Browse the repository at this point in the history
Signed-off-by: Zach Puller <[email protected]>
  • Loading branch information
zpuller committed Nov 19, 2024
1 parent 6c3a566 commit 391b11f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 15 deletions.
19 changes: 7 additions & 12 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,13 @@ private class HostAlloc(nonPinnedLimit: Long) extends HostMemoryAllocator with L
}
}

private def reportHostAllocMetrics(metrics: GpuTaskMetrics): String = {
try {
val taskId = TaskContext.get().taskAttemptId()
private def getHostAllocMetricsLogStr(metrics: GpuTaskMetrics): String = {
Option(TaskContext.get()).map({ context =>
val taskId = context.taskAttemptId()
val totalSize = metrics.getHostBytesAllocated
val maxSize = metrics.getMaxHostBytesAllocated
s"total size for task $taskId is $totalSize, max size is $maxSize"
} catch {
case _: NullPointerException =>
"allocated memory outside of a task context"
}
}).getOrElse("allocated memory outside of a task context")
}

private def releasePinned(ptr: Long, amount: Long): Unit = {
Expand All @@ -72,7 +69,7 @@ private class HostAlloc(nonPinnedLimit: Long) extends HostMemoryAllocator with L
}
val metrics = GpuTaskMetrics.get
metrics.decHostBytesAllocated(amount)
logDebug(reportHostAllocMetrics(metrics))
logDebug(getHostAllocMetricsLogStr(metrics))
RmmSpark.cpuDeallocate(ptr, amount)
}

Expand All @@ -82,7 +79,7 @@ private class HostAlloc(nonPinnedLimit: Long) extends HostMemoryAllocator with L
}
val metrics = GpuTaskMetrics.get
metrics.decHostBytesAllocated(amount)
logDebug(reportHostAllocMetrics(metrics))
logDebug(getHostAllocMetricsLogStr(metrics))
RmmSpark.cpuDeallocate(ptr, amount)
}

Expand Down Expand Up @@ -206,11 +203,9 @@ private class HostAlloc(nonPinnedLimit: Long) extends HostMemoryAllocator with L
allocAttemptFinishedWithoutException = true
} finally {
if (ret.isDefined) {
// Alternatively we could do the host watermark tracking in the JNI code to make
// it consistent with how we handle device memory tracking
val metrics = GpuTaskMetrics.get
metrics.incHostBytesAllocated(amount)
logDebug(reportHostAllocMetrics(metrics))
logDebug(getHostAllocMetricsLogStr(metrics))
RmmSpark.postCpuAllocSuccess(ret.get.getAddress, amount, blocking, isRecursive)
} else {
// shouldRetry should indicate if spill did anything for us and we should try again.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,6 @@ class GpuTaskMetrics extends Serializable {

def decHostBytesAllocated(bytes: Long): Unit = {
hostBytesAllocated -= bytes
// For some reason it's possible for the task to start out by releasing resources,
// possibly from a previous task, in such case we probably should just ignore it.
hostBytesAllocated = hostBytesAllocated.max(0)
}


Expand Down

0 comments on commit 391b11f

Please sign in to comment.