Skip to content

Commit

Permalink
Throw SparkDateTimeException for InvalidInput while casting in ANSI mode [databricks] (#5731)

Browse files Browse the repository at this point in the history

* Throw SparkDateTimeException for InvalidInput in CastToDateTime in ANSI mode

Signed-off-by: Niranjan Artal <nartal@nvidia.com>

* addressed review comments

Signed-off-by: Niranjan Artal <nartal@nvidia.com>
  • Loading branch information
nartal1 authored Jun 6, 2022
1 parent 2cc7136 commit efb4d76
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 10 deletions.
2 changes: 1 addition & 1 deletion integration_tests/src/main/python/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def fun(spark):
data = [invalid_value]
df = spark.createDataFrame(data, type)
return df.select(f.col('value').cast(TimestampType())).collect()
assert_gpu_and_cpu_error(fun, {"spark.sql.ansi.enabled": True}, "java.time.DateTimeException")
assert_gpu_and_cpu_error(fun, {"spark.sql.ansi.enabled": True}, "SparkDateTimeException")

# if float.floor > Long.max or float.ceil < Long.min, throw exception
@pytest.mark.skipif(is_before_spark_330(), reason="ansi cast throws exception only in 3.3.0+")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@

package com.nvidia.spark.rapids.shims

import java.time.DateTimeException

import ai.rapids.cudf.{ColumnVector, ColumnView, DType, Scalar}
import com.nvidia.spark.rapids.{Arm, BoolUtils, FloatUtils, GpuColumnVector}

import org.apache.spark.rapids.ShimTrampolineUtil
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType

Expand All @@ -44,12 +43,15 @@ object AnsiUtil extends Arm {
}

private def castDoubleToTimestampAnsi(doubleInput: ColumnView, toType: DataType): ColumnVector = {
val msg = s"The column contains at least a single value that is " +
s"NaN, Infinity or out-of-range values. To return NULL instead, use 'try_cast'. " +
s"If necessary set ${SQLConf.ANSI_ENABLED.key} to false to bypass this error."
val msg = s"The column contains out-of-range values. To return NULL instead, use " +
s"'try_cast'. If necessary set ${SQLConf.ANSI_ENABLED.key} to false to bypass this error."

// These are the arguments required by SparkDateTimeException class to create error message.
val errorClass = "CAST_INVALID_INPUT"
val messageParameters = Array("DOUBLE", "TIMESTAMP", SQLConf.ANSI_ENABLED.key)

def throwDateTimeException: Unit = {
throw new DateTimeException(msg)
def throwSparkDateTimeException(infOrNan: String): Unit = {
throw ShimTrampolineUtil.dateTimeException(errorClass, Array(infOrNan) ++ messageParameters)
}

def throwOverflowException: Unit = {
Expand All @@ -58,15 +60,16 @@ object AnsiUtil extends Arm {

withResource(doubleInput.isNan) { hasNan =>
if (BoolUtils.isAnyValidTrue(hasNan)) {
throwDateTimeException
throwSparkDateTimeException("NaN")
}
}

// check nan
withResource(FloatUtils.getInfinityVector(doubleInput.getType)) { inf =>
withResource(doubleInput.contains(inf)) { hasInf =>
if (BoolUtils.isAnyValidTrue(hasInf)) {
throwDateTimeException
// We specify as "Infinity" for both "+Infinity" and "-Infinity" in the error message
throwSparkDateTimeException("Infinity")
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.rapids

import org.apache.spark.SparkDateTimeException

/**
 * Trampoline utility living in the `org.apache.spark` namespace so that the
 * RAPIDS plugin can construct Spark-internal exception types that are not
 * accessible from plugin packages.
 */
object ShimTrampolineUtil {

  /**
   * Builds a [[SparkDateTimeException]] carrying a Spark error class and its
   * message parameters, as thrown by Spark in ANSI mode for invalid casts
   * (e.g. error class "CAST_INVALID_INPUT").
   *
   * @param errorClass the Spark error-class identifier used to look up the message template
   * @param messageParameters positional arguments substituted into the template
   * @return the constructed exception; callers are responsible for throwing it
   */
  def dateTimeException(
      errorClass: String,
      messageParameters: Array[String]): SparkDateTimeException = {
    new SparkDateTimeException(errorClass, messageParameters)
  }
}

0 comments on commit efb4d76

Please sign in to comment.