diff --git a/sql-plugin/src/main/format/ShuffleCommon.fbs b/sql-plugin/src/main/format/ShuffleCommon.fbs
index 9a37614ec19..be643b154f7 100644
--- a/sql-plugin/src/main/format/ShuffleCommon.fbs
+++ b/sql-plugin/src/main/format/ShuffleCommon.fbs
@@ -15,20 +15,43 @@ namespace com.nvidia.spark.rapids.format;
 
 enum CodecType : byte {
+  /// data simply copied, codec is only for testing
+  COPY = -1,
+
   /// no compression codec was used on the data
-  UNCOMPRESSED = 0
+  UNCOMPRESSED = 0,
+}
+
+/// Descriptor for a compressed buffer
+table CodecBufferDescriptor {
+  /// the compression codec used
+  codec: CodecType;
+
+  /// byte offset from the start of the enclosing compressed buffer
+  /// where the compressed data begins
+  compressed_offset: long;
+
+  /// size of the compressed data in bytes
+  compressed_size: long;
+
+  /// byte offset from the start of the enclosing uncompressed buffer
+  /// where the uncompressed data should be written
+  uncompressed_offset: long;
+
+  /// size of the uncompressed data in bytes
+  uncompressed_size: long;
 }
 
 table BufferMeta {
   /// ID of this buffer
   id: int;
 
-  /// size of the uncompressed buffer data in bytes
-  actual_size: long;
+  /// size of the buffer data in bytes
+  size: long;
 
-  /// size of the compressed buffer data if a codec is used
-  compressed_size: long;
+  /// size of the uncompressed buffer data
+  uncompressed_size: long;
 
-  /// type of compression codec used
-  codec: CodecType;
+  /// array of codec buffer descriptors if the data is compressed
+  codec_buffer_descrs: [CodecBufferDescriptor];
 }
 
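Editor's note (not part of the patch): the schema change above replaces the single codec/actual_size/compressed_size fields with a vector of per-chunk descriptors. A minimal reader-side sketch of how that vector is meant to be interpreted, using only accessors generated from this schema; the object and method names below are hypothetical. An empty codec_buffer_descrs vector means the buffer is uncompressed; otherwise each descriptor maps one compressed chunk onto its offset and size within the reassembled uncompressed buffer.

    import com.nvidia.spark.rapids.format.BufferMeta

    object BufferMetaExample {
      /** True when the described buffer still needs decompression before use. */
      def isCompressed(bufferMeta: BufferMeta): Boolean =
        bufferMeta.codecBufferDescrsLength() > 0

      /** Sums the per-chunk uncompressed sizes; expected to equal bufferMeta.uncompressedSize(). */
      def totalUncompressedBytes(bufferMeta: BufferMeta): Long =
        (0 until bufferMeta.codecBufferDescrsLength()).map { i =>
          bufferMeta.codecBufferDescrs(i).uncompressedSize()
        }.sum
    }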
diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/BufferMeta.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/BufferMeta.java
index f3bec7588df..ea8ea94ce93 100644
--- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/BufferMeta.java
+++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/BufferMeta.java
@@ -20,39 +20,42 @@ public final class BufferMeta extends Table {
   public int id() { int o = __offset(4); return o != 0 ? bb.getInt(o + bb_pos) : 0; }
   public boolean mutateId(int id) { int o = __offset(4); if (o != 0) { bb.putInt(o + bb_pos, id); return true; } else { return false; } }
   /**
-   * size of the uncompressed buffer data in bytes
+   * size of the buffer data in bytes
    */
-  public long actualSize() { int o = __offset(6); return o != 0 ? bb.getLong(o + bb_pos) : 0L; }
-  public boolean mutateActualSize(long actual_size) { int o = __offset(6); if (o != 0) { bb.putLong(o + bb_pos, actual_size); return true; } else { return false; } }
+  public long size() { int o = __offset(6); return o != 0 ? bb.getLong(o + bb_pos) : 0L; }
+  public boolean mutateSize(long size) { int o = __offset(6); if (o != 0) { bb.putLong(o + bb_pos, size); return true; } else { return false; } }
   /**
-   * size of the compressed buffer data if a codec is used
+   * size of the uncompressed buffer data
    */
-  public long compressedSize() { int o = __offset(8); return o != 0 ? bb.getLong(o + bb_pos) : 0L; }
-  public boolean mutateCompressedSize(long compressed_size) { int o = __offset(8); if (o != 0) { bb.putLong(o + bb_pos, compressed_size); return true; } else { return false; } }
+  public long uncompressedSize() { int o = __offset(8); return o != 0 ? bb.getLong(o + bb_pos) : 0L; }
+  public boolean mutateUncompressedSize(long uncompressed_size) { int o = __offset(8); if (o != 0) { bb.putLong(o + bb_pos, uncompressed_size); return true; } else { return false; } }
   /**
-   * type of compression codec used
+   * array of codec buffer descriptors if the data is compressed
    */
-  public byte codec() { int o = __offset(10); return o != 0 ? bb.get(o + bb_pos) : 0; }
-  public boolean mutateCodec(byte codec) { int o = __offset(10); if (o != 0) { bb.put(o + bb_pos, codec); return true; } else { return false; } }
+  public CodecBufferDescriptor codecBufferDescrs(int j) { return codecBufferDescrs(new CodecBufferDescriptor(), j); }
+  public CodecBufferDescriptor codecBufferDescrs(CodecBufferDescriptor obj, int j) { int o = __offset(10); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; }
+  public int codecBufferDescrsLength() { int o = __offset(10); return o != 0 ? __vector_len(o) : 0; }
 
   public static int createBufferMeta(FlatBufferBuilder builder,
       int id,
-      long actual_size,
-      long compressed_size,
-      byte codec) {
+      long size,
+      long uncompressed_size,
+      int codec_buffer_descrsOffset) {
     builder.startObject(4);
-    BufferMeta.addCompressedSize(builder, compressed_size);
-    BufferMeta.addActualSize(builder, actual_size);
+    BufferMeta.addUncompressedSize(builder, uncompressed_size);
+    BufferMeta.addSize(builder, size);
+    BufferMeta.addCodecBufferDescrs(builder, codec_buffer_descrsOffset);
     BufferMeta.addId(builder, id);
-    BufferMeta.addCodec(builder, codec);
     return BufferMeta.endBufferMeta(builder);
   }
 
   public static void startBufferMeta(FlatBufferBuilder builder) { builder.startObject(4); }
   public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); }
-  public static void addActualSize(FlatBufferBuilder builder, long actualSize) { builder.addLong(1, actualSize, 0L); }
-  public static void addCompressedSize(FlatBufferBuilder builder, long compressedSize) { builder.addLong(2, compressedSize, 0L); }
-  public static void addCodec(FlatBufferBuilder builder, byte codec) { builder.addByte(3, codec, 0); }
+  public static void addSize(FlatBufferBuilder builder, long size) { builder.addLong(1, size, 0L); }
+  public static void addUncompressedSize(FlatBufferBuilder builder, long uncompressedSize) { builder.addLong(2, uncompressedSize, 0L); }
+  public static void addCodecBufferDescrs(FlatBufferBuilder builder, int codecBufferDescrsOffset) { builder.addOffset(3, codecBufferDescrsOffset, 0); }
+  public static int createCodecBufferDescrsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
+  public static void startCodecBufferDescrsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
   public static int endBufferMeta(FlatBufferBuilder builder) {
     int o = builder.endObject();
     return o;
diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/CodecBufferDescriptor.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/CodecBufferDescriptor.java
new file mode 100644
index 00000000000..681cc7ab31d
--- /dev/null
+++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/CodecBufferDescriptor.java
@@ -0,0 +1,74 @@
+// automatically generated by the FlatBuffers compiler, do not modify
+
+package com.nvidia.spark.rapids.format;
+
+import java.nio.*;
+import java.lang.*;
+import java.util.*;
+import com.google.flatbuffers.*;
+
+@SuppressWarnings("unused")
+/**
+ * Descriptor for a compressed buffer
+ */
+public final class CodecBufferDescriptor extends Table {
+  public static CodecBufferDescriptor getRootAsCodecBufferDescriptor(ByteBuffer _bb) { return getRootAsCodecBufferDescriptor(_bb, new CodecBufferDescriptor()); }
+  public static CodecBufferDescriptor getRootAsCodecBufferDescriptor(ByteBuffer _bb, CodecBufferDescriptor obj) { _bb.order(ByteOrder.LITTLE_ENDIAN); return (obj.__assign(_bb.getInt(_bb.position()) + _bb.position(), _bb)); }
+  public void __init(int _i, ByteBuffer _bb) { bb_pos = _i; bb = _bb; vtable_start = bb_pos - bb.getInt(bb_pos); vtable_size = bb.getShort(vtable_start); }
+  public CodecBufferDescriptor __assign(int _i, ByteBuffer _bb) { __init(_i, _bb); return this; }
+
+  /**
+   * the compression codec used
+   */
+  public byte codec() { int o = __offset(4); return o != 0 ? bb.get(o + bb_pos) : 0; }
+  public boolean mutateCodec(byte codec) { int o = __offset(4); if (o != 0) { bb.put(o + bb_pos, codec); return true; } else { return false; } }
+  /**
+   * byte offset from the start of the enclosing compressed buffer
+   * where the compressed data begins
+   */
+  public long compressedOffset() { int o = __offset(6); return o != 0 ? bb.getLong(o + bb_pos) : 0L; }
+  public boolean mutateCompressedOffset(long compressed_offset) { int o = __offset(6); if (o != 0) { bb.putLong(o + bb_pos, compressed_offset); return true; } else { return false; } }
+  /**
+   * size of the compressed data in bytes
+   */
+  public long compressedSize() { int o = __offset(8); return o != 0 ? bb.getLong(o + bb_pos) : 0L; }
+  public boolean mutateCompressedSize(long compressed_size) { int o = __offset(8); if (o != 0) { bb.putLong(o + bb_pos, compressed_size); return true; } else { return false; } }
+  /**
+   * byte offset from the start of the enclosing uncompressed buffer
+   * where the uncompressed data should be written
+   */
+  public long uncompressedOffset() { int o = __offset(10); return o != 0 ? bb.getLong(o + bb_pos) : 0L; }
+  public boolean mutateUncompressedOffset(long uncompressed_offset) { int o = __offset(10); if (o != 0) { bb.putLong(o + bb_pos, uncompressed_offset); return true; } else { return false; } }
+  /**
+   * size of the uncompressed data in bytes
+   */
+  public long uncompressedSize() { int o = __offset(12); return o != 0 ? bb.getLong(o + bb_pos) : 0L; }
+  public boolean mutateUncompressedSize(long uncompressed_size) { int o = __offset(12); if (o != 0) { bb.putLong(o + bb_pos, uncompressed_size); return true; } else { return false; } }
+
+  public static int createCodecBufferDescriptor(FlatBufferBuilder builder,
+      byte codec,
+      long compressed_offset,
+      long compressed_size,
+      long uncompressed_offset,
+      long uncompressed_size) {
+    builder.startObject(5);
+    CodecBufferDescriptor.addUncompressedSize(builder, uncompressed_size);
+    CodecBufferDescriptor.addUncompressedOffset(builder, uncompressed_offset);
+    CodecBufferDescriptor.addCompressedSize(builder, compressed_size);
+    CodecBufferDescriptor.addCompressedOffset(builder, compressed_offset);
+    CodecBufferDescriptor.addCodec(builder, codec);
+    return CodecBufferDescriptor.endCodecBufferDescriptor(builder);
+  }
+
+  public static void startCodecBufferDescriptor(FlatBufferBuilder builder) { builder.startObject(5); }
+  public static void addCodec(FlatBufferBuilder builder, byte codec) { builder.addByte(0, codec, 0); }
+  public static void addCompressedOffset(FlatBufferBuilder builder, long compressedOffset) { builder.addLong(1, compressedOffset, 0L); }
+  public static void addCompressedSize(FlatBufferBuilder builder, long compressedSize) { builder.addLong(2, compressedSize, 0L); }
+  public static void addUncompressedOffset(FlatBufferBuilder builder, long uncompressedOffset) { builder.addLong(3, uncompressedOffset, 0L); }
+  public static void addUncompressedSize(FlatBufferBuilder builder, long uncompressedSize) { builder.addLong(4, uncompressedSize, 0L); }
+  public static int endCodecBufferDescriptor(FlatBufferBuilder builder) {
+    int o = builder.endObject();
+    return o;
+  }
+}
+
diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/CodecType.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/CodecType.java
index bb90659dbfb..bed8af8d41c 100644
--- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/CodecType.java
+++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/format/CodecType.java
@@ -4,13 +4,17 @@
 
 public final class CodecType {
   private CodecType() { }
+  /**
+   * data simply copied, codec is only for testing
+   */
+  public static final byte COPY = -1;
   /**
    * no compression codec was used on the data
    */
   public static final byte UNCOMPRESSED = 0;
 
-  public static final String[] names = { "UNCOMPRESSED", };
+  public static final String[] names = { "COPY", "UNCOMPRESSED", };
 
-  public static String name(int e) { return names[e]; }
+  public static String name(int e) { return names[e - COPY]; }
 }
 
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala
index f7c100641cb..63d6a2f8bef 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala
@@ -45,23 +45,81 @@ object MetaUtils {
       buffer)
   }
 
+  /**
+   * Build a TableMeta message
+   * @param tableId the ID to use for this table
+   * @param columns the columns in the table
+   * @param numRows the number of rows in the table
+   * @param buffer the contiguous buffer backing the columns in the table
+   * @return heap-based flatbuffer message
+   */
   def buildTableMeta(
       tableId: Int,
       columns: Seq[ColumnVector],
       numRows: Long,
       buffer: DeviceMemoryBuffer): TableMeta = {
     val fbb = new FlatBufferBuilder(1024)
-    val baseAddress = buffer.getAddress
     val bufferSize = buffer.getLength
-    val bufferMetaOffset = BufferMeta.createBufferMeta(
+    BufferMeta.startBufferMeta(fbb)
+    BufferMeta.addId(fbb, tableId)
+    BufferMeta.addSize(fbb, bufferSize)
+    BufferMeta.addUncompressedSize(fbb, bufferSize)
+    val bufferMetaOffset = BufferMeta.endBufferMeta(fbb)
+    buildTableMeta(fbb, bufferMetaOffset, columns, numRows, buffer.getAddress)
+  }
+
+  /**
+   * Build a TableMeta message from a Table that originated in contiguous memory that has
+   * since been compressed.
+   * @param tableId ID to use for this table
+   * @param table table whose metadata will be encoded in the message
+   * @param uncompressedBuffer uncompressed buffer that backs the Table
+   * @param codecId identifier of the codec being used, see CodecType
+   * @param compressedSize size in bytes of the compressed data produced from the uncompressed buffer
+   * @return heap-based flatbuffer message
+   */
+  def buildTableMeta(
+      tableId: Int,
+      table: Table,
+      uncompressedBuffer: DeviceMemoryBuffer,
+      codecId: Byte,
+      compressedSize: Long): TableMeta = {
+    val fbb = new FlatBufferBuilder(1024)
+    val codecDescrOffset = CodecBufferDescriptor.createCodecBufferDescriptor(
       fbb,
-      tableId,
-      bufferSize,
-      bufferSize,
-      CodecType.UNCOMPRESSED)
+      codecId,
+      0,
+      compressedSize,
+      0,
+      uncompressedBuffer.getLength)
+    val codecDescrArrayOffset =
+      BufferMeta.createCodecBufferDescrsVector(fbb, Array(codecDescrOffset))
+    BufferMeta.startBufferMeta(fbb)
+    BufferMeta.addId(fbb, tableId)
+    BufferMeta.addSize(fbb, compressedSize)
+    BufferMeta.addUncompressedSize(fbb, uncompressedBuffer.getLength)
+    BufferMeta.addCodecBufferDescrs(fbb, codecDescrArrayOffset)
+    val bufferMetaOffset = BufferMeta.endBufferMeta(fbb)
+    val columns = (0 until table.getNumberOfColumns).map(i => table.getColumn(i))
+    buildTableMeta(fbb, bufferMetaOffset, columns, table.getRowCount, uncompressedBuffer.getAddress)
+  }
 
+  /**
+   * Build a TableMeta message with a pre-built BufferMeta message
+   * @param fbb flatbuffer builder that has an already built BufferMeta message
+   * @param bufferMetaOffset offset where the BufferMeta message was built
+   * @param columns the columns in the table
+   * @param numRows the number of rows in the table
+   * @param baseAddress address of uncompressed contiguous buffer holding the table
+   * @return flatbuffer message
+   */
+  def buildTableMeta(
+      fbb: FlatBufferBuilder,
+      bufferMetaOffset: Int,
+      columns: Seq[ColumnVector],
+      numRows: Long,
+      baseAddress: Long): TableMeta = {
     val columnMetaOffsets = columns.map(col => addColumnMeta(fbb, baseAddress, col)).toArray
-
     val columnMetasOffset = TableMeta.createColumnMetasVector(fbb, columnMetaOffsets)
     TableMeta.startTableMeta(fbb)
     TableMeta.addBufferMeta(fbb, bufferMetaOffset)
@@ -187,6 +245,30 @@ object ShuffleMetadata extends Logging{
 
   val bbFactory = new DirectByteBufferFactory
 
+  private def copyBufferMeta(fbb: FlatBufferBuilder, buffMeta: BufferMeta): Int = {
+    val descrOffsets = (0 until buffMeta.codecBufferDescrsLength()).map { i =>
+      val descr = buffMeta.codecBufferDescrs(i)
+      CodecBufferDescriptor.createCodecBufferDescriptor(
+        fbb,
+        descr.codec,
+        descr.compressedOffset,
+        descr.compressedSize,
+        descr.uncompressedOffset,
+        descr.uncompressedSize)
+    }
+    val codecDescrArrayOffset = if (descrOffsets.nonEmpty) {
+      Some(BufferMeta.createCodecBufferDescrsVector(fbb, descrOffsets.toArray))
+    } else {
+      None
+    }
+    BufferMeta.startBufferMeta(fbb)
+    BufferMeta.addId(fbb, buffMeta.id)
+    BufferMeta.addSize(fbb, buffMeta.size)
+    BufferMeta.addUncompressedSize(fbb, buffMeta.uncompressedSize)
+    codecDescrArrayOffset.foreach(off => BufferMeta.addCodecBufferDescrs(fbb, off))
+    BufferMeta.endBufferMeta(fbb)
+  }
+
   /**
    * Given a sequence of `TableMeta`, re-lay the metas using the flat buffer builder in `fbb`.
    * @param fbb builder to use
@@ -196,10 +278,8 @@
   def copyTables(fbb: FlatBufferBuilder, tables: Seq[TableMeta]): Array[Int] = {
     tables.map { tableMeta =>
       val buffMeta = tableMeta.bufferMeta()
-
       val buffMetaOffset = if (buffMeta != null) {
-        Some(BufferMeta.createBufferMeta(fbb, buffMeta.id(), buffMeta.actualSize(),
-          buffMeta.compressedSize(), buffMeta.codec()))
+        Some(copyBufferMeta(fbb, buffMeta))
       } else {
         None
       }
@@ -240,12 +320,8 @@
       }
 
       TableMeta.startTableMeta(fbb)
-      if (buffMetaOffset.isDefined) {
-        TableMeta.addBufferMeta(fbb, buffMetaOffset.get)
-      }
-      if (columnMetaOffset.isDefined) {
-        TableMeta.addColumnMetas(fbb, columnMetaOffset.get)
-      }
+      buffMetaOffset.foreach(bmo => TableMeta.addBufferMeta(fbb, bmo))
+      columnMetaOffset.foreach(cmo => TableMeta.addColumnMetas(fbb, cmo))
       TableMeta.addRowCount(fbb, tableMeta.rowCount())
       TableMeta.endTableMeta(fbb)
     }.toArray
@@ -272,7 +348,7 @@
       blockIds : Seq[ShuffleBlockBatchId],
       maxResponseSize: Long) : ByteBuffer = {
     val fbb = new FlatBufferBuilder(1024, bbFactory)
-    val blockIdOffsets = blockIds.map { case blockId =>
+    val blockIdOffsets = blockIds.map { blockId =>
       BlockIdMeta.createBlockIdMeta(fbb,
         blockId.shuffleId,
         blockId.mapId,
@@ -313,8 +389,7 @@
   def buildBufferTransferResponse(bufferMetas: Seq[BufferMeta]): ByteBuffer = {
     val fbb = new FlatBufferBuilder(1024, bbFactory)
     val responses = bufferMetas.map { bm =>
-      val buffMetaOffset = BufferMeta.createBufferMeta(fbb, bm.id(), bm.actualSize(),
-        bm.compressedSize(), bm.codec())
+      val buffMetaOffset = copyBufferMeta(fbb, bm)
       BufferTransferResponse.createBufferTransferResponse(fbb, bm.id(), TransferState.STARTED,
         buffMetaOffset)
     }.toArray
@@ -365,24 +440,12 @@
       // TODO: Need to expose native ID in cudf
       if (DType.STRING == DType.fromNative(columnMeta.dtype())) {
         val offsetLenStr = columnMeta.offsets().length().toString
-        val validityLen = if (columnMeta.validity() == null) {
-          -1
-        } else {
-          columnMeta.validity().length()
-        }
         out.append(s"column: $i [rows=${columnMeta.rowCount}, " +
             s"data_len=${columnMeta.data().length()}, offset_len=${offsetLenStr}, " +
             s"validity_len=$validityLen, type=${DType.fromNative(columnMeta.dtype())}, " +
             s"null_count=${columnMeta.nullCount()}]\n")
       } else {
         val offsetLenStr = "NC"
-
-        val validityLen = if (columnMeta.validity() == null) {
-          -1
-        } else {
-          columnMeta.validity().length()
-        }
-
         out.append(s"column: $i [rows=${columnMeta.rowCount}, " +
             s"data_len=${columnMeta.data().length()}, offset_len=${offsetLenStr}, " +
             s"validity_len=$validityLen, type=${DType.fromNative(columnMeta.dtype())}, " +
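Editor's note (not part of the patch): the MetaUtils changes above add a second buildTableMeta entry point for tables whose contiguous buffer has already been compressed; it records a single descriptor covering the whole buffer (both offsets zero, the compressed size as passed in, the uncompressed size taken from the buffer). A hedged usage sketch follows; the compression step itself is left abstract because this patch does not perform compression, and metaForCompressedTable is a hypothetical wrapper, not part of the change. The "buildTableMeta with codec" test added later in this patch exercises the same call.

    import ai.rapids.cudf.{DeviceMemoryBuffer, Table}
    import com.nvidia.spark.rapids.MetaUtils
    import com.nvidia.spark.rapids.format.{CodecType, TableMeta}

    // Hypothetical helper: the caller has already produced `compressedSize` bytes from
    // `uncompressedBuffer` with some codec; CodecType.COPY is the testing-only codec added here.
    def metaForCompressedTable(
        tableId: Int,
        table: Table,
        uncompressedBuffer: DeviceMemoryBuffer,
        compressedSize: Long): TableMeta =
      MetaUtils.buildTableMeta(tableId, table, uncompressedBuffer, CodecType.COPY, compressedSize)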
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala
index a0b47d29acb..931a76a9f95 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala
@@ -22,6 +22,10 @@ import com.nvidia.spark.rapids.format.TableMeta
 
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
+/**
+ * Buffer storage using device memory.
+ * @param catalog catalog to register this store
+ */
 class RapidsDeviceMemoryStore(
     catalog: RapidsBufferCatalog) extends RapidsBufferStore("GPU", catalog) {
   override protected def createBuffer(
@@ -42,21 +46,12 @@ class RapidsDeviceMemoryStore(
       table: Table,
       contigBuffer: DeviceMemoryBuffer,
       initialSpillPriority: Long): Unit = {
-    val size = contigBuffer.getLength
-    val meta = MetaUtils.buildTableMeta(id.tableId, table, contigBuffer)
-
-    logDebug(s"Adding table for: [id=$id, size=$size, meta_id=${meta.bufferMeta().id()}, " +
-        s"meta_size=${meta.bufferMeta().actualSize()}, meta_num_cols=${meta.columnMetasLength()}]")
-
-    val buffer = new RapidsDeviceMemoryBuffer(
-      id,
-      size,
-      meta,
-      table,
-      contigBuffer,
-      initialSpillPriority)
-
+    val buffer = uncompressedBufferFromTable(id, table, contigBuffer, initialSpillPriority)
     try {
+      logDebug(s"Adding table for: [id=$id, size=${buffer.size}, " +
+          s"uncompressed=${buffer.meta.bufferMeta.uncompressedSize}, " +
+          s"meta_id=${buffer.meta.bufferMeta.id}, meta_size=${buffer.meta.bufferMeta.size}, " +
+          s"meta_num_cols=${buffer.meta.columnMetasLength}]")
       addBuffer(buffer)
     } catch {
       case t: Throwable =>
@@ -79,16 +74,21 @@ class RapidsDeviceMemoryStore(
       tableMeta: TableMeta,
       initialSpillPriority: Long): Unit = {
     logDebug(s"Adding receive side table for: [id=$id, size=${buffer.getLength}, " +
-        s"meta_id=${tableMeta.bufferMeta().id()}, " +
-        s"meta_size=${tableMeta.bufferMeta().actualSize()}, " +
-        s"meta_num_cols=${tableMeta.columnMetasLength()}]")
-
-    val batch = MetaUtils.getBatchFromMeta(buffer, tableMeta) // REFCOUNT 1 + # COLS
-    // hold the 1 ref count extra in buffer, it will be removed later in releaseResources
-    val table = try {
-      GpuColumnVector.from(batch) // batch cols have 2 ref count
-    } finally {
-      batch.close() // cols should have single references
+        s"meta_id=${tableMeta.bufferMeta.id}, " +
+        s"meta_size=${tableMeta.bufferMeta.size}, " +
+        s"meta_num_cols=${tableMeta.columnMetasLength}]")
+
+    val table = if (tableMeta.bufferMeta.codecBufferDescrsLength() > 0) {
+      // buffer is compressed so there is no Table.
+      None
+    } else {
+      val batch = MetaUtils.getBatchFromMeta(buffer, tableMeta) // REFCOUNT 1 + # COLS
+      // hold the 1 ref count extra in buffer, it will be removed later in releaseResources
+      try {
+        Some(GpuColumnVector.from(batch)) // batch cols have 2 ref count
+      } finally {
+        batch.close() // cols should have single references
+      }
     }
 
     val buff = new RapidsDeviceMemoryBuffer(
@@ -102,24 +102,46 @@ class RapidsDeviceMemoryStore(
     addBuffer(buff)
   }
 
+  private def uncompressedBufferFromTable(
+      id: RapidsBufferId,
+      table: Table,
+      contigBuffer: DeviceMemoryBuffer,
+      initialSpillPriority: Long): RapidsDeviceMemoryBuffer = {
+    val size = contigBuffer.getLength
+    val meta = MetaUtils.buildTableMeta(id.tableId, table, contigBuffer)
+    new RapidsDeviceMemoryBuffer(
+      id,
+      size,
+      meta,
+      Some(table),
+      contigBuffer,
+      initialSpillPriority)
+  }
+
   class RapidsDeviceMemoryBuffer(
       id: RapidsBufferId,
       size: Long,
       meta: TableMeta,
-      table: Table,
+      table: Option[Table],
      contigBuffer: DeviceMemoryBuffer,
      spillPriority: Long) extends RapidsBufferBase(id, size, meta, spillPriority) {
+    require(table.isDefined || meta.bufferMeta.codecBufferDescrsLength() > 0)
+
     override val storageTier: StorageTier = StorageTier.DEVICE
 
     override protected def releaseResources(): Unit = {
       contigBuffer.close()
-      table.close()
+      table.foreach(_.close())
     }
 
     override def getMemoryBuffer: MemoryBuffer = contigBuffer.slice(0, contigBuffer.getLength)
 
     override def getColumnarBatch: ColumnarBatch = {
-      GpuColumnVector.from(table) //REFCOUNT ++ of all columns
+      if (table.isDefined) {
+        GpuColumnVector.from(table.get) //REFCOUNT ++ of all columns
+      } else {
+        throw new UnsupportedOperationException("compressed buffer support not implemented")
+      }
     }
   }
 }
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala
index 9224534cb56..c6d913c37d4 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import ai.rapids.cudf.{DeviceMemoryBuffer, MemoryBuffer, NvtxColor, NvtxRange}
 import com.nvidia.spark.rapids._
-import com.nvidia.spark.rapids.format.{CodecType, MetadataResponse, TableMeta, TransferState}
+import com.nvidia.spark.rapids.format.{MetadataResponse, TableMeta, TransferState}
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.rapids.GpuShuffleEnv
@@ -67,10 +67,7 @@ case class PendingTransferRequest(client: RapidsShuffleClient,
     tableMeta: TableMeta,
     tag: Long,
     handler: RapidsShuffleFetchHandler) {
-
-  require(tableMeta.bufferMeta().codec() == CodecType.UNCOMPRESSED)
-
-  def getLength: Long = tableMeta.bufferMeta().actualSize()
+  def getLength: Long = tableMeta.bufferMeta.size
 }
 
 /**
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/MetaUtilsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/MetaUtilsSuite.scala
index fa1aa14bbaf..38c9e6e57a8 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/MetaUtilsSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/MetaUtilsSuite.scala
@@ -42,9 +42,9 @@ class MetaUtilsSuite extends FunSuite with Arm {
 
       val bufferMeta = meta.bufferMeta
       assertResult(7)(bufferMeta.id)
-      assertResult(buffer.getLength)(bufferMeta.compressedSize)
-      assertResult(buffer.getLength)(bufferMeta.actualSize)
-      assertResult(CodecType.UNCOMPRESSED)(bufferMeta.codec)
+      assertResult(buffer.getLength)(bufferMeta.size)
+      assertResult(buffer.getLength)(bufferMeta.uncompressedSize)
+      assertResult(0)(bufferMeta.codecBufferDescrsLength)
       assertResult(table.getRowCount)(meta.rowCount)
       assertResult(table.getNumberOfColumns)(meta.columnMetasLength)
 
@@ -76,6 +76,29 @@ class MetaUtilsSuite extends FunSuite with Arm {
     }
   }
 
+  test("buildTableMeta with codec") {
+    withResource(buildContiguousTable()) { contigTable =>
+      val tableId = 7
+      val codecType = CodecType.COPY
+      val compressedSize: Long = 123
+      val table = contigTable.getTable
+      val buffer = contigTable.getBuffer
+      val meta = MetaUtils.buildTableMeta(tableId, table, buffer, codecType, compressedSize)
+
+      val bufferMeta = meta.bufferMeta
+      assertResult(tableId)(bufferMeta.id)
+      assertResult(compressedSize)(bufferMeta.size)
+      assertResult(table.getRowCount)(meta.rowCount)
+      assertResult(1)(bufferMeta.codecBufferDescrsLength)
+      val codecDescr = bufferMeta.codecBufferDescrs(0)
+      assertResult(codecType)(codecDescr.codec)
+      assertResult(compressedSize)(codecDescr.compressedSize)
+      assertResult(0)(codecDescr.compressedOffset)
+      assertResult(0)(codecDescr.uncompressedOffset)
+      assertResult(buffer.getLength)(codecDescr.uncompressedSize)
+    }
+  }
+
   test("buildDegenerateTableMeta no columns") {
     val degenerateBatch = new ColumnarBatch(Array(), 127)
     val meta = MetaUtils.buildDegenerateTableMeta(8, degenerateBatch)
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala
index d50e9d358a3..a7773693c08 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala
@@ -17,7 +17,7 @@
 package com.nvidia.spark.rapids.shuffle
 
 import ai.rapids.cudf.{DeviceMemoryBuffer, MemoryBuffer}
-import com.nvidia.spark.rapids.format.TableMeta
+import com.nvidia.spark.rapids.format.{BufferMeta, TableMeta}
 import org.mockito.{ArgumentCaptor, ArgumentMatchers}
 import org.mockito.ArgumentMatchers._
 import org.mockito.Mockito._
@@ -38,10 +38,23 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper {
   def verifyTableMeta(expected: TableMeta, actual: TableMeta): Unit = {
     assertResult(expected.rowCount())(actual.rowCount())
     assertResult(expected.columnMetasLength())(actual.columnMetasLength())
-    assertResult(expected.bufferMeta().id())(actual.bufferMeta().id())
-    assertResult(expected.bufferMeta().actualSize())(actual.bufferMeta().actualSize())
-    assertResult(expected.bufferMeta().compressedSize())(actual.bufferMeta().compressedSize())
-    assertResult(expected.bufferMeta().codec())(actual.bufferMeta().codec())
+    verifyBufferMeta(expected.bufferMeta, actual.bufferMeta)
+  }
+
+  def verifyBufferMeta(expected: BufferMeta, actual: BufferMeta): Unit = {
+    assertResult(expected.id)(actual.id)
+    assertResult(expected.size)(actual.size)
+    assertResult(expected.uncompressedSize)(actual.uncompressedSize)
+    assertResult(expected.codecBufferDescrsLength)(actual.codecBufferDescrsLength)
+    (0 until expected.codecBufferDescrsLength).foreach { i =>
+      val expectedDescr = expected.codecBufferDescrs(i)
+      val actualDescr = actual.codecBufferDescrs(i)
+      assertResult(expectedDescr.codec)(actualDescr.codec)
+      assertResult(expectedDescr.compressedOffset)(actualDescr.compressedOffset)
+      assertResult(expectedDescr.compressedSize)(actualDescr.compressedSize)
+      assertResult(expectedDescr.uncompressedOffset)(actualDescr.uncompressedOffset)
+      assertResult(expectedDescr.uncompressedSize)(actualDescr.uncompressedSize)
+    }
   }
 
   test("successful metadata fetch") {
@@ -141,7 +154,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper {
       val totalReceived = mockConnection.receiveLengths.sum
      val numBuffersUsed = mockConnection.receiveLengths.size
 
-      assertResult(tableMeta.bufferMeta().actualSize())(totalReceived)
+      assertResult(tableMeta.bufferMeta().size())(totalReceived)
       assertResult(11)(numBuffersUsed)
 
       // we would perform 1 request to issue a `TransferRequest`, so the server can start.
@@ -155,7 +168,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper {
       verify(mockStorage, times(1))
         .addBuffer(any(), dmbCaptor.capture(), any(), any())
 
-      assertResult(tableMeta.bufferMeta().actualSize())(
+      assertResult(tableMeta.bufferMeta().size())(
         dmbCaptor.getValue.asInstanceOf[DeviceMemoryBuffer].getLength)
 
       // after closing, we should have freed our bounce buffers.
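Editor's note (not part of the patch): taken together, senders describe compression purely through the descriptor vector and receivers key off its length rather than a codec field, which is what the RapidsDeviceMemoryStore change above does when it declines to build a Table for compressed data. A minimal sketch of that receive-side decision, assuming a TableMeta obtained from shuffle metadata; actual decompression is intentionally out of scope here, matching the UnsupportedOperationException in the store.

    import com.nvidia.spark.rapids.format.TableMeta

    // Mirrors the branch added to RapidsDeviceMemoryStore: an empty descriptor vector means the
    // received buffer is already uncompressed and can back a Table/ColumnarBatch directly.
    def canMaterializeDirectly(tableMeta: TableMeta): Boolean =
      tableMeta.bufferMeta.codecBufferDescrsLength() == 0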