Skip to content

Commit

Permalink
addressed review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
razajafri committed Jun 15, 2020
1 parent dc89611 commit 1d38235
Showing 1 changed file with 40 additions and 35 deletions.
75 changes: 40 additions & 35 deletions tests/src/test/scala/ai/rapids/spark/GpuUnitTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -97,46 +97,51 @@ class GpuUnitTests extends SparkQueryCompareTestSuite {
"types should be the same")
withResource(expected.copyToHost()) { hostExpected =>
withResource(result.copyToHost()) { hostResult =>
for (row <- 0 until result.getRowCount().toInt) {
assert(hostExpected.isNullAt(row) == hostResult.isNullAt(row),
"expected and actual differ at " + row + " one of them isn't null")
if (!hostExpected.isNullAt(row)) {
result.getBase.getType() match {
case INT8 | BOOL8 =>
assert(hostExpected.getByte(row) == hostResult.getByte(row), "row " + row)
case INT16 =>
assert(hostExpected.getShort(row) == hostResult.getShort(row), "row " + row)

case INT32 | TIMESTAMP_DAYS =>
assert(hostExpected.getInt(row) == hostResult.getInt(row), "row " + row)

case INT64 | TIMESTAMP_MICROSECONDS | TIMESTAMP_MILLISECONDS | TIMESTAMP_NANOSECONDS |
TIMESTAMP_SECONDS =>
assert(hostExpected.getLong(row) == hostResult.getLong(row), "row " + row)

case FLOAT32 =>
assert(compare(hostExpected.getFloat(row), hostResult.getFloat(row),
0.0001), "row " + row)

case FLOAT64 =>
assert(compare(hostExpected.getDouble(row), hostResult.getDouble(row),
0.0001), "row " + row)

case STRING =>
assert(hostExpected.getUTF8String(row) == hostResult.getUTF8String(row),
"row " + row)

case _ =>
throw new IllegalArgumentException(hostResult.getBase.getType() +
" is not supported yet")
}
}
}
check(result, hostExpected, hostResult)
}
}
true
}

private def check(result: GpuColumnVector, hostExpected: RapidsHostColumnVector,
hostResult: RapidsHostColumnVector) = {
for (row <- 0 until result.getRowCount().toInt) {
assert(hostExpected.isNullAt(row) == hostResult.isNullAt(row),
"expected and actual differ at " + row + " one of them isn't null")
if (!hostExpected.isNullAt(row)) {
result.getBase.getType() match {
case INT8 | BOOL8 =>
assert(hostExpected.getByte(row) == hostResult.getByte(row), "row " + row)
case INT16 =>
assert(hostExpected.getShort(row) == hostResult.getShort(row), "row " + row)

case INT32 | TIMESTAMP_DAYS =>
assert(hostExpected.getInt(row) == hostResult.getInt(row), "row " + row)

case INT64 | TIMESTAMP_MICROSECONDS | TIMESTAMP_MILLISECONDS | TIMESTAMP_NANOSECONDS |
TIMESTAMP_SECONDS =>
assert(hostExpected.getLong(row) == hostResult.getLong(row), "row " + row)

case FLOAT32 =>
assert(compare(hostExpected.getFloat(row), hostResult.getFloat(row),
0.0001), "row " + row)

case FLOAT64 =>
assert(compare(hostExpected.getDouble(row), hostResult.getDouble(row),
0.0001), "row " + row)

case STRING =>
assert(hostExpected.getUTF8String(row) == hostResult.getUTF8String(row),
"row " + row)

case _ =>
throw new IllegalArgumentException(hostResult.getBase.getType() +
" is not supported yet")
}
}
}
}

protected def evaluateWithoutCodegen(gpuExpression: GpuExpression,
inputBatch: ColumnarBatch = EmptyBatch): GpuColumnVector = {
gpuExpression.foreach {
Expand Down

0 comments on commit 1d38235

Please sign in to comment.