From d557ef1195253dcce126003b96b5fa81a55cd6af Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Mon, 21 Aug 2023 10:49:54 -0500 Subject: [PATCH 1/2] Add SpillableHostColumnarBatch Signed-off-by: Alessandro Bellina --- .../spark/rapids/RapidsHostColumnVector.java | 26 ++ .../nvidia/spark/rapids/RapidsBuffer.scala | 29 ++- .../spark/rapids/RapidsBufferCatalog.scala | 30 ++- .../nvidia/spark/rapids/RapidsDiskStore.scala | 130 +++++++--- .../spark/rapids/RapidsHostMemoryStore.scala | 241 +++++++++++++++++- .../spark/rapids/SpillableColumnarBatch.scala | 107 +++++++- .../execution/GpuBroadcastExchangeExec.scala | 39 +-- .../execution/SerializedHostTableUtils.scala | 63 +++++ .../rapids/RapidsHostMemoryStoreSuite.scala | 232 +++++++++++++++++ 9 files changed, 811 insertions(+), 86 deletions(-) create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/SerializedHostTableUtils.scala diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVector.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVector.java index 0755e0052af..d0c0d4ee1ef 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVector.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVector.java @@ -17,9 +17,12 @@ package com.nvidia.spark.rapids; +import ai.rapids.cudf.HostColumnVector; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.vectorized.ColumnarBatch; +import java.util.HashSet; + /** * A GPU accelerated version of the Spark ColumnVector. * Most of the standard Spark APIs should never be called, as they assume that the data @@ -57,6 +60,25 @@ public static RapidsHostColumnVector[] extractColumns(ColumnarBatch batch) { return vectors; } + public static ColumnarBatch incRefCounts(ColumnarBatch batch) { + for (RapidsHostColumnVector rapidsHostCv: extractColumns(batch)) { + rapidsHostCv.incRefCount(); + } + return batch; + } + + public static long getTotalHostMemoryUsed(ColumnarBatch batch) { + long sum = 0; + if (batch.numCols() > 0) { + HashSet found = new HashSet<>(); + for (RapidsHostColumnVector rapidsHostCv: extractColumns(batch)) { + if (found.add(rapidsHostCv)) { + sum += rapidsHostCv.getHostMemoryUsed(); + } + } + } + return sum; + } private final ai.rapids.cudf.HostColumnVector cudfCv; @@ -75,6 +97,10 @@ public final RapidsHostColumnVector incRefCount() { return this; } + public final long getHostMemoryUsed() { + return cudfCv.getHostMemorySize(); + } + public final ai.rapids.cudf.HostColumnVector getBase() { return cudfCv; } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala index 2c6aa11df44..e8ed1ae4184 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala @@ -17,6 +17,7 @@ package com.nvidia.spark.rapids import java.io.File +import java.io.OutputStream import scala.collection.mutable.ArrayBuffer @@ -175,7 +176,6 @@ class RapidsBufferCopyIterator(buffer: RapidsBuffer) } else { None } - def isChunked: Boolean = chunkedPacker.isDefined // this is used for the single shot case to flag when `next` is call @@ -245,12 +245,24 @@ trait RapidsBuffer extends AutoCloseable { def getCopyIterator: RapidsBufferCopyIterator = new RapidsBufferCopyIterator(this) + /** + * At spill time, the tier we are spilling to may need to hand the rapids buffer an output stream + * to write to. 
This is the case for a `RapidsHostColumnarBatch`.
+   * @param outputStream stream that the `RapidsBuffer` will serialize itself to
+   */
+  def serializeToStream(outputStream: OutputStream): Unit = {
+    throw new IllegalStateException(s"Buffer $this does not support serializeToStream")
+  }
+
   /** Descriptor for how the memory buffer is formatted */
   def meta: TableMeta
 
   /** The storage tier for this buffer */
   val storageTier: StorageTier
 
+  /** A RapidsBuffer that needs to be serialized/deserialized at spill/materialization time */
+  val needsSerialization: Boolean = false
+
   /**
    * Get the columnar batch within this buffer. The caller must have
    * successfully acquired the buffer beforehand.
@@ -263,6 +275,21 @@ trait RapidsBuffer extends AutoCloseable {
    */
   def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch
 
+  /**
+   * Get the host-backed columnar batch from this buffer. The caller must have
+   * successfully acquired the buffer beforehand.
+   *
+   * If this `RapidsBuffer` was added originally to the device tier, or if this is
+   * just a buffer (not a batch), this function will throw.
+   *
+   * @param sparkTypes the spark data types the batch should have
+   * @see [[addReference]]
+   * @note It is the responsibility of the caller to close the batch.
+   */
+  def getHostColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = {
+    throw new IllegalStateException(s"$this does not support host columnar batches.")
+  }
+
   /**
    * Get the underlying memory buffer. This may be either a HostMemoryBuffer or a DeviceMemoryBuffer
    * depending on where the buffer currently resides.
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala
index 9ee65503a58..edbbae53688 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala
@@ -349,8 +349,15 @@ class RapidsBufferCatalog(
       batch: ColumnarBatch,
       initialSpillPriority: Long,
       needsSync: Boolean = true): RapidsBufferHandle = {
-    closeOnExcept(GpuColumnVector.from(batch)) { table =>
-      addTable(table, initialSpillPriority, needsSync)
+    require(batch.numCols() > 0,
+      "Cannot call addBatch with a batch that doesn't have columns")
+    batch.column(0) match {
+      case _: RapidsHostColumnVector =>
+        addHostBatch(batch, initialSpillPriority, needsSync)
+      case _ =>
+        closeOnExcept(GpuColumnVector.from(batch)) { table =>
+          addTable(table, initialSpillPriority, needsSync)
+        }
     }
   }
 
@@ -381,6 +388,25 @@ class RapidsBufferCatalog(
     makeNewHandle(id, initialSpillPriority)
   }
 
+
+  /**
+   * Add a host-backed ColumnarBatch to the catalog. This is only called from addBatch
+   * after we detect that this is a host-backed batch.
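+   * The batch columns are expected to all be `RapidsHostColumnVector`s, matching
+   * the check performed in `addBatch`.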
+ */ + private def addHostBatch( + hostCb: ColumnarBatch, + initialSpillPriority: Long, + needsSync: Boolean): RapidsBufferHandle = { + val id = TempSpillBufferId() + val rapidsBuffer = hostStorage.addBatch( + id, + hostCb, + initialSpillPriority, + needsSync) + registerNewBuffer(rapidsBuffer) + makeNewHandle(id, initialSpillPriority) + } + /** * Register a degenerate RapidsBufferId given a TableMeta * @note this is called from the shuffle catalogs only diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala index 1fcb9b94119..a76da733a8c 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala @@ -16,16 +16,19 @@ package com.nvidia.spark.rapids -import java.io.{File, FileOutputStream} +import java.io.{File, FileInputStream, FileOutputStream} import java.nio.channels.FileChannel.MapMode import java.util.concurrent.ConcurrentHashMap import ai.rapids.cudf.{Cuda, HostMemoryBuffer, MemoryBuffer} -import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.StorageTier.StorageTier import com.nvidia.spark.rapids.format.TableMeta import org.apache.spark.sql.rapids.RapidsDiskBlockManager +import org.apache.spark.sql.rapids.execution.SerializedHostTableUtils +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.ColumnarBatch /** A buffer store using files on the local disks. */ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager) @@ -36,6 +39,49 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager) incoming: RapidsBuffer, stream: Cuda.Stream): RapidsBufferBase = { // assuming that the disk store gets contiguous buffers + val id = incoming.id + val path = if (id.canShareDiskPaths) { + sharedBufferFiles.computeIfAbsent(id, _ => id.getDiskPath(diskBlockManager)) + } else { + id.getDiskPath(diskBlockManager) + } + + val (fileOffset, diskLength) = if (id.canShareDiskPaths) { + // only one writer at a time for now when using shared files + path.synchronized { + if (incoming.needsSerialization) { + // only shuffle buffers share paths, and adding host-backed ColumnarBatch is + // not the shuffle case, so this is not supported and is never exercised. + throw new IllegalStateException( + s"Attempted spilling to disk a RapidsBuffer $incoming that needs serialization " + + s"while sharing spill paths.") + } else { + copyBufferToPath(incoming, path, append = true) + } + } + } else { + if (incoming.needsSerialization) { + serializeBufferToStream(incoming, path) + } else { + copyBufferToPath(incoming, path, append = false) + } + } + + logDebug(s"Spilled to $path $fileOffset:$diskLength") + new RapidsDiskBuffer( + id, + fileOffset, + diskLength, + incoming.meta, + incoming.getSpillPriority, + incoming.needsSerialization) + } + + /** Copy a host buffer to a file, returning the file offset at which the data was written. 
*/ + private def copyBufferToPath( + incoming: RapidsBuffer, + path: File, + append: Boolean): (Long, Long) = { val incomingBuffer = withResource(incoming.getCopyIterator) { incomingCopyIterator => incomingCopyIterator.next() @@ -45,58 +91,44 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager) case h: HostMemoryBuffer => h case _ => throw new UnsupportedOperationException("buffer without host memory") } - val id = incoming.id - val path = if (id.canShareDiskPaths) { - sharedBufferFiles.computeIfAbsent(id, _ => id.getDiskPath(diskBlockManager)) - } else { - id.getDiskPath(diskBlockManager) - } - val fileOffset = if (id.canShareDiskPaths) { - // only one writer at a time for now when using shared files - path.synchronized { - copyBufferToPath(hostBuffer, path, append = true) + val iter = new HostByteBufferIterator(hostBuffer) + val fos = new FileOutputStream(path, append) + try { + val channel = fos.getChannel + val fileOffset = channel.position + iter.foreach { bb => + while (bb.hasRemaining) { + channel.write(bb) + } } - } else { - copyBufferToPath(hostBuffer, path, append = false) + (fileOffset, channel.position()) + } finally { + fos.close() } - logDebug(s"Spilled to $path $fileOffset:${incomingBuffer.getLength}") - new RapidsDiskBuffer( - id, - fileOffset, - incomingBuffer.getLength, - incoming.meta, - incoming.getSpillPriority) } } - /** Copy a host buffer to a file, returning the file offset at which the data was written. */ - private def copyBufferToPath( - buffer: HostMemoryBuffer, - path: File, - append: Boolean): Long = { - val iter = new HostByteBufferIterator(buffer) - val fos = new FileOutputStream(path, append) - try { - val channel = fos.getChannel - val fileOffset = channel.position - iter.foreach { bb => - while (bb.hasRemaining) { - channel.write(bb) - } + private def serializeBufferToStream( + incoming: RapidsBuffer, + path: File): (Long, Long) = { + withResource(new FileOutputStream(path, false /*append not supported*/)) { fos => + withResource(fos.getChannel) { outputChannel => + val startOffset = outputChannel.position() + incoming.serializeToStream(fos) + val endingOffset = outputChannel.position() + val writtenBytes = endingOffset - startOffset + (startOffset, writtenBytes) } - fileOffset - } finally { - fos.close() } } - class RapidsDiskBuffer( id: RapidsBufferId, fileOffset: Long, size: Long, meta: TableMeta, - spillPriority: Long) + spillPriority: Long, + override val needsSerialization: Boolean = false) extends RapidsBufferBase( id, meta, spillPriority) { private[this] var hostBuffer: Option[HostMemoryBuffer] = None @@ -106,6 +138,8 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager) override val storageTier: StorageTier = StorageTier.DISK override def getMemoryBuffer: MemoryBuffer = synchronized { + require(!needsSerialization, + "Called getMemoryBuffer on a disk buffer that needs deserialization") if (hostBuffer.isEmpty) { val path = id.getDiskPath(diskBlockManager) val mappedBuffer = HostMemoryBuffer.mapFile(path, MapMode.READ_WRITE, @@ -117,6 +151,22 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager) hostBuffer.get } + override def getHostColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = { + require(needsSerialization, + "Disk buffer was not serialized yet getHostColumnarBatch is being invoked") + require(fileOffset == 0, + "Attempted to obtain a HostColumnarBatch from a spilled RapidsBuffer that is sharing " + + "paths on disk") + val path = id.getDiskPath(diskBlockManager) + withResource(new 
FileInputStream(path)) { fis => + val (header, hostBuffer) = SerializedHostTableUtils.readTableHeaderAndBuffer(fis) + val hostCols = closeOnExcept(hostBuffer) { _ => + SerializedHostTableUtils.buildHostColumns(header, hostBuffer, sparkTypes) + } + new ColumnarBatch(hostCols.toArray, header.getNumRows) + } + } + override def close(): Unit = synchronized { if (refcount == 1) { // free the memory mapping since this is the last active reader diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala index 986f7fa73e8..7e230116d47 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala @@ -16,12 +16,20 @@ package com.nvidia.spark.rapids -import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, NvtxColor, NvtxRange, PinnedMemoryPool} +import java.io.OutputStream +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.mutable + +import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, HostColumnVector, HostMemoryBuffer, JCudfSerialization, MemoryBuffer, NvtxColor, NvtxRange, PinnedMemoryPool} import com.nvidia.spark.rapids.Arm.{closeOnExcept, freeOnExcept, withResource} import com.nvidia.spark.rapids.SpillPriorities.{applyPriorityOffset, HOST_MEMORY_BUFFER_SPILL_OFFSET} import com.nvidia.spark.rapids.StorageTier.StorageTier import com.nvidia.spark.rapids.format.TableMeta +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.ColumnarBatch + /** * A buffer store using host memory. * @param maxSize maximum size in bytes for all buffers in this store @@ -65,7 +73,8 @@ class RapidsHostMemoryStore( buffer.getLength, tableMeta, initialSpillPriority, - buffer) + buffer, + needsSerialization = false) freeOnExcept(rapidsBuffer) { _ => logDebug(s"Adding host buffer for: [id=$id, size=${buffer.getLength}, " + s"uncompressed=${rapidsBuffer.meta.bufferMeta.uncompressedSize}, " + @@ -76,6 +85,21 @@ class RapidsHostMemoryStore( } } + def addBatch(id: RapidsBufferId, + hostCb: ColumnarBatch, + initialSpillPriority: Long, + needsSync: Boolean): RapidsBuffer = { + RapidsHostColumnVector.incRefCounts(hostCb) + val rapidsBuffer = new RapidsHostColumnarBatch( + id, + hostCb, + initialSpillPriority) + freeOnExcept(rapidsBuffer) { _ => + addBuffer(rapidsBuffer, needsSync) + rapidsBuffer + } + } + override protected def createBuffer( other: RapidsBuffer, stream: Cuda.Stream): RapidsBufferBase = { @@ -123,7 +147,8 @@ class RapidsHostMemoryStore( size: Long, meta: TableMeta, spillPriority: Long, - buffer: HostMemoryBuffer) + buffer: HostMemoryBuffer, + override val needsSerialization: Boolean = false) extends RapidsBufferBase(id, meta, spillPriority) with MemoryBuffer.EventHandler { override val storageTier: StorageTier = StorageTier.HOST @@ -191,6 +216,216 @@ class RapidsHostMemoryStore( super.free() } } + + /** + * A per cuDF host column event handler that handles calls to .close() + * inside of the `HostColumnVector` lock. + */ + class RapidsHostColumnEventHandler + extends HostColumnVector.EventHandler { + + // Every RapidsHostColumnarBatch that references this column has an entry in this map. + // The value represents the number of times (normally 1) that a ColumnVector + // appears in the RapidsHostColumnarBatch. This is also the HosColumnVector refCount at which + // the column is considered spillable. 
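+    // For example, a column that appears twice in a single batch becomes spillable
+    // once its refCount drops to 2, i.e. when the batch holds the only references.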
+ // The map is protected via the ColumnVector lock. + private val registration = new mutable.HashMap[RapidsHostColumnarBatch, Int]() + + /** + * Every RapidsHostColumnarBatch iterates through its columns and either creates + * a `RapidsHostColumnEventHandler` object and associates it with the column's + * `eventHandler` or calls into the existing one, and registers itself. + * + * The registration has two goals: it accounts for repetition of a column + * in a `RapidsHostColumnarBatch`. If a batch has the same column repeated it must adjust + * the refCount at which this column is considered spillable. + * + * The second goal is to account for aliasing. If two host batches alias this column + * we are going to mark it as non spillable. + * + * @param rapidsHostCb - the host batch that is registering itself with this tracker + */ + def register(rapidsHostCb: RapidsHostColumnarBatch, repetition: Int): Unit = { + registration.put(rapidsHostCb, repetition) + } + + /** + * This is invoked during `RapidsHostColumnarBatch.free` in order to remove the entry + * in `registration`. + * + * @param rapidsHostCb - the batch that is de-registering itself + */ + def deregister(rapidsHostCb: RapidsHostColumnarBatch): Unit = { + registration.remove(rapidsHostCb) + } + + // called with the cudf HostColumnVector lock held from cuDF's side + override def onClosed(cudfCv: HostColumnVector, refCount: Int): Unit = { + // we only handle spillability if there is a single batch registered + // (no aliasing) + if (registration.size == 1) { + val (rapidsHostCb, spillableRefCount) = registration.head + if (spillableRefCount == refCount) { + rapidsHostCb.onColumnSpillable(cudfCv) + } + } + } + } + + /** + * A `RapidsHostColumnarBatch` is the spill store holder of ColumnarBatch backed by + * HostColumnVector. + * + * This class owns the host batch and will close it when `close` is called. + * + * @param id the `RapidsBufferId` this batch is associated with + * @param batch the host ColumnarBatch we are managing + * @param spillPriority a starting spill priority + */ + class RapidsHostColumnarBatch( + id: RapidsBufferId, + hostCb: ColumnarBatch, + spillPriority: Long) + extends RapidsBufferBase( + id, + null, + spillPriority) { + + override val storageTier: StorageTier = StorageTier.HOST + + override val needsSerialization: Boolean = true + + // This is the current size in batch form. It is to be used while this + // batch hasn't migrated to another store. + private val hostSizeInByes: Long = RapidsHostColumnVector.getTotalHostMemoryUsed(hostCb) + + // By default all columns are NOT spillable since we are not the only owners of + // the columns (the caller is holding onto a ColumnarBatch that will be closed + // after instantiation, triggering onClosed callbacks) + // This hash set contains the columns that are currently spillable. + private val columnSpillability = new ConcurrentHashMap[HostColumnVector, Boolean]() + + private val numDistinctColumns = RapidsHostColumnVector.extractBases(hostCb).distinct.size + + // we register our event callbacks as the very first action to deal with + // spillability + registerOnCloseEventHandler() + + /** Release the underlying resources for this buffer. 
*/ + override protected def releaseResources(): Unit = { + hostCb.close() + } + + override def meta: TableMeta = { + null + } + + override def getMemoryUsedBytes: Long = hostSizeInByes + + /** + * Mark a column as spillable + * + * @param column the ColumnVector to mark as spillable + */ + def onColumnSpillable(column: HostColumnVector): Unit = { + columnSpillability.put(column, true) + updateSpillability() + } + + /** + * Update the spillability state of this RapidsHostColumnarBatch. This is invoked from + * two places: + * + * - from the onColumnSpillable callback, which is invoked from a + * HostColumnVector.EventHandler.onClosed callback. + * + * - after adding a batch to the store to mark the batch as spillable if + * all columns are spillable. + */ + override def updateSpillability(): Unit = { + doSetSpillable(this, columnSpillability.size == numDistinctColumns) + } + + override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = { + throw new UnsupportedOperationException( + "RapidsHostColumnarBatch does not support getColumnarBatch") + } + + override def getHostColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = { + columnSpillability.clear() + doSetSpillable(this, false) + RapidsHostColumnVector.incRefCounts(hostCb) + } + + override def getMemoryBuffer: MemoryBuffer = { + throw new UnsupportedOperationException( + "RapidsHostColumnarBatch does not support getMemoryBuffer") + } + + override def getCopyIterator: RapidsBufferCopyIterator = { + throw new UnsupportedOperationException( + "RapidsHostColumnarBatch does not support getCopyIterator") + } + + override def serializeToStream(outputStream: OutputStream): Unit = { + val columns = RapidsHostColumnVector.extractBases(hostCb) + JCudfSerialization.writeToStream(columns, outputStream, 0, hostCb.numRows()) + } + + override def free(): Unit = { + // lets remove our handler from the chain of handlers for each column + removeOnCloseEventHandler() + super.free() + } + + private def registerOnCloseEventHandler(): Unit = { + val columns = RapidsHostColumnVector.extractBases(hostCb) + // cudfColumns could contain duplicates. We need to take this into account when we are + // deciding the floor refCount for a duplicated column + val repetitionPerColumn = new mutable.HashMap[HostColumnVector, Int]() + columns.foreach { col => + val repetitionCount = repetitionPerColumn.getOrElse(col, 0) + repetitionPerColumn(col) = repetitionCount + 1 + } + repetitionPerColumn.foreach { case (distinctCv, repetition) => + // lock the column because we are setting its event handler, and we are inspecting + // its refCount. 
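+        // If the caller has already dropped its own references, the refCount equals
+        // the repetition count and the column is marked spillable right away below.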
+ distinctCv.synchronized { + val eventHandler = distinctCv.getEventHandler match { + case null => + val eventHandler = new RapidsHostColumnEventHandler + distinctCv.setEventHandler(eventHandler) + eventHandler + case existing: RapidsHostColumnEventHandler => + existing + case other => + throw new IllegalStateException( + s"Invalid column event handler $other") + } + eventHandler.register(this, repetition) + if (repetition == distinctCv.getRefCount) { + onColumnSpillable(distinctCv) + } + } + } + } + + // this method is called from free() + private def removeOnCloseEventHandler(): Unit = { + val distinctColumns = RapidsHostColumnVector.extractBases(hostCb).distinct + distinctColumns.foreach { distinctCv => + distinctCv.synchronized { + distinctCv.getEventHandler match { + case eventHandler: RapidsHostColumnEventHandler => + eventHandler.deregister(this) + case t => + throw new IllegalStateException( + s"Invalid column event handler $t") + } + } + } + } + } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala index 3b76ad5fa83..98a074bf0cd 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala @@ -122,6 +122,72 @@ class SpillableColumnarBatchImpl ( } } +class JustRowsHostColumnarBatch(numRows: Int) + extends SpillableColumnarBatch { + override def numRows(): Int = numRows + override def setSpillPriority(priority: Long): Unit = () // NOOP nothing to spill + + def getColumnarBatch(): ColumnarBatch = { + new ColumnarBatch(Array.empty, numRows) + } + + override def close(): Unit = () // NOOP nothing to close + override val sizeInBytes: Long = 0L + + override def dataTypes: Array[DataType] = Array.empty +} + +/** + * The implementation of [[SpillableHostColumnarBatch]] that points to buffers that can be spilled. + * @note the buffer should be in the cache by the time this is created and this is taking over + * ownership of the life cycle of the batch. So don't call this constructor directly please + * use `SpillableHostColumnarBatch.apply` instead. + */ +class SpillableHostColumnarBatchImpl ( + handle: RapidsBufferHandle, + rowCount: Int, + sparkTypes: Array[DataType], + catalog: RapidsBufferCatalog) + extends SpillableColumnarBatch { + + override def dataTypes: Array[DataType] = sparkTypes + /** + * The number of rows stored in this batch. + */ + override def numRows(): Int = rowCount + + private def withRapidsBuffer[T](fn: RapidsBuffer => T): T = { + withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => + fn(rapidsBuffer) + } + } + + override lazy val sizeInBytes: Long = { + withRapidsBuffer(_.getMemoryUsedBytes) + } + + /** + * Set a new spill priority. + */ + override def setSpillPriority(priority: Long): Unit = { + handle.setSpillPriority(priority) + } + + override def getColumnarBatch(): ColumnarBatch = { + withRapidsBuffer { rapidsBuffer => + rapidsBuffer.getHostColumnarBatch(sparkTypes) + } + } + + /** + * Remove the `ColumnarBatch` from the cache. + */ + override def close(): Unit = { + // closing my reference + handle.close() + } +} + object SpillableColumnarBatch { /** * Create a new SpillableColumnarBatch. @@ -207,9 +273,46 @@ object SpillableColumnarBatch { } } } - } +object SpillableHostColumnarBatch { + /** + * Create a new SpillableColumnarBatch backed by host columns. 
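+   *
+   * A sketch of intended usage, assuming a host-backed batch `hostCb` and a spill
+   * priority `priority` are in scope:
+   * {{{
+   *   val spillable = SpillableHostColumnarBatch(hostCb, priority)
+   *   // hostCb is now owned by the spill framework; materialize the data when needed:
+   *   withResource(spillable) { s =>
+   *     withResource(s.getColumnarBatch()) { hostBatch =>
+   *       // process the host-backed batch
+   *     }
+   *   }
+   * }}}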
+ * + * @note This takes over ownership of batch, and batch should not be used after this. + * @param batch the batch to make spillable + * @param priority the initial spill priority of this batch + */ + def apply( + batch: ColumnarBatch, + priority: Long, + catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton): SpillableColumnarBatch = { + val numRows = batch.numRows() + if (batch.numCols() <= 0) { + // We consumed it + batch.close() + new JustRowsHostColumnarBatch(numRows) + } else { + val types = RapidsHostColumnVector.extractColumns(batch).map(_.dataType()) + val handle = addHostBatch(batch, priority, catalog) + new SpillableHostColumnarBatchImpl( + handle, + numRows, + types, + catalog) + } + } + + private[this] def addHostBatch( + batch: ColumnarBatch, + initialSpillPriority: Long, + catalog: RapidsBufferCatalog): RapidsBufferHandle = { + withResource(batch) { batch => + catalog.addBatch(batch, initialSpillPriority) + } + } + +} /** * Just like a SpillableColumnarBatch but for buffers. */ @@ -351,4 +454,4 @@ object SpillableHostBuffer { } new SpillableHostBuffer(handle, length, catalog) } -} \ No newline at end of file +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala index 6c8fda7a4a0..eaada67d400 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala @@ -26,7 +26,7 @@ import scala.ref.WeakReference import scala.util.control.NonFatal import ai.rapids.cudf.{HostMemoryBuffer, JCudfSerialization, NvtxColor, NvtxRange} -import ai.rapids.cudf.JCudfSerialization.{HostConcatResult, SerializedTableHeader} +import ai.rapids.cudf.JCudfSerialization.HostConcatResult import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.GpuMetric._ @@ -51,43 +51,6 @@ import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} -object SerializedHostTableUtils { - /** - * Read in a cuDF serialized table into host memory from an input stream. 
- */ - def readTableHeaderAndBuffer( - in: ObjectInputStream): (JCudfSerialization.SerializedTableHeader, HostMemoryBuffer) = { - val din = new DataInputStream(in) - val header = new JCudfSerialization.SerializedTableHeader(din) - if (!header.wasInitialized()) { - throw new IllegalStateException("Could not read serialized table header") - } - closeOnExcept(HostMemoryBuffer.allocate(header.getDataLen)) { buffer => - JCudfSerialization.readTableIntoBuffer(din, header, buffer) - if (!header.wasDataRead()) { - throw new IllegalStateException("Could not read serialized table data") - } - (header, buffer) - } - } - - /** - * Deserialize a cuDF serialized table to host build column vectors - */ - def buildHostColumns( - header: SerializedTableHeader, - buffer: HostMemoryBuffer, - dataTypes: Array[DataType]): Array[RapidsHostColumnVector] = { - assert(dataTypes.length == header.getNumColumns) - closeOnExcept(JCudfSerialization.unpackHostColumnVectors(header, buffer)) { hostColumns => - assert(hostColumns.length == dataTypes.length) - dataTypes.zip(hostColumns).safeMap { case (dataType, hostColumn) => - new RapidsHostColumnVector(dataType, hostColumn) - } - } - } -} - /** * Class that is used to broadcast results (a contiguous host batch) to executors. * diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/SerializedHostTableUtils.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/SerializedHostTableUtils.scala new file mode 100644 index 00000000000..3e749701652 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/SerializedHostTableUtils.scala @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.execution + +import java.io.{DataInputStream, InputStream} + +import ai.rapids.cudf.{HostMemoryBuffer, JCudfSerialization} +import com.nvidia.spark.rapids.Arm.closeOnExcept +import com.nvidia.spark.rapids.RapidsHostColumnVector +import com.nvidia.spark.rapids.RapidsPluginImplicits._ + +import org.apache.spark.sql.types.DataType + +object SerializedHostTableUtils { + /** + * Read in a cuDF serialized table into host memory from an input stream. 
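+   * The caller owns the returned host buffer and is responsible for closing it.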
+ */ + def readTableHeaderAndBuffer( + in: InputStream): (JCudfSerialization.SerializedTableHeader, HostMemoryBuffer) = { + val din = new DataInputStream(in) + val header = new JCudfSerialization.SerializedTableHeader(din) + if (!header.wasInitialized()) { + throw new IllegalStateException("Could not read serialized table header") + } + closeOnExcept(HostMemoryBuffer.allocate(header.getDataLen)) { buffer => + JCudfSerialization.readTableIntoBuffer(din, header, buffer) + if (!header.wasDataRead()) { + throw new IllegalStateException("Could not read serialized table data") + } + (header, buffer) + } + } + + /** + * Deserialize a cuDF serialized table to host build column vectors + */ + def buildHostColumns( + header: JCudfSerialization.SerializedTableHeader, + buffer: HostMemoryBuffer, + dataTypes: Array[DataType]): Array[RapidsHostColumnVector] = { + assert(dataTypes.length == header.getNumColumns) + closeOnExcept(JCudfSerialization.unpackHostColumnVectors(header, buffer)) { hostColumns => + assert(hostColumns.length == dataTypes.length) + dataTypes.zip(hostColumns).safeMap { case (dataType, hostColumn) => + new RapidsHostColumnVector(dataType, hostColumn) + } + } + } +} diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala index 7a451d29bab..6200938efbc 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala @@ -55,6 +55,37 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar { } } + private def buildHostBatch(): ColumnarBatch = { + val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType, + DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5)) + val hostCols = withResource(buildContiguousTable()) { ct => + withResource(ct.getTable) { tbl => + (0 until tbl.getNumberOfColumns) + .map(c => tbl.getColumn(c).copyToHost()) + } + }.toArray + new ColumnarBatch( + hostCols.zip(sparkTypes).map { case (hostCol, dataType) => + new RapidsHostColumnVector(dataType, hostCol) + }, hostCols.head.getRowCount.toInt) + } + + private def buildHostBatchWithDuplicate(): ColumnarBatch = { + val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType, + DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5)) + val hostCols = withResource(buildContiguousTable()) { ct => + withResource(ct.getTable) { tbl => + (0 until tbl.getNumberOfColumns) + .map(c => tbl.getColumn(c).copyToHost()) + } + }.toArray + hostCols.foreach(_.incRefCount()) + new ColumnarBatch( + (hostCols ++ hostCols).zip(sparkTypes ++ sparkTypes).map { case (hostCol, dataType) => + new RapidsHostColumnVector(dataType, hostCol) + }, hostCols.head.getRowCount.toInt) + } + test("spill updates catalog") { val spillPriority = -7 val hostStoreMaxSize = 1L * 1024 * 1024 @@ -251,6 +282,207 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar { } } + test("host batch originated: get host memory batch") { + val spillPriority = -10 + val hostStoreMaxSize = 1L * 1024 * 1024 + val bm = new RapidsDiskBlockManager(new SparkConf()) + withResource(new RapidsDiskStore(bm)) { diskStore => + withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) { hostStore => + withResource(new RapidsDeviceMemoryStore) { devStore => + val catalog = new RapidsBufferCatalog(devStore, hostStore) + devStore.setSpillStore(hostStore) + hostStore.setSpillStore(diskStore) + + val 
hostCb = buildHostBatch()
+
+          val sizeOnHost = RapidsHostColumnVector.getTotalHostMemoryUsed(hostCb)
+
+          withResource(
+            SpillableHostColumnarBatch(hostCb, spillPriority, catalog)) { spillableBuffer =>
+            assertResult(sizeOnHost)(hostStore.currentSpillableSize)
+
+            withResource(spillableBuffer.getColumnarBatch()) { hostCb =>
+              // 0 because we have a reference to the host batch
+              assertResult(0)(hostStore.currentSpillableSize)
+              val spilled = catalog.synchronousSpill(hostStore, 0)
+              assertResult(0)(spilled.get)
+            }
+
+            assertResult(sizeOnHost)(hostStore.currentSpillableSize)
+            val spilled = catalog.synchronousSpill(hostStore, 0)
+            assertResult(sizeOnHost)(spilled.get)
+
+            val sizeOnDisk = diskStore.currentSpillableSize
+
+            // reconstitute batch from disk
+            withResource(spillableBuffer.getColumnarBatch()) { hostCbFromDisk =>
+              // disk has a different size, so this spillable batch reports a different
+              // sizeInBytes right now: the size of the serialized representation
+              assertResult(sizeOnDisk)(spillableBuffer.sizeInBytes)
+              // let's recreate our original batch and compare to make sure contents match
+              withResource(buildHostBatch()) { expectedHostCb =>
+                TestUtils.compareBatches(expectedHostCb, hostCbFromDisk)
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
+  test("a host batch is not spillable when we leak it") {
+    val spillPriority = -10
+    val hostStoreMaxSize = 1L * 1024 * 1024
+    val bm = new RapidsDiskBlockManager(new SparkConf())
+    withResource(new RapidsDiskStore(bm)) { diskStore =>
+      withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) { hostStore =>
+        withResource(new RapidsDeviceMemoryStore) { devStore =>
+          val catalog = new RapidsBufferCatalog(devStore, hostStore)
+          devStore.setSpillStore(hostStore)
+          hostStore.setSpillStore(diskStore)
+
+          val hostCb = buildHostBatch()
+
+          val sizeOnHost = RapidsHostColumnVector.getTotalHostMemoryUsed(hostCb)
+
+          val leakedBatch = withResource(
+            SpillableHostColumnarBatch(hostCb, spillPriority, catalog)) { spillableBuffer =>
+            assertResult(sizeOnHost)(hostStore.currentSpillableSize)
+
+            val leakedBatch = spillableBuffer.getColumnarBatch()
+            // 0 because we have a reference to the host batch
+            assertResult(0)(hostStore.currentSpillableSize)
+            val spilled = catalog.synchronousSpill(hostStore, 0)
+            assertResult(0)(spilled.get)
+            leakedBatch
+          }
+
+          withResource(leakedBatch) { _ =>
+            // 0 because we have leaked the host batch
+            assertResult(0)(hostStore.currentSize)
+            assertResult(0)(hostStore.currentSpillableSize)
+            val spilled = catalog.synchronousSpill(hostStore, 0)
+            assertResult(0)(spilled.get)
+          }
+          // after closing we still have 0 bytes in the store or available to spill
+          assertResult(0)(hostStore.currentSize)
+          assertResult(0)(hostStore.currentSpillableSize)
+        }
+      }
+    }
+  }
+
+  test("a host batch is not spillable when columns are incRefCounted") {
+    val spillPriority = -10
+    val hostStoreMaxSize = 1L * 1024 * 1024
+    val bm = new RapidsDiskBlockManager(new SparkConf())
+    withResource(new RapidsDiskStore(bm)) { diskStore =>
+      withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) { hostStore =>
+        withResource(new RapidsDeviceMemoryStore) { devStore =>
+          val catalog = new RapidsBufferCatalog(devStore, hostStore)
+          devStore.setSpillStore(hostStore)
+          hostStore.setSpillStore(diskStore)
+
+          val hostCb = buildHostBatch()
+
+          val sizeOnHost = RapidsHostColumnVector.getTotalHostMemoryUsed(hostCb)
+
+          withResource(
+            SpillableHostColumnarBatch(hostCb, spillPriority, catalog)) { spillableBuffer =>
+            
assertResult(sizeOnHost)(hostStore.currentSpillableSize) + + val leakedFirstColumn = withResource(spillableBuffer.getColumnarBatch()) { hostCb => + // 0 because we have a reference to the host batch + assertResult(0)(hostStore.currentSpillableSize) + val spilled = catalog.synchronousSpill(hostStore, 0) + assertResult(0)(spilled.get) + // leak it by increasing the ref count of the underlying cuDF column + RapidsHostColumnVector.extractBases(hostCb).head.incRefCount() + } + withResource(leakedFirstColumn) { _ => + // 0 because we have a reference to the first column + assertResult(0)(hostStore.currentSpillableSize) + val spilled = catalog.synchronousSpill(hostStore, 0) + assertResult(0)(spilled.get) + } + // batch is now spillable because we close our reference to the column + assertResult(sizeOnHost)(hostStore.currentSpillableSize) + val spilled = catalog.synchronousSpill(hostStore, 0) + assertResult(sizeOnHost)(spilled.get) + } + } + } + } + } + + test("an aliased host batch is not spillable (until closing the original) ") { + val spillPriority = -10 + val hostStoreMaxSize = 1L * 1024 * 1024 + val bm = new RapidsDiskBlockManager(new SparkConf()) + withResource(new RapidsDiskStore(bm)) { diskStore => + withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) { hostStore => + withResource(new RapidsDeviceMemoryStore) { devStore => + val catalog = new RapidsBufferCatalog(devStore, hostStore) + val hostBatch = buildHostBatch() + val sizeOnHost = RapidsHostColumnVector.getTotalHostMemoryUsed(hostBatch) + val handle = withResource(hostBatch) { _ => + catalog.addBatch(hostBatch, spillPriority) + } + withResource(handle) { _ => + val types: Array[DataType] = + Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray + assertResult(sizeOnHost)(hostStore.currentSize) + assertResult(sizeOnHost)(hostStore.currentSpillableSize) + withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => + // extract the batch from the table we added, and add it back as a batch + withResource(rapidsBuffer.getHostColumnarBatch(types)) { batch => + catalog.addBatch(batch, spillPriority) + } + } // we now have two copies in the store + assertResult(sizeOnHost * 2)(hostStore.currentSize) + assertResult(0)(hostStore.currentSpillableSize) + } // remove the original + assertResult(sizeOnHost)(hostStore.currentSize) + assertResult(sizeOnHost)(hostStore.currentSpillableSize) + } + } + } + } + + test("an aliased host batch supports duplicated columns") { + val spillPriority = -10 + val hostStoreMaxSize = 1L * 1024 * 1024 + val bm = new RapidsDiskBlockManager(new SparkConf()) + withResource(new RapidsDiskStore(bm)) { diskStore => + withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) { hostStore => + withResource(new RapidsDeviceMemoryStore) { devStore => + val catalog = new RapidsBufferCatalog(devStore, hostStore) + val hostBatch = buildHostBatchWithDuplicate() + val sizeOnHost = RapidsHostColumnVector.getTotalHostMemoryUsed(hostBatch) + val handle = withResource(hostBatch) { _ => + catalog.addBatch(hostBatch, spillPriority) + } + withResource(handle) { _ => + val types: Array[DataType] = + Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray + assertResult(sizeOnHost)(hostStore.currentSize) + assertResult(sizeOnHost)(hostStore.currentSpillableSize) + withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => + // extract the batch from the table we added, and add it back as a batch + withResource(rapidsBuffer.getHostColumnarBatch(types)) { batch => + catalog.addBatch(batch, spillPriority) 
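+              // the batch added here aliases the same host columns as the original
+              // entry, so neither copy is spillable until one of them is removed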
+            }
+          } // we now have two copies in the store
+          assertResult(sizeOnHost * 2)(hostStore.currentSize)
+          assertResult(0)(hostStore.currentSpillableSize)
+        } // remove the original
+        assertResult(sizeOnHost)(hostStore.currentSize)
+        assertResult(sizeOnHost)(hostStore.currentSpillableSize)
+      }
+    }
+  }
+
   test("buffer exceeds maximum size") {
     val sparkTypes = Array[DataType](LongType)
     val spillPriority = -10

From ff6ab4c1c2318e392589da06f566b1dded7511c1 Mon Sep 17 00:00:00 2001
From: Alessandro Bellina
Date: Sun, 27 Aug 2023 23:02:31 -0500
Subject: [PATCH 2/2] Introduce a trait that encapsulates host-backed
 ColumnarBatch pieces

---
 .../spark/rapids/RapidsHostColumnVector.java  |   3 +-
 .../nvidia/spark/rapids/RapidsBuffer.scala    |  41 +++--
 .../spark/rapids/RapidsBufferCatalog.scala    |  28 ++++
 .../nvidia/spark/rapids/RapidsDiskStore.scala | 151 +++++++++---------
 .../spark/rapids/RapidsHostMemoryStore.scala  |  43 +++--
 .../spark/rapids/SpillableColumnarBatch.scala |  11 +-
 6 files changed, 169 insertions(+), 108 deletions(-)

diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVector.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVector.java
index d0c0d4ee1ef..c7913cd93e5 100644
--- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVector.java
+++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVector.java
@@ -1,6 +1,5 @@
-
 /*
- * Copyright (c) 2020-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala
index e8ed1ae4184..ed0699e92f0 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala
@@ -17,7 +17,7 @@ package com.nvidia.spark.rapids
 
 import java.io.File
-import java.io.OutputStream
+import java.nio.channels.WritableByteChannel
 
 import scala.collection.mutable.ArrayBuffer
 
@@ -245,24 +245,12 @@ trait RapidsBuffer extends AutoCloseable {
   def getCopyIterator: RapidsBufferCopyIterator =
     new RapidsBufferCopyIterator(this)
 
-  /**
-   * At spill time, the tier we are spilling to may need to hand the rapids buffer an output stream
-   * to write to. This is the case for a `RapidsHostColumnarBatch`.
-   * @param outputStream stream that the `RapidsBuffer` will serialize itself to
-   */
-  def serializeToStream(outputStream: OutputStream): Unit = {
-    throw new IllegalStateException(s"Buffer $this does not support serializeToStream")
-  }
-
   /** Descriptor for how the memory buffer is formatted */
   def meta: TableMeta
 
   /** The storage tier for this buffer */
   val storageTier: StorageTier
 
-  /** A RapidsBuffer that needs to be serialized/deserialized at spill/materialization time */
-  val needsSerialization: Boolean = false
-
   /**
    * Get the columnar batch within this buffer. The caller must have
    * successfully acquired the buffer beforehand.
@@ -448,3 +436,30 @@ sealed class DegenerateRapidsBuffer(
 
   override def close(): Unit = {}
 }
+
+trait RapidsHostBatchBuffer extends AutoCloseable {
+  /**
+   * Get the host-backed columnar batch from this buffer. The caller must have
+   * successfully acquired the buffer beforehand.
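+   * Implementations include `RapidsHostColumnarBatch` and `RapidsDiskColumnarBatch`.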
+   *
+   * If this `RapidsBuffer` was added originally to the device tier, or if this is
+   * just a buffer (not a batch), this function will throw.
+   *
+   * @param sparkTypes the spark data types the batch should have
+   * @see [[addReference]]
+   * @note It is the responsibility of the caller to close the batch.
+   */
+  def getHostColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch
+
+  def getMemoryUsedBytes(): Long
+}
+
+trait RapidsBufferChannelWritable {
+  /**
+   * At spill time, write this buffer to an NIO WritableByteChannel.
+   * @param writableChannel the channel this buffer writes itself to, either byte-for-byte
+   *                        or via serialization if needed.
+   * @return the number of bytes written to the channel
+   */
+  def writeToChannel(writableChannel: WritableByteChannel): Long
+}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala
index edbbae53688..6c1593058db 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala
@@ -456,6 +456,23 @@ class RapidsBufferCatalog(
     throw new IllegalStateException(s"Unable to acquire buffer for ID: $id")
   }
 
+  /**
+   * Acquires a RapidsBuffer that the caller expects to be host-backed and not
+   * device bound. This ensures that the acquired buffer implements the correct
+   * trait; otherwise it closes the acquired buffer and throws.
+   *
+   * @param handle handle associated with this `RapidsBuffer`
+   * @return host-backed RapidsBuffer that has been acquired
+   */
+  def acquireHostBatchBuffer(handle: RapidsBufferHandle): RapidsHostBatchBuffer = {
+    closeOnExcept(acquireBuffer(handle)) {
+      case hrb: RapidsHostBatchBuffer => hrb
+      case other =>
+        throw new IllegalStateException(
+          s"Attempted to acquire a RapidsHostBatchBuffer, but got $other instead")
+    }
+  }
+
   /**
    * Lookup the buffer that corresponds to the specified buffer ID at the specified storage tier,
    * and acquire it.
@@ -940,6 +957,17 @@ object RapidsBufferCatalog extends Logging {
   def acquireBuffer(handle: RapidsBufferHandle): RapidsBuffer =
     singleton.acquireBuffer(handle)
 
+  /**
+   * Acquires a RapidsBuffer that the caller expects to be host-backed and not
+   * device bound. This ensures that the acquired buffer implements the correct
+   * trait; otherwise it closes the acquired buffer and throws.
+   *
+   * @param handle handle associated with this `RapidsBuffer`
+   * @return host-backed RapidsBuffer that has been acquired
+   */
+  def acquireHostBatchBuffer(handle: RapidsBufferHandle): RapidsHostBatchBuffer =
+    singleton.acquireHostBatchBuffer(handle)
+
   def getDiskBlockManager(): RapidsDiskBlockManager = diskBlockManager
 
   /**
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala
index a76da733a8c..bd34317f927 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala
@@ -49,86 +49,62 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager)
     val (fileOffset, diskLength) = if (id.canShareDiskPaths) {
       // only one writer at a time for now when using shared files
       path.synchronized {
-        if (incoming.needsSerialization) {
-          // only shuffle buffers share paths, and adding host-backed ColumnarBatch is
-          // not the shuffle case, so this is not supported and is never exercised.
-          throw new IllegalStateException(
-            s"Attempted spilling to disk a RapidsBuffer $incoming that needs serialization " +
-              s"while sharing spill paths.")
-        } else {
-          copyBufferToPath(incoming, path, append = true)
-        }
+        writeToFile(incoming, path, append = true)
       }
     } else {
-      if (incoming.needsSerialization) {
-        serializeBufferToStream(incoming, path)
-      } else {
-        copyBufferToPath(incoming, path, append = false)
-      }
+      writeToFile(incoming, path, append = false)
     }
 
     logDebug(s"Spilled to $path $fileOffset:$diskLength")
-    new RapidsDiskBuffer(
-      id,
-      fileOffset,
-      diskLength,
-      incoming.meta,
-      incoming.getSpillPriority,
-      incoming.needsSerialization)
+    incoming match {
+      case _: RapidsHostBatchBuffer =>
+        new RapidsDiskColumnarBatch(
+          id,
+          fileOffset,
+          diskLength,
+          incoming.meta,
+          incoming.getSpillPriority)
+
+      case _ =>
+        new RapidsDiskBuffer(
+          id,
+          fileOffset,
+          diskLength,
+          incoming.meta,
+          incoming.getSpillPriority)
+    }
   }
 
   /** Copy a host buffer to a file, returning the file offset at which the data was written. */
-  private def copyBufferToPath(
+  private def writeToFile(
       incoming: RapidsBuffer,
       path: File,
       append: Boolean): (Long, Long) = {
-    val incomingBuffer =
-      withResource(incoming.getCopyIterator) { incomingCopyIterator =>
-        incomingCopyIterator.next()
-      }
-    withResource(incomingBuffer) { _ =>
-      val hostBuffer = incomingBuffer match {
-        case h: HostMemoryBuffer => h
-        case _ => throw new UnsupportedOperationException("buffer without host memory")
-      }
-      val iter = new HostByteBufferIterator(hostBuffer)
-      val fos = new FileOutputStream(path, append)
-      try {
-        val channel = fos.getChannel
-        val fileOffset = channel.position
-        iter.foreach { bb =>
-          while (bb.hasRemaining) {
-            channel.write(bb)
+    incoming match {
+      case fileWritable: RapidsBufferChannelWritable =>
+        withResource(new FileOutputStream(path, append)) { fos =>
+          withResource(fos.getChannel) { outputChannel =>
+            val startOffset = outputChannel.position()
+            val writtenBytes = fileWritable.writeToChannel(outputChannel)
+            (startOffset, writtenBytes)
           }
         }
-        (fileOffset, channel.position())
-      } finally {
-        fos.close()
-      }
-    }
-  }
-
-  private def serializeBufferToStream(
-    incoming: RapidsBuffer,
-    path: File): (Long, Long) = {
-    withResource(new FileOutputStream(path, false /*append not supported*/)) { fos =>
-      withResource(fos.getChannel) { outputChannel =>
-        val startOffset = outputChannel.position()
-        incoming.serializeToStream(fos)
-        val endingOffset = outputChannel.position()
-        val writtenBytes = endingOffset - startOffset
-        (startOffset, writtenBytes)
-      }
+      case other =>
+        throw new IllegalStateException(
+          s"Unable to write $other to file")
     }
   }
 
+  /**
+   * A RapidsDiskBuffer that is meant to represent device-bound memory. This
+   * buffer can produce a device-backed ColumnarBatch.
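+   * When materialized, the spilled data is memory-mapped back into host memory.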
+ */ class RapidsDiskBuffer( id: RapidsBufferId, fileOffset: Long, size: Long, meta: TableMeta, - spillPriority: Long, - override val needsSerialization: Boolean = false) + spillPriority: Long) extends RapidsBufferBase( id, meta, spillPriority) { private[this] var hostBuffer: Option[HostMemoryBuffer] = None @@ -138,8 +114,6 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager) override val storageTier: StorageTier = StorageTier.DISK override def getMemoryBuffer: MemoryBuffer = synchronized { - require(!needsSerialization, - "Called getMemoryBuffer on a disk buffer that needs deserialization") if (hostBuffer.isEmpty) { val path = id.getDiskPath(diskBlockManager) val mappedBuffer = HostMemoryBuffer.mapFile(path, MapMode.READ_WRITE, @@ -151,22 +125,6 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager) hostBuffer.get } - override def getHostColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = { - require(needsSerialization, - "Disk buffer was not serialized yet getHostColumnarBatch is being invoked") - require(fileOffset == 0, - "Attempted to obtain a HostColumnarBatch from a spilled RapidsBuffer that is sharing " + - "paths on disk") - val path = id.getDiskPath(diskBlockManager) - withResource(new FileInputStream(path)) { fis => - val (header, hostBuffer) = SerializedHostTableUtils.readTableHeaderAndBuffer(fis) - val hostCols = closeOnExcept(hostBuffer) { _ => - SerializedHostTableUtils.buildHostColumns(header, hostBuffer, sparkTypes) - } - new ColumnarBatch(hostCols.toArray, header.getNumRows) - } - } - override def close(): Unit = synchronized { if (refcount == 1) { // free the memory mapping since this is the last active reader @@ -193,4 +151,43 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager) } } } + + /** + * A RapidsDiskBuffer that should remain in the host, producing host-backed + * ColumnarBatch if the caller invokes getHostColumnarBatch, but not producing + * anything on the device. 
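+   * On materialization, the cuDF-serialized table is read back from disk and
+   * rebuilt into host columns.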
+ */ + class RapidsDiskColumnarBatch( + id: RapidsBufferId, + fileOffset: Long, + size: Long, + // TODO: remove meta + meta: TableMeta, + spillPriority: Long) + extends RapidsDiskBuffer( + id, fileOffset, size, meta, spillPriority) + with RapidsHostBatchBuffer { + + override def getMemoryBuffer: MemoryBuffer = + throw new IllegalStateException( + "Called getMemoryBuffer on a disk buffer that needs deserialization") + + override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = + throw new IllegalStateException( + "Called getColumnarBatch on a disk buffer that needs deserialization") + + override def getHostColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = { + require(fileOffset == 0, + "Attempted to obtain a HostColumnarBatch from a spilled RapidsBuffer that is sharing " + + "paths on disk") + val path = id.getDiskPath(diskBlockManager) + withResource(new FileInputStream(path)) { fis => + val (header, hostBuffer) = SerializedHostTableUtils.readTableHeaderAndBuffer(fis) + val hostCols = closeOnExcept(hostBuffer) { _ => + SerializedHostTableUtils.buildHostColumns(header, hostBuffer, sparkTypes) + } + new ColumnarBatch(hostCols.toArray, header.getNumRows) + } + } + } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala index 7e230116d47..e3dbfad7b8f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala @@ -16,7 +16,8 @@ package com.nvidia.spark.rapids -import java.io.OutputStream +import java.io.DataOutputStream +import java.nio.channels.{Channels, WritableByteChannel} import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable @@ -27,6 +28,7 @@ import com.nvidia.spark.rapids.SpillPriorities.{applyPriorityOffset, HOST_MEMORY import com.nvidia.spark.rapids.StorageTier.StorageTier import com.nvidia.spark.rapids.format.TableMeta +import org.apache.spark.sql.rapids.storage.RapidsStorageUtils import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.ColumnarBatch @@ -73,8 +75,7 @@ class RapidsHostMemoryStore( buffer.getLength, tableMeta, initialSpillPriority, - buffer, - needsSerialization = false) + buffer) freeOnExcept(rapidsBuffer) { _ => logDebug(s"Adding host buffer for: [id=$id, size=${buffer.getLength}, " + s"uncompressed=${rapidsBuffer.meta.bufferMeta.uncompressedSize}, " + @@ -147,9 +148,9 @@ class RapidsHostMemoryStore( size: Long, meta: TableMeta, spillPriority: Long, - buffer: HostMemoryBuffer, - override val needsSerialization: Boolean = false) + buffer: HostMemoryBuffer) extends RapidsBufferBase(id, meta, spillPriority) + with RapidsBufferChannelWritable with MemoryBuffer.EventHandler { override val storageTier: StorageTier = StorageTier.HOST @@ -161,6 +162,21 @@ class RapidsHostMemoryStore( } } + override def writeToChannel(outputChannel: WritableByteChannel): Long = { + var written: Long = 0L + val iter = new HostByteBufferIterator(buffer) + iter.foreach { bb => + try { + while (bb.hasRemaining) { + written += outputChannel.write(bb) + } + } finally { + RapidsStorageUtils.dispose(bb) + } + } + written + } + override def updateSpillability(): Unit = { if (buffer.getRefCount == 1) { setSpillable(this, true) @@ -289,12 +305,12 @@ class RapidsHostMemoryStore( extends RapidsBufferBase( id, null, - spillPriority) { + spillPriority) + with RapidsBufferChannelWritable + with 
RapidsHostBatchBuffer { override val storageTier: StorageTier = StorageTier.HOST - override val needsSerialization: Boolean = true - // This is the current size in batch form. It is to be used while this // batch hasn't migrated to another store. private val hostSizeInByes: Long = RapidsHostColumnVector.getTotalHostMemoryUsed(hostCb) @@ -367,9 +383,14 @@ class RapidsHostMemoryStore( "RapidsHostColumnarBatch does not support getCopyIterator") } - override def serializeToStream(outputStream: OutputStream): Unit = { - val columns = RapidsHostColumnVector.extractBases(hostCb) - JCudfSerialization.writeToStream(columns, outputStream, 0, hostCb.numRows()) + override def writeToChannel(outputChannel: WritableByteChannel): Long = { + withResource(Channels.newOutputStream(outputChannel)) { outputStream => + withResource(new DataOutputStream(outputStream)) { dos => + val columns = RapidsHostColumnVector.extractBases(hostCb) + JCudfSerialization.writeToStream(columns, dos, 0, hostCb.numRows()) + dos.size() + } + } } override def free(): Unit = { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala index 98a074bf0cd..beb5db35cbd 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala @@ -151,19 +151,20 @@ class SpillableHostColumnarBatchImpl ( extends SpillableColumnarBatch { override def dataTypes: Array[DataType] = sparkTypes + /** * The number of rows stored in this batch. */ override def numRows(): Int = rowCount - private def withRapidsBuffer[T](fn: RapidsBuffer => T): T = { - withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => + private def withRapidsHostBatchBuffer[T](fn: RapidsHostBatchBuffer => T): T = { + withResource(catalog.acquireHostBatchBuffer(handle)) { rapidsBuffer => fn(rapidsBuffer) } } override lazy val sizeInBytes: Long = { - withRapidsBuffer(_.getMemoryUsedBytes) + withRapidsHostBatchBuffer(_.getMemoryUsedBytes) } /** @@ -174,8 +175,8 @@ class SpillableHostColumnarBatchImpl ( } override def getColumnarBatch(): ColumnarBatch = { - withRapidsBuffer { rapidsBuffer => - rapidsBuffer.getHostColumnarBatch(sparkTypes) + withRapidsHostBatchBuffer { hostBatchBuffer => + hostBatchBuffer.getHostColumnarBatch(sparkTypes) } }