diff --git a/tests/src/test/scala/ai/rapids/spark/GpuUnitTests.scala b/tests/src/test/scala/ai/rapids/spark/GpuUnitTests.scala index 3da37a308d6..99b9c7a6694 100644 --- a/tests/src/test/scala/ai/rapids/spark/GpuUnitTests.scala +++ b/tests/src/test/scala/ai/rapids/spark/GpuUnitTests.scala @@ -97,46 +97,51 @@ class GpuUnitTests extends SparkQueryCompareTestSuite { "types should be the same") withResource(expected.copyToHost()) { hostExpected => withResource(result.copyToHost()) { hostResult => - for (row <- 0 until result.getRowCount().toInt) { - assert(hostExpected.isNullAt(row) == hostResult.isNullAt(row), - "expected and actual differ at " + row + " one of them isn't null") - if (!hostExpected.isNullAt(row)) { - result.getBase.getType() match { - case INT8 | BOOL8 => - assert(hostExpected.getByte(row) == hostResult.getByte(row), "row " + row) - case INT16 => - assert(hostExpected.getShort(row) == hostResult.getShort(row), "row " + row) - - case INT32 | TIMESTAMP_DAYS => - assert(hostExpected.getInt(row) == hostResult.getInt(row), "row " + row) - - case INT64 | TIMESTAMP_MICROSECONDS | TIMESTAMP_MILLISECONDS | TIMESTAMP_NANOSECONDS | - TIMESTAMP_SECONDS => - assert(hostExpected.getLong(row) == hostResult.getLong(row), "row " + row) - - case FLOAT32 => - assert(compare(hostExpected.getFloat(row), hostResult.getFloat(row), - 0.0001), "row " + row) - - case FLOAT64 => - assert(compare(hostExpected.getDouble(row), hostResult.getDouble(row), - 0.0001), "row " + row) - - case STRING => - assert(hostExpected.getUTF8String(row) == hostResult.getUTF8String(row), - "row " + row) - - case _ => - throw new IllegalArgumentException(hostResult.getBase.getType() + - " is not supported yet") - } - } - } + check(result, hostExpected, hostResult) } } true } + private def check(result: GpuColumnVector, hostExpected: RapidsHostColumnVector, + hostResult: RapidsHostColumnVector) = { + for (row <- 0 until result.getRowCount().toInt) { + assert(hostExpected.isNullAt(row) == hostResult.isNullAt(row), + "expected and actual differ at " + row + " one of them isn't null") + if (!hostExpected.isNullAt(row)) { + result.getBase.getType() match { + case INT8 | BOOL8 => + assert(hostExpected.getByte(row) == hostResult.getByte(row), "row " + row) + case INT16 => + assert(hostExpected.getShort(row) == hostResult.getShort(row), "row " + row) + + case INT32 | TIMESTAMP_DAYS => + assert(hostExpected.getInt(row) == hostResult.getInt(row), "row " + row) + + case INT64 | TIMESTAMP_MICROSECONDS | TIMESTAMP_MILLISECONDS | TIMESTAMP_NANOSECONDS | + TIMESTAMP_SECONDS => + assert(hostExpected.getLong(row) == hostResult.getLong(row), "row " + row) + + case FLOAT32 => + assert(compare(hostExpected.getFloat(row), hostResult.getFloat(row), + 0.0001), "row " + row) + + case FLOAT64 => + assert(compare(hostExpected.getDouble(row), hostResult.getDouble(row), + 0.0001), "row " + row) + + case STRING => + assert(hostExpected.getUTF8String(row) == hostResult.getUTF8String(row), + "row " + row) + + case _ => + throw new IllegalArgumentException(hostResult.getBase.getType() + + " is not supported yet") + } + } + } + } + protected def evaluateWithoutCodegen(gpuExpression: GpuExpression, inputBatch: ColumnarBatch = EmptyBatch): GpuColumnVector = { gpuExpression.foreach {