Skip to content

Commit

Permalink
getMemoryUsedBytes -> memoryUsedBytes as a val
Browse files Browse the repository at this point in the history
Signed-off-by: Alessandro Bellina <abellina@nvidia.com>
  • Loading branch information
abellina committed Sep 8, 2023
1 parent cbb954f commit 49659ba
Show file tree
Hide file tree
Showing 16 changed files with 55 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -230,14 +230,14 @@ trait RapidsBuffer extends AutoCloseable {
*
* @note Do not use this size to allocate a target buffer to copy, always use `getPackedSize.`
*/
def getMemoryUsedBytes: Long
val memoryUsedBytes: Long

/**
* The size of this buffer if it has already gone through contiguous_split.
*
* @note Use this function when allocating a target buffer for spill or shuffle purposes.
*/
def getPackedSizeBytes: Long = getMemoryUsedBytes
def getPackedSizeBytes: Long = memoryUsedBytes

/**
* At spill time, obtain an iterator used to copy this buffer to a different tier.
Expand Down Expand Up @@ -389,7 +389,7 @@ sealed class DegenerateRapidsBuffer(
override val id: RapidsBufferId,
override val meta: TableMeta) extends RapidsBuffer {

override def getMemoryUsedBytes: Long = 0L
override val memoryUsedBytes: Long = 0L

override val storageTier: StorageTier = StorageTier.DEVICE

Expand Down Expand Up @@ -451,7 +451,7 @@ trait RapidsHostBatchBuffer extends AutoCloseable {
*/
def getHostColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch

def getMemoryUsedBytes(): Long
val memoryUsedBytes: Long
}

trait RapidsBufferChannelWritable {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ class RapidsBufferCatalog(
val totalSpilled = bufferSpills.map { case BufferSpill(spilledBuffer, maybeNewBuffer) =>
maybeNewBuffer.foreach(registerNewBuffer)
removeBufferTier(spilledBuffer.id, spilledBuffer.storageTier)
spilledBuffer.getMemoryUsedBytes
spilledBuffer.memoryUsedBytes
}.sum
Some(totalSpilled)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ import org.apache.spark.sql.rapids.GpuTaskMetrics
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.vectorized.ColumnarBatch

/**
* A helper case class that contains the buffer we spilled from our current tier
* and likely a new buffer created in a spill store tier, but it can be set to None.
* If the buffer already exists in the target spill store, `newBuffer` will be None.
* @param spilledBuffer a `RapidsBuffer` we spilled from this store
* @param newBuffer an optional `RapidsBuffer` in the target spill store.
*/
case class BufferSpill(spilledBuffer: RapidsBuffer, newBuffer: Option[RapidsBuffer])

/**
Expand Down Expand Up @@ -69,14 +76,14 @@ abstract class RapidsBufferStore(val tier: StorageTier)
if (old != null) {
throw new DuplicateBufferException(s"duplicate buffer registered: ${buffer.id}")
}
totalBytesStored += buffer.getMemoryUsedBytes
totalBytesStored += buffer.memoryUsedBytes

// device buffers "spillability" is handled via DeviceMemoryBuffer ref counting
// so spillableOnAdd should be false, all other buffer tiers are spillable at
// all times.
if (spillableOnAdd) {
if (spillable.offer(buffer)) {
totalBytesSpillable += buffer.getMemoryUsedBytes
totalBytesSpillable += buffer.memoryUsedBytes
}
}
}
Expand All @@ -86,9 +93,9 @@ abstract class RapidsBufferStore(val tier: StorageTier)
spilling.remove(id)
val obj = buffers.remove(id)
if (obj != null) {
totalBytesStored -= obj.getMemoryUsedBytes
totalBytesStored -= obj.memoryUsedBytes
if (spillable.remove(obj)) {
totalBytesSpillable -= obj.getMemoryUsedBytes
totalBytesSpillable -= obj.memoryUsedBytes
}
}
}
Expand Down Expand Up @@ -122,14 +129,14 @@ abstract class RapidsBufferStore(val tier: StorageTier)
if (!spilling.contains(buffer.id) && buffers.containsKey(buffer.id)) {
// try to add it to the spillable collection
if (spillable.offer(buffer)) {
totalBytesSpillable += buffer.getMemoryUsedBytes
totalBytesSpillable += buffer.memoryUsedBytes
logDebug(s"Buffer ${buffer.id} is spillable. " +
s"total=${totalBytesStored} spillable=${totalBytesSpillable}")
} // else it was already there (unlikely)
}
} else {
if (spillable.remove(buffer)) {
totalBytesSpillable -= buffer.getMemoryUsedBytes
totalBytesSpillable -= buffer.memoryUsedBytes
logDebug(s"Buffer ${buffer.id} is not spillable. " +
s"total=${totalBytesStored}, spillable=${totalBytesSpillable}")
} // else it was already removed
Expand All @@ -141,8 +148,8 @@ abstract class RapidsBufferStore(val tier: StorageTier)
if (buffer != null) {
// mark the id as "spilling" (this buffer is in the middle of a spill operation)
spilling.add(buffer.id)
totalBytesSpillable -= buffer.getMemoryUsedBytes
logDebug(s"Spilling buffer ${buffer.id}. size=${buffer.getMemoryUsedBytes} " +
totalBytesSpillable -= buffer.memoryUsedBytes
logDebug(s"Spilling buffer ${buffer.id}. size=${buffer.memoryUsedBytes} " +
s"total=${totalBytesStored}, new spillable=${totalBytesSpillable}")
}
buffer
Expand Down Expand Up @@ -294,7 +301,14 @@ abstract class RapidsBufferStore(val tier: StorageTier)
this,
stream)
bufferSpills.append(bufferSpill)
totalSpilled += bufferSpill.spilledBuffer.getMemoryUsedBytes
totalSpilled += bufferSpill.spilledBuffer.memoryUsedBytes
} else {
// if `nextSpillableBuffer` already spilled, we still need to
// remove it from our tier and call free on it, but set
// `newBuffer` to None because there's nothing to register
// as it has already spilled.
bufferSpills.append(BufferSpill(nextSpillableBuffer, None))
totalSpilled += nextSpillableBuffer.memoryUsedBytes
}
}
}
Expand Down Expand Up @@ -331,7 +345,9 @@ abstract class RapidsBufferStore(val tier: StorageTier)
/**
* Given a specific `RapidsBuffer` spill it to `spillStore`
*
* @return the buffer, if successfully spilled, in order for the caller to free it
* @return a `BufferSpill` instance with the target buffer in this store, and an optional
* new `RapidsBuffer` in the target spill store if this rapids buffer hadn't already
* spilled.
* @note called with catalog lock held
*/
private def spillBuffer(
Expand All @@ -351,7 +367,7 @@ abstract class RapidsBufferStore(val tier: StorageTier)
}
if (maybeNewBuffer.isEmpty) {
throw new IllegalStateException(
s"Unable to spill buffer ${buffer.id} of size ${buffer.getMemoryUsedBytes} " +
s"Unable to spill buffer ${buffer.id} of size ${buffer.memoryUsedBytes} " +
s"to tier ${lastTier}")
}
// return the buffer to free and the new buffer to register
Expand Down Expand Up @@ -527,7 +543,7 @@ abstract class RapidsBufferStore(val tier: StorageTier)
freeBuffer()
}
} else {
logWarning(s"Trying to free an invalid buffer => $id, size = ${getMemoryUsedBytes}, $this")
logWarning(s"Trying to free an invalid buffer => $id, size = ${memoryUsedBytes}, $this")
}
}

Expand Down Expand Up @@ -569,7 +585,7 @@ abstract class RapidsBufferStore(val tier: StorageTier)
releaseResources()
}

override def toString: String = s"$name buffer size=${getMemoryUsedBytes}"
override def toString: String = s"$name buffer size=${memoryUsedBytes}"
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ class RapidsDeviceMemoryStore(
chunkedPacker.getMeta
}

override def getMemoryUsedBytes: Long = unpackedSizeInBytes
override val memoryUsedBytes: Long = unpackedSizeInBytes

override def getPackedSizeBytes: Long = getChunkedPacker.getTotalContiguousSize

Expand Down Expand Up @@ -415,7 +415,7 @@ class RapidsDeviceMemoryStore(
with MemoryBuffer.EventHandler
with RapidsBufferChannelWritable {

override def getMemoryUsedBytes(): Long = size
override val memoryUsedBytes: Long = size

override val storageTier: StorageTier = StorageTier.DEVICE

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager)
id, meta, spillPriority) {
private[this] var hostBuffer: Option[HostMemoryBuffer] = None

override def getMemoryUsedBytes(): Long = size
override val memoryUsedBytes: Long = size

override val storageTier: StorageTier = StorageTier.DISK

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class RapidsGdsStore(
extends RapidsBufferBase(id, meta, spillPriority) {
override val storageTier: StorageTier = StorageTier.GDS

override def getMemoryUsedBytes(): Long = size
override val memoryUsedBytes: Long = size

override def getMemoryBuffer: MemoryBuffer = getDeviceMemoryBuffer
}
Expand Down Expand Up @@ -232,7 +232,7 @@ class RapidsGdsStore(
var isPending: Boolean = true)
extends RapidsGdsBuffer(id, size, meta, spillPriority) {

override def getMemoryUsedBytes(): Long = size
override val memoryUsedBytes: Long = size

override def materializeMemoryBuffer: MemoryBuffer = this.synchronized {
closeOnExcept(DeviceMemoryBuffer.allocate(size)) { buffer =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class RapidsHostMemoryStore(
stream: Cuda.Stream): Boolean = {
// this spillStore has a maximum size requirement (host only). We need to spill from it
// in order to make room for `buffer`.
val targetTotalSize = maxSize - buffer.getMemoryUsedBytes
val targetTotalSize = maxSize - buffer.memoryUsedBytes
if (targetTotalSize <= 0) {
// lets not spill to host when the buffer we are about
// to spill is larger than our limit
Expand All @@ -112,15 +112,15 @@ class RapidsHostMemoryStore(
val amountSpilled =
synchronousSpill(targetTotalSize, stream).map {
case BufferSpill(spilledBuffer, _) =>
spilledBuffer.getMemoryUsedBytes
spilledBuffer.memoryUsedBytes
}.sum

if (amountSpilled != 0) {
logInfo(s"Spilled $amountSpilled bytes from the ${spillStore.name} store")
TrampolineUtil.incTaskMetricsDiskBytesSpilled(amountSpilled)
}
// if after spill we can fit the new buffer, return true
buffer.getMemoryUsedBytes <= currentSize
buffer.memoryUsedBytes <= currentSize
}
}

Expand Down Expand Up @@ -219,7 +219,7 @@ class RapidsHostMemoryStore(
}

/** The size of this buffer in bytes. */
override def getMemoryUsedBytes: Long = size
override val memoryUsedBytes: Long = size

// If this require triggers, we are re-adding a `HostMemoryBuffer` outside of
// the catalog lock, which should not possible. The event handler is set to null
Expand Down Expand Up @@ -364,7 +364,7 @@ class RapidsHostMemoryStore(
null
}

override def getMemoryUsedBytes: Long = hostSizeInByes
override val memoryUsedBytes: Long = hostSizeInByes

/**
* Mark a column as spillable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class SpillableColumnarBatchImpl (
}

override lazy val sizeInBytes: Long =
withRapidsBuffer(_.getMemoryUsedBytes)
withRapidsBuffer(_.memoryUsedBytes)

/**
* Set a new spill priority.
Expand Down Expand Up @@ -164,7 +164,7 @@ class SpillableHostColumnarBatchImpl (
}

override lazy val sizeInBytes: Long = {
withRapidsHostBatchBuffer(_.getMemoryUsedBytes)
withRapidsHostBatchBuffer(_.memoryUsedBytes)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ class RapidsShuffleIterator(
try {
sb = catalog.acquireBuffer(handle)
cb = sb.getColumnarBatch(sparkTypes)
metricsUpdater.update(blockedTime, 1, sb.getMemoryUsedBytes, cb.numRows())
metricsUpdater.update(blockedTime, 1, sb.memoryUsedBytes, cb.numRows())
} finally {
nvtxRangeAfterGettingBatch.close()
range.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ class RapidsBufferCatalogSuite extends AnyFunSuite with MockitoSugar {
var _acquireAttempts: Int = acquireAttempts
var currentPriority: Long = initialPriority
override val id: RapidsBufferId = bufferId
override def getMemoryUsedBytes: Long = 0
override val memoryUsedBytes: Long = 0
override def meta: TableMeta = tableMeta
override val storageTier: StorageTier = tier
override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ class RapidsDeviceMemoryStoreSuite extends AnyFunSuite with MockitoSugar {
throw new UnsupportedOperationException

/** The size of this buffer in bytes. */
override def getMemoryUsedBytes: Long = size
override val memoryUsedBytes: Long = size
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with MockitoSugar {
ArgumentMatchers.eq(handle.id), ArgumentMatchers.eq(StorageTier.DEVICE))
withResource(catalog.acquireBuffer(handle)) { buffer =>
assertResult(StorageTier.DISK)(buffer.storageTier)
assertResult(bufferSize)(buffer.getMemoryUsedBytes)
assertResult(bufferSize)(buffer.memoryUsedBytes)
assertResult(handle.id)(buffer.id)
assertResult(spillPriority)(buffer.getSpillPriority)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class RapidsGdsStoreSuite extends FunSuiteWithTempDir with MockitoSugar {
withResource(catalog.acquireBuffer(handle)) { buffer =>
assertResult(StorageTier.GDS)(buffer.storageTier)
assertResult(id)(buffer.id)
assertResult(size)(buffer.getMemoryUsedBytes)
assertResult(size)(buffer.memoryUsedBytes)
assertResult(spillPriority)(buffer.getSpillPriority)
}
}
Expand Down Expand Up @@ -126,7 +126,7 @@ class RapidsGdsStoreSuite extends FunSuiteWithTempDir with MockitoSugar {
ArgumentMatchers.eq(bufferId), ArgumentMatchers.eq(StorageTier.DEVICE))
withResource(catalog.acquireBuffer(handle)) { buffer =>
assertResult(StorageTier.GDS)(buffer.storageTier)
assertResult(bufferSize)(buffer.getMemoryUsedBytes)
assertResult(bufferSize)(buffer.memoryUsedBytes)
assertResult(bufferId)(buffer.id)
assertResult(spillPriority)(buffer.getSpillPriority)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar {
ArgumentMatchers.eq(handle.id), ArgumentMatchers.eq(StorageTier.DEVICE))
withResource(catalog.acquireBuffer(handle)) { buffer =>
assertResult(StorageTier.HOST)(buffer.storageTier)
assertResult(bufferSize)(buffer.getMemoryUsedBytes)
assertResult(bufferSize)(buffer.memoryUsedBytes)
assertResult(handle.id)(buffer.id)
assertResult(spillPriority)(buffer.getSpillPriority)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper {
assert(cl.hasNext)
assertResult(cb)(cl.next())
assertResult(1)(testMetricsUpdater.totalRemoteBlocksFetched)
assertResult(mockBuffer.getMemoryUsedBytes)(testMetricsUpdater.totalRemoteBytesRead)
assertResult(mockBuffer.memoryUsedBytes)(testMetricsUpdater.totalRemoteBytesRead)
assertResult(10)(testMetricsUpdater.totalRowsFetched)
} finally {
RmmSpark.taskDone(taskId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class SpillableColumnarBatchSuite extends AnyFunSuite {
}

class MockBuffer(override val id: RapidsBufferId) extends RapidsBuffer {
override def getMemoryUsedBytes: Long = 123
override val memoryUsedBytes: Long = 123
override def meta: TableMeta = null
override val storageTier: StorageTier = StorageTier.DEVICE
override def getMemoryBuffer: MemoryBuffer = null
Expand Down

0 comments on commit 49659ba

Please sign in to comment.