diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVectorFromBuffer.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVectorFromBuffer.java index 83a51a6dc26..b5911c71505 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVectorFromBuffer.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVectorFromBuffer.java @@ -100,7 +100,7 @@ public static ColumnarBatch from(Table table, DeviceMemoryBuffer buffer, DataTyp * @param cudfColumn a ColumnVector instance * @param buffer the buffer to hold */ - private GpuColumnVectorFromBuffer(DataType type, ColumnVector cudfColumn, + public GpuColumnVectorFromBuffer(DataType type, ColumnVector cudfColumn, DeviceMemoryBuffer buffer) { super(type, cudfColumn); this.buffer = buffer; diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala index 942da56d6d8..d7e5a539827 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala @@ -260,7 +260,7 @@ object MetaUtils extends Arm { sparkType: DataType): GpuColumnVector = { val columnView = makeCudfColumnView(buffer, meta) val column = ColumnViewUtil.fromViewWithContiguousAllocation(columnView, buffer) - GpuColumnVector.from(column, sparkType) + new GpuColumnVectorFromBuffer(sparkType, column, buffer) } private def makeCudfColumn(buffer: DeviceMemoryBuffer, meta: ColumnMeta): ColumnVector = { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/MetaUtilsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/MetaUtilsSuite.scala index 8445b00c9c9..5808036ef0f 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/MetaUtilsSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/MetaUtilsSuite.scala @@ -215,8 +215,10 @@ class MetaUtilsSuite extends FunSuite with Arm { assertResult(table.getRowCount)(batch.numRows) assertResult(table.getNumberOfColumns)(batch.numCols) (0 until table.getNumberOfColumns).foreach { i => + val batchColumn = batch.column(i) + assert(batchColumn.isInstanceOf[GpuColumnVectorFromBuffer]) TestUtils.compareColumns(table.getColumn(i), - batch.column(i).asInstanceOf[GpuColumnVector].getBase) + batchColumn.asInstanceOf[GpuColumnVector].getBase) } } }