From f93ca9791248b681abd308e83e3fbcf26e3a6bd2 Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Fri, 20 Oct 2023 08:54:42 -0500 Subject: [PATCH] Fix GpuSemaphore to support multiple threads per task Signed-off-by: Robert (Bobby) Evans --- .../nvidia/spark/rapids/GpuSemaphore.scala | 196 ++++++++++++------ 1 file changed, 135 insertions(+), 61 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala index 3f510493c52..1b7241d8922 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala @@ -16,14 +16,15 @@ package com.nvidia.spark.rapids -import java.util.concurrent.{ConcurrentHashMap, Semaphore} +import java.util +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, Semaphore} import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{NvtxColor, NvtxRange} import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion import com.nvidia.spark.rapids.jni.RmmSpark -import org.apache.commons.lang3.mutable.MutableInt import org.apache.spark.TaskContext import org.apache.spark.internal.Logging @@ -102,7 +103,7 @@ object GpuSemaphore { private val MAX_PERMITS = 1000 - private def computeNumPermits(conf: SQLConf): Int = { + def computeNumPermits(conf: SQLConf): Int = { val concurrentStr = conf.getConfString(RapidsConf.CONCURRENT_GPU_TASKS.key, null) val concurrentInt = Option(concurrentStr) .map(ConfHelper.toInteger(_, RapidsConf.CONCURRENT_GPU_TASKS.key)) @@ -115,41 +116,118 @@ object GpuSemaphore { } } +private final class TaskInfo(val taskId: Long) extends Logging { + private val blockedThreads = new LinkedBlockingQueue[Thread]() + private val activeThreads = new util.LinkedHashSet[Thread]() + private lazy val numPermits = GpuSemaphore.computeNumPermits(SQLConf.get) + private var hasSemaphore = false + + def isHoldingSemaphore: Boolean = synchronized { + hasSemaphore + } + + def getActiveThreads: Seq[Thread] = synchronized { + val ret = ArrayBuffer.empty[Thread] + activeThreads.forEach { item => + ret += item + } + ret + } + + def addCurrentThread(): Unit = synchronized { + // All threads start out in blocked, but will move out of it when + // they call blockUntilReady. + val t = Thread.currentThread() + blockedThreads.add(t) + } + + private def moveToActive(t: Thread): Unit = synchronized { + if (!hasSemaphore) { + throw new IllegalStateException("Should not move to active without holding the semaphore") + } + blockedThreads.remove(t) + activeThreads.add(t) + } + + def blockUntilReady(semaphore: Semaphore): Unit = { + val t = Thread.currentThread() + var done = false + var shouldBlockOnSemaphore = false + while (!done) { + try { + synchronized { + done = hasSemaphore + if (done) { + moveToActive(t) + } + shouldBlockOnSemaphore = t == blockedThreads.peek + if (!done && !shouldBlockOnSemaphore) { + wait() + if (hasSemaphore) { + moveToActive(t) + done = true + } + } + } + if (!done && shouldBlockOnSemaphore) { + semaphore.acquire(numPermits) + synchronized { + hasSemaphore = true + moveToActive(t) + notifyAll() + done = true + } + } + } catch { + case throwable: Throwable => + synchronized { + // a thread is exiting because of an exception, so we want to reset things if needed. + blockedThreads.remove(t) + activeThreads.remove(t) + if (!hasSemaphore && shouldBlockOnSemaphore) { + // wake up the other threads so a new thread tries to get the semaphore + notifyAll() + } + } + throw throwable + } + } + } + + def releaseSemaphore(semaphore: Semaphore): Unit = synchronized { + val t = Thread.currentThread() + activeThreads.remove(t) + if (hasSemaphore) { + semaphore.release(numPermits) + hasSemaphore = false + } + // This is only an issue because we are on the thread that is supposedly blocked. + // So it is really more of a sanity test. In reality there should be no threads + // that are blocked, but one might have been added here when the semaphore was held + // and now it is being released so it will block. + if (blockedThreads.remove(t)) { + throw new IllegalStateException(s"$t tried to release the semaphore when it is blocked!!!") + } + } +} + private final class GpuSemaphore() extends Logging { import GpuSemaphore._ private val semaphore = new Semaphore(MAX_PERMITS) - - // Map to track which tasks have acquired the semaphore. - case class TaskInfo(count: MutableInt, thread: Thread, numPermits: Int) - private val activeTasks = new ConcurrentHashMap[Long, TaskInfo] + // Keep track of all tasks that are both active on the GPU and blocked waiting on the GPU + private val tasks = new ConcurrentHashMap[Long, TaskInfo] def acquireIfNecessary(context: TaskContext): Unit = { GpuTaskMetrics.get.semWaitTime { val taskAttemptId = context.taskAttemptId() - val refs = activeTasks.get(taskAttemptId) - if (refs == null || refs.count.getValue == 0) { - val permits = if (refs == null) { - computeNumPermits(SQLConf.get) - } else { - refs.numPermits - } - logDebug(s"Task $taskAttemptId acquiring GPU with $permits permits") - semaphore.acquire(permits) - RmmSpark.associateCurrentThreadWithTask(taskAttemptId) - if (refs != null) { - refs.count.increment() - } else { - // first time this task has been seen - activeTasks.put( - taskAttemptId, - TaskInfo(new MutableInt(1), Thread.currentThread(), permits)) - onTaskCompletion(context, completeTask) - } - GpuDeviceManager.initializeFromTask() - } else { - // Already had the semaphore, but we don't know if the thread is new or not - RmmSpark.associateCurrentThreadWithTask(taskAttemptId) - } + val taskInfo = tasks.computeIfAbsent(taskAttemptId, key => { + onTaskCompletion(context, completeTask) + new TaskInfo(key) + }) + taskInfo.addCurrentThread() + taskInfo.blockUntilReady(semaphore) + RmmSpark.associateCurrentThreadWithTask(taskAttemptId) + GpuDeviceManager.initializeFromTask() } } @@ -159,12 +237,9 @@ private final class GpuSemaphore() extends Logging { val taskAttemptId = context.taskAttemptId() GpuTaskMetrics.get.updateRetry(taskAttemptId) RmmSpark.removeCurrentThreadAssociation() - val refs = activeTasks.get(taskAttemptId) - if (refs != null && refs.count.getValue > 0) { - if (refs.count.decrementAndGet() == 0) { - logDebug(s"Task $taskAttemptId releasing GPU with ${refs.numPermits} permits") - semaphore.release(refs.numPermits) - } + val taskInfo = tasks.get(taskAttemptId) + if (taskInfo != null) { + taskInfo.releaseSemaphore(semaphore) } } finally { nvtxRange.close() @@ -175,38 +250,37 @@ private final class GpuSemaphore() extends Logging { val taskAttemptId = context.taskAttemptId() GpuTaskMetrics.get.updateRetry(taskAttemptId) RmmSpark.taskDone(taskAttemptId) - val refs = activeTasks.remove(taskAttemptId) + val refs = tasks.remove(taskAttemptId) if (refs == null) { throw new IllegalStateException(s"Completion of unknown task $taskAttemptId") } - if (refs.count.getValue > 0) { - logDebug(s"Task $taskAttemptId releasing GPU with ${refs.numPermits} permits") - semaphore.release(refs.numPermits) - } + refs.releaseSemaphore(semaphore) } def dumpActiveStackTracesToLog(): Unit = { try { val stackTracesSemaphoreHeld = new mutable.ArrayBuffer[String]() val otherStackTraces = new mutable.ArrayBuffer[String]() - activeTasks.forEach { (taskAttemptId, taskInfo) => - val sb = new mutable.StringBuilder() - val semaphoreHeld = taskInfo.count.getValue > 0 - taskInfo.thread.getStackTrace.foreach { stackTraceElement => - sb.append(" " + stackTraceElement + "\n") - } - if (semaphoreHeld) { - stackTracesSemaphoreHeld.append( - s"Semaphore held. " + - s"Stack trace for task attempt id $taskAttemptId:\n${sb.toString()}") - } else { - otherStackTraces.append( - s"Semaphore not held. " + - s"Stack trace for task attempt id $taskAttemptId:\n${sb.toString()}") + tasks.forEach { (taskAttemptId, taskInfo) => + val semaphoreHeld = taskInfo.isHoldingSemaphore + taskInfo.getActiveThreads.foreach { thread => + val sb = new mutable.StringBuilder() + thread.getStackTrace.foreach { stackTraceElement => + sb.append(" " + stackTraceElement + "\n") + } + if (semaphoreHeld) { + stackTracesSemaphoreHeld.append( + s"Semaphore held. " + + s"Stack trace for task attempt id $taskAttemptId:\n${sb.toString()}") + } else { + otherStackTraces.append( + s"Semaphore not held. " + + s"Stack trace for task attempt id $taskAttemptId:\n${sb.toString()}") + } } } - logWarning(s"Dumping stack traces. The semaphore sees ${activeTasks.size()} tasks, " + - s"${stackTracesSemaphoreHeld.size} are holding onto the semaphore. " + + logWarning(s"Dumping stack traces. The semaphore sees ${tasks.size()} tasks, " + + s"${stackTracesSemaphoreHeld.size} threads are holding onto the semaphore. " + stackTracesSemaphoreHeld.mkString("\n", "\n", "\n") + otherStackTraces.mkString("\n", "\n", "\n")) } catch { @@ -216,8 +290,8 @@ private final class GpuSemaphore() extends Logging { } def shutdown(): Unit = { - if (!activeTasks.isEmpty) { - logDebug(s"shutting down with ${activeTasks.size} tasks still registered") + if (!tasks.isEmpty) { + logDebug(s"shutting down with ${tasks.size} tasks still registered") } } -} +} \ No newline at end of file