diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala index eb6fe1583e0..dbd0943eb4e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala @@ -28,6 +28,11 @@ import org.apache.spark.internal.Logging import org.apache.spark.resource.ResourceInformation import org.apache.spark.sql.rapids.GpuShuffleEnv +sealed trait MemoryState +private case object Initialized extends MemoryState +private case object Uninitialized extends MemoryState +private case object Errored extends MemoryState + object GpuDeviceManager extends Logging { // This config controls whether RMM/Pinned memory are initialized from the task // or from the executor side plugin. The default is to initialize from the @@ -43,7 +48,7 @@ object GpuDeviceManager extends Logging { } private val threadGpuInitialized = new ThreadLocal[Boolean]() - @volatile private var singletonMemoryInitialized: Boolean = false + @volatile private var singletonMemoryInitialized: MemoryState = Uninitialized @volatile private var deviceId: Option[Int] = None /** @@ -127,9 +132,11 @@ object GpuDeviceManager extends Logging { } def shutdown(): Unit = synchronized { + // assume error during shutdown until we complete it + singletonMemoryInitialized = Errored RapidsBufferCatalog.close() Rmm.shutdown() - singletonMemoryInitialized = false + singletonMemoryInitialized = Uninitialized } def getResourcesFromTaskContext: Map[String, ResourceInformation] = { @@ -283,15 +290,18 @@ object GpuDeviceManager extends Logging { * @param rapidsConf the config to use. */ def initializeMemory(gpuId: Option[Int], rapidsConf: Option[RapidsConf] = None): Unit = { - if (singletonMemoryInitialized == false) { + if (singletonMemoryInitialized != Initialized) { // Memory or memory related components that only need to be initialized once per executor. // This synchronize prevents multiple tasks from trying to initialize these at the same time. GpuDeviceManager.synchronized { - if (singletonMemoryInitialized == false) { + if (singletonMemoryInitialized == Errored) { + throw new IllegalStateException( + "Cannot initialize memory due to previous shutdown failing") + } else if (singletonMemoryInitialized == Uninitialized) { val gpu = gpuId.getOrElse(findGpuAndAcquire()) initializeRmm(gpu, rapidsConf) allocatePinnedMemory(gpu, rapidsConf) - singletonMemoryInitialized = true + singletonMemoryInitialized = Initialized } } }