Fixed leaks in unit tests and used ColumnarBatch for testing
razajafri committed Jun 15, 2020
1 parent 2c2883d commit dc89611
Showing 2 changed files with 100 additions and 103 deletions.
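
Both files move resource handling onto the plugin's withResource helper: the input vectors, the host copies made by copyToHost(), and the vector returned by evaluateWithoutCodegen() are now closed when their scope ends, whereas the old code never closed them; that is the leak this commit fixes. Conceptually, withResource is the usual loan pattern. The sketch below is an illustrative stand-in, not the project's actual definition (the real helper ships with the plugin's own utilities, so the exact signature here is assumed); it is included only to make the diff easier to follow.

def withResource[T <: AutoCloseable, V](resource: T)(block: T => V): V = {
  try {
    block(resource)    // run the caller's code with the resource
  } finally {
    resource.close()   // always runs, so the device/host buffer is freed even if block throws
  }
}

Because close() runs in a finally block, a failed assertion inside the block still releases the buffer before the failure propagates to the test framework.
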
98 changes: 53 additions & 45 deletions tests/src/test/scala/ai/rapids/spark/GpuUnitTests.scala
@@ -92,41 +92,45 @@ class GpuUnitTests extends SparkQueryCompareTestSuite {
expression: Expression): Boolean = {
// The result is null for a non-nullable expression
assert(result != null || expression.nullable, "expression.nullable should be true if " +
"result is null")
assert(result.getBase().getType() == expected.getBase().getType(), "types should be the same")
val hostExpected = expected.copyToHost()
val hostResult = result.copyToHost()
for (row <- 0 until result.getRowCount().toInt) {
assert(hostExpected.isNullAt(row) == hostResult.isNullAt(row),
"expected and actual differ at " + row + " one of them isn't null")
if (!hostExpected.isNullAt(row)) {
result.getBase.getType() match {
case INT8 | BOOL8 =>
assert(hostExpected.getByte(row) == hostResult.getByte(row), "row " + row)
case INT16 =>
assert(hostExpected.getShort(row) == hostResult.getShort(row), "row " + row)

case INT32 | TIMESTAMP_DAYS =>
assert(hostExpected.getInt(row) == hostResult.getInt(row), "row " + row)

case INT64 | TIMESTAMP_MICROSECONDS | TIMESTAMP_MILLISECONDS | TIMESTAMP_NANOSECONDS |
TIMESTAMP_SECONDS =>
assert(hostExpected.getLong(row) == hostResult.getLong(row), "row " + row)

case FLOAT32 =>
assert(compare(hostExpected.getFloat(row), hostResult.getFloat(row), 0.0001),
"row " + row)

case FLOAT64 =>
assert(compare(hostExpected.getDouble(row), hostResult.getDouble(row), 0.0001),
"row " + row)

case STRING =>
assert(hostExpected.getUTF8String(row) == hostResult.getUTF8String(row), "row " + row)

case _ =>
throw new IllegalArgumentException(hostResult.getBase.getType() + " is not supported " +
"yet")
"result is null")
assert(result.getBase().getType() == expected.getBase().getType(),
"types should be the same")
withResource(expected.copyToHost()) { hostExpected =>
withResource(result.copyToHost()) { hostResult =>
for (row <- 0 until result.getRowCount().toInt) {
assert(hostExpected.isNullAt(row) == hostResult.isNullAt(row),
"expected and actual differ at " + row + " one of them isn't null")
if (!hostExpected.isNullAt(row)) {
result.getBase.getType() match {
case INT8 | BOOL8 =>
assert(hostExpected.getByte(row) == hostResult.getByte(row), "row " + row)
case INT16 =>
assert(hostExpected.getShort(row) == hostResult.getShort(row), "row " + row)

case INT32 | TIMESTAMP_DAYS =>
assert(hostExpected.getInt(row) == hostResult.getInt(row), "row " + row)

case INT64 | TIMESTAMP_MICROSECONDS | TIMESTAMP_MILLISECONDS | TIMESTAMP_NANOSECONDS |
TIMESTAMP_SECONDS =>
assert(hostExpected.getLong(row) == hostResult.getLong(row), "row " + row)

case FLOAT32 =>
assert(compare(hostExpected.getFloat(row), hostResult.getFloat(row),
0.0001), "row " + row)

case FLOAT64 =>
assert(compare(hostExpected.getDouble(row), hostResult.getDouble(row),
0.0001), "row " + row)

case STRING =>
assert(hostExpected.getUTF8String(row) == hostResult.getUTF8String(row),
"row " + row)

case _ =>
throw new IllegalArgumentException(hostResult.getBase.getType() +
" is not supported yet")
}
}
}
}
}
@@ -143,17 +147,21 @@ class GpuUnitTests extends SparkQueryCompareTestSuite {
}

private def checkEvaluationWithoutCodegen(gpuExpression: GpuExpression,
expected: GpuColumnVector,
inputBatch: ColumnarBatch = EmptyBatch): Unit = {
val actual = try evaluateWithoutCodegen(gpuExpression, inputBatch) catch {
expected: GpuColumnVector,
inputBatch: ColumnarBatch = EmptyBatch): Unit = {
try {
withResource(evaluateWithoutCodegen(gpuExpression, inputBatch)) { actual =>

if (!checkResult(actual, expected, gpuExpression)) {
val input = if (inputBatch == EmptyBatch) "" else s", input: $inputBatch"
fail(s"Incorrect evaluation (codegen off): $gpuExpression, " +
s"actual: $actual, " +
s"expected: $expected$input")
}
}
} catch {
case e: Exception => e.printStackTrace()
fail(s"Exception evaluating $gpuExpression", e)
}
if (!checkResult(actual, expected, gpuExpression)) {
val input = if (inputBatch == EmptyBatch) "" else s", input: $inputBatch"
fail(s"Incorrect evaluation (codegen off): $gpuExpression, " +
s"actual: $actual, " +
s"expected: $expected$input")
fail(s"Exception evaluating $gpuExpression", e)
}
}

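With this change, checkEvaluationWithoutCodegen evaluates the expression inside withResource and runs the comparison within that scope, so the returned GpuColumnVector stays valid while checkResult reads it and is closed afterwards whether the check passes, the assertion fails, or evaluation throws; the outer try/catch then only has to report the failure via fail(..., e). (This assumes withResource follows the try/finally loan pattern sketched above.)
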
105 changes: 47 additions & 58 deletions tests/src/test/scala/ai/rapids/spark/unit/DateTimeUnitTest.scala
@@ -17,9 +17,11 @@
package ai.rapids.spark.unit

import ai.rapids.cudf.ColumnVector
import ai.rapids.spark.{GpuColumnVector, GpuLiteral, GpuUnitTests}
import ai.rapids.spark.{GpuBoundReference, GpuColumnVector, GpuLiteral, GpuUnitTests}

import org.apache.spark.sql.rapids.{GpuDateAdd, GpuDateSub}
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.sql.vectorized.ColumnarBatch

class DateTimeUnitTest extends GpuUnitTests {
val TIMES_DAY = Array(-1528, //1965-10-26
@@ -35,35 +37,28 @@ class DateTimeUnitTest extends GpuUnitTests {
withResource(GpuColumnVector.from(ColumnVector.daysFromInts(17676, 17706, 17680,
17683, 17680))) { expected2 =>

// vector 0
val dates0 = GpuColumnVector.from(ColumnVector.daysFromInts(TIMES_DAY: _*))
// assign to new handles for code readability
val dates0_1 = dates0.incRefCount()

//vector 1
val days0 = GpuColumnVector.from(ColumnVector.fromInts(2, 32, 6, 9, 6))
// assign to new handles for code readability
val days0_1 = days0.incRefCount()

// saving types for convenience
val lhsType = getSparkType(dates0.getBase.getType)
val rhsType = getSparkType(days0.getBase.getType)

val daysExprV0 = GpuLiteral(days0_1, rhsType)
val datesS = GpuLiteral(17674, lhsType)

// lhs = vector, rhs = scalar
val datesExprV = GpuLiteral(dates0, lhsType)
val daysS = GpuLiteral(30, rhsType)
checkEvaluation(GpuDateAdd(datesExprV, daysS), expected0)

// lhs = vector, rhs = vector
val daysExprV = GpuLiteral(days0, rhsType)
val datesExprV_1 = GpuLiteral(dates0_1, lhsType)
checkEvaluation(GpuDateAdd(datesExprV_1, daysExprV), expected1)

// lhs = scalar, rhs = vector
checkEvaluation(GpuDateAdd(datesS, daysExprV0), expected2)
withResource(GpuColumnVector.from(ColumnVector.daysFromInts(TIMES_DAY: _*))) {
datesVector =>
withResource(GpuColumnVector.from(ColumnVector.fromInts(2, 32, 6, 9, 6))) {
daysVector =>
val datesExpressionVector = GpuBoundReference(0, DataTypes.DateType, false)
val daysExpressionVector = GpuBoundReference(1, DataTypes.IntegerType, false)
val batch = new ColumnarBatch(List(datesVector, daysVector).toArray,
TIMES_DAY.length)

val daysScalar = GpuLiteral(30, DataTypes.IntegerType)
// lhs = vector, rhs = scalar
checkEvaluation(GpuDateAdd(datesExpressionVector, daysScalar), expected0, batch)

// lhs = vector, rhs = vector
checkEvaluation(GpuDateAdd(datesExpressionVector, daysExpressionVector),
expected1, batch)

// lhs = scalar, rhs = vector
val datesS = GpuLiteral(17674, DataTypes.DateType)
checkEvaluation(GpuDateAdd(datesS, daysExpressionVector), expected2, batch)
}
}
}
}
}
@@ -76,35 +71,29 @@ class DateTimeUnitTest extends GpuUnitTests {
-13537, 1710))) { expected1 =>
withResource(GpuColumnVector.from(ColumnVector.daysFromInts(17672, 17642, 17668,
17665, 17668))) { expected2 =>
// vector 0
val dates0 = GpuColumnVector.from(ColumnVector.daysFromInts(TIMES_DAY: _*))
// assign to new handles for code readability
val dates0_1 = dates0.incRefCount()

//vector 1
val days0 = GpuColumnVector.from(ColumnVector.fromInts(2, 32, 6, 9, 6))
// assign to new handles for code readability
val days0_1 = days0.incRefCount()

// saving types for convenience
val lhsType = getSparkType(dates0.getBase.getType)
val rhsType = getSparkType(days0.getBase.getType)

val daysExprV0 = GpuLiteral(days0_1, rhsType)
val datesS = GpuLiteral(17674, lhsType)

// lhs = vector, rhs = scalar
val datesExprV = GpuLiteral(dates0, lhsType)
val daysS = GpuLiteral(30, rhsType)
checkEvaluation(GpuDateSub(datesExprV, daysS), expected0)

// lhs = vector, rhs = vector
val daysExprV = GpuLiteral(days0, rhsType)
val datesExprV_1 = GpuLiteral(dates0_1, lhsType)
checkEvaluation(GpuDateSub(datesExprV_1, daysExprV), expected1)

// lhs = scalar, rhs = vector
checkEvaluation(GpuDateSub(datesS, daysExprV0), expected2)
withResource(GpuColumnVector.from(ColumnVector.daysFromInts(TIMES_DAY: _*))) {
datesVector =>
withResource(GpuColumnVector.from(ColumnVector.fromInts(2, 32, 6, 9, 6))) {
daysVector =>
val datesExpressionVector = GpuBoundReference(0, DataTypes.DateType, false)
val daysExpressionVector = GpuBoundReference(1, DataTypes.IntegerType, false)
val batch = new ColumnarBatch(List(datesVector, daysVector).toArray,
TIMES_DAY.length)

val daysScalar = GpuLiteral(30, DataTypes.IntegerType)
// lhs = vector, rhs = scalar
checkEvaluation(GpuDateSub(datesExpressionVector, daysScalar), expected0, batch)

// lhs = vector, rhs = vector
checkEvaluation(GpuDateSub(datesExpressionVector, daysExpressionVector),
expected1, batch)

// lhs = scalar, rhs = vector
val datesS = GpuLiteral(17674, DataTypes.DateType)
checkEvaluation(GpuDateSub(datesS, daysExpressionVector), expected2, batch)
}
}
}
}
}
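
The rewritten date tests also change how inputs reach the expression under test: instead of wrapping pre-built GpuColumnVectors in GpuLiteral and juggling extra incRefCount() handles, each vector is scoped with withResource, bound to a column ordinal with GpuBoundReference, and handed to checkEvaluation through a ColumnarBatch. A condensed illustration of that shape follows; the column values and expected results are placeholders chosen for this sketch rather than values from the commit, and it assumes the same imports as DateTimeUnitTest above plus a body inside a GpuUnitTests test so that withResource and checkEvaluation are in scope.

// Two input columns: dates (days since the epoch) and day offsets.
withResource(GpuColumnVector.from(ColumnVector.daysFromInts(17674, 17675))) { dates =>
  withResource(GpuColumnVector.from(ColumnVector.fromInts(1, 2))) { days =>
    // Bind each vector to its ordinal in the batch rather than wrapping it in a GpuLiteral.
    val datesRef = GpuBoundReference(0, DataTypes.DateType, false)
    val daysRef = GpuBoundReference(1, DataTypes.IntegerType, false)
    val batch = new ColumnarBatch(List(dates, days).toArray, 2)
    // date_add row by row: 17674 + 1 = 17675, 17675 + 2 = 17677
    withResource(GpuColumnVector.from(ColumnVector.daysFromInts(17675, 17677))) { expected =>
      checkEvaluation(GpuDateAdd(datesRef, daysRef), expected, batch)
    }
  }
}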