Fix GpuSemaphore to support multiple threads per task [databricks] #9501

Merged: 2 commits, Oct 23, 2023
196 changes: 135 additions & 61 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala
@@ -16,14 +16,15 @@

package com.nvidia.spark.rapids

import java.util.concurrent.{ConcurrentHashMap, Semaphore}
import java.util
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, Semaphore}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.{NvtxColor, NvtxRange}
import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion
import com.nvidia.spark.rapids.jni.RmmSpark
import org.apache.commons.lang3.mutable.MutableInt

import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
@@ -102,7 +103,7 @@ object GpuSemaphore {

private val MAX_PERMITS = 1000

private def computeNumPermits(conf: SQLConf): Int = {
def computeNumPermits(conf: SQLConf): Int = {
val concurrentStr = conf.getConfString(RapidsConf.CONCURRENT_GPU_TASKS.key, null)
val concurrentInt = Option(concurrentStr)
.map(ConfHelper.toInteger(_, RapidsConf.CONCURRENT_GPU_TASKS.key))
Expand All @@ -115,41 +116,118 @@ object GpuSemaphore {
}
}
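The diff widens `computeNumPermits` from `private` to public so that the new `TaskInfo` class can lazily compute its permit count from `SQLConf`. The body of the method is truncated in this view; the sketch below is only a hypothetical illustration of the likely shape (`MAX_PERMITS` divided among the configured number of concurrent GPU tasks, clamped to at least one permit). The `PermitSketch` name, the plain-integer argument, and the exact clamping are assumptions, not the plugin's code.

```scala
// Hypothetical sketch: the real computeNumPermits reads
// spark.rapids.sql.concurrentGpuTasks from SQLConf; here the value is a
// plain argument and the division/clamping is an assumption.
object PermitSketch {
  val MaxPermits = 1000 // matches MAX_PERMITS in the diff

  def computeNumPermits(concurrentGpuTasks: Int): Int = {
    val tasks = math.max(1, concurrentGpuTasks)
    math.max(1, MaxPermits / tasks)
  }
}
```

Under this assumption, two concurrent tasks would each take 500 of the 1000 permits, so the `Semaphore(MAX_PERMITS)` below admits exactly two tasks at a time.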

private final class TaskInfo(val taskId: Long) extends Logging {
private val blockedThreads = new LinkedBlockingQueue[Thread]()
private val activeThreads = new util.LinkedHashSet[Thread]()
private lazy val numPermits = GpuSemaphore.computeNumPermits(SQLConf.get)
private var hasSemaphore = false

def isHoldingSemaphore: Boolean = synchronized {
hasSemaphore
}

def getActiveThreads: Seq[Thread] = synchronized {
val ret = ArrayBuffer.empty[Thread]
activeThreads.forEach { item =>
ret += item
}
ret
}

def addCurrentThread(): Unit = synchronized {
// All threads start out in blocked, but will move out of it when
// they call blockUntilReady.
val t = Thread.currentThread()
blockedThreads.add(t)
}

private def moveToActive(t: Thread): Unit = synchronized {
if (!hasSemaphore) {
throw new IllegalStateException("Should not move to active without holding the semaphore")
}
blockedThreads.remove(t)
activeThreads.add(t)
}

def blockUntilReady(semaphore: Semaphore): Unit = {
val t = Thread.currentThread()
var done = false
var shouldBlockOnSemaphore = false
while (!done) {
try {
synchronized {
done = hasSemaphore
if (done) {
moveToActive(t)
}
shouldBlockOnSemaphore = t == blockedThreads.peek
Collaborator:

Took me a little bit to understand why we need to check the head of blockedThreads, but I think I get it: there's a race after adding to blockedThreads and calling blockUntilReady, so you want all the later threads to block while the thread at the head of blockedThreads is the one that will flag the semaphore as acquired and notify the blocked threads.

Should shouldBlockOnSemaphore be shouldNotBlockOnSemaphore? If it is true (we are the head of blockedThreads) we won't wait(); instead we are the notifier thread, if I don't misunderstand.

Collaborator:

Isn't this just saying that this thread is at the head of the queue, so if it doesn't hold the semaphore already this one should block on it and try to acquire it? Threads that aren't at the head of the queue call wait() instead.

Collaborator Author:

Yes, I need to add some comments to explain what is happening.

if (!done && !shouldBlockOnSemaphore) {
wait()
if (hasSemaphore) {
moveToActive(t)
done = true
}
}
}
if (!done && shouldBlockOnSemaphore) {
semaphore.acquire(numPermits)
synchronized {
hasSemaphore = true
moveToActive(t)
notifyAll()
done = true
}
}
} catch {
case throwable: Throwable =>
synchronized {
// a thread is exiting because of an exception, so we want to reset things if needed.
blockedThreads.remove(t)
activeThreads.remove(t)
if (!hasSemaphore && shouldBlockOnSemaphore) {
// wake up the other threads so a new thread tries to get the semaphore
notifyAll()
}
}
throw throwable
}
}
}
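The review thread above can be summarized as a "single acquirer" pattern: every thread of the task enqueues itself, but only the thread at the head of the queue blocks on `semaphore.acquire()`; the rest `wait()` on the shared lock and are woken by `notifyAll()` once the semaphore is held for the whole task. The stand-alone sketch below illustrates that pattern only; the `SingleAcquirer`/`enter` names are illustrative, not the plugin's API, and the real `blockUntilReady` additionally re-notifies waiters if the acquiring thread dies with an exception.

```scala
import java.util.concurrent.{LinkedBlockingQueue, Semaphore}

// Illustrative sketch (not the plugin's API) of the blockUntilReady pattern.
final class SingleAcquirer(semaphore: Semaphore, permits: Int) {
  private val blocked = new LinkedBlockingQueue[Thread]()
  private var held = false

  def enter(): Unit = {
    val t = Thread.currentThread()
    blocked.add(t)
    var done = false
    while (!done) {
      var shouldAcquire = false
      synchronized {
        done = held
        shouldAcquire = t == blocked.peek
        if (!done && !shouldAcquire) {
          wait() // woken by the acquiring thread's notifyAll()
          done = held
        }
      }
      if (!done && shouldAcquire) {
        semaphore.acquire(permits) // only the head thread ever blocks here
        synchronized {
          held = true
          notifyAll()
          done = true
        }
      }
    }
    synchronized { blocked.remove(t) }
  }
}
```

The design choice this buys: however many threads a task spawns, the underlying `Semaphore` sees at most one `acquire` per task, so permits stay a per-task resource.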

def releaseSemaphore(semaphore: Semaphore): Unit = synchronized {
val t = Thread.currentThread()
activeThreads.remove(t)
if (hasSemaphore) {
semaphore.release(numPermits)
hasSemaphore = false
}
// This is only an issue because we are on the thread that is supposedly blocked.
// So it is really more of a sanity test. In reality there should be no threads
// that are blocked, but one might have been added here when the semaphore was held
// and now it is being released so it will block.
Collaborator:

I am not too clear on this comment. Could you provide an example of where you think this could happen? The notify logic above seems to be doing the right things. I was trying to come up with a scenario to match the concern, but I am not sure I have one.

Collaborator Author:

I don't think it can happen. That is why it is an error.

if (blockedThreads.remove(t)) {
throw new IllegalStateException(s"$t tried to release the semaphore when it is blocked!!!")
}
}
}

private final class GpuSemaphore() extends Logging {
import GpuSemaphore._
private val semaphore = new Semaphore(MAX_PERMITS)

// Map to track which tasks have acquired the semaphore.
case class TaskInfo(count: MutableInt, thread: Thread, numPermits: Int)
private val activeTasks = new ConcurrentHashMap[Long, TaskInfo]
// Keep track of all tasks that are both active on the GPU and blocked waiting on the GPU
private val tasks = new ConcurrentHashMap[Long, TaskInfo]

def acquireIfNecessary(context: TaskContext): Unit = {
GpuTaskMetrics.get.semWaitTime {
val taskAttemptId = context.taskAttemptId()
val refs = activeTasks.get(taskAttemptId)
if (refs == null || refs.count.getValue == 0) {
val permits = if (refs == null) {
computeNumPermits(SQLConf.get)
} else {
refs.numPermits
}
logDebug(s"Task $taskAttemptId acquiring GPU with $permits permits")
semaphore.acquire(permits)
RmmSpark.associateCurrentThreadWithTask(taskAttemptId)
if (refs != null) {
refs.count.increment()
} else {
// first time this task has been seen
activeTasks.put(
taskAttemptId,
TaskInfo(new MutableInt(1), Thread.currentThread(), permits))
onTaskCompletion(context, completeTask)
}
GpuDeviceManager.initializeFromTask()
} else {
// Already had the semaphore, but we don't know if the thread is new or not
RmmSpark.associateCurrentThreadWithTask(taskAttemptId)
}
val taskInfo = tasks.computeIfAbsent(taskAttemptId, key => {
onTaskCompletion(context, completeTask)
new TaskInfo(key)
Collaborator:

Can we rename key to taskId?

})
taskInfo.addCurrentThread()
taskInfo.blockUntilReady(semaphore)
RmmSpark.associateCurrentThreadWithTask(taskAttemptId)
GpuDeviceManager.initializeFromTask()
}
}
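The per-task bookkeeping in `acquireIfNecessary` hinges on `ConcurrentHashMap.computeIfAbsent`: the first thread of a task atomically creates the shared `TaskInfo` (and registers the completion callback exactly once), while every later thread of the same task gets the existing instance back. A minimal stand-alone sketch of that idiom, with a thread set standing in for `TaskInfo` (the `Registry`/`register` names are illustrative):

```scala
import java.util.concurrent.ConcurrentHashMap

// Illustrative sketch of per-task registration via computeIfAbsent.
final class Registry {
  // One shared, synchronized thread set per task id; the factory lambda
  // runs at most once per task no matter how many threads race here.
  private val tasks = new ConcurrentHashMap[Long, java.util.Set[Thread]]()

  def register(taskId: Long): java.util.Set[Thread] = {
    val threads = tasks.computeIfAbsent(taskId, _ =>
      java.util.Collections.synchronizedSet(new java.util.HashSet[Thread]()))
    threads.add(Thread.currentThread())
    threads
  }
}
```

Because `computeIfAbsent` is atomic, this is what lets the diff drop the old per-task reference counting: threads of one task converge on a single `TaskInfo` instead of each incrementing a `MutableInt`.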

Expand All @@ -159,12 +237,9 @@ private final class GpuSemaphore() extends Logging {
val taskAttemptId = context.taskAttemptId()
GpuTaskMetrics.get.updateRetry(taskAttemptId)
RmmSpark.removeCurrentThreadAssociation()
val refs = activeTasks.get(taskAttemptId)
if (refs != null && refs.count.getValue > 0) {
if (refs.count.decrementAndGet() == 0) {
logDebug(s"Task $taskAttemptId releasing GPU with ${refs.numPermits} permits")
semaphore.release(refs.numPermits)
}
val taskInfo = tasks.get(taskAttemptId)
if (taskInfo != null) {
taskInfo.releaseSemaphore(semaphore)
}
} finally {
nvtxRange.close()
Expand All @@ -175,38 +250,37 @@ private final class GpuSemaphore() extends Logging {
val taskAttemptId = context.taskAttemptId()
GpuTaskMetrics.get.updateRetry(taskAttemptId)
RmmSpark.taskDone(taskAttemptId)
val refs = activeTasks.remove(taskAttemptId)
val refs = tasks.remove(taskAttemptId)
if (refs == null) {
throw new IllegalStateException(s"Completion of unknown task $taskAttemptId")
}
if (refs.count.getValue > 0) {
logDebug(s"Task $taskAttemptId releasing GPU with ${refs.numPermits} permits")
semaphore.release(refs.numPermits)
}
refs.releaseSemaphore(semaphore)
}

def dumpActiveStackTracesToLog(): Unit = {
try {
val stackTracesSemaphoreHeld = new mutable.ArrayBuffer[String]()
val otherStackTraces = new mutable.ArrayBuffer[String]()
activeTasks.forEach { (taskAttemptId, taskInfo) =>
val sb = new mutable.StringBuilder()
val semaphoreHeld = taskInfo.count.getValue > 0
taskInfo.thread.getStackTrace.foreach { stackTraceElement =>
sb.append(" " + stackTraceElement + "\n")
}
if (semaphoreHeld) {
stackTracesSemaphoreHeld.append(
s"Semaphore held. " +
s"Stack trace for task attempt id $taskAttemptId:\n${sb.toString()}")
} else {
otherStackTraces.append(
s"Semaphore not held. " +
s"Stack trace for task attempt id $taskAttemptId:\n${sb.toString()}")
tasks.forEach { (taskAttemptId, taskInfo) =>
val semaphoreHeld = taskInfo.isHoldingSemaphore
taskInfo.getActiveThreads.foreach { thread =>
val sb = new mutable.StringBuilder()
thread.getStackTrace.foreach { stackTraceElement =>
sb.append(" " + stackTraceElement + "\n")
}
if (semaphoreHeld) {
stackTracesSemaphoreHeld.append(
s"Semaphore held. " +
s"Stack trace for task attempt id $taskAttemptId:\n${sb.toString()}")
} else {
otherStackTraces.append(
s"Semaphore not held. " +
s"Stack trace for task attempt id $taskAttemptId:\n${sb.toString()}")
}
}
}
logWarning(s"Dumping stack traces. The semaphore sees ${activeTasks.size()} tasks, " +
s"${stackTracesSemaphoreHeld.size} are holding onto the semaphore. " +
logWarning(s"Dumping stack traces. The semaphore sees ${tasks.size()} tasks, " +
s"${stackTracesSemaphoreHeld.size} threads are holding onto the semaphore. " +
stackTracesSemaphoreHeld.mkString("\n", "\n", "\n") +
otherStackTraces.mkString("\n", "\n", "\n"))
} catch {
Expand All @@ -216,8 +290,8 @@ private final class GpuSemaphore() extends Logging {
}

def shutdown(): Unit = {
if (!activeTasks.isEmpty) {
logDebug(s"shutting down with ${activeTasks.size} tasks still registered")
if (!tasks.isEmpty) {
logDebug(s"shutting down with ${tasks.size} tasks still registered")
}
}
}
}