diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala index 336d9b4f5d3..f028e244286 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala @@ -19,6 +19,8 @@ package com.nvidia.spark.rapids import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{BufferType, NvtxColor, Table} +import com.nvidia.spark.rapids.RapidsPluginImplicits._ +import com.nvidia.spark.rapids.SpillPriorities.COALESCE_BATCH_ON_DECK_PRIORITY import com.nvidia.spark.rapids.format.{ColumnMeta, SubBufferMeta, TableMeta} import org.apache.spark.TaskContext @@ -29,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.rapids.TempSpillBufferId import org.apache.spark.sql.types.{DataTypes, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -158,9 +161,28 @@ abstract class AbstractGpuCoalesceIterator(origIter: Iterator[ColumnarBatch], peakDevMemory: SQLMetric, opName: String) extends Iterator[ColumnarBatch] with Logging { private val iter = new RemoveEmptyBatchIterator(origIter, numInputBatches) - private var onDeck: Option[ColumnarBatch] = None private var batchInitialized: Boolean = false + /** + * Return true if there is something saved on deck for later processing. + */ + protected def hasOnDeck: Boolean + + /** + * Save a batch for later processing. + */ + protected def saveOnDeck(batch: ColumnarBatch): Unit + + /** + * If there is anything saved on deck close it. + */ + protected def clearOnDeck(): Unit + + /** + * Remove whatever is on deck and return it. + */ + protected def popOnDeck(): ColumnarBatch + /** We need to track the sizes of string columns to make sure we don't exceed 2GB */ private val stringFieldIndices: Array[Int] = schema.fields.zipWithIndex .filter(_._1.dataType == DataTypes.StringType) @@ -172,9 +194,9 @@ abstract class AbstractGpuCoalesceIterator(origIter: Iterator[ColumnarBatch], // note that TaskContext.get() can return null during unit testing so we wrap it in an // option here Option(TaskContext.get()) - .foreach(_.addTaskCompletionListener[Unit](_ => onDeck.foreach(_.close()))) + .foreach(_.addTaskCompletionListener[Unit]( _ => clearOnDeck())) - override def hasNext: Boolean = onDeck.isDefined || iter.hasNext + override def hasNext: Boolean = hasOnDeck || iter.hasNext /** * Called first to initialize any state needed for a new batch to be created. @@ -251,10 +273,9 @@ abstract class AbstractGpuCoalesceIterator(origIter: Iterator[ColumnarBatch], var numBatches = 0 // check if there is a batch "on deck" from a previous call to next() - if (onDeck.isDefined) { - val batch = onDeck.get + if (hasOnDeck) { + val batch = popOnDeck() addBatch(batch) - onDeck = None numBatches += 1 numRows += batch.numRows() columnSizes = getColumnSizes(batch) @@ -265,7 +286,7 @@ abstract class AbstractGpuCoalesceIterator(origIter: Iterator[ColumnarBatch], try { // there is a hard limit of 2^31 rows - while (numRows < Int.MaxValue && onDeck.isEmpty && iter.hasNext) { + while (numRows < Int.MaxValue && !hasOnDeck && iter.hasNext) { val cb = iter.next() val nextRows = cb.numRows() @@ -300,11 +321,11 @@ abstract class AbstractGpuCoalesceIterator(origIter: Iterator[ColumnarBatch], s" but cuDF only supports ${Int.MaxValue} rows. At least $wouldBeRows are in" + s" this partition. Please try increasing your partition count.") } - onDeck = Some(cb) + saveOnDeck(cb) } else if (batchRowLimit > 0 && wouldBeRows > batchRowLimit) { - onDeck = Some(cb) + saveOnDeck(cb) } else if (wouldBeBytes > goal.targetSizeBytes && numBytes > 0) { - onDeck = Some(cb) + saveOnDeck(cb) } else if (wouldBeStringColumnSizes.exists(size => size > Int.MaxValue)) { if (goal == RequireSingleBatch) { throw new IllegalStateException("A single batch is required for this operation," + @@ -312,7 +333,7 @@ abstract class AbstractGpuCoalesceIterator(origIter: Iterator[ColumnarBatch], s" At least ${wouldBeStringColumnSizes.max} are in a single column in this" + s" partition. Please try increasing your partition count.") } - onDeck = Some(cb) + saveOnDeck(cb) } else { addBatch(cb) numBatches += 1 @@ -327,7 +348,7 @@ abstract class AbstractGpuCoalesceIterator(origIter: Iterator[ColumnarBatch], } // enforce single batch limit when appropriate - if (goal == RequireSingleBatch && (onDeck.isDefined || iter.hasNext)) { + if (goal == RequireSingleBatch && (hasOnDeck || iter.hasNext)) { throw new IllegalStateException("A single batch is required for this operation." + " Please try increasing your partition count.") } @@ -500,6 +521,59 @@ class GpuCoalesceIterator(iter: Iterator[ColumnarBatch], peakDevMemory.set(maxDeviceMemory) batches.foreach(_.close()) } + + private var onDeck: Option[TempSpillBufferId] = None + + override protected def hasOnDeck: Boolean = onDeck.isDefined + + override protected def saveOnDeck(batch: ColumnarBatch): Unit = { + assert(onDeck.isEmpty) + val id = TempSpillBufferId() + val priority = COALESCE_BATCH_ON_DECK_PRIORITY + val numColumns = batch.numCols() + + if (numColumns > 0 && batch.column(0).isInstanceOf[GpuCompressedColumnVector]) { + val cv = batch.column(0).asInstanceOf[GpuCompressedColumnVector] + RapidsBufferCatalog.addBuffer(id, cv.getBuffer, cv.getTableMeta, priority) + } else if (numColumns > 0 && + (0 until numColumns) + .forall(i => batch.column(i).isInstanceOf[GpuColumnVectorFromBuffer])) { + val cv = batch.column(0).asInstanceOf[GpuColumnVectorFromBuffer] + withResource(GpuColumnVector.from(batch)) { table => + RapidsBufferCatalog.addTable(id, table, cv.getBuffer, priority) + } + } else { + withResource(batch) { batch => + withResource(GpuColumnVector.from(batch)) { tmpTable => + val contigTables = tmpTable.contiguousSplit(batch.numRows()) + val tab = contigTables.head + contigTables.tail.safeClose() + RapidsBufferCatalog.addTable(id, tab.getTable, tab.getBuffer, priority) + } + } + } + + onDeck = Some(id) + } + + override protected def clearOnDeck(): Unit = { + onDeck.foreach { id => + withResource(RapidsBufferCatalog.acquireBuffer(id)) { rapidsBuffer => + rapidsBuffer.free() + } + } + onDeck = None + } + + override protected def popOnDeck(): ColumnarBatch = { + val id = onDeck.get + val ret = withResource(RapidsBufferCatalog.acquireBuffer(id)) { rapidsBuffer => + rapidsBuffer.free() + rapidsBuffer.getColumnarBatch + } + onDeck = None + ret + } } case class GpuCoalesceBatches(child: SparkPlan, goal: CoalesceGoal) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala index 33bbdd95a58..a716b527c74 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala @@ -233,7 +233,8 @@ object GpuDeviceManager extends Logging { try { Cuda.setDevice(gpuId) Rmm.initialize(init, logConf, initialAllocation, maxAllocation) - GpuShuffleEnv.init(conf, info) + RapidsBufferCatalog.init(conf) + GpuShuffleEnv.init(conf) } catch { case e: Exception => logError("Could not initialize RMM", e) } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala index 3f916d9e87d..de8bf674f99 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala @@ -214,6 +214,20 @@ class HostToGpuCoalesceIterator(iter: Iterator[ColumnarBatch], totalRows = 0 peakDevMemory.set(maxDeviceMemory) } + + private var onDeck: Option[ColumnarBatch] = None + + override protected def hasOnDeck: Boolean = onDeck.isDefined + override protected def saveOnDeck(batch: ColumnarBatch): Unit = onDeck = Some(batch) + override protected def clearOnDeck(): Unit = { + onDeck.foreach(_.close()) + onDeck = None + } + override protected def popOnDeck(): ColumnarBatch = { + val ret = onDeck.get + onDeck = None + ret + } } /** diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala index 7cae58be464..c5aea52b0e3 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala @@ -19,14 +19,18 @@ package com.nvidia.spark.rapids import java.util.concurrent.ConcurrentHashMap import java.util.function.BiFunction -import scala.collection.mutable.ArrayBuffer - +import ai.rapids.cudf.{DeviceMemoryBuffer, Rmm, Table} import com.nvidia.spark.rapids.StorageTier.StorageTier import com.nvidia.spark.rapids.format.TableMeta +import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging +import org.apache.spark.sql.rapids.RapidsDiskBlockManager -/** Catalog for lookup of buffers by ID */ +/** + * Catalog for lookup of buffers by ID. The constructor is only visible for testing, generally + * `RapidsBufferCatalog.singleton` should be used instead. + */ class RapidsBufferCatalog extends Logging { /** Map of buffer IDs to buffers */ private[this] val bufferMap = new ConcurrentHashMap[RapidsBufferId, RapidsBuffer] @@ -99,6 +103,95 @@ class RapidsBufferCatalog extends Logging { } } -object RapidsBufferCatalog { +object RapidsBufferCatalog extends Logging { private val MAX_BUFFER_LOOKUP_ATTEMPTS = 100 -} + + val singleton = new RapidsBufferCatalog + private var deviceStorage: RapidsDeviceMemoryStore = _ + private var hostStorage: RapidsHostMemoryStore = _ + private var diskStorage: RapidsDiskStore = _ + private var memoryEventHandler: DeviceMemoryEventHandler = _ + + private lazy val conf = SparkEnv.get.conf + + def init(rapidsConf: RapidsConf): Unit = { + // We are going to re-initialize so make sure all of the old things were closed... + closeImpl() + assert(memoryEventHandler == null) + deviceStorage = new RapidsDeviceMemoryStore() + hostStorage = new RapidsHostMemoryStore(rapidsConf.hostSpillStorageSize) + val diskBlockManager = new RapidsDiskBlockManager(conf) + diskStorage = new RapidsDiskStore(diskBlockManager) + deviceStorage.setSpillStore(hostStorage) + hostStorage.setSpillStore(diskStorage) + + logInfo("Installing GPU memory handler for spill") + memoryEventHandler = new DeviceMemoryEventHandler(deviceStorage) + Rmm.setEventHandler(memoryEventHandler) + } + + def close(): Unit = { + logInfo("Closing storage") + closeImpl() + } + + private def closeImpl(): Unit = { + if (memoryEventHandler != null) { + // Workaround for shutdown ordering problems where device buffers allocated with this handler + // are being freed after the handler is destroyed + //Rmm.clearEventHandler() + memoryEventHandler = null + } + + if (deviceStorage != null) { + deviceStorage.close() + deviceStorage = null + } + if (hostStorage != null) { + hostStorage.close() + hostStorage = null + } + if (diskStorage != null) { + diskStorage.close() + diskStorage = null + } + } + + def getDeviceStorage: RapidsDeviceMemoryStore = deviceStorage + + /** + * Adds a contiguous table to the device storage, taking ownership of the table. + * @param id buffer ID to associate with this buffer + * @param table cudf table based from the contiguous buffer + * @param contigBuffer device memory buffer backing the table + * @param initialSpillPriority starting spill priority value for the buffer + */ + def addTable( + id: RapidsBufferId, + table: Table, + contigBuffer: DeviceMemoryBuffer, + initialSpillPriority: Long): Unit = + deviceStorage.addTable(id, table, contigBuffer, initialSpillPriority) + + /** + * Adds a buffer to the device storage, taking ownership of the buffer. + * @param id buffer ID to associate with this buffer + * @param buffer buffer that will be owned by the store + * @param tableMeta metadata describing the buffer layout + * @param initialSpillPriority starting spill priority value for the buffer + */ + def addBuffer( + id: RapidsBufferId, + buffer: DeviceMemoryBuffer, + tableMeta: TableMeta, + initialSpillPriority: Long): Unit = + deviceStorage.addBuffer(id, buffer, tableMeta, initialSpillPriority) + + /** + * Lookup the buffer that corresponds to the specified buffer ID and acquire it. + * NOTE: It is the responsibility of the caller to close the buffer. + * @param id buffer identifier + * @return buffer that has been acquired + */ + def acquireBuffer(id: RapidsBufferId): RapidsBuffer = singleton.acquireBuffer(id) +} \ No newline at end of file diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala index 71feb406ade..a06e78b68f5 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala @@ -38,7 +38,8 @@ object RapidsBufferStore { */ abstract class RapidsBufferStore( val name: String, - catalog: RapidsBufferCatalog) extends AutoCloseable with Logging { + catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton) + extends AutoCloseable with Logging { private class BufferTracker { private[this] val comparator: Comparator[RapidsBufferBase] = diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala index d5fc8c943b4..5e72c7978ef 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala @@ -26,8 +26,8 @@ import org.apache.spark.sql.vectorized.ColumnarBatch * Buffer storage using device memory. * @param catalog catalog to register this store */ -class RapidsDeviceMemoryStore( - catalog: RapidsBufferCatalog) extends RapidsBufferStore("GPU", catalog) { +class RapidsDeviceMemoryStore(catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton) + extends RapidsBufferStore("GPU", catalog) { override protected def createBuffer( other: RapidsBuffer, stream: Cuda.Stream): RapidsBufferBase = { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala index 87086dcae4c..2c0caae212e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala @@ -28,8 +28,9 @@ import org.apache.spark.sql.rapids.RapidsDiskBlockManager /** A buffer store using files on the local disks. */ class RapidsDiskStore( - catalog: RapidsBufferCatalog, - diskBlockManager: RapidsDiskBlockManager) extends RapidsBufferStore("disk", catalog) { + diskBlockManager: RapidsDiskBlockManager, + catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton) + extends RapidsBufferStore("disk", catalog) { private[this] val sharedBufferFiles = new ConcurrentHashMap[RapidsBufferId, File] override def createBuffer( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala index ba6a5145a83..07ecc0689e8 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala @@ -26,8 +26,9 @@ import com.nvidia.spark.rapids.format.TableMeta * @param maxSize maximum size in bytes for all buffers in this store */ class RapidsHostMemoryStore( - catalog: RapidsBufferCatalog, - maxSize: Long) extends RapidsBufferStore("host", catalog) { + maxSize: Long, + catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton) + extends RapidsBufferStore("host", catalog) { private[this] val pool = HostMemoryBuffer.allocate(maxSize, false) private[this] val addressAllocator = new AddressSpaceAllocator(maxSize) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillPriorities.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillPriorities.scala index ac73941470d..a4dc9de5d09 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillPriorities.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillPriorities.scala @@ -42,9 +42,16 @@ object SpillPriorities { /** * Priorities for buffers received from shuffle. - * Shuffle input buffers are about to be read by a task, so only spill - * them if there's no other choice. + * Shuffle input buffers are about to be read by a task, so spill + * them if there's no other choice, but leave some space at the end of the priority range + * so there can be some things after it. */ - // TODO: Should these be ordered amongst themselves? Maybe consider buffer size? - val INPUT_FROM_SHUFFLE_PRIORITY: Long = Long.MaxValue + val INPUT_FROM_SHUFFLE_PRIORITY: Long = Long.MaxValue - 1000 + + /** + * Priority for buffers in coalesce batch that did not fit into the batch we are working on. + * Most of the time this is shuffle input data that we read early so it should be slightly higher + * priority to keep around than other input shuffle buffers. + */ + val COALESCE_BATCH_ON_DECK_PRIORITY: Long = INPUT_FROM_SHUFFLE_PRIORITY + 1 } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala index 84132d2aace..c8cf6b75a0b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala @@ -380,7 +380,7 @@ class RapidsShuffleClient( exec: Executor, clientCopyExecutor: Executor, maximumMetadataSize: Long, - devStorage: RapidsDeviceMemoryStore = GpuShuffleEnv.getDeviceStorage, + devStorage: RapidsDeviceMemoryStore = RapidsBufferCatalog.getDeviceStorage, catalog: ShuffleReceivedBufferCatalog = GpuShuffleEnv.getReceivedCatalog) extends Logging { object ShuffleClientOps { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala index ebc271aadc5..8ae9f5c0c8b 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala @@ -18,20 +18,14 @@ package org.apache.spark.sql.rapids import java.util.Locale -import ai.rapids.cudf.{CudaMemInfo, Rmm} import com.nvidia.spark.rapids._ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging class GpuShuffleEnv(rapidsConf: RapidsConf) extends Logging { - private val catalog = new RapidsBufferCatalog private var shuffleCatalog: ShuffleBufferCatalog = _ private var shuffleReceivedBufferCatalog: ShuffleReceivedBufferCatalog = _ - private var deviceStorage: RapidsDeviceMemoryStore = _ - private var hostStorage: RapidsHostMemoryStore = _ - private var diskStorage: RapidsDiskStore = _ - private var memoryEventHandler: DeviceMemoryEventHandler = _ private lazy val conf = SparkEnv.get.conf @@ -49,52 +43,19 @@ class GpuShuffleEnv(rapidsConf: RapidsConf) extends Logging { } } - def initStorage(devInfo: CudaMemInfo): Unit = { + def init(): Unit = { if (isRapidsShuffleConfigured) { - assert(memoryEventHandler == null) - deviceStorage = new RapidsDeviceMemoryStore(catalog) - hostStorage = new RapidsHostMemoryStore(catalog, rapidsConf.hostSpillStorageSize) val diskBlockManager = new RapidsDiskBlockManager(conf) - diskStorage = new RapidsDiskStore(catalog, diskBlockManager) - deviceStorage.setSpillStore(hostStorage) - hostStorage.setSpillStore(diskStorage) - - logInfo("Installing GPU memory handler for spill") - memoryEventHandler = new DeviceMemoryEventHandler(deviceStorage) - Rmm.setEventHandler(memoryEventHandler) - - shuffleCatalog = new ShuffleBufferCatalog(catalog, diskBlockManager) - shuffleReceivedBufferCatalog = new ShuffleReceivedBufferCatalog(catalog, diskBlockManager) - } - } - - def closeStorage(): Unit = { - logInfo("Closing shuffle storage") - if (memoryEventHandler != null) { - // Workaround for shutdown ordering problems where device buffers allocated with this handler - // are being freed after the handler is destroyed - //Rmm.clearEventHandler() - memoryEventHandler = null - } - if (deviceStorage != null) { - deviceStorage.close() - deviceStorage = null - } - if (hostStorage != null) { - hostStorage.close() - hostStorage = null - } - if (diskStorage != null) { - diskStorage.close() - diskStorage = null + shuffleCatalog = + new ShuffleBufferCatalog(RapidsBufferCatalog.singleton, diskBlockManager) + shuffleReceivedBufferCatalog = + new ShuffleReceivedBufferCatalog(RapidsBufferCatalog.singleton, diskBlockManager) } } def getCatalog: ShuffleBufferCatalog = shuffleCatalog def getReceivedCatalog: ShuffleReceivedBufferCatalog = shuffleReceivedBufferCatalog - - def getDeviceStorage: RapidsDeviceMemoryStore = deviceStorage } object GpuShuffleEnv extends Logging { @@ -119,11 +80,6 @@ object GpuShuffleEnv extends Logging { isRapidsShuffleManagerInitialized = initialized } - def shutdown(): Unit = { - // in the driver, this will not be set - Option(env).foreach(_.closeStorage()) - } - def getCatalog: ShuffleBufferCatalog = if (env == null) { null } else { @@ -134,16 +90,13 @@ object GpuShuffleEnv extends Logging { // Functions below only get called from the executor // - def init(conf: RapidsConf, devInfo: CudaMemInfo): Unit = { - Option(env).foreach(_.closeStorage()) + def init(conf: RapidsConf): Unit = { val shuffleEnv = new GpuShuffleEnv(conf) - shuffleEnv.initStorage(devInfo) + shuffleEnv.init() env = shuffleEnv } def getReceivedCatalog: ShuffleReceivedBufferCatalog = env.getReceivedCatalog - def getDeviceStorage: RapidsDeviceMemoryStore = env.getDeviceStorage - def rapidsShuffleCodec: Option[TableCompressionCodec] = env.rapidsShuffleCodec } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala index 55b87de4e23..309bd466264 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala @@ -290,7 +290,7 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, isDriver: Boole mapId, metricsReporter, catalog, - GpuShuffleEnv.getDeviceStorage, + RapidsBufferCatalog.getDeviceStorage, server, gpu.dependency.metrics) case other => @@ -361,7 +361,6 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, isDriver: Boole override def shuffleBlockResolver: ShuffleBlockResolver = resolver override def stop(): Unit = { - GpuShuffleEnv.shutdown() wrapped.stop() server.foreach(_.close()) transport.foreach(_.close()) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/TempSpillBufferId.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/TempSpillBufferId.scala new file mode 100644 index 00000000000..c193182a67e --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/TempSpillBufferId.scala @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids + +import java.io.File +import java.util.UUID +import java.util.concurrent.atomic.AtomicInteger +import java.util.function.IntUnaryOperator + +import com.nvidia.spark.rapids.RapidsBufferId + +import org.apache.spark.storage.TempLocalBlockId + +object TempSpillBufferId { + private val MAX_TABLE_ID = Integer.MAX_VALUE + private val TABLE_ID_UPDATER = new IntUnaryOperator { + override def applyAsInt(i: Int): Int = if (i < MAX_TABLE_ID) i + 1 else 0 + } + + /** Tracks the next table identifier */ + private[this] val tableIdCounter = new AtomicInteger(0) + + def apply(): TempSpillBufferId = { + val tableId = tableIdCounter.getAndUpdate(TABLE_ID_UPDATER) + val tempBlockId = TempLocalBlockId(UUID.randomUUID()) + new TempSpillBufferId(tableId, tempBlockId) + } +} + +class TempSpillBufferId private( + override val tableId: Int, + bufferId: TempLocalBlockId) extends RapidsBufferId { + + override def getDiskPath(diskBlockManager: RapidsDiskBlockManager): File = + diskBlockManager.getFile(bufferId) +} \ No newline at end of file diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala index 2fd73d0a29a..ea82a40bb6a 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala @@ -62,7 +62,7 @@ class GpuPartitioningSuite extends FunSuite with Arm { SparkSession.getActiveSession.foreach(_.close()) val conf = new SparkConf().set(RapidsConf.SHUFFLE_COMPRESSION_CODEC.key, "none") TestUtils.withGpuSparkSession(conf) { _ => - GpuShuffleEnv.init(new RapidsConf(conf), Cuda.memGetInfo()) + GpuShuffleEnv.init(new RapidsConf(conf)) val partitionIndices = Array(0, 2, 2) val gp = new GpuPartitioning { override val numPartitions: Int = partitionIndices.length @@ -98,9 +98,9 @@ class GpuPartitioningSuite extends FunSuite with Arm { val conf = new SparkConf() .set(RapidsConf.SHUFFLE_COMPRESSION_CODEC.key, "copy") TestUtils.withGpuSparkSession(conf) { _ => - GpuShuffleEnv.init(new RapidsConf(conf), Cuda.memGetInfo()) + GpuShuffleEnv.init(new RapidsConf(conf)) val spillPriority = 7L - val catalog = new RapidsBufferCatalog + val catalog = RapidsBufferCatalog.singleton withResource(new RapidsDeviceMemoryStore(catalog)) { deviceStore => val partitionIndices = Array(0, 2, 2) val gp = new GpuPartitioning { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuSinglePartitioningSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuSinglePartitioningSuite.scala index cf05e9b11de..983d367e2b7 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuSinglePartitioningSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuSinglePartitioningSuite.scala @@ -38,7 +38,7 @@ class GpuSinglePartitioningSuite extends FunSuite with Arm { val conf = new SparkConf().set("spark.shuffle.manager", GpuShuffleEnv.RAPIDS_SHUFFLE_CLASS) .set(RapidsConf.SHUFFLE_COMPRESSION_CODEC.key, "none") TestUtils.withGpuSparkSession(conf) { _ => - GpuShuffleEnv.init(new RapidsConf(conf), Cuda.memGetInfo()) + GpuShuffleEnv.init(new RapidsConf(conf)) val partitioner = GpuSinglePartitioning(Nil) withResource(buildBatch()) { expected => // partition will consume batch, so make a new batch with incremented refcounts diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala index 92cae53351a..be6d16fd429 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala @@ -53,9 +53,9 @@ class RapidsDiskStoreSuite extends FunSuite with BeforeAndAfterEach with Arm wit val hostStoreMaxSize = 1L * 1024 * 1024 val catalog = spy(new RapidsBufferCatalog) withResource(new RapidsDeviceMemoryStore(catalog)) { devStore => - withResource(new RapidsHostMemoryStore(catalog, hostStoreMaxSize)) { hostStore => + withResource(new RapidsHostMemoryStore(hostStoreMaxSize, catalog)) { hostStore => devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(catalog, mock[RapidsDiskBlockManager])) { diskStore => + withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager], catalog)) { diskStore => assertResult(0)(diskStore.currentSize) hostStore.setSpillStore(diskStore) val bufferSize = addTableToStore(devStore, bufferId, spillPriority) @@ -89,9 +89,9 @@ class RapidsDiskStoreSuite extends FunSuite with BeforeAndAfterEach with Arm wit val hostStoreMaxSize = 1L * 1024 * 1024 val catalog = new RapidsBufferCatalog withResource(new RapidsDeviceMemoryStore(catalog)) { devStore => - withResource(new RapidsHostMemoryStore(catalog, hostStoreMaxSize)) { hostStore => + withResource(new RapidsHostMemoryStore(hostStoreMaxSize, catalog)) { hostStore => devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(catalog, mock[RapidsDiskBlockManager])) { diskStore => + withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager], catalog)) { diskStore => hostStore.setSpillStore(diskStore) addTableToStore(devStore, bufferId, spillPriority) val expectedBatch = withResource(catalog.acquireBuffer(bufferId)) { buffer => @@ -119,9 +119,9 @@ class RapidsDiskStoreSuite extends FunSuite with BeforeAndAfterEach with Arm wit val hostStoreMaxSize = 1L * 1024 * 1024 val catalog = new RapidsBufferCatalog withResource(new RapidsDeviceMemoryStore(catalog)) { devStore => - withResource(new RapidsHostMemoryStore(catalog, hostStoreMaxSize)) { hostStore => + withResource(new RapidsHostMemoryStore(hostStoreMaxSize, catalog)) { hostStore => devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(catalog, mock[RapidsDiskBlockManager])) { diskStore => + withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager], catalog)) { diskStore => hostStore.setSpillStore(diskStore) addTableToStore(devStore, bufferId, spillPriority) val expectedBuffer = withResource(catalog.acquireBuffer(bufferId)) { buffer => @@ -166,9 +166,9 @@ class RapidsDiskStoreSuite extends FunSuite with BeforeAndAfterEach with Arm wit val hostStoreMaxSize = 1L * 1024 * 1024 val catalog = new RapidsBufferCatalog withResource(new RapidsDeviceMemoryStore(catalog)) { devStore => - withResource(new RapidsHostMemoryStore(catalog, hostStoreMaxSize)) { hostStore => + withResource(new RapidsHostMemoryStore(hostStoreMaxSize, catalog)) { hostStore => devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(catalog, mock[RapidsDiskBlockManager])) { diskStore => + withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager], catalog)) { diskStore => hostStore.setSpillStore(diskStore) addTableToStore(devStore, bufferId, spillPriority) devStore.synchronousSpill(0) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala index 196c330e515..8cdeec5f16f 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala @@ -43,7 +43,7 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { val hostStoreMaxSize = 1L * 1024 * 1024 val catalog = spy(new RapidsBufferCatalog) withResource(new RapidsDeviceMemoryStore(catalog)) { devStore => - withResource(new RapidsHostMemoryStore(catalog, hostStoreMaxSize)) { hostStore => + withResource(new RapidsHostMemoryStore(hostStoreMaxSize, catalog)) { hostStore => assertResult(0)(hostStore.currentSize) assertResult(hostStoreMaxSize)(hostStore.numBytesFree) devStore.setSpillStore(hostStore) @@ -76,7 +76,7 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { val hostStoreMaxSize = 1L * 1024 * 1024 val catalog = new RapidsBufferCatalog withResource(new RapidsDeviceMemoryStore(catalog)) { devStore => - withResource(new RapidsHostMemoryStore(catalog, hostStoreMaxSize)) { hostStore => + withResource(new RapidsHostMemoryStore(hostStoreMaxSize, catalog)) { hostStore => devStore.setSpillStore(hostStore) var ct = buildContiguousTable() try { @@ -111,7 +111,7 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { val hostStoreMaxSize = 1L * 1024 * 1024 val catalog = new RapidsBufferCatalog withResource(new RapidsDeviceMemoryStore(catalog)) { devStore => - withResource(new RapidsHostMemoryStore(catalog, hostStoreMaxSize)) { hostStore => + withResource(new RapidsHostMemoryStore(hostStoreMaxSize, catalog)) { hostStore => devStore.setSpillStore(hostStore) var ct = buildContiguousTable() try {