diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala index cb38dd9f6db..ebc271aadc5 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala @@ -124,6 +124,12 @@ object GpuShuffleEnv extends Logging { Option(env).foreach(_.closeStorage()) } + def getCatalog: ShuffleBufferCatalog = if (env == null) { + null + } else { + env.getCatalog + } + // // Functions below only get called from the executor // @@ -135,8 +141,6 @@ object GpuShuffleEnv extends Logging { env = shuffleEnv } - def getCatalog: ShuffleBufferCatalog = env.getCatalog - def getReceivedCatalog: ShuffleReceivedBufferCatalog = env.getReceivedCatalog def getDeviceStorage: RapidsDeviceMemoryStore = env.getDeviceStorage diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala index 2939645fbb8..ca3458cd234 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala @@ -202,6 +202,8 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, isDriver: Boole //Many of these values like blockManager are not initialized when the constructor is called, // so they all need to be lazy values that are executed when things are first called + + // NOTE: this can be null in the driver side. private[this] lazy val catalog = GpuShuffleEnv.getCatalog private lazy val env = SparkEnv.get private lazy val blockManager = env.blockManager @@ -284,7 +286,7 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, isDriver: Boole gpu.asInstanceOf[GpuShuffleHandle[K, V]], mapId, metrics, - GpuShuffleEnv.getCatalog, + catalog, GpuShuffleEnv.getDeviceStorage, server) case other =>