From fcef8da9c5e59dc578d5f939cfd176423ad613b3 Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Thu, 19 Nov 2020 14:53:23 -0600 Subject: [PATCH] Fix the cast tests for 3.1.0+ (#1166) Signed-off-by: Robert (Bobby) Evans --- .../nvidia/spark/rapids/AnsiCastOpSuite.scala | 79 +++++++++++++------ .../com/nvidia/spark/rapids/CastOpSuite.scala | 18 ++++- .../rapids/SparkQueryCompareTestSuite.scala | 8 +- 3 files changed, 78 insertions(+), 27 deletions(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala index e81a323f5bb..b7543648ff4 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala @@ -67,55 +67,67 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite { // Ansi cast from timestamp to integral types /////////////////////////////////////////////////////////////////////////// + def before3_1_0(s: SparkSession): (Boolean, String) = { + (s.version < "3.1.0", s"Spark version must be prior to 3.1.0") + } + testSparkResultsAreEqual("ansi_cast timestamps to long", - generateValidValuesTimestampsDF(Short.MinValue, Short.MaxValue), sparkConf) { + generateValidValuesTimestampsDF(Short.MinValue, Short.MaxValue), sparkConf, + assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.LongType)(frame) } testSparkResultsAreEqual("ansi_cast successful timestamps to shorts", - generateValidValuesTimestampsDF(Short.MinValue, Short.MaxValue), sparkConf) { + generateValidValuesTimestampsDF(Short.MinValue, Short.MaxValue), sparkConf, + assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.ShortType)(frame) } testSparkResultsAreEqual("ansi_cast successful timestamps to ints", - generateValidValuesTimestampsDF(Int.MinValue, Int.MaxValue), sparkConf) { + generateValidValuesTimestampsDF(Int.MinValue, Int.MaxValue), sparkConf, + assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.IntegerType)(frame) } testSparkResultsAreEqual("ansi_cast successful timestamps to bytes", - generateValidValuesTimestampsDF(Byte.MinValue, Byte.MaxValue), sparkConf) { + generateValidValuesTimestampsDF(Byte.MinValue, Byte.MaxValue), sparkConf, + assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.ByteType)(frame) } testCastFailsForBadInputs("ansi_cast overflow timestamps to bytes", - generateOutOfRangeTimestampsDF(Byte.MinValue, Byte.MaxValue, Byte.MaxValue + 1), sparkConf) { + generateOutOfRangeTimestampsDF(Byte.MinValue, Byte.MaxValue, Byte.MaxValue + 1), sparkConf, + assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.ByteType)(frame) } testCastFailsForBadInputs("ansi_cast underflow timestamps to bytes", - generateOutOfRangeTimestampsDF(Byte.MinValue, Byte.MaxValue, Byte.MinValue - 1), sparkConf) { + generateOutOfRangeTimestampsDF(Byte.MinValue, Byte.MaxValue, Byte.MinValue - 1), sparkConf, + assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.ByteType)(frame) } testCastFailsForBadInputs("ansi_cast overflow timestamps to shorts", - generateOutOfRangeTimestampsDF(Short.MinValue, Short.MaxValue, Short.MaxValue + 1), sparkConf) { + generateOutOfRangeTimestampsDF(Short.MinValue, Short.MaxValue, Short.MaxValue + 1), sparkConf, + assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.ShortType)(frame) } testCastFailsForBadInputs("ansi_cast underflow timestamps to shorts", - generateOutOfRangeTimestampsDF(Short.MinValue, Short.MaxValue, Short.MinValue - 1), sparkConf) { + generateOutOfRangeTimestampsDF(Short.MinValue, Short.MaxValue, Short.MinValue - 1), sparkConf, + assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.ShortType)(frame) } testCastFailsForBadInputs("ansi_cast overflow timestamps to int", generateOutOfRangeTimestampsDF(Int.MinValue, Int.MaxValue, Int.MaxValue.toLong + 1), - sparkConf) { + sparkConf, assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.IntegerType)(frame) } testCastFailsForBadInputs("ansi_cast underflow timestamps to int", generateOutOfRangeTimestampsDF(Int.MinValue, Int.MaxValue, Int.MinValue.toLong - 1), - sparkConf) { + sparkConf, assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.IntegerType)(frame) } @@ -123,31 +135,38 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite { // Ansi cast from date /////////////////////////////////////////////////////////////////////////// - testSparkResultsAreEqual("ansi_cast date to bool", testDates, sparkConf) { + testSparkResultsAreEqual("ansi_cast date to bool", testDates, sparkConf, + assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.BooleanType)(frame) } - testSparkResultsAreEqual("ansi_cast date to byte", testDates, sparkConf) { + testSparkResultsAreEqual("ansi_cast date to byte", testDates, sparkConf, + assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.ByteType)(frame) } - testSparkResultsAreEqual("ansi_cast date to short", testDates, sparkConf) { + testSparkResultsAreEqual("ansi_cast date to short", testDates, sparkConf, + assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.ShortType)(frame) } - testSparkResultsAreEqual("ansi_cast date to int", testDates, sparkConf) { + testSparkResultsAreEqual("ansi_cast date to int", testDates, sparkConf, + assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.IntegerType)(frame) } - testSparkResultsAreEqual("ansi_cast date to long", testDates, sparkConf) { + testSparkResultsAreEqual("ansi_cast date to long", testDates, sparkConf, + assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.LongType)(frame) } - testSparkResultsAreEqual("ansi_cast date to float", testDates, sparkConf) { + testSparkResultsAreEqual("ansi_cast date to float", testDates, sparkConf, + assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.FloatType)(frame) } - testSparkResultsAreEqual("ansi_cast date to double", testDates, sparkConf) { + testSparkResultsAreEqual("ansi_cast date to double", testDates, sparkConf, + assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.DoubleType)(frame) } @@ -187,7 +206,8 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite { frame => testCastTo(DataTypes.DoubleType)(frame) } - testSparkResultsAreEqual("ansi_cast bool to timestamp", testBools, sparkConf) { + testSparkResultsAreEqual("ansi_cast bool to timestamp", testBools, sparkConf, + assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.TimestampType)(frame) } @@ -219,7 +239,8 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite { frame => testCastTo(DataTypes.BooleanType)(frame) } - testSparkResultsAreEqual("ansi_cast timestamp to bool", testTimestamps, sparkConf) { + testSparkResultsAreEqual("ansi_cast timestamp to bool", testTimestamps, sparkConf, + assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.BooleanType)(frame) } @@ -377,19 +398,23 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite { // Ansi cast integral types to timestamp /////////////////////////////////////////////////////////////////////////// - testSparkResultsAreEqual("ansi_cast bytes to timestamp", testBytes, sparkConf) { + testSparkResultsAreEqual("ansi_cast bytes to timestamp", testBytes, sparkConf, + assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.TimestampType)(frame) } - testSparkResultsAreEqual("ansi_cast shorts to timestamp", testShorts, sparkConf) { + testSparkResultsAreEqual("ansi_cast shorts to timestamp", testShorts, sparkConf, + assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.TimestampType)(frame) } - testSparkResultsAreEqual("ansi_cast ints to timestamp", testInts, sparkConf) { + testSparkResultsAreEqual("ansi_cast ints to timestamp", testInts, sparkConf, + assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.TimestampType)(frame) } - testSparkResultsAreEqual("ansi_cast longs to timestamp", testLongs, sparkConf) { + testSparkResultsAreEqual("ansi_cast longs to timestamp", testLongs, sparkConf, + assumeCondition = before3_1_0) { frame => testCastTo(DataTypes.TimestampType)(frame) } @@ -613,10 +638,16 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite { testName: String, frame: SparkSession => DataFrame, sparkConf: SparkConf = sparkConf, - msg: String = GpuCast.INVALID_INPUT_MESSAGE)(transformation: DataFrame => DataFrame) + msg: String = GpuCast.INVALID_INPUT_MESSAGE, + assumeCondition: SparkSession => (Boolean, String) = null) + (transformation: DataFrame => DataFrame) : Unit = { test(testName) { + if (assumeCondition != null) { + val (isAllowed, reason) = withCpuSparkSession(assumeCondition, conf = sparkConf) + assume(isAllowed, reason) + } try { withGpuSparkSession(spark => { val input = frame(spark).repartition(1) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala index 7e0c8d2910c..8a2c3772ee2 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala @@ -49,7 +49,17 @@ class CastOpSuite extends GpuExpressionTestSuite { for (from <- supportedTypes; to <- supportedTypes) yield (from, to) } + def should310SkipAnsiCast(from: DataType, to: DataType): Boolean = (from, to) match { + case (_: NumericType, TimestampType | DateType) => true + case (BooleanType, TimestampType | DateType) => true + case (TimestampType | DateType, _: NumericType) => true + case (TimestampType | DateType, BooleanType) => true + case _ => false + } + + test("Test all supported casts with in-range values") { + val is310OrAfter = !withCpuSparkSession(s => s.version < "3.1.0") // test cast() and ansi_cast() Seq(false, true).foreach { ansiEnabled => @@ -63,8 +73,12 @@ class CastOpSuite extends GpuExpressionTestSuite { typeMatrix.foreach { case (from, to) => + // In 3.1.0 Cast.canCast was split with a separate ANSI version + // Until we are on 3.1.0 or more we cannot call this easily so for now + // We will check and skip a very specific one. + val shouldSkip = is310OrAfter && ansiEnabled && should310SkipAnsiCast(to, from) // check if Spark supports this cast - if (Cast.canCast(from, to)) { + if (!shouldSkip && Cast.canCast(from, to)) { // check if plugin supports this cast if (GpuCast.canCast(from, to)) { // test the cast @@ -87,7 +101,7 @@ class CastOpSuite extends GpuExpressionTestSuite { fail(s"Cast from $from to $to failed; ansi=$ansiEnabled", e) } } - } else { + } else if (!shouldSkip) { // if Spark doesn't support this cast then the plugin shouldn't either assert(!GpuCast.canCast(from, to)) } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala index f75f40403ee..fb2093e2944 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala @@ -698,13 +698,19 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm { maxFloatDiff: Double = 0.0, incompat: Boolean = false, execsAllowedNonGpu: Seq[String] = Seq.empty, - sortBeforeRepart: Boolean = false) + sortBeforeRepart: Boolean = false, + assumeCondition: SparkSession => (Boolean, String) = null) (fun: DataFrame => DataFrame): Unit = { val (testConf, qualifiedTestName) = setupTestConfAndQualifierName(testName, incompat, sort, conf, execsAllowedNonGpu, maxFloatDiff, sortBeforeRepart) + test(qualifiedTestName) { + if (assumeCondition != null) { + val (isAllowed, reason) = withCpuSparkSession(assumeCondition, conf = testConf) + assume(isAllowed, reason) + } val (fromCpu, fromGpu) = runOnCpuAndGpu(df, fun, conf = testConf, repart = repart)