Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create a PrioritySemaphore to back the GpuSemaphore #11376

Merged
merged 3 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package com.nvidia.spark.rapids

import java.util
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, Semaphore}
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
Expand Down Expand Up @@ -183,6 +183,9 @@ private final class SemaphoreTaskInfo() extends Logging {
* If this task holds the GPU semaphore or not.
*/
private var hasSemaphore = false
private var lastHeld: Long = 0

type GpuBackingSemaphore = PrioritySemaphore[Long]

/**
* Does this task have the GPU semaphore or not. Be careful because it can change at
Expand Down Expand Up @@ -216,7 +219,7 @@ private final class SemaphoreTaskInfo() extends Logging {
* Block the current thread until we have the semaphore.
* @param semaphore what we are going to wait on.
*/
def blockUntilReady(semaphore: Semaphore): Unit = {
def blockUntilReady(semaphore: GpuBackingSemaphore): Unit = {
val t = Thread.currentThread()
// All threads start out in blocked, but will move out of it inside of the while loop.
synchronized {
Expand Down Expand Up @@ -250,7 +253,7 @@ private final class SemaphoreTaskInfo() extends Logging {
if (!done && shouldBlockOnSemaphore) {
// We cannot be in a synchronized block and wait on the semaphore
// so we have to release it and grab it again afterwards.
semaphore.acquire(numPermits)
semaphore.acquire(numPermits, lastHeld)
synchronized {
// We now own the semaphore so we need to wake up all of the other tasks that are
// waiting.
Expand All @@ -277,7 +280,7 @@ private final class SemaphoreTaskInfo() extends Logging {
}
}

def tryAcquire(semaphore: Semaphore): Boolean = synchronized {
def tryAcquire(semaphore: GpuBackingSemaphore): Boolean = synchronized {
val t = Thread.currentThread()
if (hasSemaphore) {
activeThreads.add(t)
Expand All @@ -299,12 +302,13 @@ private final class SemaphoreTaskInfo() extends Logging {
}
}

def releaseSemaphore(semaphore: Semaphore): Unit = synchronized {
def releaseSemaphore(semaphore: GpuBackingSemaphore): Unit = synchronized {
val t = Thread.currentThread()
activeThreads.remove(t)
if (hasSemaphore) {
semaphore.release(numPermits)
hasSemaphore = false
lastHeld = System.currentTimeMillis()
}
// It should be impossible for the current thread to be blocked when releasing the semaphore
// because no blocked thread should ever leave `blockUntilReady`, which is where we put it in
Expand All @@ -317,7 +321,9 @@ private final class SemaphoreTaskInfo() extends Logging {

private final class GpuSemaphore() extends Logging {
import GpuSemaphore._
private val semaphore = new Semaphore(MAX_PERMITS)

type GpuBackingSemaphore = PrioritySemaphore[Long]
private val semaphore = new GpuBackingSemaphore(MAX_PERMITS)
// Keep track of all tasks that are both active on the GPU and blocked waiting on the GPU
private val tasks = new ConcurrentHashMap[Long, SemaphoreTaskInfo]

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import java.util.concurrent.locks.{Condition, ReentrantLock}

import scala.collection.mutable.PriorityQueue

object PrioritySemaphore {
  private val DEFAULT_MAX_PERMITS = 1000
}

/**
 * A counting semaphore that wakes blocked acquirers in priority order: when permits are
 * released, the waiter whose priority is greatest under `ordering` is signalled first.
 *
 * @param maxPermits the total number of permits that may be held at one time
 * @param ordering   ordering used to rank waiting threads by their priority value
 * @tparam T the type of the priority value supplied to [[acquire]]
 */
class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T]) {
  // This lock is used to generate condition variables, which affords us the flexibility to notify
  // specific threads at a time. If we use the regular synchronized pattern, we have to either
  // notify randomly, or if we try creating condition variables not tied to a shared lock, they
  // won't work together properly, and we see things like deadlocks.
  private val lock = new ReentrantLock()

  // Number of permits currently handed out. Only read or written while `lock` is held.
  private var occupiedSlots: Int = 0

  /**
   * Book-keeping for one blocked call to [[acquire]]. All fields are only touched while `lock`
   * is held.
   *
   * @param priority  the priority the waiter was enqueued with
   * @param condition the condition the waiting thread is parked on
   */
  private final class ThreadInfo(val priority: T, val condition: Condition) {
    // True while this entry sits in `waitingQueue`; cleared when a releaser dequeues it.
    var queued: Boolean = false
    // Set when the owning thread stopped waiting while its entry was still queued. The queue
    // cannot remove arbitrary elements, so such entries are skipped lazily when dequeued.
    var abandoned: Boolean = false
  }

  // We expect a relatively small number of threads to be contending for this lock at any given
  // time, therefore we are not concerned with the insertion/removal time complexity.
  private val waitingQueue: PriorityQueue[ThreadInfo] =
    PriorityQueue.empty[ThreadInfo](Ordering.by[ThreadInfo, T](_.priority))

  def this()(implicit ordering: Ordering[T]) = this(PrioritySemaphore.DEFAULT_MAX_PERMITS)(ordering)

  /**
   * Take `numPermits` permits if they are available right now, without blocking on permits.
   *
   * @param numPermits how many permits to take
   * @return true if the permits were taken, false if there were not enough free permits
   */
  def tryAcquire(numPermits: Int): Boolean = {
    lock.lock()
    try {
      if (canAcquire(numPermits)) {
        commitAcquire(numPermits)
        true
      } else {
        false
      }
    } finally {
      lock.unlock()
    }
  }

  /**
   * Take `numPermits` permits, blocking the calling thread until they are available.
   *
   * @param numPermits how many permits to take
   * @param priority   rank of this request among blocked threads; greater values are woken first
   * @throws InterruptedException if the thread is interrupted while waiting; no permits are
   *                              taken in that case
   */
  def acquire(numPermits: Int, priority: T): Unit = {
    lock.lock()
    try {
      if (!canAcquire(numPermits)) {
        awaitPermits(numPermits, priority)
      }
      commitAcquire(numPermits)
    } finally {
      lock.unlock()
    }
  }

  /**
   * Park the calling thread until `numPermits` permits are free. Must be called with `lock`
   * held.
   */
  private def awaitPermits(numPermits: Int, priority: T): Unit = {
    val info = new ThreadInfo(priority, lock.newCondition())
    try {
      while (!canAcquire(numPermits)) {
        // Condition.await may return spuriously. Only (re-)enqueue when a releaser actually
        // dequeued us, otherwise the queue would accumulate duplicate entries for this thread.
        if (!info.queued) {
          waitingQueue.enqueue(info)
          info.queued = true
        }
        info.condition.await()
      }
    } catch {
      case e: InterruptedException =>
        // A releaser may already have dequeued us before the interrupt was delivered. We are
        // not going to take the permits, so pass that wake-up on instead of swallowing it.
        if (!info.queued) {
          signalNextWaiter()
        }
        throw e
    } finally {
      // Leaving (normally or not) while still queued: mark the entry dead so that a later
      // release does not spend its single wake-up on a thread that is no longer waiting.
      if (info.queued) {
        info.abandoned = true
      }
    }
  }

  private def commitAcquire(numPermits: Int): Unit = {
    occupiedSlots += numPermits
  }

  /**
   * Return `numPermits` permits and wake the highest priority thread that is still waiting.
   *
   * @param numPermits how many permits to give back
   */
  def release(numPermits: Int): Unit = {
    lock.lock()
    try {
      occupiedSlots -= numPermits
      signalNextWaiter()
    } finally {
      lock.unlock()
    }
  }

  /**
   * Dequeue entries until one belonging to a thread that is still waiting is found, and signal
   * that thread. Abandoned entries are discarded on the way. Must be called with `lock` held.
   */
  private def signalNextWaiter(): Unit = {
    var signalled = false
    while (!signalled && waitingQueue.nonEmpty) {
      val next = waitingQueue.dequeue()
      if (!next.abandoned) {
        // Lets the waiter know it has to enqueue itself again if it still cannot proceed.
        next.queued = false
        next.condition.signal()
        signalled = true
      }
    }
  }

  private def canAcquire(numPermits: Int): Boolean = {
    occupiedSlots + numPermits <= maxPermits
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import java.util.concurrent.{CountDownLatch, TimeUnit}

import scala.collection.JavaConverters._

import org.scalatest.funsuite.AnyFunSuite

/**
 * Unit tests for [[PrioritySemaphore]]: non-blocking acquisition, blocking acquisition with
 * interruption, and priority ordering of blocked threads.
 */
class PrioritySemaphoreSuite extends AnyFunSuite {
  type TestPrioritySemaphore = PrioritySemaphore[Long]

  test("tryAcquire should return true if permits are available") {
    val semaphore = new TestPrioritySemaphore(10)

    assert(semaphore.tryAcquire(5))
    assert(semaphore.tryAcquire(3))
    assert(semaphore.tryAcquire(2))
    // All 10 permits are now taken.
    assert(!semaphore.tryAcquire(1))
  }

  test("acquire and release should work correctly") {
    val semaphore = new TestPrioritySemaphore(1)

    assert(semaphore.tryAcquire(1))

    // A failure raised on the worker thread is invisible to the test framework, so the worker
    // records what happened and the assertions are made on the test thread.
    val acquiredWhileHeld = new java.util.concurrent.atomic.AtomicBoolean(false)
    val latch = new CountDownLatch(1)
    val t = new Thread(() => {
      try {
        // The only permit is held by the test thread, so this is expected to block until
        // the interrupt arrives.
        semaphore.acquire(1, 1)
        acquiredWhileHeld.set(true)
      } catch {
        case _: InterruptedException =>
          // Retry; this succeeds once the test thread releases its permit.
          semaphore.acquire(1, 1)
      } finally {
        latch.countDown()
      }
    })
    t.start()

    // Give the worker time to block inside acquire before interrupting it.
    Thread.sleep(100)
    t.interrupt()

    semaphore.release(1)

    assert(latch.await(1, TimeUnit.SECONDS), "worker did not acquire the permit after release")
    assert(!acquiredWhileHeld.get(), "worker acquired a permit that was still held")
  }

  test("multiple threads should handle permits and priority correctly") {
    // Zero permits: every acquire below blocks until the explicit release further down.
    val semaphore = new TestPrioritySemaphore(0)
    val latch = new CountDownLatch(3)
    val results = new java.util.ArrayList[Int]()

    def taskWithPriority(priority: Int) = new Runnable {
      override def run(): Unit = {
        try {
          semaphore.acquire(1, priority)
          results.add(priority)
          semaphore.release(1)
        } finally {
          latch.countDown()
        }
      }
    }

    new Thread(taskWithPriority(2)).start()
    new Thread(taskWithPriority(1)).start()
    new Thread(taskWithPriority(3)).start()

    // Give all three workers time to block before handing out the first permit.
    Thread.sleep(100)
    semaphore.release(1)

    assert(latch.await(1, TimeUnit.SECONDS), "not every worker finished in time")
    // Workers must have run one after another, highest priority first.
    assert(results.asScala.toList == List(3, 2, 1))
  }
}
Loading