From 61b49f1645254993c8fefefee77a32ec8a833910 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Fri, 7 Aug 2020 22:05:19 -0500 Subject: [PATCH] Fixes bugs around GpuShuffleEnv initialization Signed-off-by: Alessandro Bellina --- .../spark/sql/rapids/GpuShuffleEnv.scala | 46 ++++++++++--------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala index 932b249d733..53e26d8c550 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala @@ -38,12 +38,6 @@ class GpuShuffleEnv(rapidsConf: RapidsConf) extends Logging { conf.get("spark.shuffle.manager") == GpuShuffleEnv.RAPIDS_SHUFFLE_CLASS } - lazy val isRapidsShuffleEnabled: Boolean = { - val env = SparkEnv.get - val isRapidsManager = GpuShuffleEnv.isRapidsShuffleManagerInitialized - val externalShuffle = env.blockManager.externalShuffleServiceEnabled - isRapidsManager && !externalShuffle - } def initStorage(devInfo: CudaMemInfo): Unit = { if (isRapidsShuffleConfigured) { @@ -99,6 +93,31 @@ object GpuShuffleEnv extends Logging { private var isRapidsShuffleManagerInitialized: Boolean = false @volatile private var env: GpuShuffleEnv = _ + // + // Functions below get called from the driver or executors + // + + def isRapidsShuffleEnabled: Boolean = { + val isRapidsManager = GpuShuffleEnv.isRapidsShuffleManagerInitialized + val externalShuffle = SparkEnv.get.blockManager.externalShuffleServiceEnabled + isRapidsManager && !externalShuffle + } + + def setRapidsShuffleManagerInitialized(initialized: Boolean, className: String): Unit = { + assert(className == GpuShuffleEnv.RAPIDS_SHUFFLE_CLASS) + logInfo("RapidsShuffleManager is initialized") + isRapidsShuffleManagerInitialized = initialized + } + + def shutdown(): Unit = { + // in the driver, this will not be set + Option(env).foreach(_.closeStorage()) + } + + // + // Functions below only get called from the executor + // + def init(conf: RapidsConf, devInfo: CudaMemInfo): Unit = { Option(env).foreach(_.closeStorage()) val shuffleEnv = new GpuShuffleEnv(conf) @@ -106,24 +125,9 @@ object GpuShuffleEnv extends Logging { env = shuffleEnv } - def shutdown(): Unit = { - env.closeStorage() - } - - def get: GpuShuffleEnv = env - def getCatalog: ShuffleBufferCatalog = env.getCatalog def getReceivedCatalog: ShuffleReceivedBufferCatalog = env.getReceivedCatalog def getDeviceStorage: RapidsDeviceMemoryStore = env.getDeviceStorage - - def isRapidsShuffleEnabled: Boolean = env.isRapidsShuffleEnabled - - // the shuffle plugin will call this on initialize - def setRapidsShuffleManagerInitialized(initialized: Boolean, className: String): Unit = { - assert(className == GpuShuffleEnv.RAPIDS_SHUFFLE_CLASS) - logInfo("RapidsShuffleManager is initialized") - isRapidsShuffleManagerInitialized = initialized - } }