Skip to content

Commit

Permalink
Allow skipping host spill for a direct device->disk spill (#9211)
Browse files Browse the repository at this point in the history
* Refactor code to allow skipping host

* Implement RapidsBufferChannelWritable in RapidsTable and RapidsDeviceMemoryBuffer

* Add a test for device->disk skipping host

* Small fixes

* getMemoryUsedBytes -> memoryUsedBytes as a val

Signed-off-by: Alessandro Bellina <abellina@nvidia.com>

* Update copyright

* Fix bug where buffer could be spilled to a lower tier while it was being spilled at a higher tier

* Fix leak in aggregate when there are retries

Signed-off-by: Alessandro Bellina <abellina@nvidia.com>

* Ensure that 0-byte RapidsBuffers are never spillable

* Remove spillBuffer from catalog

* Fix test issues

---------

Signed-off-by: Alessandro Bellina <abellina@nvidia.com>
  • Loading branch information
abellina authored Sep 15, 2023
1 parent 4826b49 commit 34d615d
Show file tree
Hide file tree
Showing 19 changed files with 747 additions and 294 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import java.nio.ByteBuffer

import ai.rapids.cudf.{Cuda, HostMemoryBuffer, MemoryBuffer}

/**
 * Base iterator that covers a region of `totalLength` bytes with a sequence of
 * `ByteBuffer` chunks, each at most `limit` bytes long.
 *
 * Subclasses supply `totalLength`, may narrow `limit`, and implement `getByteBuffer`
 * to produce the chunk for a given offset and length.
 */
abstract class AbstractHostByteBufferIterator
    extends Iterator[ByteBuffer] {
  // Offset of the first byte that has not been handed out yet.
  private[this] var nextBufferStart: Long = 0L

  /** Total number of bytes this iterator covers. */
  val totalLength: Long

  /** Largest chunk, in bytes, produced by a single call to `next()`. */
  protected val limit: Long = Integer.MAX_VALUE

  /**
   * Produce the chunk that starts at `offset` and is `length` bytes long.
   *
   * @param offset starting byte offset of the chunk within the covered region
   * @param length size of the chunk in bytes; never larger than `limit`
   * @return the `ByteBuffer` for the requested range
   */
  def getByteBuffer(offset: Long, length: Long): ByteBuffer

  override def hasNext: Boolean = nextBufferStart < totalLength

  /**
   * Return the next chunk and advance past it.
   *
   * @return the `ByteBuffer` for the next unread range
   * @throws NoSuchElementException if every byte has already been returned
   * @throws IllegalStateException if `limit` is not positive, since no progress could be made
   */
  override def next(): ByteBuffer = {
    if (!hasNext) {
      throw new NoSuchElementException(
        s"no bytes left: all $totalLength bytes have already been returned")
    }
    val offset = nextBufferStart
    val length = Math.min(totalLength - offset, limit)
    // `limit` is checked here rather than in the constructor because subclasses override it
    // as a `val`, which is not initialized yet while this class's constructor runs.
    // Without this check a non-positive limit would keep `hasNext` true forever.
    if (length <= 0) {
      throw new IllegalStateException(
        s"chunk limit must be positive to make progress, but was $limit")
    }
    nextBufferStart += length
    getByteBuffer(offset, length)
  }
}

/**
 * Iterator that exposes a host buffer as consecutive `ByteBuffer` views, working around
 * the 2GB size limit of a single `ByteBuffer`. Taken together, the emitted views span the
 * whole address range of the host buffer, even when it is larger than 2GB.
 *
 * @note The iterator does NOT increment the reference count of the host buffer. The caller
 *       is responsible for making sure the iterator does not outlive the host buffer.
 *
 * @param hostBuffer host buffer to iterate; a null buffer yields an empty iterator
 */
class HostByteBufferIterator(hostBuffer: HostMemoryBuffer)
    extends AbstractHostByteBufferIterator {
  // A single chunk is capped at the largest Int value.
  override protected val limit: Long = Int.MaxValue.toLong

  // A null host buffer is treated as covering zero bytes.
  override val totalLength: Long =
    Option(hostBuffer).map(_.getLength).getOrElse(0L)

  /**
   * View `length` bytes of the host buffer, starting at `offset`, as a `ByteBuffer`.
   *
   * @param offset byte offset into the host buffer
   * @param length number of bytes to expose; narrowed to an `Int` for `asByteBuffer`
   * @return the `ByteBuffer` returned by the host buffer for that range
   */
  override def getByteBuffer(offset: Long, length: Long): ByteBuffer = {
    val chunkSize = length.toInt
    hostBuffer.asByteBuffer(offset, chunkSize)
  }
}

/**
 * Iterator that emits the contents of a `MemoryBuffer` (likely a `DeviceMemoryBuffer`) as a
 * sequence of `ByteBuffer` instances by staging one chunk at a time through a host-backed
 * bounce buffer. This works around the 2GB `ByteBuffer` size limitation; the bounce buffer
 * is likely smaller than 2GB.
 *
 * Every chunk is copied to offset 0 of the same bounce buffer.
 *
 * @note It is the caller's responsibility to ensure this iterator does not outlive
 *       `memoryBuffer`. The iterator DOES NOT increment the reference count of
 *       `memoryBuffer` to ensure it remains valid.
 *
 * @param memoryBuffer memory buffer to copy, likely a `DeviceMemoryBuffer`; a null buffer
 *                     yields an empty iterator
 * @param bounceBuffer a host bounce buffer that will be used to stage copies onto the host
 * @param stream stream to synchronize on after staging to `bounceBuffer`
 */
class MemoryBufferToHostByteBufferIterator(
    memoryBuffer: MemoryBuffer,
    bounceBuffer: HostMemoryBuffer,
    stream: Cuda.Stream)
    extends AbstractHostByteBufferIterator {
  // A null source buffer is treated as covering zero bytes.
  override val totalLength: Long = memoryBuffer match {
    case null => 0L
    case source => source.getLength
  }

  // A chunk can be no larger than the bounce buffer, nor larger than the largest Int value.
  override protected val limit: Long =
    math.min(bounceBuffer.getLength, Int.MaxValue.toLong)

  /**
   * Stage `length` bytes of `memoryBuffer`, starting at `offset`, into the bounce buffer and
   * return them as a `ByteBuffer`.
   *
   * @param offset byte offset into `memoryBuffer` to start copying from
   * @param length number of bytes to copy
   * @return a `ByteBuffer` over the first `length` bytes of the bounce buffer
   */
  override def getByteBuffer(offset: Long, length: Long): ByteBuffer = {
    val stagingOffset = 0L
    bounceBuffer.copyFromMemoryBufferAsync(stagingOffset, memoryBuffer, offset, length, stream)
    // The copy above is asynchronous; sync the stream before the bounce buffer is read.
    stream.sync()
    bounceBuffer.asByteBuffer(stagingOffset, length.toInt)
  }
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -230,14 +230,14 @@ trait RapidsBuffer extends AutoCloseable {
*
* @note Do not use this size to allocate a target buffer to copy; always use `getPackedSizeBytes`.
*/
def getMemoryUsedBytes: Long
val memoryUsedBytes: Long

/**
* The size of this buffer if it has already gone through contiguous_split.
*
* @note Use this function when allocating a target buffer for spill or shuffle purposes.
*/
def getPackedSizeBytes: Long = getMemoryUsedBytes
def getPackedSizeBytes: Long = memoryUsedBytes

/**
* At spill time, obtain an iterator used to copy this buffer to a different tier.
Expand Down Expand Up @@ -389,7 +389,7 @@ sealed class DegenerateRapidsBuffer(
override val id: RapidsBufferId,
override val meta: TableMeta) extends RapidsBuffer {

override def getMemoryUsedBytes: Long = 0L
override val memoryUsedBytes: Long = 0L

override val storageTier: StorageTier = StorageTier.DEVICE

Expand Down Expand Up @@ -451,15 +451,18 @@ trait RapidsHostBatchBuffer extends AutoCloseable {
*/
def getHostColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch

def getMemoryUsedBytes(): Long
val memoryUsedBytes: Long
}

/**
* Contract for a buffer that can write its own contents to an NIO channel at spill time.
*/
trait RapidsBufferChannelWritable {
/**
* At spill time, write this buffer to an NIO `WritableByteChannel`.
* @param writableChannel channel that this buffer can write itself to, either byte-for-byte
* or via serialization if needed.
* @param stream the `Cuda.Stream` for the spilling thread. If the `RapidsBuffer` that
* implements this method is on the device, synchronization may be needed
* for staged copies.
* @return the number of bytes written to the channel
*/
def writeToChannel(writableChannel: WritableByteChannel): Long
// NOTE(review): two `writeToChannel` signatures appear back to back here, which looks like
// the before/after lines of a diff view rather than an intentional overload; confirm which
// signature is current. The `@param stream` text above describes this two-argument form.
def writeToChannel(writableChannel: WritableByteChannel, stream: Cuda.Stream): Long
}
Loading

0 comments on commit 34d615d

Please sign in to comment.