Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the error message of changing decimal type the same as Spark's [databricks] #5915

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions integration_tests/src/main/python/arithmetic_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from logging import exception
import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error, assert_gpu_fallback_collect, assert_gpu_and_cpu_are_equal_sql
Expand Down Expand Up @@ -449,14 +450,16 @@ def test_ceil_scale_zero(data_gen):

@pytest.mark.parametrize('data_gen', [_decimal_gen_36_neg5, _decimal_gen_38_neg10], ids=idfn)
def test_floor_ceil_overflow(data_gen):
    # Verify that floor/ceil on out-of-range decimals raises the same exception
    # type on GPU as on CPU. Spark 3.3.0+ (and Databricks 10.4+) raises
    # SparkArithmeticException; older versions raise java.lang.ArithmeticException.
    exception_type = "java.lang.ArithmeticException" if is_before_spark_330() and not is_databricks104_or_later() \
        else "SparkArithmeticException"
    assert_gpu_and_cpu_error(
        lambda spark: unary_op_df(spark, data_gen).selectExpr('floor(a)').collect(),
        conf={},
        error_message=exception_type)
    assert_gpu_and_cpu_error(
        lambda spark: unary_op_df(spark, data_gen).selectExpr('ceil(a)').collect(),
        conf={},
        error_message=exception_type)

@pytest.mark.parametrize('data_gen', double_gens, ids=idfn)
def test_rint(data_gen):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package org.apache.spark.sql.rapids.shims

import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}

object RapidsErrorUtils {
def invalidArrayIndexError(index: Int, numElements: Int,
Expand Down Expand Up @@ -52,4 +52,12 @@ object RapidsErrorUtils {
errorContext: String = ""): ArithmeticException = {
new ArithmeticException(message)
}

def cannotChangeDecimalPrecisionError(
value: Decimal,
toType: DecimalType,
context: String = ""): ArithmeticException = {
new ArithmeticException(s"${value.toDebugString} cannot be represented as " +
s"Decimal(${toType.precision}, ${toType.scale}).")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package org.apache.spark.sql.rapids.shims

import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}

object RapidsErrorUtils {
def invalidArrayIndexError(index: Int, numElements: Int,
Expand Down Expand Up @@ -52,4 +52,12 @@ object RapidsErrorUtils {
errorContext: String = ""): ArithmeticException = {
new ArithmeticException(message)
}

def cannotChangeDecimalPrecisionError(
value: Decimal,
toType: DecimalType,
context: String = ""): ArithmeticException = {
new ArithmeticException(s"${value.toDebugString} cannot be represented as " +
s"Decimal(${toType.precision}, ${toType.scale}).")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.shims

import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}

object RapidsErrorUtils {
def invalidArrayIndexError(index: Int, numElements: Int,
Expand Down Expand Up @@ -53,5 +53,14 @@ object RapidsErrorUtils {
errorContext: String = ""): ArithmeticException = {
new ArithmeticException(message)
}

def cannotChangeDecimalPrecisionError(
value: Decimal,
toType: DecimalType,
context: String = ""): ArithmeticException = {
QueryExecutionErrors.cannotChangeDecimalPrecisionError(
value, toType.precision, toType.scale
)
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.shims

import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}

object RapidsErrorUtils {
def invalidArrayIndexError(index: Int, numElements: Int,
Expand Down Expand Up @@ -56,4 +56,13 @@ object RapidsErrorUtils {
errorContext: String = ""): ArithmeticException = {
new ArithmeticException(message)
}

def cannotChangeDecimalPrecisionError(
value: Decimal,
toType: DecimalType,
context: String = ""): ArithmeticException = {
QueryExecutionErrors.cannotChangeDecimalPrecisionError(
value, toType.precision, toType.scale
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.shims

import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}

object RapidsErrorUtils {
def invalidArrayIndexError(index: Int, numElements: Int,
Expand Down Expand Up @@ -55,4 +55,13 @@ object RapidsErrorUtils {
errorContext: String = ""): ArithmeticException = {
QueryExecutionErrors.arithmeticOverflowError(message, hint, errorContext)
}

def cannotChangeDecimalPrecisionError(
value: Decimal,
toType: DecimalType,
context: String = ""): ArithmeticException = {
QueryExecutionErrors.cannotChangeDecimalPrecisionError(
value, toType.precision, toType.scale, context
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import ai.rapids.cudf.ast.BinaryOperator
import com.nvidia.spark.rapids._

import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.rapids.shims.RapidsErrorUtils
import org.apache.spark.sql.types._

abstract class CudfUnaryMathExpression(name: String) extends GpuUnaryMathExpression(name)
Expand Down Expand Up @@ -198,7 +199,9 @@ case class GpuCeil(child: Expression, outputType: DataType)
withResource(DecimalUtil.outOfBounds(input.getBase, outputType)) { outOfBounds =>
withResource(outOfBounds.any()) { isAny =>
if (isAny.isValid && isAny.getBoolean) {
throw new ArithmeticException(s"Some data cannot be represented as $outputType")
throw RoundingErrorUtil.cannotChangeDecimalPrecisionError(
input, outOfBounds, dt, outputType
)
}
}
}
Expand Down Expand Up @@ -272,7 +275,9 @@ case class GpuFloor(child: Expression, outputType: DataType)
withResource(DecimalUtil.outOfBounds(input.getBase, outputType)) { outOfBounds =>
withResource(outOfBounds.any()) { isAny =>
if (isAny.isValid && isAny.getBoolean) {
throw new ArithmeticException(s"Some data cannot be represented as $outputType")
throw RoundingErrorUtil.cannotChangeDecimalPrecisionError(
input, outOfBounds, dt, outputType
)
}
}
}
Expand Down Expand Up @@ -781,3 +786,32 @@ case class GpuRint(child: Expression) extends CudfUnaryMathExpression("ROUND") {
override def unaryOp: UnaryOp = UnaryOp.RINT
override def outputTypeOverride: DType = DType.FLOAT64
}

private object RoundingErrorUtil extends Arm {
  /**
   * Wrapper of the `cannotChangeDecimalPrecisionError` of RapidsErrorUtils.
   * Locates the first out-of-bounds value on the host and reports it in the
   * exception, matching the concrete-value message Spark produces on the CPU.
   *
   * @param values A decimal column containing the values being cast.
   * @param outOfBounds A boolean column indicating which rows cannot be cast.
   *                    Callers must guarantee at least one `true` in this column.
   * @param fromType The current decimal type.
   * @param toType The decimal type to cast to.
   * @param context The error context, default value is "".
   */
  def cannotChangeDecimalPrecisionError(
      values: GpuColumnVector,
      outOfBounds: ColumnVector,
      fromType: DecimalType,
      toType: DecimalType,
      context: String = ""): ArithmeticException = {
    // Find the first row flagged as out of bounds. Fail loudly with a clear
    // message (rather than Option.get's NoSuchElementException) if the caller
    // violated the "at least one true" precondition.
    val rowId = withResource(outOfBounds.copyToHost()) { hcv =>
      (0L until hcv.getRowCount)
        .find(i => !hcv.isNull(i) && hcv.getBoolean(i))
        .getOrElse(throw new IllegalStateException(
          "outOfBounds column must contain at least one true value"))
    }
    // Pull just the offending value to the host to build the error message.
    val value = withResource(values.copyToHost()) { hcv =>
      hcv.getDecimal(rowId.toInt, fromType.precision, fromType.scale)
    }
    RapidsErrorUtils.cannotChangeDecimalPrecisionError(value, toType)
  }
}