diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala
index 7d36fc73e8e..abcb0466860 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala
@@ -19,7 +19,7 @@ package com.nvidia.spark.rapids
 import java.io.File
 import java.math.RoundingMode

-import ai.rapids.cudf.{ColumnVector, DType, Table}
+import ai.rapids.cudf.{ColumnVector, Cuda, DType, Table}
 import org.scalatest.FunSuite

 import org.apache.spark.SparkConf
@@ -43,45 +43,63 @@ class GpuPartitioningSuite extends FunSuite with Arm {
   }

   /**
-   * Retrieves the underlying column vectors for a batch without incrementing
-   * the refcounts of those columns. Therefore the column vectors are only
-   * valid as long as the batch is valid.
+   * Retrieves the underlying column vectors for a batch. It increments the reference counts for
+   * them if needed so the results need to be closed.
    */
-  private def extractBases(batch: ColumnarBatch): Array[ColumnVector] = {
+  private def extractColumnVectors(batch: ColumnarBatch): Array[ColumnVector] = {
     if (GpuPackedTableColumn.isBatchPacked(batch)) {
       val packedColumn = batch.column(0).asInstanceOf[GpuPackedTableColumn]
       val table = packedColumn.getContiguousTable.getTable
       // The contiguous table is still responsible for closing these columns.
-      (0 until table.getNumberOfColumns).map(table.getColumn).toArray
+      (0 until table.getNumberOfColumns).map(i => table.getColumn(i).incRefCount()).toArray
+    } else if (GpuCompressedColumnVector.isBatchCompressed(batch)) {
+      val compressedColumn = batch.column(0).asInstanceOf[GpuCompressedColumnVector]
+      val descr = compressedColumn.getTableMeta.bufferMeta.codecBufferDescrs(0)
+      val codec = TableCompressionCodec.getCodec(descr.codec)
+      withResource(codec.createBatchDecompressor(100 * 1024 * 1024L,
+          Cuda.DEFAULT_STREAM)) { decompressor =>
+        compressedColumn.getTableBuffer.incRefCount()
+        decompressor.addBufferToDecompress(compressedColumn.getTableBuffer,
+          compressedColumn.getTableMeta.bufferMeta)
+        withResource(decompressor.finishAsync()) { outputBuffers =>
+          val outputBuffer = outputBuffers.head
+          // There should be only one
+          withResource(
+              MetaUtils.getTableFromMeta(outputBuffer, compressedColumn.getTableMeta)) { table =>
+            (0 until table.getNumberOfColumns).map(i => table.getColumn(i).incRefCount()).toArray
+          }
+        }
+      }
     } else {
-      GpuColumnVector.extractBases(batch)
+      GpuColumnVector.extractBases(batch).map(_.incRefCount())
     }
   }

   private def buildSubBatch(batch: ColumnarBatch, startRow: Int, endRow: Int): ColumnarBatch = {
-    val columns = extractBases(batch)
-    val types = GpuColumnVector.extractTypes(batch)
-    val sliced = columns.zip(types).map { case (c, t) =>
-      GpuColumnVector.from(c.subVector(startRow, endRow), t)
+    withResource(extractColumnVectors(batch)) { columns =>
+      val types = GpuColumnVector.extractTypes(batch)
+      val sliced = columns.zip(types).map { case (c, t) =>
+        GpuColumnVector.from(c.subVector(startRow, endRow), t)
+      }
+      new ColumnarBatch(sliced.toArray, endRow - startRow)
     }
-    new ColumnarBatch(sliced.toArray, endRow - startRow)
   }

   private def compareBatches(expected: ColumnarBatch, actual: ColumnarBatch): Unit = {
     assertResult(expected.numRows)(actual.numRows)
-    val expectedColumns = extractBases(expected)
-    val actualColumns = extractBases(expected)
-    assertResult(expectedColumns.length)(actualColumns.length)
-    expectedColumns.zip(actualColumns).foreach { case (expected, actual) =>
-      // FIXME: For decimal types, NULL_EQUALS has not been supported in cuDF yet
-      val cpVec = if (expected.getType.isDecimalType) {
-        expected.equalTo(actual)
-      } else {
-        expected.equalToNullAware(actual)
-      }
-      withResource(cpVec) { compareVector =>
-        withResource(compareVector.all()) { compareResult =>
-          assert(compareResult.getBoolean)
+    withResource(extractColumnVectors(expected)) { expectedColumns =>
+      withResource(extractColumnVectors(actual)) { actualColumns =>
+        assertResult(expectedColumns.length)(actualColumns.length)
+        expectedColumns.zip(actualColumns).foreach { case (expected, actual) =>
+          if (expected.getRowCount == 0) {
+            assertResult(expected.getType)(actual.getType)
+          } else {
+            withResource(expected.equalToNullAware(actual)) { compareVector =>
+              withResource(compareVector.all()) { compareResult =>
+                assert(compareResult.getBoolean)
+              }
+            }
+          }
         }
       }
     }
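
Note for reviewers: the heart of this change is the ownership contract of the new extractColumnVectors, which now hands back incRef'd columns that the caller must close (hence the withResource wrappers in buildSubBatch and compareBatches). Below is a minimal, self-contained Scala sketch of that contract; the local withResource helper merely stands in for the one the Arm trait provides, and all names here are illustrative, not code from this patch.

import ai.rapids.cudf.ColumnVector

object RefCountSketch {
  // Stand-in for Arm.withResource: close every resource in the sequence
  // when the body finishes, even if it throws.
  def withResource[T <: AutoCloseable, V](rs: Seq[T])(body: Seq[T] => V): V = {
    try body(rs) finally rs.foreach(_.close())
  }

  def main(args: Array[String]): Unit = {
    val vec = ColumnVector.fromInts(1, 2, 3)
    // incRefCount() registers an additional owner, so columns handed out by
    // a helper like extractColumnVectors stay valid even after their source
    // batch closes; the caller is then obligated to close what it received.
    withResource(Seq(vec.incRefCount())) { borrowed =>
      assert(borrowed.head.getRowCount == 3)
    }
    vec.close() // the last owner closes; device memory is freed here
  }
}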
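The getRowCount == 0 special case in compareBatches deserves a remark as well: to my understanding, cudf reductions such as all() produce an invalid (null) scalar for zero-row columns, so compareResult.getBoolean would fail on empty partitions, and comparing only the types sidesteps that. A sketch of that assumed behavior:

import ai.rapids.cudf.ColumnVector

object EmptyReductionSketch {
  def main(args: Array[String]): Unit = {
    // Assumption being illustrated: all() over an empty boolean column
    // yields an invalid scalar rather than true or false.
    val empty = ColumnVector.fromBooleans()
    try {
      val result = empty.all()
      try assert(!result.isValid) finally result.close()
    } finally empty.close()
  }
}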