Skip to content

Commit

Permalink
Fix GpuSemaphore to support multiple threads per task
Browse files Browse the repository at this point in the history
Signed-off-by: Robert (Bobby) Evans <[email protected]>
  • Loading branch information
revans2 committed Oct 20, 2023
1 parent 1baa350 commit f93ca97
Showing 1 changed file with 135 additions and 61 deletions.
196 changes: 135 additions & 61 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@

package com.nvidia.spark.rapids

import java.util.concurrent.{ConcurrentHashMap, Semaphore}
import java.util
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, Semaphore}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.{NvtxColor, NvtxRange}
import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion
import com.nvidia.spark.rapids.jni.RmmSpark
import org.apache.commons.lang3.mutable.MutableInt

import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
Expand Down Expand Up @@ -102,7 +103,7 @@ object GpuSemaphore {

private val MAX_PERMITS = 1000

private def computeNumPermits(conf: SQLConf): Int = {
def computeNumPermits(conf: SQLConf): Int = {
val concurrentStr = conf.getConfString(RapidsConf.CONCURRENT_GPU_TASKS.key, null)
val concurrentInt = Option(concurrentStr)
.map(ConfHelper.toInteger(_, RapidsConf.CONCURRENT_GPU_TASKS.key))
Expand All @@ -115,41 +116,118 @@ object GpuSemaphore {
}
}

private final class TaskInfo(val taskId: Long) extends Logging {
private val blockedThreads = new LinkedBlockingQueue[Thread]()
private val activeThreads = new util.LinkedHashSet[Thread]()
private lazy val numPermits = GpuSemaphore.computeNumPermits(SQLConf.get)
private var hasSemaphore = false

def isHoldingSemaphore: Boolean = synchronized {
hasSemaphore
}

def getActiveThreads: Seq[Thread] = synchronized {
val ret = ArrayBuffer.empty[Thread]
activeThreads.forEach { item =>
ret += item
}
ret
}

def addCurrentThread(): Unit = synchronized {
// All threads start out in blocked, but will move out of it when
// they call blockUntilReady.
val t = Thread.currentThread()
blockedThreads.add(t)
}

private def moveToActive(t: Thread): Unit = synchronized {
if (!hasSemaphore) {
throw new IllegalStateException("Should not move to active without holding the semaphore")
}
blockedThreads.remove(t)
activeThreads.add(t)
}

def blockUntilReady(semaphore: Semaphore): Unit = {
val t = Thread.currentThread()
var done = false
var shouldBlockOnSemaphore = false
while (!done) {
try {
synchronized {
done = hasSemaphore
if (done) {
moveToActive(t)
}
shouldBlockOnSemaphore = t == blockedThreads.peek
if (!done && !shouldBlockOnSemaphore) {
wait()
if (hasSemaphore) {
moveToActive(t)
done = true
}
}
}
if (!done && shouldBlockOnSemaphore) {
semaphore.acquire(numPermits)
synchronized {
hasSemaphore = true
moveToActive(t)
notifyAll()
done = true
}
}
} catch {
case throwable: Throwable =>
synchronized {
// a thread is exiting because of an exception, so we want to reset things if needed.
blockedThreads.remove(t)
activeThreads.remove(t)
if (!hasSemaphore && shouldBlockOnSemaphore) {
// wake up the other threads so a new thread tries to get the semaphore
notifyAll()
}
}
throw throwable
}
}
}

def releaseSemaphore(semaphore: Semaphore): Unit = synchronized {
val t = Thread.currentThread()
activeThreads.remove(t)
if (hasSemaphore) {
semaphore.release(numPermits)
hasSemaphore = false
}
// This is only an issue because we are on the thread that is supposedly blocked.
// So it is really more of a sanity test. In reality there should be no threads
// that are blocked, but one might have been added here when the semaphore was held
// and now it is being released so it will block.
if (blockedThreads.remove(t)) {
throw new IllegalStateException(s"$t tried to release the semaphore when it is blocked!!!")
}
}
}

private final class GpuSemaphore() extends Logging {
import GpuSemaphore._
private val semaphore = new Semaphore(MAX_PERMITS)

// Map to track which tasks have acquired the semaphore.
case class TaskInfo(count: MutableInt, thread: Thread, numPermits: Int)
private val activeTasks = new ConcurrentHashMap[Long, TaskInfo]
// Keep track of all tasks that are both active on the GPU and blocked waiting on the GPU
private val tasks = new ConcurrentHashMap[Long, TaskInfo]

def acquireIfNecessary(context: TaskContext): Unit = {
GpuTaskMetrics.get.semWaitTime {
val taskAttemptId = context.taskAttemptId()
val refs = activeTasks.get(taskAttemptId)
if (refs == null || refs.count.getValue == 0) {
val permits = if (refs == null) {
computeNumPermits(SQLConf.get)
} else {
refs.numPermits
}
logDebug(s"Task $taskAttemptId acquiring GPU with $permits permits")
semaphore.acquire(permits)
RmmSpark.associateCurrentThreadWithTask(taskAttemptId)
if (refs != null) {
refs.count.increment()
} else {
// first time this task has been seen
activeTasks.put(
taskAttemptId,
TaskInfo(new MutableInt(1), Thread.currentThread(), permits))
onTaskCompletion(context, completeTask)
}
GpuDeviceManager.initializeFromTask()
} else {
// Already had the semaphore, but we don't know if the thread is new or not
RmmSpark.associateCurrentThreadWithTask(taskAttemptId)
}
val taskInfo = tasks.computeIfAbsent(taskAttemptId, key => {
onTaskCompletion(context, completeTask)
new TaskInfo(key)
})
taskInfo.addCurrentThread()
taskInfo.blockUntilReady(semaphore)
RmmSpark.associateCurrentThreadWithTask(taskAttemptId)
GpuDeviceManager.initializeFromTask()
}
}

Expand All @@ -159,12 +237,9 @@ private final class GpuSemaphore() extends Logging {
val taskAttemptId = context.taskAttemptId()
GpuTaskMetrics.get.updateRetry(taskAttemptId)
RmmSpark.removeCurrentThreadAssociation()
val refs = activeTasks.get(taskAttemptId)
if (refs != null && refs.count.getValue > 0) {
if (refs.count.decrementAndGet() == 0) {
logDebug(s"Task $taskAttemptId releasing GPU with ${refs.numPermits} permits")
semaphore.release(refs.numPermits)
}
val taskInfo = tasks.get(taskAttemptId)
if (taskInfo != null) {
taskInfo.releaseSemaphore(semaphore)
}
} finally {
nvtxRange.close()
Expand All @@ -175,38 +250,37 @@ private final class GpuSemaphore() extends Logging {
val taskAttemptId = context.taskAttemptId()
GpuTaskMetrics.get.updateRetry(taskAttemptId)
RmmSpark.taskDone(taskAttemptId)
val refs = activeTasks.remove(taskAttemptId)
val refs = tasks.remove(taskAttemptId)
if (refs == null) {
throw new IllegalStateException(s"Completion of unknown task $taskAttemptId")
}
if (refs.count.getValue > 0) {
logDebug(s"Task $taskAttemptId releasing GPU with ${refs.numPermits} permits")
semaphore.release(refs.numPermits)
}
refs.releaseSemaphore(semaphore)
}

def dumpActiveStackTracesToLog(): Unit = {
try {
val stackTracesSemaphoreHeld = new mutable.ArrayBuffer[String]()
val otherStackTraces = new mutable.ArrayBuffer[String]()
activeTasks.forEach { (taskAttemptId, taskInfo) =>
val sb = new mutable.StringBuilder()
val semaphoreHeld = taskInfo.count.getValue > 0
taskInfo.thread.getStackTrace.foreach { stackTraceElement =>
sb.append(" " + stackTraceElement + "\n")
}
if (semaphoreHeld) {
stackTracesSemaphoreHeld.append(
s"Semaphore held. " +
s"Stack trace for task attempt id $taskAttemptId:\n${sb.toString()}")
} else {
otherStackTraces.append(
s"Semaphore not held. " +
s"Stack trace for task attempt id $taskAttemptId:\n${sb.toString()}")
tasks.forEach { (taskAttemptId, taskInfo) =>
val semaphoreHeld = taskInfo.isHoldingSemaphore
taskInfo.getActiveThreads.foreach { thread =>
val sb = new mutable.StringBuilder()
thread.getStackTrace.foreach { stackTraceElement =>
sb.append(" " + stackTraceElement + "\n")
}
if (semaphoreHeld) {
stackTracesSemaphoreHeld.append(
s"Semaphore held. " +
s"Stack trace for task attempt id $taskAttemptId:\n${sb.toString()}")
} else {
otherStackTraces.append(
s"Semaphore not held. " +
s"Stack trace for task attempt id $taskAttemptId:\n${sb.toString()}")
}
}
}
logWarning(s"Dumping stack traces. The semaphore sees ${activeTasks.size()} tasks, " +
s"${stackTracesSemaphoreHeld.size} are holding onto the semaphore. " +
logWarning(s"Dumping stack traces. The semaphore sees ${tasks.size()} tasks, " +
s"${stackTracesSemaphoreHeld.size} threads are holding onto the semaphore. " +
stackTracesSemaphoreHeld.mkString("\n", "\n", "\n") +
otherStackTraces.mkString("\n", "\n", "\n"))
} catch {
Expand All @@ -216,8 +290,8 @@ private final class GpuSemaphore() extends Logging {
}

def shutdown(): Unit = {
if (!activeTasks.isEmpty) {
logDebug(s"shutting down with ${activeTasks.size} tasks still registered")
if (!tasks.isEmpty) {
logDebug(s"shutting down with ${tasks.size} tasks still registered")
}
}
}
}

0 comments on commit f93ca97

Please sign in to comment.