diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala index 1b7241d8922..c293fe4cb0d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala @@ -116,16 +116,55 @@ object GpuSemaphore { } } -private final class TaskInfo(val taskId: Long) extends Logging { +/** + * This represents the state associated with a given task. A task can have multiple threads + * associated with it. That tends to happen when there is a UDF in an external language + * a.k.a python. In that case a writer thread is created to feed the python process and + * the original thread is used as a reader thread that pulls data from the python process. + * For the GPU semaphore to avoid deadlocks we either allow all threads associated with a task + * on the GPU or none of them. But this requires coordination to block all of them or wake up + * all of them. That is the primary job of this class. + * + * It should be noted that there is no special coordination when releasing the semaphore. This + * can result in one thread running on the GPU when it thinks it has the semaphore, but does + * not. As the semaphore is used as a first line of defense to avoid using too much GPU memory + * this is considered to be okay as there are other mechanisms in place, and it should be rather + * rare. + */ +private final class SemaphoreTaskInfo() extends Logging { + /** + * This holds threads that are not on the GPU yet. Most of the time they are + * blocked waiting for the semaphore to let them on, but it may hold one + * briefly even when this task is holding the semaphore. This is a queue + * mostly to give us a simple way to elect one thread to block on the semaphore + * while the others will block with a call to `wait`. There should typically be + * very few threads in here, if any. + */ private val blockedThreads = new LinkedBlockingQueue[Thread]() + /** + * All threads that are currently active on the GPU. This is mostly used for + * debugging. It is a `Set` to avoid duplicates, not for performance because there + * should be very few in here at a time. + */ private val activeThreads = new util.LinkedHashSet[Thread]() private lazy val numPermits = GpuSemaphore.computeNumPermits(SQLConf.get) + /** + * If this task holds the GPU semaphore or not. + */ private var hasSemaphore = false + /** + * Does this task have the GPU semaphore or not. Be careful because it can change at + * any point in time. So only use it for logging. + */ def isHoldingSemaphore: Boolean = synchronized { hasSemaphore } + /** + * Get the list of threads currently running on the GPU Semaphore for this task. Be + * careful because these can change at any point in time. So only use it for logging. + */ def getActiveThreads: Seq[Thread] = synchronized { val ret = ArrayBuffer.empty[Thread] activeThreads.forEach { item => @@ -134,13 +173,6 @@ private final class TaskInfo(val taskId: Long) extends Logging { ret } - def addCurrentThread(): Unit = synchronized { - // All threads start out in blocked, but will move out of it when - // they call blockUntilReady. - val t = Thread.currentThread() - blockedThreads.add(t) - } - private def moveToActive(t: Thread): Unit = synchronized { if (!hasSemaphore) { throw new IllegalStateException("Should not move to active without holding the semaphore") @@ -149,19 +181,34 @@ private final class TaskInfo(val taskId: Long) extends Logging { activeThreads.add(t) } + /** + * Block the current thread until we have the semaphore. + * @param semaphore what we are going to wait on. + */ def blockUntilReady(semaphore: Semaphore): Unit = { val t = Thread.currentThread() + // All threads start out in blocked, but will move out of it inside of the while loop. + synchronized { + blockedThreads.add(t) + } var done = false var shouldBlockOnSemaphore = false while (!done) { try { synchronized { + // This thread can continue if this task owns the GPU semaphore. When that happens + // move the state of the thread from blocked to active. done = hasSemaphore if (done) { moveToActive(t) } + // Only one thread can block on the semaphore itself, we pick the first thread in + // blockedThread to be that one. This is arbitrary and does not matter, it is just + // simple to do. shouldBlockOnSemaphore = t == blockedThreads.peek if (!done && !shouldBlockOnSemaphore) { + // If we need to block and are not blocking on the semaphore we will wait + // on this class until the task has the semaphore and we wake up. wait() if (hasSemaphore) { moveToActive(t) @@ -170,8 +217,12 @@ private final class TaskInfo(val taskId: Long) extends Logging { } } if (!done && shouldBlockOnSemaphore) { + // We cannot be in a synchronized block and wait on the semaphore + // so we have to release it and grab it again afterwards. semaphore.acquire(numPermits) synchronized { + // We now own the semaphore so we need to wake up all of the other tasks that are + // waiting. hasSemaphore = true moveToActive(t) notifyAll() @@ -181,7 +232,8 @@ private final class TaskInfo(val taskId: Long) extends Logging { } catch { case throwable: Throwable => synchronized { - // a thread is exiting because of an exception, so we want to reset things if needed. + // a thread is exiting because of an exception, so we want to reset things, + // and possibly elect another thread to wait on the semaphore. blockedThreads.remove(t) activeThreads.remove(t) if (!hasSemaphore && shouldBlockOnSemaphore) { @@ -201,10 +253,9 @@ private final class TaskInfo(val taskId: Long) extends Logging { semaphore.release(numPermits) hasSemaphore = false } - // This is only an issue because we are on the thread that is supposedly blocked. - // So it is really more of a sanity test. In reality there should be no threads - // that are blocked, but one might have been added here when the semaphore was held - // and now it is being released so it will block. + // It should be impossible for the current thread to be blocked when releasing the semaphore + // because no blocked thread should ever leave `blockUntilReady`, which is where we put it in + // the blocked state. So this is just a sanity test that we didn't do something stupid. if (blockedThreads.remove(t)) { throw new IllegalStateException(s"$t tried to release the semaphore when it is blocked!!!") } @@ -215,16 +266,15 @@ private final class GpuSemaphore() extends Logging { import GpuSemaphore._ private val semaphore = new Semaphore(MAX_PERMITS) // Keep track of all tasks that are both active on the GPU and blocked waiting on the GPU - private val tasks = new ConcurrentHashMap[Long, TaskInfo] + private val tasks = new ConcurrentHashMap[Long, SemaphoreTaskInfo] def acquireIfNecessary(context: TaskContext): Unit = { GpuTaskMetrics.get.semWaitTime { val taskAttemptId = context.taskAttemptId() - val taskInfo = tasks.computeIfAbsent(taskAttemptId, key => { + val taskInfo = tasks.computeIfAbsent(taskAttemptId, _ => { onTaskCompletion(context, completeTask) - new TaskInfo(key) + new SemaphoreTaskInfo() }) - taskInfo.addCurrentThread() taskInfo.blockUntilReady(semaphore) RmmSpark.associateCurrentThreadWithTask(taskAttemptId) GpuDeviceManager.initializeFromTask()