Exclude the semaphore waiting time from the deserialization metric (#38)
Exclude the semaphore waiting time from the deserialization metric, along with a double-close bug fix.

---------

Signed-off-by: Firestarman <firestarmanllc@gmail.com>
firestarman authored May 7, 2024
1 parent 61404b6 commit f9b634c
Showing 2 changed files with 28 additions and 25 deletions.
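
The change threads the iterator's existing `deserTime` metric down into `SimpleTableDeserializer`, so that only the real work (the header and metadata reads, the host buffer read, the host-to-device copy, and the final batch construction) is timed, while the `GpuSemaphore` release/acquire calls stay outside every timed region. Below is a minimal, self-contained sketch of why the placement of the timing call matters; `SimpleNsMetric` is a stand-in for the plugin's `GpuMetric`, whose `ns` method the diff relies on, and the sleep stands in for the semaphore wait:

```scala
// Minimal sketch (not the spark-rapids GpuMetric) of an `ns`-style metric:
// it accumulates only the wall time of the wrapped block, so whatever is
// left outside the block -- here, a blocking wait -- is excluded.
class SimpleNsMetric {
  private var total = 0L
  // Run `f`, add its duration in nanoseconds to the metric, return its result.
  def ns[T](f: => T): T = {
    val start = System.nanoTime()
    try f finally { total += System.nanoTime() - start }
  }
  def value: Long = total
}

object MetricSketch {
  def main(args: Array[String]): Unit = {
    val deserTime = new SimpleNsMetric
    // Not timed: stands in for GpuSemaphore.acquireIfNecessary, which can
    // block while other tasks hold the GPU.
    Thread.sleep(50)
    // Timed: stands in for the actual read/copy/deserialize work.
    val result = deserTime.ns((1 to 1000).sum)
    // Prints a duration in microseconds, not the ~50ms spent waiting.
    println(s"sum=$result, timed=${deserTime.value}ns")
  }
}
```

Moving the timing inside `readFromStream`, rather than wrapping the whole call, is what keeps the potentially long semaphore wait out of the metric.
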
@@ -71,9 +71,7 @@ trait GpuPartitioning extends Partitioning {
     if (_serializingOnGPU) {
       table =>
         withResource(new NvtxRange("Table to Host", NvtxColor.BLUE)) { _ =>
-          withResource(table) { _ =>
-            PackedTableHostColumnVector.from(table)
-          }
+          PackedTableHostColumnVector.from(table)
         }
     } else {
       GpuCompressedColumnVector.from
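
The hunk above is the double-close fix: `PackedTableHostColumnVector.from(table)` was wrapped in `withResource(table)`, but the table is evidently also closed elsewhere (by `from` itself or by a caller; the diff alone does not show which), so the wrapper closed it a second time. A generic sketch of the hazard, with a simplified `withResource` and hypothetical `Resource`/`convert` names:

```scala
// Sketch of the double-close hazard; `Resource`, `convert`, and the
// ownership rules are illustrative assumptions, not the spark-rapids API.
object DoubleCloseSketch {
  class Resource extends AutoCloseable {
    private var closed = false
    override def close(): Unit = {
      if (closed) throw new IllegalStateException("double close!")
      closed = true
    }
  }

  // RAII-style helper: always closes `r` when `body` finishes.
  def withResource[R <: AutoCloseable, T](r: R)(body: R => T): T =
    try body(r) finally r.close()

  // A consumer that takes ownership of its argument and closes it itself.
  def convert(r: Resource): String =
    try "converted" finally r.close()

  def main(args: Array[String]): Unit = {
    println(convert(new Resource))          // fixed pattern: closed exactly once
    // withResource(new Resource)(convert)  // buggy pattern: closed twice, throws
  }
}
```

The remaining hunks are from the second changed file, which contains `SimpleTableSerializer`, `SimpleTableDeserializer`, and `SerializedTableIterator`.
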
@@ -120,7 +120,9 @@ private[rapids] class SimpleTableSerializer extends TableSerde {
   }
 }
 
-private[rapids] class SimpleTableDeserializer(sparkTypes: Array[DataType]) extends TableSerde {
+private[rapids] class SimpleTableDeserializer(
+    sparkTypes: Array[DataType],
+    deserTime: GpuMetric) extends TableSerde {
   private def readProtocolHeader(dIn: DataInputStream): Unit = {
     val magicNum = dIn.readInt()
     if (magicNum != P_MAGIC_NUM) {
@@ -172,10 +174,12 @@ private[rapids] class SimpleTableDeserializer(sparkTypes: Array[DataType]) extends TableSerde {
   def readFromStream(dIn: DataInputStream): ColumnarBatch = {
     // IO operation is coming, so leave GPU for a while.
     GpuSemaphore.releaseIfNecessary(TaskContext.get())
-    // 1) read and check header
-    readProtocolHeader(dIn)
-    // 2) read table metadata
-    val tableMeta = TableMeta.getRootAsTableMeta(readByteBufferFromStream(dIn))
+    val tableMeta = deserTime.ns {
+      // 1) read and check header
+      readProtocolHeader(dIn)
+      // 2) read table metadata
+      TableMeta.getRootAsTableMeta(readByteBufferFromStream(dIn))
+    }
     if (tableMeta.packedMetaAsByteBuffer() == null) {
       // no packed metadata, must be a table with zero columns
       // Acquiring the GPU even the coming batch is empty, because the downstream
@@ -186,39 +190,42 @@ private[rapids] class SimpleTableDeserializer(sparkTypes: Array[DataType]) extends TableSerde {
     } else {
       // 3) read table data
       val hostBuf = withResource(new NvtxRange("Read Host Table", NvtxColor.ORANGE)) { _ =>
-        readHostBufferFromStream(dIn)
+        deserTime.ns(readHostBufferFromStream(dIn))
       }
       val data = withResource(hostBuf) { _ =>
         // Begin to use GPU
         GpuSemaphore.acquireIfNecessary(TaskContext.get())
         withResource(new NvtxRange("Table to Device", NvtxColor.YELLOW)) { _ =>
-          closeOnExcept(DeviceMemoryBuffer.allocate(hostBuf.getLength)) { devBuf =>
-            devBuf.copyFromHostBuffer(hostBuf)
-            devBuf
+          deserTime.ns {
+            closeOnExcept(DeviceMemoryBuffer.allocate(hostBuf.getLength)) { devBuf =>
+              devBuf.copyFromHostBuffer(hostBuf)
+              devBuf
+            }
           }
         }
       }
       withResource(new NvtxRange("Deserialize Table", NvtxColor.RED)) { _ =>
-        withResource(data) { _ =>
-          val bufferMeta = tableMeta.bufferMeta()
-          if (bufferMeta == null || bufferMeta.codecBufferDescrsLength == 0) {
-            MetaUtils.getBatchFromMeta(data, tableMeta, sparkTypes)
-          } else {
-            // Compressed table is not supported by the write side, but ok to
-            // put it here for the read side. Since compression will be supported later.
-            GpuCompressedColumnVector.from(data, tableMeta)
+        deserTime.ns {
+          withResource(data) { _ =>
+            val bufferMeta = tableMeta.bufferMeta()
+            if (bufferMeta == null || bufferMeta.codecBufferDescrsLength == 0) {
+              MetaUtils.getBatchFromMeta(data, tableMeta, sparkTypes)
+            } else {
+              GpuCompressedColumnVector.from(data, tableMeta)
+            }
+          }
+        }
       }
     }
   }
 
 }
 
 private[rapids] class SerializedTableIterator(dIn: DataInputStream,
     sparkTypes: Array[DataType],
     deserTime: GpuMetric) extends Iterator[(Int, ColumnarBatch)] {
 
-  private val tableDeserializer = new SimpleTableDeserializer(sparkTypes)
+  private val tableDeserializer = new SimpleTableDeserializer(sparkTypes, deserTime)
   private var closed = false
   private var onDeck: Option[SpillableColumnarBatch] = None
   Option(TaskContext.get()).foreach { tc =>
@@ -255,10 +262,8 @@ private[rapids] class SerializedTableIterator(dIn: DataInputStream,
       return
     }
     try {
-      onDeck = deserTime.ns(
-        Some(SpillableColumnarBatch(tableDeserializer.readFromStream(dIn),
-          SpillPriorities.ACTIVE_ON_DECK_PRIORITY))
-      )
+      onDeck = Some(SpillableColumnarBatch(tableDeserializer.readFromStream(dIn),
+        SpillPriorities.ACTIVE_ON_DECK_PRIORITY))
     } catch {
       case _: EOFException => // we reach the end
         dIn.close()
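
The final hunk removes the `deserTime.ns` wrapper at the call site: `readFromStream` now times its own work internally, so timing it again here would put the semaphore wait back into the metric. As context, the iterator's `onDeck` field is a one-batch prefetch; below is a generic sketch of that pattern (a hypothetical `PrefetchIterator`, without the spilling, stream handling, and EOF cleanup of the real class):

```scala
// One-batch prefetch, as in SerializedTableIterator's `onDeck`:
// hasNext eagerly loads the next element so the end of input is
// discovered before next() is called.
class PrefetchIterator[T](readNext: () => Option[T]) extends Iterator[T] {
  private var onDeck: Option[T] = None
  private def load(): Unit = if (onDeck.isEmpty) onDeck = readNext()
  override def hasNext: Boolean = { load(); onDeck.isDefined }
  override def next(): T = {
    load()
    val batch = onDeck.getOrElse(throw new NoSuchElementException("no more batches"))
    onDeck = None
    batch
  }
}

object PrefetchSketch {
  def main(args: Array[String]): Unit = {
    var i = 0
    // readNext stands in for tableDeserializer.readFromStream plus the
    // SpillableColumnarBatch wrapping done by the real iterator.
    val it = new PrefetchIterator[Int](() => { i += 1; if (i <= 3) Some(i) else None })
    println(it.toList) // List(1, 2, 3)
  }
}
```
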
