Skip to content

Commit

Permalink
Fix Part Suite Tests (NVIDIA#1852)
Browse files Browse the repository at this point in the history
* Fix Part Suite Tests

Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>

* Addressed review comments
  • Loading branch information
revans2 authored Mar 3, 2021
1 parent 8f4011c commit 269e51f
Showing 1 changed file with 43 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package com.nvidia.spark.rapids
import java.io.File
import java.math.RoundingMode

import ai.rapids.cudf.{ColumnVector, DType, Table}
import ai.rapids.cudf.{ColumnVector, Cuda, DType, Table}
import org.scalatest.FunSuite

import org.apache.spark.SparkConf
Expand All @@ -43,45 +43,63 @@ class GpuPartitioningSuite extends FunSuite with Arm {
}

/**
* Retrieves the underlying column vectors for a batch without incrementing
* the refcounts of those columns. Therefore the column vectors are only
* valid as long as the batch is valid.
* Retrieves the underlying column vectors for a batch. It increments the reference counts for
* them if needed so the results need to be closed.
*
* NOTE(review): this listing carries two descriptions and two signatures. The first
* description matches the lines below that return columns without calling incRefCount();
* the second matches the lines that call incRefCount() on every returned column.
*
* Three batch layouts are handled: packed (GpuPackedTableColumn), compressed
* (GpuCompressedColumnVector), and everything else via GpuColumnVector.extractBases.
*/
private def extractBases(batch: ColumnarBatch): Array[ColumnVector] = {
private def extractColumnVectors(batch: ColumnarBatch): Array[ColumnVector] = {
// Packed batch: column 0 is cast to GpuPackedTableColumn and the columns are read from
// its contiguous table.
if (GpuPackedTableColumn.isBatchPacked(batch)) {
val packedColumn = batch.column(0).asInstanceOf[GpuPackedTableColumn]
val table = packedColumn.getContiguousTable.getTable
// The contiguous table is still responsible for closing these columns.
(0 until table.getNumberOfColumns).map(table.getColumn).toArray
// This form calls incRefCount() on each column before returning it.
(0 until table.getNumberOfColumns).map(i => table.getColumn(i).incRefCount()).toArray
// Compressed batch: column 0 is cast to GpuCompressedColumnVector and the codec is looked
// up from the first codec buffer descriptor in its table meta.
} else if (GpuCompressedColumnVector.isBatchCompressed(batch)) {
val compressedColumn = batch.column(0).asInstanceOf[GpuCompressedColumnVector]
val descr = compressedColumn.getTableMeta.bufferMeta.codecBufferDescrs(0)
val codec = TableCompressionCodec.getCodec(descr.codec)
// NOTE(review): 100 * 1024 * 1024L is presumably a size limit in bytes for the batch
// decompressor -- confirm against createBatchDecompressor.
withResource(codec.createBatchDecompressor(100 * 1024 * 1024L,
Cuda.DEFAULT_STREAM)) { decompressor =>
// NOTE(review): the extra reference is presumably taken because the decompressor closes
// the buffers handed to it -- confirm against addBufferToDecompress.
compressedColumn.getTableBuffer.incRefCount()
decompressor.addBufferToDecompress(compressedColumn.getTableBuffer,
compressedColumn.getTableMeta.bufferMeta)
withResource(decompressor.finishAsync()) { outputBuffers =>
val outputBuffer = outputBuffers.head
// There should be only one output buffer, because a single buffer was added above.
// The table is rebuilt from the decompressed buffer plus the original table meta.
// NOTE(review): withResource presumably closes the table on exit, which is why each
// column gets incRefCount() before being returned -- confirm.
withResource(
MetaUtils.getTableFromMeta(outputBuffer, compressedColumn.getTableMeta)) { table =>
(0 until table.getNumberOfColumns).map(i => table.getColumn(i).incRefCount()).toArray
}
}
}
// Any other batch: delegate to GpuColumnVector.extractBases.
} else {
GpuColumnVector.extractBases(batch)
// This form calls incRefCount() on each extracted column before returning it.
GpuColumnVector.extractBases(batch).map(_.incRefCount())
}
}

/**
* Builds a new batch from a row range of `batch`. Every column is sliced with
* subVector(startRow, endRow) and wrapped with GpuColumnVector.from using the matching
* type from GpuColumnVector.extractTypes.
*
* @param batch the batch to slice
* @param startRow first argument passed to subVector
* @param endRow second argument passed to subVector
* @return a ColumnarBatch whose reported row count is endRow - startRow
*/
private def buildSubBatch(batch: ColumnarBatch, startRow: Int, endRow: Int): ColumnarBatch = {
// Form using extractBases: the columns are used directly, with no surrounding withResource.
val columns = extractBases(batch)
val types = GpuColumnVector.extractTypes(batch)
val sliced = columns.zip(types).map { case (c, t) =>
GpuColumnVector.from(c.subVector(startRow, endRow), t)
// Form using extractColumnVectors: the returned columns are scoped by withResource.
// NOTE(review): presumably this closes them once the sliced batch has been built -- confirm.
withResource(extractColumnVectors(batch)) { columns =>
val types = GpuColumnVector.extractTypes(batch)
val sliced = columns.zip(types).map { case (c, t) =>
GpuColumnVector.from(c.subVector(startRow, endRow), t)
}
new ColumnarBatch(sliced.toArray, endRow - startRow)
}
new ColumnarBatch(sliced.toArray, endRow - startRow)
}

private def compareBatches(expected: ColumnarBatch, actual: ColumnarBatch): Unit = {
assertResult(expected.numRows)(actual.numRows)
val expectedColumns = extractBases(expected)
val actualColumns = extractBases(expected)
assertResult(expectedColumns.length)(actualColumns.length)
expectedColumns.zip(actualColumns).foreach { case (expected, actual) =>
// FIXME: For decimal types, NULL_EQUALS has not been supported in cuDF yet
val cpVec = if (expected.getType.isDecimalType) {
expected.equalTo(actual)
} else {
expected.equalToNullAware(actual)
}
withResource(cpVec) { compareVector =>
withResource(compareVector.all()) { compareResult =>
assert(compareResult.getBoolean)
withResource(extractColumnVectors(expected)) { expectedColumns =>
withResource(extractColumnVectors(actual)) { actualColumns =>
assertResult(expectedColumns.length)(actualColumns.length)
expectedColumns.zip(actualColumns).foreach { case (expected, actual) =>
if (expected.getRowCount == 0) {
assertResult(expected.getType)(actual.getType)
} else {
withResource(expected.equalToNullAware(actual)) { compareVector =>
withResource(compareVector.all()) { compareResult =>
assert(compareResult.getBoolean)
}
}
}
}
}
}
Expand Down

0 comments on commit 269e51f

Please sign in to comment.