diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index e1b69d509b8..ebeafccc9f6 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -409,7 +409,7 @@ case class GpuCast( } val longStrings = withResource(trimmed.matchesRe(regex)) { regexMatches => if (ansiMode) { - withResource(regexMatches.all()) { allRegexMatches => + withResource(regexMatches.all(DType.BOOL8)) { allRegexMatches => if (!allRegexMatches.getBoolean) { throw new NumberFormatException(GpuCast.INVALID_INPUT_MESSAGE) } @@ -520,7 +520,7 @@ case class GpuCast( withResource(input.contains(boolStrings)) { validBools => // in ansi mode, fail if any values are not valid bool strings if (ansiEnabled) { - withResource(validBools.all()) { isAllBool => + withResource(validBools.all(DType.BOOL8)) { isAllBool => if (!isAllBool.getBoolean) { throw new IllegalStateException(GpuCast.INVALID_INPUT_MESSAGE) } @@ -776,7 +776,7 @@ case class GpuCast( } // prepend today's date to timestamp formats without dates - sanitizedInput = withResource(sanitizedInput) { cv => + sanitizedInput = withResource(sanitizedInput) { _ => sanitizedInput.stringReplaceWithBackrefs(TIMESTAMP_REGEX_NO_DATE, s"${todayStr}T\\1") } @@ -818,7 +818,7 @@ case class GpuCast( // replace values less than minValue with null val gtEqMinOrNull = withResource(values.greaterOrEqualTo(minValue)) { isGtEqMin => if (ansiMode) { - withResource(isGtEqMin.all()) { all => + withResource(isGtEqMin.all(DType.BOOL8)) { all => if (!all.getBoolean) { throw new NumberFormatException(GpuCast.INVALID_INPUT_MESSAGE) } @@ -831,7 +831,7 @@ case class GpuCast( val ltEqMaxOrNull = withResource(gtEqMinOrNull) { gtEqMinOrNull => withResource(gtEqMinOrNull.lessOrEqualTo(maxValue)) { isLtEqMax => if (ansiMode) { - withResource(isLtEqMax.all()) { all => + withResource(isLtEqMax.all(DType.BOOL8)) { all => if (!all.getBoolean) { throw new NumberFormatException(GpuCast.INVALID_INPUT_MESSAGE) } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala index cf616385381..a4bc79deeb0 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala @@ -204,8 +204,6 @@ case class GpuWindowExpression(windowFunction: Expression, windowSpec: GpuWindow case other => throw new IllegalStateException(s"${other.getClass} is not a supported window aggregation") } - // Add support for Pandas (Python) UDF - case pythonFunc: GpuPythonUDF => pythonFunc case other => throw new IllegalStateException(s"${other.getClass} is not a supported window function") } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala index 670c70a31f8..fce2f380392 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala @@ -95,7 +95,7 @@ object GpuScalar { case dec: BigDecimal => Scalar.fromDecimal(-dec.scale, dec.bigDecimal.unscaledValue().longValueExact()) case _ => - throw new IllegalStateException(s"${v.getClass} '${v}' is not supported as a scalar yet") + throw new IllegalStateException(s"${v.getClass} '$v' is not supported as a scalar yet") } def from(v: Any, t: DataType): Scalar = v match { @@ -105,13 +105,12 @@ object GpuScalar { case vv: Decimal => vv.toBigDecimal.bigDecimal case vv: BigDecimal => vv.bigDecimal case vv: Double => BigDecimal(vv).bigDecimal - case vv: Float => BigDecimal(vv).bigDecimal + case vv: Float => BigDecimal(vv.toDouble).bigDecimal case vv: String => BigDecimal(vv).bigDecimal - case vv: Double => BigDecimal(vv).bigDecimal case vv: Long => BigDecimal(vv).bigDecimal case vv: Int => BigDecimal(vv).bigDecimal case vv => throw new IllegalStateException( - s"${vv.getClass} '${vv}' is not supported as a scalar yet") + s"${vv.getClass} '$vv' is not supported as a scalar yet") } bigDec = bigDec.setScale(t.asInstanceOf[DecimalType].scale) if (bigDec.precision() > t.asInstanceOf[DecimalType].precision) { @@ -137,7 +136,7 @@ object GpuScalar { case s: String => Scalar.fromString(s) case s: UTF8String => Scalar.fromString(s.toString) case _ => - throw new IllegalStateException(s"${v.getClass} '${v}' is not supported as a scalar yet") + throw new IllegalStateException(s"${v.getClass} '$v' is not supported as a scalar yet") } def isNan(s: Scalar): Boolean = { @@ -220,7 +219,7 @@ case class GpuLiteral (value: Any, dataType: DataType) extends GpuLeafExpression case Double.NegativeInfinity => s"CAST('-Infinity' AS ${DoubleType.sql})" case _ => v + "D" } - case (v: Decimal, t: DecimalType) => v + "BD" + case (v: Decimal, _: DecimalType) => v + "BD" case (v: Int, DateType) => val formatter = DateFormatter(DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone)) s"DATE '${formatter.format(v)}'" diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSource.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSource.scala index 1d3eca46679..5c7b777a3ea 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSource.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSource.scala @@ -189,8 +189,8 @@ case class GpuDataSource( /** - * Create a resolved [[BaseRelation]] that can be used to read data from or write data into this - * [[DataSource]] + * Create a resolved `BaseRelation` that can be used to read data from or write data into this + * `DataSource` * * @param checkFilesExist Whether to confirm that the files exist when generating the * non-streaming file based datasource. StructuredStreaming jobs already @@ -355,7 +355,7 @@ case class GpuDataSource( } /** - * Writes the given [[LogicalPlan]] out to this [[DataSource]] and returns a [[BaseRelation]] for + * Writes the given `LogicalPlan` out to this `DataSource` and returns a `BaseRelation` for * the following reading. * * @param mode The save mode for this writing. diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala index c1323fd27e0..ac751b59baf 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala @@ -261,8 +261,8 @@ trait GpuWindowInPandasExecBase extends UnaryExecNode with GpuExec { val frame = spec.frameSpecification.asInstanceOf[GpuSpecifiedWindowFrame] function match { case GpuAggregateExpression(_, _, _, _, _) => collect("AGGREGATE", frame, e) + // GpuPythonUDF is a GpuAggregateWindowFunction, so it is covered here. case _: GpuAggregateWindowFunction => collect("AGGREGATE", frame, e) - case _: GpuPythonUDF => collect("AGGREGATE", frame, e) // OffsetWindowFunction is not supported yet, no harm to keep it here case _: OffsetWindowFunction => collect("OFFSET", frame, e) case f => sys.error(s"Unsupported window function: $f") diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala index 8e50c7aeb28..7ff2a4c9912 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala @@ -18,8 +18,7 @@ package com.nvidia.spark.rapids import java.io.File -import ai.rapids.cudf.Table -import com.nvidia.spark.rapids.RapidsPluginImplicits._ +import ai.rapids.cudf.{DType, Table} import org.scalatest.FunSuite import org.apache.spark.SparkConf @@ -55,7 +54,7 @@ class GpuPartitioningSuite extends FunSuite with Arm { val actualColumns = GpuColumnVector.extractBases(expected) expectedColumns.zip(actualColumns).foreach { case (expected, actual) => withResource(expected.equalToNullAware(actual)) { compareVector => - withResource(compareVector.all()) { compareResult => + withResource(compareVector.all(DType.BOOL8)) { compareResult => assert(compareResult.getBoolean) } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala index a5ee8a7aea0..a88fefe3554 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala @@ -20,6 +20,7 @@ import java.nio.file.Files import java.sql.{Date, Timestamp} import java.util.{Locale, TimeZone} +import scala.reflect.ClassTag import scala.util.{Failure, Try} import org.scalatest.FunSuite @@ -810,8 +811,8 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm { incompat: Boolean = false, execsAllowedNonGpu: Seq[String] = Seq.empty, sortBeforeRepart: Boolean = false) - (fun: DataFrame => DataFrame): Unit = { - + (fun: DataFrame => DataFrame)(implicit classTag: ClassTag[T]): Unit = { + val clazz = classTag.runtimeClass val (testConf, qualifiedTestName) = setupTestConfAndQualifierName(testName, incompat, sort, conf, execsAllowedNonGpu, maxFloatDiff, sortBeforeRepart) @@ -824,9 +825,8 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm { compareResults(sort, maxFloatDiff, fromCpu, fromGpu) }) t match { - case Failure(e) if e.isInstanceOf[T] => { + case Failure(e) if clazz.isAssignableFrom(e.getClass) => assert(expectedException(e.asInstanceOf[T])) - } case Failure(e) => throw e case _ => fail("Expected an exception") }