diff --git a/integration_tests/src/main/python/cast_test.py b/integration_tests/src/main/python/cast_test.py index 6f686931240..1ae36c0d230 100644 --- a/integration_tests/src/main/python/cast_test.py +++ b/integration_tests/src/main/python/cast_test.py @@ -372,7 +372,7 @@ def fun(spark): data = [invalid_value] df = spark.createDataFrame(data, type) return df.select(f.col('value').cast(TimestampType())).collect() - assert_gpu_and_cpu_error(fun, {"spark.sql.ansi.enabled": True}, "java.time.DateTimeException") + assert_gpu_and_cpu_error(fun, {"spark.sql.ansi.enabled": True}, "SparkDateTimeException") # if float.floor > Long.max or float.ceil < Long.min, throw exception @pytest.mark.skipif(is_before_spark_330(), reason="ansi cast throws exception only in 3.3.0+") diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/AnsiUtil.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/AnsiUtil.scala index 2ab25dfa84f..1898a21587b 100644 --- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/AnsiUtil.scala +++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/AnsiUtil.scala @@ -16,11 +16,10 @@ package com.nvidia.spark.rapids.shims -import java.time.DateTimeException - import ai.rapids.cudf.{ColumnVector, ColumnView, DType, Scalar} import com.nvidia.spark.rapids.{Arm, BoolUtils, FloatUtils, GpuColumnVector} +import org.apache.spark.rapids.ShimTrampolineUtil import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.DataType @@ -44,12 +43,15 @@ object AnsiUtil extends Arm { } private def castDoubleToTimestampAnsi(doubleInput: ColumnView, toType: DataType): ColumnVector = { - val msg = s"The column contains at least a single value that is " + - s"NaN, Infinity or out-of-range values. To return NULL instead, use 'try_cast'. " + - s"If necessary set ${SQLConf.ANSI_ENABLED.key} to false to bypass this error." + val msg = s"The column contains out-of-range values. 
To return NULL instead, use " + + s"'try_cast'. If necessary set ${SQLConf.ANSI_ENABLED.key} to false to bypass this error." + + // These are the arguments required by SparkDateTimeException class to create error message. + val errorClass = "CAST_INVALID_INPUT" + val messageParameters = Array("DOUBLE", "TIMESTAMP", SQLConf.ANSI_ENABLED.key) - def throwDateTimeException: Unit = { - throw new DateTimeException(msg) + def throwSparkDateTimeException(infOrNan: String): Unit = { + throw ShimTrampolineUtil.dateTimeException(errorClass, Array(infOrNan) ++ messageParameters) } def throwOverflowException: Unit = { @@ -58,7 +60,7 @@ object AnsiUtil extends Arm { withResource(doubleInput.isNan) { hasNan => if (BoolUtils.isAnyValidTrue(hasNan)) { - throwDateTimeException + throwSparkDateTimeException("NaN") } } @@ -66,7 +68,8 @@ object AnsiUtil extends Arm { withResource(FloatUtils.getInfinityVector(doubleInput.getType)) { inf => withResource(doubleInput.contains(inf)) { hasInf => if (BoolUtils.isAnyValidTrue(hasInf)) { - throwDateTimeException + // We specify as "Infinity" for both "+Infinity" and "-Infinity" in the error message + throwSparkDateTimeException("Infinity") } } } diff --git a/sql-plugin/src/main/330+/scala/org/apache/spark/rapids/ShimTrampolineUtil.scala b/sql-plugin/src/main/330+/scala/org/apache/spark/rapids/ShimTrampolineUtil.scala new file mode 100644 index 00000000000..bd3e3ae3a3c --- /dev/null +++ b/sql-plugin/src/main/330+/scala/org/apache/spark/rapids/ShimTrampolineUtil.scala @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
package org.apache.spark.rapids

import org.apache.spark.SparkDateTimeException

/**
 * Trampoline helpers for the RAPIDS plugin, 3.3.0+ shim.
 *
 * NOTE(review): this object is deliberately placed under the
 * `org.apache.spark` package namespace — presumably so plugin code can reach
 * Spark constructors that are restricted to `private[spark]` visibility
 * (such as `SparkDateTimeException`); confirm against the Spark 3.3 sources.
 */
object ShimTrampolineUtil {

  /**
   * Builds a `SparkDateTimeException` carrying a structured Spark error class.
   *
   * @param errorClass        Spark error-class identifier
   *                          (e.g. "CAST_INVALID_INPUT")
   * @param messageParameters positional arguments substituted into the error
   *                          class's message template, in template order
   * @return a new `SparkDateTimeException` for the given error class
   */
  def dateTimeException(
      errorClass: String,
      messageParameters: Array[String]): SparkDateTimeException = {
    // Explicit return type declared: public API members should not rely on
    // type inference for their signature.
    new SparkDateTimeException(errorClass, messageParameters)
  }
}