From 8cce522d8b7d6c967d27f90cfeb00ef7a8ad972b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 12 Nov 2020 14:25:51 -0700 Subject: [PATCH] Make testExpectedExceptionStartsWith more flexible Signed-off-by: Andy Grove --- .../scala/com/nvidia/spark/rapids/CsvScanSuite.scala | 5 ++--- .../com/nvidia/spark/rapids/HashAggregatesSuite.scala | 11 +++++------ .../com/nvidia/spark/rapids/ParquetWriterSuite.scala | 5 ++--- .../spark/rapids/SparkQueryCompareTestSuite.scala | 9 ++++----- 4 files changed, 13 insertions(+), 17 deletions(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/CsvScanSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/CsvScanSuite.scala index 25ed1764cf6..ce9f17349ad 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/CsvScanSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/CsvScanSuite.scala @@ -20,9 +20,8 @@ import org.apache.spark.SparkConf import org.apache.spark.sql.functions.col class CsvScanSuite extends SparkQueryCompareTestSuite { - testExpectedExceptionStartsWith("Test CSV projection including unsupported types", - classOf[IllegalArgumentException], - "Part of the plan is not columnar", + testExpectedException[IllegalArgumentException]("Test CSV projection including unsupported types", + _.getMessage.startsWith("Part of the plan is not columnar"), mixedTypesFromCsvWithHeader) { frame => frame.select(col("c_string"), col("c_int"), col("c_timestamp")) } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/HashAggregatesSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/HashAggregatesSuite.scala index 413514cdc63..2c1300c3094 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/HashAggregatesSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/HashAggregatesSuite.scala @@ -283,10 +283,9 @@ class HashAggregatesSuite extends SparkQueryCompareTestSuite { frame => frame.agg(avg(lit("abc")),avg(lit("pqr"))) } - testExpectedExceptionStartsWith( + testExpectedException[AnalysisException]( "avg literals bools fail", - classOf[AnalysisException], - "cannot resolve", + _.getMessage.startsWith("cannot resolve"), longsFromCSVDf, conf = floatAggConf) { frame => frame.agg(avg(lit(true)),avg(lit(false))) @@ -1550,10 +1549,10 @@ class HashAggregatesSuite extends SparkQueryCompareTestSuite { if (spark.SPARK_VERSION_SHORT < "3.1.0") { // A test that verifies that Distinct with Filter is not supported on the CPU or the GPU. - testExpectedExceptionStartsWith( + testExpectedException[AnalysisException]( "Avg Distinct with filter - unsupported on CPU and GPU", - classOf[AnalysisException], - "DISTINCT and FILTER cannot be used in aggregate functions at the same time", + _.getMessage.startsWith( + "DISTINCT and FILTER cannot be used in aggregate functions at the same time"), longsFromCSVDf, conf = floatAggConf) { frame => frame.selectExpr("avg(distinct longs) filter (where longs < 5)") } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ParquetWriterSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ParquetWriterSuite.scala index 2efdab41ba9..ce537e81be7 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/ParquetWriterSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/ParquetWriterSuite.scala @@ -76,10 +76,9 @@ class ParquetWriterSuite extends SparkQueryCompareTestSuite { } } - testExpectedExceptionStartsWith( + testExpectedException[IllegalArgumentException]( "int96 timestamps not supported", - classOf[IllegalArgumentException], - "Part of the plan is not columnar", + _.getMessage.startsWith("Part of the plan is not columnar"), frameFromParquet("timestamp-date-test-msec.parquet"), new SparkConf().set("spark.sql.parquet.outputTimestampType", "INT96")) { val tempFile = File.createTempFile("int96", "parquet") diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala index d4c8e4835b1..f75f40403ee 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala @@ -793,10 +793,9 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm { } } - def testExpectedExceptionStartsWith[T <: Throwable]( + def testExpectedException[T <: Throwable]( testName: String, - exceptionClass: Class[T], - expectedException: String, + expectedException: T => Boolean, df: SparkSession => DataFrame, conf: SparkConf = new SparkConf(), repart: Integer = 1, @@ -819,8 +818,8 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm { compareResults(sort, maxFloatDiff, fromCpu, fromGpu) }) t match { - case Failure(e) if e.getClass == exceptionClass => { - assert(e.getMessage != null && e.getMessage.startsWith(expectedException)) + case Failure(e) if e.isInstanceOf[T] => { + assert(expectedException(e.asInstanceOf[T])) } case Failure(e) => throw e case _ => fail("Expected an exception")