Skip to content

Commit

Permalink
Added in docs
Browse files Browse the repository at this point in the history
  • Loading branch information
revans2 committed Oct 20, 2023
1 parent f93ca97 commit f4c4b2a
Showing 1 changed file with 67 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,55 @@ object GpuSemaphore {
}
}

private final class TaskInfo(val taskId: Long) extends Logging {
/**
* This represents the state associated with a given task. A task can have multiple threads
* associated with it. That tends to happen when there is a UDF in an external language
* a.k.a python. In that case a writer thread is created to feed the python process and
* the original thread is used as a reader thread that pulls data from the python process.
* For the GPU semaphore to avoid deadlocks we either allow all threads associated with a task
* on the GPU or none of them. But this requires coordination to block all of them or wake up
* all of them. That is the primary job of this class.
*
* It should be noted that there is no special coordination when releasing the semaphore. This
* can result in one thread running on the GPU when it thinks it has the semaphore, but does
* not. As the semaphore is used as a first line of defense to avoid using too much GPU memory
* this is considered to be okay as there are other mechanisms in place, and it should be rather
* rare.
*/
private final class SemaphoreTaskInfo() extends Logging {
/**
* This holds threads that are not on the GPU yet. Most of the time they are
* blocked waiting for the semaphore to let them on, but it may hold one
* briefly even when this task is holding the semaphore. This is a queue
* mostly to give us a simple way to elect one thread to block on the semaphore
* while the others will block with a call to `wait`. There should typically be
* very few threads in here, if any.
*/
private val blockedThreads = new LinkedBlockingQueue[Thread]()
/**
* All threads that are currently active on the GPU. This is mostly used for
* debugging. It is a `Set` to avoid duplicates, not for performance because there
* should be very few in here at a time.
*/
private val activeThreads = new util.LinkedHashSet[Thread]()
private lazy val numPermits = GpuSemaphore.computeNumPermits(SQLConf.get)
/**
* If this task holds the GPU semaphore or not.
*/
private var hasSemaphore = false

/**
* Does this task have the GPU semaphore or not. Be careful because it can change at
* any point in time. So only use it for logging.
*/
def isHoldingSemaphore: Boolean = synchronized {
hasSemaphore
}

/**
* Get the list of threads currently running on the GPU Semaphore for this task. Be
* careful because these can change at any point in time. So only use it for logging.
*/
def getActiveThreads: Seq[Thread] = synchronized {
val ret = ArrayBuffer.empty[Thread]
activeThreads.forEach { item =>
Expand All @@ -134,13 +173,6 @@ private final class TaskInfo(val taskId: Long) extends Logging {
ret
}

def addCurrentThread(): Unit = synchronized {
// All threads start out in blocked, but will move out of it when
// they call blockUntilReady.
val t = Thread.currentThread()
blockedThreads.add(t)
}

private def moveToActive(t: Thread): Unit = synchronized {
if (!hasSemaphore) {
throw new IllegalStateException("Should not move to active without holding the semaphore")
Expand All @@ -149,19 +181,34 @@ private final class TaskInfo(val taskId: Long) extends Logging {
activeThreads.add(t)
}

/**
* Block the current thread until we have the semaphore.
* @param semaphore what we are going to wait on.
*/
def blockUntilReady(semaphore: Semaphore): Unit = {
val t = Thread.currentThread()
// All threads start out in blocked, but will move out of it inside of the while loop.
synchronized {
blockedThreads.add(t)
}
var done = false
var shouldBlockOnSemaphore = false
while (!done) {
try {
synchronized {
// This thread can continue if this task owns the GPU semaphore. When that happens
// move the state of the thread from blocked to active.
done = hasSemaphore
if (done) {
moveToActive(t)
}
// Only one thread can block on the semaphore itself, we pick the first thread in
// blockedThread to be that one. This is arbitrary and does not matter, it is just
// simple to do.
shouldBlockOnSemaphore = t == blockedThreads.peek
if (!done && !shouldBlockOnSemaphore) {
// If we need to block and are not blocking on the semaphore we will wait
// on this class until the task has the semaphore and we wake up.
wait()
if (hasSemaphore) {
moveToActive(t)
Expand All @@ -170,8 +217,12 @@ private final class TaskInfo(val taskId: Long) extends Logging {
}
}
if (!done && shouldBlockOnSemaphore) {
// We cannot be in a synchronized block and wait on the semaphore
// so we have to release it and grab it again afterwards.
semaphore.acquire(numPermits)
synchronized {
// We now own the semaphore so we need to wake up all of the other tasks that are
// waiting.
hasSemaphore = true
moveToActive(t)
notifyAll()
Expand All @@ -181,7 +232,8 @@ private final class TaskInfo(val taskId: Long) extends Logging {
} catch {
case throwable: Throwable =>
synchronized {
// a thread is exiting because of an exception, so we want to reset things if needed.
// a thread is exiting because of an exception, so we want to reset things,
// and possibly elect another thread to wait on the semaphore.
blockedThreads.remove(t)
activeThreads.remove(t)
if (!hasSemaphore && shouldBlockOnSemaphore) {
Expand All @@ -201,10 +253,9 @@ private final class TaskInfo(val taskId: Long) extends Logging {
semaphore.release(numPermits)
hasSemaphore = false
}
// This is only an issue because we are on the thread that is supposedly blocked.
// So it is really more of a sanity test. In reality there should be no threads
// that are blocked, but one might have been added here when the semaphore was held
// and now it is being released so it will block.
// It should be impossible for the current thread to be blocked when releasing the semaphore
// because no blocked thread should ever leave `blockUntilReady`, which is where we put it in
// the blocked state. So this is just a sanity test that we didn't do something stupid.
if (blockedThreads.remove(t)) {
throw new IllegalStateException(s"$t tried to release the semaphore when it is blocked!!!")
}
Expand All @@ -215,16 +266,15 @@ private final class GpuSemaphore() extends Logging {
import GpuSemaphore._
private val semaphore = new Semaphore(MAX_PERMITS)
// Keep track of all tasks that are both active on the GPU and blocked waiting on the GPU
private val tasks = new ConcurrentHashMap[Long, TaskInfo]
private val tasks = new ConcurrentHashMap[Long, SemaphoreTaskInfo]

def acquireIfNecessary(context: TaskContext): Unit = {
GpuTaskMetrics.get.semWaitTime {
val taskAttemptId = context.taskAttemptId()
val taskInfo = tasks.computeIfAbsent(taskAttemptId, key => {
val taskInfo = tasks.computeIfAbsent(taskAttemptId, _ => {
onTaskCompletion(context, completeTask)
new TaskInfo(key)
new SemaphoreTaskInfo()
})
taskInfo.addCurrentThread()
taskInfo.blockUntilReady(semaphore)
RmmSpark.associateCurrentThreadWithTask(taskAttemptId)
GpuDeviceManager.initializeFromTask()
Expand Down

0 comments on commit f4c4b2a

Please sign in to comment.