Fix GpuSemaphore to support multiple threads per task [databricks] #9501

Merged: 2 commits, Oct 23, 2023

sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala (246 changes: 185 additions, 61 deletions)
@@ -16,14 +16,15 @@

package com.nvidia.spark.rapids

import java.util.concurrent.{ConcurrentHashMap, Semaphore}
import java.util
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, Semaphore}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.{NvtxColor, NvtxRange}
import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion
import com.nvidia.spark.rapids.jni.RmmSpark
import org.apache.commons.lang3.mutable.MutableInt

import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
@@ -102,7 +103,7 @@ object GpuSemaphore {

private val MAX_PERMITS = 1000

private def computeNumPermits(conf: SQLConf): Int = {
def computeNumPermits(conf: SQLConf): Int = {
val concurrentStr = conf.getConfString(RapidsConf.CONCURRENT_GPU_TASKS.key, null)
val concurrentInt = Option(concurrentStr)
.map(ConfHelper.toInteger(_, RapidsConf.CONCURRENT_GPU_TASKS.key))
@@ -115,41 +116,168 @@
}
}

/**
* This represents the state associated with a given task. A task can have multiple threads
* associated with it. That tends to happen when there is a UDF in an external language,
* e.g. Python. In that case a writer thread is created to feed the Python process and
* the original thread is used as a reader thread that pulls data from the Python process.
* For the GPU semaphore to avoid deadlocks we either allow all threads associated with a task
* on the GPU or none of them. But this requires coordination to block all of them or wake up
* all of them. That is the primary job of this class.
*
* It should be noted that there is no special coordination when releasing the semaphore. This
* can result in one thread running on the GPU when it thinks it has the semaphore, but does
* not. As the semaphore is used as a first line of defense to avoid using too much GPU memory
* this is considered to be okay as there are other mechanisms in place, and it should be rather
* rare.
*/
private final class SemaphoreTaskInfo() extends Logging {
/**
* This holds threads that are not on the GPU yet. Most of the time they are
* blocked waiting for the semaphore to let them on, but it may hold one
* briefly even when this task is holding the semaphore. This is a queue
* mostly to give us a simple way to elect one thread to block on the semaphore
* while the others will block with a call to `wait`. There should typically be
* very few threads in here, if any.
*/
private val blockedThreads = new LinkedBlockingQueue[Thread]()
/**
* All threads that are currently active on the GPU. This is mostly used for
* debugging. It is a `Set` to avoid duplicates, not for performance because there
* should be very few in here at a time.
*/
private val activeThreads = new util.LinkedHashSet[Thread]()
private lazy val numPermits = GpuSemaphore.computeNumPermits(SQLConf.get)
/**
* If this task holds the GPU semaphore or not.
*/
private var hasSemaphore = false

/**
* Does this task have the GPU semaphore or not. Be careful because it can change at
* any point in time. So only use it for logging.
*/
def isHoldingSemaphore: Boolean = synchronized {
hasSemaphore
}

/**
* Get the list of threads currently running on the GPU Semaphore for this task. Be
* careful because these can change at any point in time. So only use it for logging.
*/
def getActiveThreads: Seq[Thread] = synchronized {
val ret = ArrayBuffer.empty[Thread]
activeThreads.forEach { item =>
ret += item
}
ret
}

private def moveToActive(t: Thread): Unit = synchronized {
if (!hasSemaphore) {
throw new IllegalStateException("Should not move to active without holding the semaphore")
}
blockedThreads.remove(t)
activeThreads.add(t)
}

/**
* Block the current thread until we have the semaphore.
* @param semaphore what we are going to wait on.
*/
def blockUntilReady(semaphore: Semaphore): Unit = {
val t = Thread.currentThread()
// All threads start out in the blocked queue, but will move out of it inside the while loop.
synchronized {
blockedThreads.add(t)
}
var done = false
var shouldBlockOnSemaphore = false
while (!done) {
try {
synchronized {
// This thread can continue if this task owns the GPU semaphore. When that happens
// move the state of the thread from blocked to active.
done = hasSemaphore
if (done) {
moveToActive(t)
}
// Only one thread can block on the semaphore itself; we pick the first thread in
// blockedThreads to be that one. The choice is arbitrary and does not matter, it is just
// simple to do.
shouldBlockOnSemaphore = t == blockedThreads.peek
Collaborator:

Took me a little bit to understand why we need to check the last item in blockedThreads, but I think I get it: there's a race after adding to blockedThreads and calling blockUntilReady, so you want all the "winners" to block while the last thread added to blockedThreads is the one that flags the semaphore as acquired and notifies the blocked threads.

Should shouldBlockOnSemaphore be shouldNotBlockOnSemaphore? If it is true as written (we are the last thread inserted into blockedThreads) we won't block; instead we are the notifier thread, if I'm not misunderstanding.

Collaborator:

Isn't this just saying this thread is at the head of the queue, so if it doesn't already have the semaphore this one should block on it and try to acquire it? If a thread isn't at the head of the queue then it wait()s.

Collaborator (Author):

Yes, I need to add some comments to explain what is happening.
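
For readers following this thread, here is a minimal, standalone sketch (not part of the diff) of the election pattern being discussed: only the thread at the head of the per-task queue blocks on the shared semaphore, while the task's other threads park in `wait()` and are released with `notifyAll()` once the semaphore is held. All names here (`ElectionSketch`, `TaskState`, `blocked`, `holding`) are invented for illustration and do not appear in the PR; error handling is omitted.

```scala
import java.util.concurrent.{LinkedBlockingQueue, Semaphore}

object ElectionSketch {
  private val shared = new Semaphore(1)

  private class TaskState {
    private val blocked = new LinkedBlockingQueue[Thread]()
    private var holding = false

    def blockUntilReady(): Unit = {
      val t = Thread.currentThread()
      synchronized { blocked.add(t) }
      var done = false
      while (!done) {
        var elected = false
        synchronized {
          done = holding              // task already owns the semaphore
          elected = t == blocked.peek // head of the queue is the one acquirer
          if (!done && !elected) {
            wait()                    // some other thread will acquire for us
            done = holding
          }
        }
        if (!done && elected) {
          shared.acquire()            // only the elected thread blocks here
          synchronized {
            holding = true
            notifyAll()               // wake the threads parked in wait()
            done = true
          }
        }
      }
      synchronized { blocked.remove(t) }
    }
  }

  def main(args: Array[String]): Unit = {
    val state = new TaskState
    val threads = (1 to 3).map { i =>
      new Thread(() => {
        state.blockUntilReady()
        println(s"thread $i is on the GPU (conceptually)")
      })
    }
    threads.foreach(_.start())
    threads.foreach(_.join())
  }
}
```

Using the queue head is just a cheap way to pick exactly one acquirer per task; any deterministic tie-break would work, which is what the "arbitrary" comment in the diff refers to.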

if (!done && !shouldBlockOnSemaphore) {
// If we need to block and are not blocking on the semaphore we will wait
// on this class until the task has the semaphore and we wake up.
wait()
if (hasSemaphore) {
moveToActive(t)
done = true
}
}
}
if (!done && shouldBlockOnSemaphore) {
// We cannot be in a synchronized block and wait on the semaphore
// so we have to release it and grab it again afterwards.
semaphore.acquire(numPermits)
synchronized {
// We now own the semaphore so we need to wake up all of the other threads that are
// waiting.
hasSemaphore = true
moveToActive(t)
notifyAll()
done = true
}
}
} catch {
case throwable: Throwable =>
synchronized {
// a thread is exiting because of an exception, so we want to reset things,
// and possibly elect another thread to wait on the semaphore.
blockedThreads.remove(t)
activeThreads.remove(t)
if (!hasSemaphore && shouldBlockOnSemaphore) {
// wake up the other threads so a new thread tries to get the semaphore
notifyAll()
}
}
throw throwable
}
}
}

def releaseSemaphore(semaphore: Semaphore): Unit = synchronized {
val t = Thread.currentThread()
activeThreads.remove(t)
if (hasSemaphore) {
semaphore.release(numPermits)
hasSemaphore = false
}
// It should be impossible for the current thread to be blocked when releasing the semaphore
// because no blocked thread should ever leave `blockUntilReady`, which is where we put it in
// the blocked state. So this is just a sanity test that we didn't do something stupid.
if (blockedThreads.remove(t)) {
throw new IllegalStateException(s"$t tried to release the semaphore when it is blocked!!!")
}
}
}

private final class GpuSemaphore() extends Logging {
import GpuSemaphore._
private val semaphore = new Semaphore(MAX_PERMITS)

// Map to track which tasks have acquired the semaphore.
case class TaskInfo(count: MutableInt, thread: Thread, numPermits: Int)
private val activeTasks = new ConcurrentHashMap[Long, TaskInfo]
// Keep track of all tasks that are both active on the GPU and blocked waiting on the GPU
private val tasks = new ConcurrentHashMap[Long, SemaphoreTaskInfo]

def acquireIfNecessary(context: TaskContext): Unit = {
GpuTaskMetrics.get.semWaitTime {
val taskAttemptId = context.taskAttemptId()
val refs = activeTasks.get(taskAttemptId)
if (refs == null || refs.count.getValue == 0) {
val permits = if (refs == null) {
computeNumPermits(SQLConf.get)
} else {
refs.numPermits
}
logDebug(s"Task $taskAttemptId acquiring GPU with $permits permits")
semaphore.acquire(permits)
RmmSpark.associateCurrentThreadWithTask(taskAttemptId)
if (refs != null) {
refs.count.increment()
} else {
// first time this task has been seen
activeTasks.put(
taskAttemptId,
TaskInfo(new MutableInt(1), Thread.currentThread(), permits))
onTaskCompletion(context, completeTask)
}
GpuDeviceManager.initializeFromTask()
} else {
// Already had the semaphore, but we don't know if the thread is new or not
RmmSpark.associateCurrentThreadWithTask(taskAttemptId)
}
val taskInfo = tasks.computeIfAbsent(taskAttemptId, _ => {
onTaskCompletion(context, completeTask)
new SemaphoreTaskInfo()
})
taskInfo.blockUntilReady(semaphore)
RmmSpark.associateCurrentThreadWithTask(taskAttemptId)
GpuDeviceManager.initializeFromTask()
}
}

@@ -159,12 +287,9 @@ private final class GpuSemaphore() extends Logging {
val taskAttemptId = context.taskAttemptId()
GpuTaskMetrics.get.updateRetry(taskAttemptId)
RmmSpark.removeCurrentThreadAssociation()
val refs = activeTasks.get(taskAttemptId)
if (refs != null && refs.count.getValue > 0) {
if (refs.count.decrementAndGet() == 0) {
logDebug(s"Task $taskAttemptId releasing GPU with ${refs.numPermits} permits")
semaphore.release(refs.numPermits)
}
val taskInfo = tasks.get(taskAttemptId)
if (taskInfo != null) {
taskInfo.releaseSemaphore(semaphore)
}
} finally {
nvtxRange.close()
@@ -175,38 +300,37 @@
val taskAttemptId = context.taskAttemptId()
GpuTaskMetrics.get.updateRetry(taskAttemptId)
RmmSpark.taskDone(taskAttemptId)
val refs = activeTasks.remove(taskAttemptId)
val refs = tasks.remove(taskAttemptId)
if (refs == null) {
throw new IllegalStateException(s"Completion of unknown task $taskAttemptId")
}
if (refs.count.getValue > 0) {
logDebug(s"Task $taskAttemptId releasing GPU with ${refs.numPermits} permits")
semaphore.release(refs.numPermits)
}
refs.releaseSemaphore(semaphore)
}

def dumpActiveStackTracesToLog(): Unit = {
try {
val stackTracesSemaphoreHeld = new mutable.ArrayBuffer[String]()
val otherStackTraces = new mutable.ArrayBuffer[String]()
activeTasks.forEach { (taskAttemptId, taskInfo) =>
val sb = new mutable.StringBuilder()
val semaphoreHeld = taskInfo.count.getValue > 0
taskInfo.thread.getStackTrace.foreach { stackTraceElement =>
sb.append(" " + stackTraceElement + "\n")
}
if (semaphoreHeld) {
stackTracesSemaphoreHeld.append(
s"Semaphore held. " +
s"Stack trace for task attempt id $taskAttemptId:\n${sb.toString()}")
} else {
otherStackTraces.append(
s"Semaphore not held. " +
s"Stack trace for task attempt id $taskAttemptId:\n${sb.toString()}")
tasks.forEach { (taskAttemptId, taskInfo) =>
val semaphoreHeld = taskInfo.isHoldingSemaphore
taskInfo.getActiveThreads.foreach { thread =>
val sb = new mutable.StringBuilder()
thread.getStackTrace.foreach { stackTraceElement =>
sb.append(" " + stackTraceElement + "\n")
}
if (semaphoreHeld) {
stackTracesSemaphoreHeld.append(
s"Semaphore held. " +
s"Stack trace for task attempt id $taskAttemptId:\n${sb.toString()}")
} else {
otherStackTraces.append(
s"Semaphore not held. " +
s"Stack trace for task attempt id $taskAttemptId:\n${sb.toString()}")
}
}
}
logWarning(s"Dumping stack traces. The semaphore sees ${activeTasks.size()} tasks, " +
s"${stackTracesSemaphoreHeld.size} are holding onto the semaphore. " +
logWarning(s"Dumping stack traces. The semaphore sees ${tasks.size()} tasks, " +
s"${stackTracesSemaphoreHeld.size} threads are holding onto the semaphore. " +
stackTracesSemaphoreHeld.mkString("\n", "\n", "\n") +
otherStackTraces.mkString("\n", "\n", "\n"))
} catch {
@@ -216,8 +340,8 @@
}

def shutdown(): Unit = {
if (!activeTasks.isEmpty) {
logDebug(s"shutting down with ${activeTasks.size} tasks still registered")
if (!tasks.isEmpty) {
logDebug(s"shutting down with ${tasks.size} tasks still registered")
}
}
}
}
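
To make the permit accounting concrete: assuming computeNumPermits works out to roughly MAX_PERMITS / concurrentGpuTasks (its body is partially collapsed in the hunks above, so treat the exact rounding as an assumption), a Semaphore(1000) caps the number of tasks that can hold the GPU at the configured concurrency, no matter how many threads each task now has. The sketch below is illustrative only; PermitMathSketch and its values are not part of the PR.

```scala
import java.util.concurrent.Semaphore

object PermitMathSketch {
  val MaxPermits = 1000 // mirrors GpuSemaphore.MAX_PERMITS

  def main(args: Array[String]): Unit = {
    val concurrentGpuTasks = 2                            // e.g. concurrent-GPU-tasks config set to 2
    val permitsPerTask = MaxPermits / concurrentGpuTasks  // 500, under the assumption above

    val semaphore = new Semaphore(MaxPermits)
    // Two tasks can each acquire their permits...
    assert(semaphore.tryAcquire(permitsPerTask))
    assert(semaphore.tryAcquire(permitsPerTask))
    // ...but a third task cannot get on until one of them releases.
    assert(!semaphore.tryAcquire(permitsPerTask))
    semaphore.release(permitsPerTask)
    assert(semaphore.tryAcquire(permitsPerTask))
    println("at most 2 tasks hold the GPU semaphore at a time")
  }
}
```

The per-task SemaphoreTaskInfo is what keeps the multi-threaded case honest: all of a task's threads ride on a single acquire of numPermits (guarded by hasSemaphore), so adding a writer thread for a Python UDF does not consume a second task's worth of permits.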