Add SpillableHostColumnarBatch #9098

Merged 2 commits on Aug 28, 2023
File: RapidsHostColumnVector.java
@@ -1,6 +1,5 @@

/*
- * Copyright (c) 2020-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,9 +16,12 @@

package com.nvidia.spark.rapids;

import ai.rapids.cudf.HostColumnVector;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.vectorized.ColumnarBatch;

import java.util.HashSet;

/**
* A GPU accelerated version of the Spark ColumnVector.
* Most of the standard Spark APIs should never be called, as they assume that the data
@@ -57,6 +59,25 @@ public static RapidsHostColumnVector[] extractColumns(ColumnarBatch batch) {
return vectors;
}

public static ColumnarBatch incRefCounts(ColumnarBatch batch) {
for (RapidsHostColumnVector rapidsHostCv: extractColumns(batch)) {
rapidsHostCv.incRefCount();
}
return batch;
}

public static long getTotalHostMemoryUsed(ColumnarBatch batch) {
long sum = 0;
if (batch.numCols() > 0) {
HashSet<RapidsHostColumnVector> found = new HashSet<>();
for (RapidsHostColumnVector rapidsHostCv: extractColumns(batch)) {
if (found.add(rapidsHostCv)) {
sum += rapidsHostCv.getHostMemoryUsed();
}
}
}
return sum;
}

private final ai.rapids.cudf.HostColumnVector cudfCv;

Expand All @@ -75,6 +96,10 @@ public final RapidsHostColumnVector incRefCount() {
return this;
}

public final long getHostMemoryUsed() {
return cudfCv.getHostMemorySize();
}

public final ai.rapids.cudf.HostColumnVector getBase() {
return cudfCv;
}
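These helpers appear to mirror the device-side GpuColumnVector utilities for host memory: incRefCounts shares a batch across owners, and getTotalHostMemoryUsed sizes a batch for spill accounting, with the HashSet guarding against double-counting a vector that occupies more than one column slot. A minimal caller-side sketch, with `hostBatch` assumed to be built of RapidsHostColumnVectors:

// Sketch only: size the batch for spill accounting; shared vectors count once.
val sizeInBytes = RapidsHostColumnVector.getTotalHostMemoryUsed(hostBatch)
// Hand the batch to a second owner: each column's ref count is bumped, so
// every owner must close the batch it holds.
val sharedBatch = RapidsHostColumnVector.incRefCounts(hostBatch)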
File: RapidsBuffer.scala
@@ -17,6 +17,7 @@
package com.nvidia.spark.rapids

import java.io.File
import java.nio.channels.WritableByteChannel

import scala.collection.mutable.ArrayBuffer

@@ -175,7 +176,6 @@ class RapidsBufferCopyIterator(buffer: RapidsBuffer)
} else {
None
}

def isChunked: Boolean = chunkedPacker.isDefined

// this is used for the single shot case to flag when `next` is called
@@ -263,6 +263,21 @@ trait RapidsBuffer extends AutoCloseable {
*/
def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch

/**
* Get the host-backed columnar batch from this buffer. The caller must have
* successfully acquired the buffer beforehand.
*
* If this `RapidsBuffer` was added originally to the device tier, or if this is
* just a buffer (not a batch), this function will throw.
*
* @param sparkTypes the spark data types the batch should have
* @see [[addReference]]
* @note It is the responsibility of the caller to close the batch.
*/
def getHostColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = {
throw new IllegalStateException(s"$this does not support host columnar batches.")
Comment on lines +277 to +278

Member:

Nit: Similar comment here. It's nice that this doesn't have the sentinel value, but I'd rather see a trait that defines the ability to provide a HostColumnarBatch and have those that need to use it on their underlying RAPIDS buffer pattern match to downcast the buffer type to get access to this rather than have a method that explodes if you don't carefully know what you're doing. Not a must-fix for me.

Collaborator (Author):

Putting these in a trait is easy, I can do that.
}
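The trait the reviewer asks for is what this diff later introduces as RapidsHostBatchBuffer. A minimal sketch of the downcast pattern being suggested, using a hypothetical helper name:

// Sketch only: callers that need a host-backed batch downcast explicitly, so
// the capability shows up in the type instead of exploding at runtime.
def asHostBatchBuffer(buffer: RapidsBuffer): RapidsHostBatchBuffer = buffer match {
  case hb: RapidsHostBatchBuffer => hb
  case other => throw new IllegalStateException(
    s"$other does not support host columnar batches")
}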

/**
* Get the underlying memory buffer. This may be either a HostMemoryBuffer or a DeviceMemoryBuffer
* depending on where the buffer currently resides.
@@ -421,3 +436,30 @@ sealed class DegenerateRapidsBuffer(

override def close(): Unit = {}
}

trait RapidsHostBatchBuffer extends AutoCloseable {
/**
* Get the host-backed columnar batch from this buffer. The caller must have
* successfully acquired the buffer beforehand.
*
* If this `RapidsBuffer` was added originally to the device tier, or if this is
* just a buffer (not a batch), this function will throw.
*
* @param sparkTypes the spark data types the batch should have
* @see [[addReference]]
* @note It is the responsibility of the caller to close the batch.
*/
def getHostColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch

def getMemoryUsedBytes(): Long
}

trait RapidsBufferChannelWritable {
/**
* At spill time, write this buffer to an NIO WritableByteChannel.
* @param writableChannel the channel this buffer writes itself to, either byte-for-byte
* or via serialization if needed.
* @return the number of bytes written to the channel
*/
def writeToChannel(writableChannel: WritableByteChannel): Long
}
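For a plain host memory buffer the write can be byte-for-byte; a hedged sketch of an implementation, reusing the HostByteBufferIterator helper the old disk-store code relied on (the class name here is hypothetical, and the PR's real implementations live in the store-specific buffer classes):

// Sketch only: expose the host buffer as NIO ByteBuffers and push each one
// fully into the channel, returning the total bytes written.
class ChannelWritableHostBuffer(hmb: HostMemoryBuffer) extends RapidsBufferChannelWritable {
  override def writeToChannel(writableChannel: WritableByteChannel): Long = {
    var written = 0L
    new HostByteBufferIterator(hmb).foreach { bb =>
      while (bb.hasRemaining) {
        written += writableChannel.write(bb)
      }
    }
    written
  }
}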
File: RapidsBufferCatalog.scala
@@ -349,8 +349,15 @@ class RapidsBufferCatalog(
batch: ColumnarBatch,
initialSpillPriority: Long,
needsSync: Boolean = true): RapidsBufferHandle = {
closeOnExcept(GpuColumnVector.from(batch)) { table =>
addTable(table, initialSpillPriority, needsSync)
require(batch.numCols() > 0,
"Cannot call addBatch with a batch that doesn't have columns")
batch.column(0) match {
case _: RapidsHostColumnVector =>
addHostBatch(batch, initialSpillPriority, needsSync)
case _ =>
closeOnExcept(GpuColumnVector.from(batch)) { table =>
addTable(table, initialSpillPriority, needsSync)
}
}
}

@@ -381,6 +388,25 @@
makeNewHandle(id, initialSpillPriority)
}


/**
* Add a host-backed ColumnarBatch to the catalog. This is only called from addBatch
* after we detect that this is a host-backed batch.
*/
private def addHostBatch(
hostCb: ColumnarBatch,
initialSpillPriority: Long,
needsSync: Boolean): RapidsBufferHandle = {
val id = TempSpillBufferId()
val rapidsBuffer = hostStorage.addBatch(
id,
hostCb,
initialSpillPriority,
needsSync)
registerNewBuffer(rapidsBuffer)
makeNewHandle(id, initialSpillPriority)
}

/**
* Register a degenerate RapidsBufferId given a TableMeta
* @note this is called from the shuffle catalogs only
@@ -430,6 +456,23 @@
throw new IllegalStateException(s"Unable to acquire buffer for ID: $id")
}

/**
* Acquires a RapidsBuffer that the caller expects to be host-backed and not
* device bound. This ensures that the acquired buffer implements the correct
* trait; otherwise it releases the acquired buffer and throws.
*
* @param handle handle associated with this `RapidsBuffer`
* @return host-backed RapidsBuffer that has been acquired
*/
def acquireHostBatchBuffer(handle: RapidsBufferHandle): RapidsHostBatchBuffer = {
closeOnExcept(acquireBuffer(handle)) {
case hrb: RapidsHostBatchBuffer => hrb
case other =>
throw new IllegalStateException(
s"Attempted to acquire a RapidsHostBatchBuffer, but got $other instead")
}
}
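Together with addBatch above, this completes a host-batch round trip through the catalog. A hypothetical caller sketch; `catalog`, `hostBatch`, and `sparkTypes` stand in for real values:

// Register a host-backed batch; addBatch routes it to the host store because
// column 0 is a RapidsHostColumnVector.
val handle = catalog.addBatch(hostBatch, initialSpillPriority = 0)
// Later, possibly after the batch has spilled to disk, acquire it as a
// host-batch buffer and materialize the host-backed ColumnarBatch.
withResource(catalog.acquireHostBatchBuffer(handle)) { hostBuffer =>
  withResource(hostBuffer.getHostColumnarBatch(sparkTypes)) { batch =>
    // consume the host-backed batch here
  }
}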

/**
* Lookup the buffer that corresponds to the specified buffer ID at the specified storage tier,
* and acquire it.
@@ -914,6 +957,17 @@ object RapidsBufferCatalog extends Logging {
def acquireBuffer(handle: RapidsBufferHandle): RapidsBuffer =
singleton.acquireBuffer(handle)

/**
* Acquires a RapidsBuffer that the caller expects to be host-backed and not
* device bound. This ensures that the acquired buffer implements the correct
* trait; otherwise it releases the acquired buffer and throws.
*
* @param handle handle associated with this `RapidsBuffer`
* @return host-backed RapidsBuffer that has been acquired
*/
def acquireHostBatchBuffer(handle: RapidsBufferHandle): RapidsHostBatchBuffer =
singleton.acquireHostBatchBuffer(handle)

def getDiskBlockManager(): RapidsDiskBlockManager = diskBlockManager

/**
File: RapidsDiskStore.scala
@@ -16,16 +16,19 @@

package com.nvidia.spark.rapids

-import java.io.{File, FileOutputStream}
+import java.io.{File, FileInputStream, FileOutputStream}
import java.nio.channels.FileChannel.MapMode
import java.util.concurrent.ConcurrentHashMap

import ai.rapids.cudf.{Cuda, HostMemoryBuffer, MemoryBuffer}
-import com.nvidia.spark.rapids.Arm.withResource
+import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.StorageTier.StorageTier
import com.nvidia.spark.rapids.format.TableMeta

import org.apache.spark.sql.rapids.RapidsDiskBlockManager
import org.apache.spark.sql.rapids.execution.SerializedHostTableUtils
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.vectorized.ColumnarBatch

/** A buffer store using files on the local disks. */
class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager)
@@ -36,61 +39,66 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager)
incoming: RapidsBuffer,
stream: Cuda.Stream): RapidsBufferBase = {
// assuming that the disk store gets contiguous buffers
-    val incomingBuffer =
-      withResource(incoming.getCopyIterator) { incomingCopyIterator =>
-        incomingCopyIterator.next()
-      }
-    withResource(incomingBuffer) { _ =>
-      val hostBuffer = incomingBuffer match {
-        case h: HostMemoryBuffer => h
-        case _ => throw new UnsupportedOperationException("buffer without host memory")
-      }
-      val id = incoming.id
-      val path = if (id.canShareDiskPaths) {
-        sharedBufferFiles.computeIfAbsent(id, _ => id.getDiskPath(diskBlockManager))
-      } else {
-        id.getDiskPath(diskBlockManager)
-      }
-      val fileOffset = if (id.canShareDiskPaths) {
-        // only one writer at a time for now when using shared files
-        path.synchronized {
-          copyBufferToPath(hostBuffer, path, append = true)
-        }
-      } else {
-        copyBufferToPath(hostBuffer, path, append = false)
-      }
-      logDebug(s"Spilled to $path $fileOffset:${incomingBuffer.getLength}")
-      new RapidsDiskBuffer(
-        id,
-        fileOffset,
-        incomingBuffer.getLength,
-        incoming.meta,
-        incoming.getSpillPriority)
-    }
+    val id = incoming.id
+    val path = if (id.canShareDiskPaths) {
+      sharedBufferFiles.computeIfAbsent(id, _ => id.getDiskPath(diskBlockManager))
+    } else {
+      id.getDiskPath(diskBlockManager)
+    }
+
+    val (fileOffset, diskLength) = if (id.canShareDiskPaths) {
+      // only one writer at a time for now when using shared files
+      path.synchronized {
+        writeToFile(incoming, path, append = true)
+      }
+    } else {
+      writeToFile(incoming, path, append = false)
+    }
+
+    logDebug(s"Spilled to $path $fileOffset:$diskLength")
+    incoming match {
+      case _: RapidsHostBatchBuffer =>
+        new RapidsDiskColumnarBatch(
+          id,
+          fileOffset,
+          diskLength,
+          incoming.meta,
+          incoming.getSpillPriority)
+
+      case _ =>
+        new RapidsDiskBuffer(
+          id,
+          fileOffset,
+          diskLength,
+          incoming.meta,
+          incoming.getSpillPriority)
+    }
   }

/** Copy a host buffer to a file, returning the file offset at which the data was written. */
-  private def copyBufferToPath(
-      buffer: HostMemoryBuffer,
+  private def writeToFile(
+      incoming: RapidsBuffer,
       path: File,
-      append: Boolean): Long = {
-    val iter = new HostByteBufferIterator(buffer)
-    val fos = new FileOutputStream(path, append)
-    try {
-      val channel = fos.getChannel
-      val fileOffset = channel.position
-      iter.foreach { bb =>
-        while (bb.hasRemaining) {
-          channel.write(bb)
-        }
-      }
-      fileOffset
-    } finally {
-      fos.close()
-    }
+      append: Boolean): (Long, Long) = {
+    incoming match {
+      case fileWritable: RapidsBufferChannelWritable =>
+        withResource(new FileOutputStream(path, append)) { fos =>
+          withResource(fos.getChannel) { outputChannel =>
+            val startOffset = outputChannel.position()
+            val writtenBytes = fileWritable.writeToChannel(outputChannel)
+            (startOffset, writtenBytes)
+          }
+        }
+      case other =>
+        throw new IllegalStateException(
+          s"Unable to write $other to file")
+    }
   }


/**
* A RapidsDiskBuffer that is meant to represent device-bound memory. This
* buffer can produce a device-backed ColumnarBatch.
*/
class RapidsDiskBuffer(
id: RapidsBufferId,
fileOffset: Long,
@@ -143,4 +151,43 @@
}
}
}

/**
* A RapidsDiskBuffer that should remain on the host, producing a host-backed
* ColumnarBatch if the caller invokes getHostColumnarBatch, but never producing
* anything on the device.
*/
class RapidsDiskColumnarBatch(
id: RapidsBufferId,
fileOffset: Long,
size: Long,
// TODO: remove meta
meta: TableMeta,
spillPriority: Long)
extends RapidsDiskBuffer(
id, fileOffset, size, meta, spillPriority)
with RapidsHostBatchBuffer {

override def getMemoryBuffer: MemoryBuffer =
throw new IllegalStateException(
"Called getMemoryBuffer on a disk buffer that needs deserialization")

override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch =
throw new IllegalStateException(
"Called getColumnarBatch on a disk buffer that needs deserialization")

override def getHostColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = {
require(fileOffset == 0,
"Attempted to obtain a HostColumnarBatch from a spilled RapidsBuffer that is sharing " +
"paths on disk")
val path = id.getDiskPath(diskBlockManager)
withResource(new FileInputStream(path)) { fis =>
val (header, hostBuffer) = SerializedHostTableUtils.readTableHeaderAndBuffer(fis)
val hostCols = closeOnExcept(hostBuffer) { _ =>
SerializedHostTableUtils.buildHostColumns(header, hostBuffer, sparkTypes)
}
new ColumnarBatch(hostCols.toArray, header.getNumRows)
}
}
}
}
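getHostColumnarBatch above reads back a batch that writeToChannel serialized at spill time. The write side is not shown in this excerpt; a hedged sketch of what it might look like, assuming cudf's JCudfSerialization as the counterpart of the SerializedHostTableUtils reader used here:

import java.io.DataOutputStream
import java.nio.channels.{Channels, WritableByteChannel}

import ai.rapids.cudf.JCudfSerialization

import org.apache.spark.sql.vectorized.ColumnarBatch

// Sketch only: serialize the host columns of a batch into the spill channel
// as a cudf header + buffer, which the disk tier can later rebuild from.
def writeHostBatch(hostCb: ColumnarBatch, channel: WritableByteChannel): Long = {
  val hostColumns = RapidsHostColumnVector.extractColumns(hostCb).map(_.getBase)
  val out = new DataOutputStream(Channels.newOutputStream(channel))
  JCudfSerialization.writeToStream(hostColumns, out, 0, hostCb.numRows())
  out.flush()
  // DataOutputStream tracks written bytes (capped at Int.MaxValue in a sketch)
  out.size().toLong
}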