Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

backport fixes of #11573 to branch 24.10 #11588

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ object GpuSemaphore {
* this is considered to be okay as there are other mechanisms in place, and it should be rather
* rare.
*/
private final class SemaphoreTaskInfo() extends Logging {
private final class SemaphoreTaskInfo(val taskAttemptId: Long) extends Logging {
/**
* This holds threads that are not on the GPU yet. Most of the time they are
* blocked waiting for the semaphore to let them on, but it may hold one
Expand Down Expand Up @@ -253,7 +253,7 @@ private final class SemaphoreTaskInfo() extends Logging {
if (!done && shouldBlockOnSemaphore) {
// We cannot be in a synchronized block and wait on the semaphore
// so we have to release it and grab it again afterwards.
semaphore.acquire(numPermits, lastHeld)
semaphore.acquire(numPermits, lastHeld, taskAttemptId)
synchronized {
// We now own the semaphore so we need to wake up all of the other tasks that are
// waiting.
Expand All @@ -280,15 +280,15 @@ private final class SemaphoreTaskInfo() extends Logging {
}
}

def tryAcquire(semaphore: GpuBackingSemaphore): Boolean = synchronized {
def tryAcquire(semaphore: GpuBackingSemaphore, taskAttemptId: Long): Boolean = synchronized {
val t = Thread.currentThread()
if (hasSemaphore) {
activeThreads.add(t)
true
} else {
if (blockedThreads.size() == 0) {
// No other threads for this task are waiting, so we might be able to grab this directly
val ret = semaphore.tryAcquire(numPermits, lastHeld)
val ret = semaphore.tryAcquire(numPermits, lastHeld, taskAttemptId)
if (ret) {
hasSemaphore = true
activeThreads.add(t)
Expand Down Expand Up @@ -333,9 +333,9 @@ private final class GpuSemaphore() extends Logging {
val taskAttemptId = context.taskAttemptId()
val taskInfo = tasks.computeIfAbsent(taskAttemptId, _ => {
onTaskCompletion(context, completeTask)
new SemaphoreTaskInfo()
new SemaphoreTaskInfo(taskAttemptId)
})
if (taskInfo.tryAcquire(semaphore)) {
if (taskInfo.tryAcquire(semaphore, taskAttemptId)) {
GpuDeviceManager.initializeFromTask()
SemaphoreAcquired
} else {
Expand All @@ -357,7 +357,7 @@ private final class GpuSemaphore() extends Logging {
val taskAttemptId = context.taskAttemptId()
val taskInfo = tasks.computeIfAbsent(taskAttemptId, _ => {
onTaskCompletion(context, completeTask)
new SemaphoreTaskInfo()
new SemaphoreTaskInfo(taskAttemptId)
})
taskInfo.blockUntilReady(semaphore)
GpuDeviceManager.initializeFromTask()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,30 @@ class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T])
private val lock = new ReentrantLock()
private var occupiedSlots: Int = 0

private case class ThreadInfo(priority: T, condition: Condition, numPermits: Int) {
private case class ThreadInfo(priority: T, condition: Condition, numPermits: Int, taskId: Long) {
var signaled: Boolean = false
}

// use task id as tie breaker when priorities are equal (both are 0 because never hold lock)
private val priorityComp = Ordering.by[ThreadInfo, T](_.priority).reverse.
thenComparing((a, b) => a.taskId.compareTo(b.taskId))

// We expect a relatively small number of threads to be contending for this lock at any given
// time, therefore we are not concerned with the insertion/removal time complexity.
private val waitingQueue: PriorityQueue[ThreadInfo] =
new PriorityQueue[ThreadInfo](Ordering.by[ThreadInfo, T](_.priority).reverse)
new PriorityQueue[ThreadInfo](priorityComp)

def tryAcquire(numPermits: Int, priority: T): Boolean = {
def tryAcquire(numPermits: Int, priority: T, taskAttemptId: Long): Boolean = {
lock.lock()
try {
if (waitingQueue.size() > 0 && ordering.gt(waitingQueue.peek.priority, priority)) {
if (waitingQueue.size() > 0 &&
priorityComp.compare(
waitingQueue.peek(),
ThreadInfo(priority, null, numPermits, taskAttemptId)
) < 0) {
false
} else if (!canAcquire(numPermits)) {
}
else if (!canAcquire(numPermits)) {
Comment on lines +52 to +53
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
}
else if (!canAcquire(numPermits)) {
} else if (!canAcquire(numPermits)) {

false
} else {
commitAcquire(numPermits)
Expand All @@ -52,12 +61,12 @@ class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T])
}
}

def acquire(numPermits: Int, priority: T): Unit = {
def acquire(numPermits: Int, priority: T, taskAttemptId: Long): Unit = {
lock.lock()
try {
if (!tryAcquire(numPermits, priority)) {
if (!tryAcquire(numPermits, priority, taskAttemptId)) {
val condition = lock.newCondition()
val info = ThreadInfo(priority, condition, numPermits)
val info = ThreadInfo(priority, condition, numPermits, taskAttemptId)
try {
waitingQueue.add(info)
while (!info.signaled) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,24 @@ class PrioritySemaphoreSuite extends AnyFunSuite {
test("tryAcquire should return true if permits are available") {
val semaphore = new TestPrioritySemaphore(10)

assert(semaphore.tryAcquire(5, 0))
assert(semaphore.tryAcquire(3, 0))
assert(semaphore.tryAcquire(2, 0))
assert(!semaphore.tryAcquire(1, 0))
assert(semaphore.tryAcquire(5, 0, 0))
assert(semaphore.tryAcquire(3, 0, 0))
assert(semaphore.tryAcquire(2, 0, 0))
assert(!semaphore.tryAcquire(1, 0, 0))
}

test("acquire and release should work correctly") {
val semaphore = new TestPrioritySemaphore(1)

assert(semaphore.tryAcquire(1, 0))
assert(semaphore.tryAcquire(1, 0, 0))

val t = new Thread(() => {
try {
semaphore.acquire(1, 1)
semaphore.acquire(1, 1, 0)
fail("Should not acquire permit")
} catch {
case _: InterruptedException =>
semaphore.acquire(1, 1)
semaphore.acquire(1, 1, 0)
}
})
t.start()
Expand All @@ -62,7 +62,7 @@ class PrioritySemaphoreSuite extends AnyFunSuite {

def taskWithPriority(priority: Int) = new Runnable {
override def run(): Unit = {
semaphore.acquire(1, priority)
semaphore.acquire(1, priority, 0)
results.add(priority)
semaphore.release(1)
}
Expand All @@ -84,20 +84,46 @@ class PrioritySemaphoreSuite extends AnyFunSuite {

test("low priority thread cannot surpass high priority thread") {
val semaphore = new TestPrioritySemaphore(10)
semaphore.acquire(5, 0)
semaphore.acquire(5, 0, 0)
val t = new Thread(() => {
semaphore.acquire(10, 2)
semaphore.acquire(10, 2, 0)
semaphore.release(10)
})
t.start()
Thread.sleep(100)

// Here, there should be 5 available permits, but a thread with higher priority (2)
// is waiting to acquire, therefore we should get rejected here
assert(!semaphore.tryAcquire(5, 0))
assert(!semaphore.tryAcquire(5, 0, 0))
semaphore.release(5)
t.join(1000)
// After the high priority thread finishes, we can acquire with lower priority
assert(semaphore.tryAcquire(5, 0))
assert(semaphore.tryAcquire(5, 0, 0))
}

// this case is described at https://github.com/NVIDIA/spark-rapids/pull/11574/files#r1795652488
test("thread with larger task id should not surpass smaller task id in the waiting queue") {
val semaphore = new TestPrioritySemaphore(10)
semaphore.acquire(8, 0, 0)
val t = new Thread(() => {
semaphore.acquire(5, 0, 0)
semaphore.release(5)
})
t.start()
Thread.sleep(100)

// Here, there should be 2 available permits, and a thread with same task id (0)
// is waiting to acquire 5 permits, in this case we should succeed here
assert(semaphore.tryAcquire(2, 0, 0))
semaphore.release(2)

// Here, there should be 2 available permits, but a thread with smaller task id (0)
// is waiting to acquire, therefore we should get rejected here
assert(!semaphore.tryAcquire(2, 0, 1))

semaphore.release(8)
t.join(1000)
// After the high priority thread finishes, we can acquire with lower priority
assert(semaphore.tryAcquire(2, 0, 1))
}
}