Skip to content

Commit

Permalink
Fixes bugs around GpuShuffleEnv initialization (NVIDIA#534)
Browse files Browse the repository at this point in the history
Signed-off-by: Alessandro Bellina <abellina@nvidia.com>
  • Loading branch information
abellina authored Aug 10, 2020
1 parent fbea5c5 commit 28ad4a3
Showing 1 changed file with 25 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,6 @@ class GpuShuffleEnv(rapidsConf: RapidsConf) extends Logging {
conf.get("spark.shuffle.manager") == GpuShuffleEnv.RAPIDS_SHUFFLE_CLASS
}

/**
 * Whether the RAPIDS shuffle is in effect: true only when the RAPIDS shuffle manager has
 * been flagged as initialized and the block manager does not report the external shuffle
 * service as enabled. Being a `lazy val`, the result is computed once, on first access.
 */
lazy val isRapidsShuffleEnabled: Boolean = {
  val sparkEnv = SparkEnv.get
  val managerInitialized = GpuShuffleEnv.isRapidsShuffleManagerInitialized
  val externalShuffleOn = sparkEnv.blockManager.externalShuffleServiceEnabled
  // both inputs are read up front; the external shuffle service always wins
  if (externalShuffleOn) false else managerInitialized
}

def initStorage(devInfo: CudaMemInfo): Unit = {
if (isRapidsShuffleConfigured) {
Expand Down Expand Up @@ -99,31 +93,41 @@ object GpuShuffleEnv extends Logging {
// Flag written by setRapidsShuffleManagerInitialized and read by isRapidsShuffleEnabled.
// Marked @volatile, like `env` below, so that a write made on one thread is visible to
// reads performed on other threads.
@volatile private var isRapidsShuffleManagerInitialized: Boolean = false
// The active shuffle environment; holds the default (null) value until `init` assigns it.
@volatile private var env: GpuShuffleEnv = _

//
// Functions below get called from the driver or executors
//

/**
 * Whether the RAPIDS shuffle is in effect: true only when the RAPIDS shuffle manager has
 * been flagged as initialized and the current `SparkEnv`'s block manager does not report
 * the external shuffle service as enabled. Re-evaluated on every call.
 */
def isRapidsShuffleEnabled: Boolean = {
  val managerInitialized = GpuShuffleEnv.isRapidsShuffleManagerInitialized
  val blockManager = SparkEnv.get.blockManager
  val externalShuffleOn = blockManager.externalShuffleServiceEnabled
  // both inputs are read before being combined
  if (externalShuffleOn) false else managerInitialized
}

/**
 * Records whether the RAPIDS shuffle manager has been initialized.
 *
 * @param initialized value to store in the initialization flag
 * @param className   class name of the caller's shuffle manager; asserted to equal
 *                    `GpuShuffleEnv.RAPIDS_SHUFFLE_CLASS`
 */
def setRapidsShuffleManagerInitialized(initialized: Boolean, className: String): Unit = {
  // reject any shuffle manager other than the RAPIDS one before touching state
  val expectedClass = GpuShuffleEnv.RAPIDS_SHUFFLE_CLASS
  assert(className == expectedClass)
  logInfo("RapidsShuffleManager is initialized")
  GpuShuffleEnv.isRapidsShuffleManagerInitialized = initialized
}

/**
 * Closes the storage of the current shuffle environment, if there is one.
 * Does nothing when no environment has been set.
 */
def shutdown(): Unit = {
  // in the driver, this will not be set
  val current = env
  if (current != null) {
    current.closeStorage()
  }
}

//
// Functions below only get called from the executor
//

/**
 * Replaces the current shuffle environment with a freshly created one.
 *
 * Any environment already in place has its storage closed first; the new environment is
 * published through `env` only after its storage has been initialized.
 *
 * @param conf    configuration used to build the new `GpuShuffleEnv`
 * @param devInfo memory information handed to `initStorage`
 */
def init(conf: RapidsConf, devInfo: CudaMemInfo): Unit = {
  val previous = env
  if (previous != null) {
    previous.closeStorage()
  }
  val replacement = new GpuShuffleEnv(conf)
  replacement.initStorage(devInfo)
  env = replacement
}

/**
 * Closes the storage of the current shuffle environment, if one has been set.
 *
 * `env` is only assigned by `init`, so it can still be null when this is called;
 * wrapping it in `Option` turns that case into a no-op instead of a
 * NullPointerException.
 */
def shutdown(): Unit = {
  Option(env).foreach(_.closeStorage())
}

/** Returns the current shuffle environment as stored in `env` (null if none has been set). */
def get: GpuShuffleEnv = {
  env
}

/** Returns the shuffle buffer catalog of the current shuffle environment. */
def getCatalog: ShuffleBufferCatalog = {
  val shuffleEnv = env
  shuffleEnv.getCatalog
}

/** Returns the received-buffer catalog of the current shuffle environment. */
def getReceivedCatalog: ShuffleReceivedBufferCatalog = {
  val shuffleEnv = env
  shuffleEnv.getReceivedCatalog
}

/** Returns the device memory store of the current shuffle environment. */
def getDeviceStorage: RapidsDeviceMemoryStore = {
  val shuffleEnv = env
  shuffleEnv.getDeviceStorage
}

/**
 * Delegates to the `isRapidsShuffleEnabled` member of the current shuffle environment.
 * NOTE(review): `env` is dereferenced without a null check, so this relies on an
 * environment having been set beforehand — confirm against callers.
 */
def isRapidsShuffleEnabled: Boolean = {
  val shuffleEnv = env
  shuffleEnv.isRapidsShuffleEnabled
}

/**
 * Records whether the RAPIDS shuffle manager has been initialized.
 * The shuffle plugin will call this on initialize.
 *
 * @param initialized value to store in the initialization flag
 * @param className   class name of the caller's shuffle manager; asserted to equal
 *                    `GpuShuffleEnv.RAPIDS_SHUFFLE_CLASS`
 */
def setRapidsShuffleManagerInitialized(initialized: Boolean, className: String): Unit = {
  // reject any shuffle manager other than the RAPIDS one before touching state
  val expectedClass = GpuShuffleEnv.RAPIDS_SHUFFLE_CLASS
  assert(className == expectedClass)
  logInfo("RapidsShuffleManager is initialized")
  GpuShuffleEnv.isRapidsShuffleManagerInitialized = initialized
}
}

0 comments on commit 28ad4a3

Please sign in to comment.