diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala index 3f510493c52..c293fe4cb0d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala @@ -16,14 +16,15 @@ package com.nvidia.spark.rapids -import java.util.concurrent.{ConcurrentHashMap, Semaphore} +import java.util +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, Semaphore} import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{NvtxColor, NvtxRange} import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion import com.nvidia.spark.rapids.jni.RmmSpark -import org.apache.commons.lang3.mutable.MutableInt import org.apache.spark.TaskContext import org.apache.spark.internal.Logging @@ -102,7 +103,7 @@ object GpuSemaphore { private val MAX_PERMITS = 1000 - private def computeNumPermits(conf: SQLConf): Int = { + def computeNumPermits(conf: SQLConf): Int = { val concurrentStr = conf.getConfString(RapidsConf.CONCURRENT_GPU_TASKS.key, null) val concurrentInt = Option(concurrentStr) .map(ConfHelper.toInteger(_, RapidsConf.CONCURRENT_GPU_TASKS.key)) @@ -115,41 +116,168 @@ object GpuSemaphore { } } +/** + * This represents the state associated with a given task. A task can have multiple threads + * associated with it. That tends to happen when there is a UDF in an external language + * a.k.a python. In that case a writer thread is created to feed the python process and + * the original thread is used as a reader thread that pulls data from the python process. + * For the GPU semaphore to avoid deadlocks we either allow all threads associated with a task + * on the GPU or none of them. But this requires coordination to block all of them or wake up + * all of them. That is the primary job of this class. + * + * It should be noted that there is no special coordination when releasing the semaphore. This + * can result in one thread running on the GPU when it thinks it has the semaphore, but does + * not. As the semaphore is used as a first line of defense to avoid using too much GPU memory + * this is considered to be okay as there are other mechanisms in place, and it should be rather + * rare. + */ +private final class SemaphoreTaskInfo() extends Logging { + /** + * This holds threads that are not on the GPU yet. Most of the time they are + * blocked waiting for the semaphore to let them on, but it may hold one + * briefly even when this task is holding the semaphore. This is a queue + * mostly to give us a simple way to elect one thread to block on the semaphore + * while the others will block with a call to `wait`. There should typically be + * very few threads in here, if any. + */ + private val blockedThreads = new LinkedBlockingQueue[Thread]() + /** + * All threads that are currently active on the GPU. This is mostly used for + * debugging. It is a `Set` to avoid duplicates, not for performance because there + * should be very few in here at a time. + */ + private val activeThreads = new util.LinkedHashSet[Thread]() + private lazy val numPermits = GpuSemaphore.computeNumPermits(SQLConf.get) + /** + * If this task holds the GPU semaphore or not. + */ + private var hasSemaphore = false + + /** + * Does this task have the GPU semaphore or not. Be careful because it can change at + * any point in time. So only use it for logging. + */ + def isHoldingSemaphore: Boolean = synchronized { + hasSemaphore + } + + /** + * Get the list of threads currently running on the GPU Semaphore for this task. Be + * careful because these can change at any point in time. So only use it for logging. + */ + def getActiveThreads: Seq[Thread] = synchronized { + val ret = ArrayBuffer.empty[Thread] + activeThreads.forEach { item => + ret += item + } + ret + } + + private def moveToActive(t: Thread): Unit = synchronized { + if (!hasSemaphore) { + throw new IllegalStateException("Should not move to active without holding the semaphore") + } + blockedThreads.remove(t) + activeThreads.add(t) + } + + /** + * Block the current thread until we have the semaphore. + * @param semaphore what we are going to wait on. + */ + def blockUntilReady(semaphore: Semaphore): Unit = { + val t = Thread.currentThread() + // All threads start out in blocked, but will move out of it inside of the while loop. + synchronized { + blockedThreads.add(t) + } + var done = false + var shouldBlockOnSemaphore = false + while (!done) { + try { + synchronized { + // This thread can continue if this task owns the GPU semaphore. When that happens + // move the state of the thread from blocked to active. + done = hasSemaphore + if (done) { + moveToActive(t) + } + // Only one thread can block on the semaphore itself, we pick the first thread in + // blockedThread to be that one. This is arbitrary and does not matter, it is just + // simple to do. + shouldBlockOnSemaphore = t == blockedThreads.peek + if (!done && !shouldBlockOnSemaphore) { + // If we need to block and are not blocking on the semaphore we will wait + // on this class until the task has the semaphore and we wake up. + wait() + if (hasSemaphore) { + moveToActive(t) + done = true + } + } + } + if (!done && shouldBlockOnSemaphore) { + // We cannot be in a synchronized block and wait on the semaphore + // so we have to release it and grab it again afterwards. + semaphore.acquire(numPermits) + synchronized { + // We now own the semaphore so we need to wake up all of the other tasks that are + // waiting. + hasSemaphore = true + moveToActive(t) + notifyAll() + done = true + } + } + } catch { + case throwable: Throwable => + synchronized { + // a thread is exiting because of an exception, so we want to reset things, + // and possibly elect another thread to wait on the semaphore. + blockedThreads.remove(t) + activeThreads.remove(t) + if (!hasSemaphore && shouldBlockOnSemaphore) { + // wake up the other threads so a new thread tries to get the semaphore + notifyAll() + } + } + throw throwable + } + } + } + + def releaseSemaphore(semaphore: Semaphore): Unit = synchronized { + val t = Thread.currentThread() + activeThreads.remove(t) + if (hasSemaphore) { + semaphore.release(numPermits) + hasSemaphore = false + } + // It should be impossible for the current thread to be blocked when releasing the semaphore + // because no blocked thread should ever leave `blockUntilReady`, which is where we put it in + // the blocked state. So this is just a sanity test that we didn't do something stupid. + if (blockedThreads.remove(t)) { + throw new IllegalStateException(s"$t tried to release the semaphore when it is blocked!!!") + } + } +} + private final class GpuSemaphore() extends Logging { import GpuSemaphore._ private val semaphore = new Semaphore(MAX_PERMITS) - - // Map to track which tasks have acquired the semaphore. - case class TaskInfo(count: MutableInt, thread: Thread, numPermits: Int) - private val activeTasks = new ConcurrentHashMap[Long, TaskInfo] + // Keep track of all tasks that are both active on the GPU and blocked waiting on the GPU + private val tasks = new ConcurrentHashMap[Long, SemaphoreTaskInfo] def acquireIfNecessary(context: TaskContext): Unit = { GpuTaskMetrics.get.semWaitTime { val taskAttemptId = context.taskAttemptId() - val refs = activeTasks.get(taskAttemptId) - if (refs == null || refs.count.getValue == 0) { - val permits = if (refs == null) { - computeNumPermits(SQLConf.get) - } else { - refs.numPermits - } - logDebug(s"Task $taskAttemptId acquiring GPU with $permits permits") - semaphore.acquire(permits) - RmmSpark.associateCurrentThreadWithTask(taskAttemptId) - if (refs != null) { - refs.count.increment() - } else { - // first time this task has been seen - activeTasks.put( - taskAttemptId, - TaskInfo(new MutableInt(1), Thread.currentThread(), permits)) - onTaskCompletion(context, completeTask) - } - GpuDeviceManager.initializeFromTask() - } else { - // Already had the semaphore, but we don't know if the thread is new or not - RmmSpark.associateCurrentThreadWithTask(taskAttemptId) - } + val taskInfo = tasks.computeIfAbsent(taskAttemptId, _ => { + onTaskCompletion(context, completeTask) + new SemaphoreTaskInfo() + }) + taskInfo.blockUntilReady(semaphore) + RmmSpark.associateCurrentThreadWithTask(taskAttemptId) + GpuDeviceManager.initializeFromTask() } } @@ -159,12 +287,9 @@ private final class GpuSemaphore() extends Logging { val taskAttemptId = context.taskAttemptId() GpuTaskMetrics.get.updateRetry(taskAttemptId) RmmSpark.removeCurrentThreadAssociation() - val refs = activeTasks.get(taskAttemptId) - if (refs != null && refs.count.getValue > 0) { - if (refs.count.decrementAndGet() == 0) { - logDebug(s"Task $taskAttemptId releasing GPU with ${refs.numPermits} permits") - semaphore.release(refs.numPermits) - } + val taskInfo = tasks.get(taskAttemptId) + if (taskInfo != null) { + taskInfo.releaseSemaphore(semaphore) } } finally { nvtxRange.close() @@ -175,38 +300,37 @@ private final class GpuSemaphore() extends Logging { val taskAttemptId = context.taskAttemptId() GpuTaskMetrics.get.updateRetry(taskAttemptId) RmmSpark.taskDone(taskAttemptId) - val refs = activeTasks.remove(taskAttemptId) + val refs = tasks.remove(taskAttemptId) if (refs == null) { throw new IllegalStateException(s"Completion of unknown task $taskAttemptId") } - if (refs.count.getValue > 0) { - logDebug(s"Task $taskAttemptId releasing GPU with ${refs.numPermits} permits") - semaphore.release(refs.numPermits) - } + refs.releaseSemaphore(semaphore) } def dumpActiveStackTracesToLog(): Unit = { try { val stackTracesSemaphoreHeld = new mutable.ArrayBuffer[String]() val otherStackTraces = new mutable.ArrayBuffer[String]() - activeTasks.forEach { (taskAttemptId, taskInfo) => - val sb = new mutable.StringBuilder() - val semaphoreHeld = taskInfo.count.getValue > 0 - taskInfo.thread.getStackTrace.foreach { stackTraceElement => - sb.append(" " + stackTraceElement + "\n") - } - if (semaphoreHeld) { - stackTracesSemaphoreHeld.append( - s"Semaphore held. " + - s"Stack trace for task attempt id $taskAttemptId:\n${sb.toString()}") - } else { - otherStackTraces.append( - s"Semaphore not held. " + - s"Stack trace for task attempt id $taskAttemptId:\n${sb.toString()}") + tasks.forEach { (taskAttemptId, taskInfo) => + val semaphoreHeld = taskInfo.isHoldingSemaphore + taskInfo.getActiveThreads.foreach { thread => + val sb = new mutable.StringBuilder() + thread.getStackTrace.foreach { stackTraceElement => + sb.append(" " + stackTraceElement + "\n") + } + if (semaphoreHeld) { + stackTracesSemaphoreHeld.append( + s"Semaphore held. " + + s"Stack trace for task attempt id $taskAttemptId:\n${sb.toString()}") + } else { + otherStackTraces.append( + s"Semaphore not held. " + + s"Stack trace for task attempt id $taskAttemptId:\n${sb.toString()}") + } } } - logWarning(s"Dumping stack traces. The semaphore sees ${activeTasks.size()} tasks, " + - s"${stackTracesSemaphoreHeld.size} are holding onto the semaphore. " + + logWarning(s"Dumping stack traces. The semaphore sees ${tasks.size()} tasks, " + + s"${stackTracesSemaphoreHeld.size} threads are holding onto the semaphore. " + stackTracesSemaphoreHeld.mkString("\n", "\n", "\n") + otherStackTraces.mkString("\n", "\n", "\n")) } catch { @@ -216,8 +340,8 @@ private final class GpuSemaphore() extends Logging { } def shutdown(): Unit = { - if (!activeTasks.isEmpty) { - logDebug(s"shutting down with ${activeTasks.size} tasks still registered") + if (!tasks.isEmpty) { + logDebug(s"shutting down with ${tasks.size} tasks still registered") } } -} +} \ No newline at end of file