diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala index 78ff83a6994..52cc9123088 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -215,7 +215,6 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, isDriver: Boole // so they all need to be lazy values that are executed when things are first called // NOTE: this can be null in the driver side. - private[this] lazy val catalog = GpuShuffleEnv.getCatalog private lazy val env = SparkEnv.get private lazy val blockManager = env.blockManager private lazy val shouldFallThroughOnEverything = { @@ -232,10 +231,17 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, isDriver: Boole private lazy val localBlockManagerId = blockManager.blockManagerId + // Code that expects the shuffle catalog to be initialized gets it this way, + // with error checking in case we are in a bad state. + private def getCatalogOrThrow: ShuffleBufferCatalog = + Option(GpuShuffleEnv.getCatalog).getOrElse( + throw new IllegalStateException("The ShuffleBufferCatalog is not initialized but the " + + "RapidsShuffleManager is configured")) + private lazy val resolver = if (shouldFallThroughOnEverything) { wrapped.shuffleBlockResolver } else { - new GpuShuffleBlockResolver(wrapped.shuffleBlockResolver, catalog) + new GpuShuffleBlockResolver(wrapped.shuffleBlockResolver, getCatalogOrThrow) } private[this] lazy val transport: Option[RapidsShuffleTransport] = { @@ -248,6 +254,7 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, isDriver: Boole private[this] lazy val server: Option[RapidsShuffleServer] = { if (rapidsConf.shuffleTransportEnabled && !isDriver) { + val catalog = getCatalogOrThrow val requestHandler = new RapidsShuffleRequestHandler() { override def acquireShuffleBuffer(tableId: Int): RapidsBuffer = { val shuffleBufferId = catalog.getShuffleBufferId(tableId) @@ -295,7 +302,7 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, isDriver: Boole gpu.asInstanceOf[GpuShuffleHandle[K, V]], mapId, metricsReporter, - catalog, + getCatalogOrThrow, RapidsBufferCatalog.getDeviceStorage, server, gpu.dependency.metrics) @@ -335,7 +342,7 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, isDriver: Boole context, metrics, transport, - catalog, + getCatalogOrThrow, gpu.dependency.sparkTypes) case other => { val shuffleHandle = RapidsShuffleInternalManagerBase.unwrapHandle(other) @@ -345,6 +352,7 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, isDriver: Boole } def registerGpuShuffle(shuffleId: Int): Unit = { + val catalog = GpuShuffleEnv.getCatalog if (catalog != null) { // Note that in local mode this can be called multiple times. logInfo(s"Registering shuffle $shuffleId") @@ -353,6 +361,7 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, isDriver: Boole } def unregisterGpuShuffle(shuffleId: Int): Unit = { + val catalog = GpuShuffleEnv.getCatalog if (catalog != null) { logInfo(s"Unregistering shuffle $shuffleId") catalog.unregisterShuffle(shuffleId)