diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVectorFromBuffer.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVectorFromBuffer.java index ab8335eea72..4cffdeb54af 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVectorFromBuffer.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVectorFromBuffer.java @@ -40,6 +40,21 @@ public final class GpuColumnVectorFromBuffer extends GpuColumnVector { public static ColumnarBatch from(ContiguousTable contigTable) { DeviceMemoryBuffer buffer = contigTable.getBuffer(); Table table = contigTable.getTable(); + return from(table, buffer); + } + + /** + * Get a ColumnarBatch from a set of columns in a table, and the corresponding device buffer, + * which backs such columns. The resulting batch is composed of columns which are instances of + * GpuColumnVectorFromBuffer. This will increment the reference count for all columns + * converted so you will need to close both the table that is passed in and the batch + * returned to be sure that there are no leaks. + * + * @param table a table with columns at offsets of `buffer` + * @param buffer a device buffer that packs data for columns in `table` + * @return batch of GpuColumnVectorFromBuffer instances derived from the table and buffer + */ + public static ColumnarBatch from(Table table, DeviceMemoryBuffer buffer) { long rows = table.getRowCount(); if (rows != (int) rows) { throw new IllegalStateException("Cannot support a batch larger that MAX INT rows"); diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala index 5e72c7978ef..e68a706ca27 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala @@ -138,7 +138,7 @@ class RapidsDeviceMemoryStore(catalog: RapidsBufferCatalog = RapidsBufferCatalog override def getColumnarBatch: ColumnarBatch = { if (table.isDefined) { - GpuColumnVector.from(table.get) //REFCOUNT ++ of all columns + GpuColumnVectorFromBuffer.from(table.get, contigBuffer) //REFCOUNT ++ of all columns } else { columnarBatchFromDeviceBuffer(contigBuffer) }