Skip to content

Commit

Permalink
Make the error message of changing decimal type the same as Spark's [databricks] (#5915)
Browse files Browse the repository at this point in the history

* update the error msg of casting decimal type

Signed-off-by: remzi <13716567376yh@gmail.com>

* refactor

Signed-off-by: remzi <13716567376yh@gmail.com>

* update checking for spark321db

Signed-off-by: remzi <13716567376yh@gmail.com>

* Update sql-plugin/src/main/321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
  • Loading branch information
HaoYang670 authored Jul 8, 2022
1 parent 4726644 commit 5015754
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 9 deletions.
7 changes: 5 additions & 2 deletions integration_tests/src/main/python/arithmetic_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from logging import exception
import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error, assert_gpu_fallback_collect, assert_gpu_and_cpu_are_equal_sql
Expand Down Expand Up @@ -449,14 +450,16 @@ def test_ceil_scale_zero(data_gen):

@pytest.mark.parametrize('data_gen', [_decimal_gen_36_neg5, _decimal_gen_38_neg10], ids=idfn)
def test_floor_ceil_overflow(data_gen):
    # Spark 3.3.0+ (and Databricks 10.4+) raise SparkArithmeticException when a
    # floor/ceil result cannot be represented in the output decimal type;
    # earlier versions raise a plain java.lang.ArithmeticException.
    # NOTE: the pasted diff residue carried both the old and new
    # `error_message=` keyword lines; only the new one is kept here.
    exception_type = "java.lang.ArithmeticException" if is_before_spark_330() and not is_databricks104_or_later() \
        else "SparkArithmeticException"
    assert_gpu_and_cpu_error(
        lambda spark: unary_op_df(spark, data_gen).selectExpr('floor(a)').collect(),
        conf={},
        error_message=exception_type)
    assert_gpu_and_cpu_error(
        lambda spark: unary_op_df(spark, data_gen).selectExpr('ceil(a)').collect(),
        conf={},
        error_message=exception_type)

@pytest.mark.parametrize('data_gen', double_gens, ids=idfn)
def test_rint(data_gen):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package org.apache.spark.sql.rapids.shims

import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}

object RapidsErrorUtils {
def invalidArrayIndexError(index: Int, numElements: Int,
Expand Down Expand Up @@ -52,4 +52,12 @@ object RapidsErrorUtils {
errorContext: String = ""): ArithmeticException = {
new ArithmeticException(message)
}

/**
 * Build the ArithmeticException reported when a Decimal value cannot be
 * represented with the target precision/scale.
 *
 * @param value the decimal value that did not fit
 * @param toType the target decimal type
 * @param context error-context string; accepted for cross-shim API symmetry
 *                but not used on this Spark version
 */
def cannotChangeDecimalPrecisionError(
    value: Decimal,
    toType: DecimalType,
    context: String = ""): ArithmeticException = {
  // Message text must match Spark's CPU wording exactly.
  val message = s"${value.toDebugString} cannot be represented as " +
      s"Decimal(${toType.precision}, ${toType.scale})."
  new ArithmeticException(message)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package org.apache.spark.sql.rapids.shims

import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}

object RapidsErrorUtils {
def invalidArrayIndexError(index: Int, numElements: Int,
Expand Down Expand Up @@ -52,4 +52,12 @@ object RapidsErrorUtils {
errorContext: String = ""): ArithmeticException = {
new ArithmeticException(message)
}

/**
 * Build the ArithmeticException reported when a Decimal value cannot be
 * represented with the target precision/scale. The message text mirrors
 * Spark's CPU wording for this Spark version.
 *
 * @param value the decimal value that did not fit
 * @param toType the target decimal type
 * @param context error-context string; accepted for cross-shim API symmetry
 *                but not used here
 */
def cannotChangeDecimalPrecisionError(
value: Decimal,
toType: DecimalType,
context: String = ""): ArithmeticException = {
new ArithmeticException(s"${value.toDebugString} cannot be represented as " +
s"Decimal(${toType.precision}, ${toType.scale}).")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.shims

import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}

object RapidsErrorUtils {
def invalidArrayIndexError(index: Int, numElements: Int,
Expand Down Expand Up @@ -53,5 +53,14 @@ object RapidsErrorUtils {
errorContext: String = ""): ArithmeticException = {
new ArithmeticException(message)
}

/**
 * Delegate to Spark's QueryExecutionErrors so the GPU error message matches
 * the CPU behavior exactly.
 *
 * @param value the decimal value that did not fit
 * @param toType the target decimal type
 * @param context error-context string; accepted for cross-shim API symmetry —
 *                this Spark version's error builder does not take a context
 */
def cannotChangeDecimalPrecisionError(
value: Decimal,
toType: DecimalType,
context: String = ""): ArithmeticException = {
QueryExecutionErrors.cannotChangeDecimalPrecisionError(
value, toType.precision, toType.scale
)
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.shims

import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}

object RapidsErrorUtils {
def invalidArrayIndexError(index: Int, numElements: Int,
Expand Down Expand Up @@ -56,4 +56,13 @@ object RapidsErrorUtils {
errorContext: String = ""): ArithmeticException = {
new ArithmeticException(message)
}

/**
 * Delegate to Spark's QueryExecutionErrors so the GPU error message matches
 * the CPU behavior exactly.
 *
 * @param value the decimal value that did not fit
 * @param toType the target decimal type
 * @param context error-context string; accepted for cross-shim API symmetry —
 *                this Spark version's error builder does not take a context
 */
def cannotChangeDecimalPrecisionError(
value: Decimal,
toType: DecimalType,
context: String = ""): ArithmeticException = {
QueryExecutionErrors.cannotChangeDecimalPrecisionError(
value, toType.precision, toType.scale
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.shims

import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}

object RapidsErrorUtils {
def invalidArrayIndexError(index: Int, numElements: Int,
Expand Down Expand Up @@ -55,4 +55,13 @@ object RapidsErrorUtils {
errorContext: String = ""): ArithmeticException = {
QueryExecutionErrors.arithmeticOverflowError(message, hint, errorContext)
}

/**
 * Delegate to Spark's QueryExecutionErrors so the GPU error message matches
 * the CPU behavior exactly. Unlike the older shims, this Spark version's
 * error builder accepts the error context, so it is forwarded.
 *
 * @param value the decimal value that did not fit
 * @param toType the target decimal type
 * @param context error-context string forwarded to Spark's error builder
 */
def cannotChangeDecimalPrecisionError(
value: Decimal,
toType: DecimalType,
context: String = ""): ArithmeticException = {
QueryExecutionErrors.cannotChangeDecimalPrecisionError(
value, toType.precision, toType.scale, context
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import ai.rapids.cudf.ast.BinaryOperator
import com.nvidia.spark.rapids._

import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.rapids.shims.RapidsErrorUtils
import org.apache.spark.sql.types._

abstract class CudfUnaryMathExpression(name: String) extends GpuUnaryMathExpression(name)
Expand Down Expand Up @@ -198,7 +199,9 @@ case class GpuCeil(child: Expression, outputType: DataType)
withResource(DecimalUtil.outOfBounds(input.getBase, outputType)) { outOfBounds =>
withResource(outOfBounds.any()) { isAny =>
if (isAny.isValid && isAny.getBoolean) {
throw new ArithmeticException(s"Some data cannot be represented as $outputType")
throw RoundingErrorUtil.cannotChangeDecimalPrecisionError(
input, outOfBounds, dt, outputType
)
}
}
}
Expand Down Expand Up @@ -272,7 +275,9 @@ case class GpuFloor(child: Expression, outputType: DataType)
withResource(DecimalUtil.outOfBounds(input.getBase, outputType)) { outOfBounds =>
withResource(outOfBounds.any()) { isAny =>
if (isAny.isValid && isAny.getBoolean) {
throw new ArithmeticException(s"Some data cannot be represented as $outputType")
throw RoundingErrorUtil.cannotChangeDecimalPrecisionError(
input, outOfBounds, dt, outputType
)
}
}
}
Expand Down Expand Up @@ -781,3 +786,32 @@ case class GpuRint(child: Expression) extends CudfUnaryMathExpression("ROUND") {
override def unaryOp: UnaryOp = UnaryOp.RINT
override def outputTypeOverride: DType = DType.FLOAT64
}

private object RoundingErrorUtil extends Arm {
  /**
   * Wrapper of the `cannotChangeDecimalPrecisionError` of RapidsErrorUtils.
   *
   * @param values A decimal column which contains values that try to cast.
   * @param outOfBounds A boolean column that indicates which value cannot be casted.
   *                    Users must make sure that there is at least one `true` in this column.
   * @param fromType The current decimal type.
   * @param toType The type to cast.
   * @param context The error context, default value is "".
   *                (Currently not forwarded; kept for API symmetry with the shims.)
   */
  def cannotChangeDecimalPrecisionError(
      values: GpuColumnVector,
      outOfBounds: ColumnVector,
      fromType: DecimalType,
      toType: DecimalType,
      context: String = ""): ArithmeticException = {
    // Find the first row flagged out-of-bounds. Read the row count from the
    // host copy (not the device column) so everything inside this scope works
    // on the same resource. The caller guarantees at least one `true`; fail
    // with a clear message if that invariant is ever broken.
    val rowId = withResource(outOfBounds.copyToHost()) { hcv =>
      (0L until hcv.getRowCount)
        .find(i => !hcv.isNull(i) && hcv.getBoolean(i))
        .getOrElse(throw new IllegalStateException(
          "outOfBounds column contains no true value"))
    }
    // Copy only once to the host and extract the single offending value for
    // the error message.
    val value = withResource(values.copyToHost()) { hcv =>
      hcv.getDecimal(rowId.toInt, fromType.precision, fromType.scale)
    }
    RapidsErrorUtils.cannotChangeDecimalPrecisionError(value, toType)
  }
}

0 comments on commit 5015754

Please sign in to comment.