
update the error msg of casting decimal type
Signed-off-by: remzi <13716567376yh@gmail.com>
HaoYang670 committed Jun 27, 2022
1 parent 7ad61fb commit 50c840a
Showing 7 changed files with 179 additions and 14 deletions.
6 changes: 4 additions & 2 deletions integration_tests/src/main/python/arithmetic_ops_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from logging import exception
 import pytest
 
 from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error, assert_gpu_fallback_collect, assert_gpu_and_cpu_are_equal_sql
@@ -449,14 +450,15 @@ def test_ceil_scale_zero(data_gen):
 
 @pytest.mark.parametrize('data_gen', [_decimal_gen_36_neg5, _decimal_gen_38_neg10], ids=idfn)
 def test_floor_ceil_overflow(data_gen):
+    exception_type = "java.lang.ArithmeticException" if is_before_spark_330() else "SparkArithmeticException"
     assert_gpu_and_cpu_error(
         lambda spark: unary_op_df(spark, data_gen).selectExpr('floor(a)').collect(),
         conf={},
-        error_message="ArithmeticException")
+        error_message=exception_type)
     assert_gpu_and_cpu_error(
         lambda spark: unary_op_df(spark, data_gen).selectExpr('ceil(a)').collect(),
         conf={},
-        error_message="ArithmeticException")
+        error_message=exception_type)
 
 @pytest.mark.parametrize('data_gen', double_gens, ids=idfn)
 def test_rint(data_gen):

@@ -16,10 +16,13 @@
 
 package org.apache.spark.sql.rapids.shims
 
+import ai.rapids.cudf.ColumnVector
+import com.nvidia.spark.rapids.{Arm, GpuColumnVector}
+
 import org.apache.spark.sql.catalyst.trees.Origin
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, DecimalType}
 
-object RapidsErrorUtils {
+object RapidsErrorUtils extends Arm {
   def invalidArrayIndexError(index: Int, numElements: Int,
       isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
     // Follow the Spark string format before 3.3.0
@@ -52,4 +55,32 @@ object RapidsErrorUtils {
       errorContext: String = ""): ArithmeticException = {
     new ArithmeticException(message)
   }
+
+  /**
+   * Wrapper around Spark's `cannotChangeDecimalPrecisionError`.
+   *
+   * @param values The decimal column whose values are being cast.
+   * @param outOfBounds A boolean column flagging the values that cannot be cast.
+   *                    Callers must ensure it contains at least one `true`.
+   * @param fromType The current decimal type.
+   * @param toType The target decimal type.
+   * @param context The error context; defaults to "".
+   */
+  def cannotChangeDecimalPrecisionError(
+      values: GpuColumnVector,
+      outOfBounds: ColumnVector,
+      fromType: DecimalType,
+      toType: DecimalType,
+      context: String = ""): ArithmeticException = {
+    val rowId = withResource(outOfBounds.copyToHost()) { hcv =>
+      (0L until outOfBounds.getRowCount())
+        .find(i => !hcv.isNull(i) && hcv.getBoolean(i))
+        .get
+    }
+    val value = withResource(values.copyToHost()) { hcv =>
+      hcv.getDecimal(rowId.toInt, fromType.precision, fromType.scale)
+    }
+    new ArithmeticException(s"${value.toDebugString} cannot be represented as " +
+      s"Decimal(${toType.precision}, ${toType.scale}).")
+  }
 }
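
For orientation, the call-site pattern this new helper serves appears later in this
commit in GpuCeil and GpuFloor. A minimal sketch of that pattern, assuming
`DecimalUtil.outOfBounds` from `com.nvidia.spark.rapids` as used in those hunks
(the object and method names here are hypothetical):

    import com.nvidia.spark.rapids.{Arm, DecimalUtil, GpuColumnVector}
    import org.apache.spark.sql.rapids.shims.RapidsErrorUtils
    import org.apache.spark.sql.types.DecimalType

    object DecimalBoundsCheck extends Arm {
      // Flag the rows that overflow the target type, and only pay the
      // device-to-host copies inside cannotChangeDecimalPrecisionError
      // when at least one row is actually flagged.
      def assertInBounds(input: GpuColumnVector, from: DecimalType, to: DecimalType): Unit = {
        withResource(DecimalUtil.outOfBounds(input.getBase, to)) { outOfBounds =>
          withResource(outOfBounds.any()) { isAny =>
            if (isAny.isValid && isAny.getBoolean) {
              throw RapidsErrorUtils.cannotChangeDecimalPrecisionError(input, outOfBounds, from, to)
            }
          }
        }
      }
    }
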
@@ -16,10 +16,13 @@
 
 package org.apache.spark.sql.rapids.shims
 
+import ai.rapids.cudf.ColumnVector
+import com.nvidia.spark.rapids.{Arm, GpuColumnVector}
+
 import org.apache.spark.sql.catalyst.trees.Origin
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, DecimalType}
 
-object RapidsErrorUtils {
+object RapidsErrorUtils extends Arm {
   def invalidArrayIndexError(index: Int, numElements: Int,
       isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
     // Follow the Spark string format before 3.3.0
@@ -52,4 +55,32 @@ object RapidsErrorUtils {
       errorContext: String = ""): ArithmeticException = {
     new ArithmeticException(message)
   }
+
+  /**
+   * Wrapper around Spark's `cannotChangeDecimalPrecisionError`.
+   *
+   * @param values The decimal column whose values are being cast.
+   * @param outOfBounds A boolean column flagging the values that cannot be cast.
+   *                    Callers must ensure it contains at least one `true`.
+   * @param fromType The current decimal type.
+   * @param toType The target decimal type.
+   * @param context The error context; defaults to "".
+   */
+  def cannotChangeDecimalPrecisionError(
+      values: GpuColumnVector,
+      outOfBounds: ColumnVector,
+      fromType: DecimalType,
+      toType: DecimalType,
+      context: String = ""): ArithmeticException = {
+    val rowId = withResource(outOfBounds.copyToHost()) { hcv =>
+      (0L until outOfBounds.getRowCount())
+        .find(i => !hcv.isNull(i) && hcv.getBoolean(i))
+        .get
+    }
+    val value = withResource(values.copyToHost()) { hcv =>
+      hcv.getDecimal(rowId.toInt, fromType.precision, fromType.scale)
+    }
+    new ArithmeticException(s"${value.toDebugString} cannot be represented as " +
+      s"Decimal(${toType.precision}, ${toType.scale}).")
+  }
 }
@@ -16,11 +16,14 @@
 
 package org.apache.spark.sql.rapids.shims
 
+import ai.rapids.cudf.ColumnVector
+import com.nvidia.spark.rapids.{Arm, GpuColumnVector}
+
 import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, DecimalType}
 
-object RapidsErrorUtils {
+object RapidsErrorUtils extends Arm {
   def invalidArrayIndexError(index: Int, numElements: Int,
       isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
     // Follow the Spark string format before 3.3.0
@@ -53,5 +56,34 @@ object RapidsErrorUtils {
       errorContext: String = ""): ArithmeticException = {
     new ArithmeticException(message)
   }
+
+  /**
+   * Wrapper around Spark's `cannotChangeDecimalPrecisionError`.
+   *
+   * @param values The decimal column whose values are being cast.
+   * @param outOfBounds A boolean column flagging the values that cannot be cast.
+   *                    Callers must ensure it contains at least one `true`.
+   * @param fromType The current decimal type.
+   * @param toType The target decimal type.
+   * @param context The error context; defaults to "".
+   */
+  def cannotChangeDecimalPrecisionError(
+      values: GpuColumnVector,
+      outOfBounds: ColumnVector,
+      fromType: DecimalType,
+      toType: DecimalType,
+      context: String = ""): ArithmeticException = {
+    val rowId = withResource(outOfBounds.copyToHost()) { hcv =>
+      (0L until outOfBounds.getRowCount())
+        .find(i => !hcv.isNull(i) && hcv.getBoolean(i))
+        .get
+    }
+    val value = withResource(values.copyToHost()) { hcv =>
+      hcv.getDecimal(rowId.toInt, fromType.precision, fromType.scale)
+    }
+    QueryExecutionErrors.cannotChangeDecimalPrecisionError(
+      value, toType.precision, toType.scale
+    )
+  }
 }

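The device-to-host scan in this helper is identical in every shim. Stripped of the
cudf types, the lookup is just "index of the first non-null true flag", as in this
standalone sketch (illustrative only, not code from the commit):

    // Host-side equivalent of the scan over `outOfBounds`. Mirrors the
    // helper's contract that callers guarantee at least one hit
    // (otherwise `.get` throws NoSuchElementException).
    def firstOutOfBoundsRow(flags: Seq[Option[Boolean]]): Long =
      flags.iterator.zipWithIndex
        .collectFirst { case (Some(true), i) => i.toLong }
        .get

    // firstOutOfBoundsRow(Seq(Some(false), None, Some(true)))  // == 2L
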
@@ -16,11 +16,14 @@
 
 package org.apache.spark.sql.rapids.shims
 
+import ai.rapids.cudf.ColumnVector
+import com.nvidia.spark.rapids.{Arm, GpuColumnVector}
+
 import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, DecimalType}
 
-object RapidsErrorUtils {
+object RapidsErrorUtils extends Arm {
   def invalidArrayIndexError(index: Int, numElements: Int,
       isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
     if (isElementAtF) {
@@ -56,4 +59,33 @@ object RapidsErrorUtils {
       errorContext: String = ""): ArithmeticException = {
     new ArithmeticException(message)
   }
+
+  /**
+   * Wrapper around Spark's `cannotChangeDecimalPrecisionError`.
+   *
+   * @param values The decimal column whose values are being cast.
+   * @param outOfBounds A boolean column flagging the values that cannot be cast.
+   *                    Callers must ensure it contains at least one `true`.
+   * @param fromType The current decimal type.
+   * @param toType The target decimal type.
+   * @param context The error context; defaults to "".
+   */
+  def cannotChangeDecimalPrecisionError(
+      values: GpuColumnVector,
+      outOfBounds: ColumnVector,
+      fromType: DecimalType,
+      toType: DecimalType,
+      context: String = ""): ArithmeticException = {
+    val rowId = withResource(outOfBounds.copyToHost()) { hcv =>
+      (0L until outOfBounds.getRowCount())
+        .find(i => !hcv.isNull(i) && hcv.getBoolean(i))
+        .get
+    }
+    val value = withResource(values.copyToHost()) { hcv =>
+      hcv.getDecimal(rowId.toInt, fromType.precision, fromType.scale)
+    }
+    QueryExecutionErrors.cannotChangeDecimalPrecisionError(
+      value, toType.precision, toType.scale
+    )
+  }
 }
@@ -16,11 +16,14 @@
 
 package org.apache.spark.sql.rapids.shims
 
+import ai.rapids.cudf.ColumnVector
+import com.nvidia.spark.rapids.{Arm, GpuColumnVector}
+
 import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, DecimalType}
 
-object RapidsErrorUtils {
+object RapidsErrorUtils extends Arm {
   def invalidArrayIndexError(index: Int, numElements: Int,
       isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
     if (isElementAtF) {
@@ -55,4 +58,33 @@ object RapidsErrorUtils {
       errorContext: String = ""): ArithmeticException = {
     QueryExecutionErrors.arithmeticOverflowError(message, hint, errorContext)
   }
+
+  /**
+   * Wrapper around Spark's `cannotChangeDecimalPrecisionError`.
+   *
+   * @param values The decimal column whose values are being cast.
+   * @param outOfBounds A boolean column flagging the values that cannot be cast.
+   *                    Callers must ensure it contains at least one `true`.
+   * @param fromType The current decimal type.
+   * @param toType The target decimal type.
+   * @param context The error context; defaults to "".
+   */
+  def cannotChangeDecimalPrecisionError(
+      values: GpuColumnVector,
+      outOfBounds: ColumnVector,
+      fromType: DecimalType,
+      toType: DecimalType,
+      context: String = ""): ArithmeticException = {
+    val rowId = withResource(outOfBounds.copyToHost()) { hcv =>
+      (0L until outOfBounds.getRowCount())
+        .find(i => !hcv.isNull(i) && hcv.getBoolean(i))
+        .get
+    }
+    val value = withResource(values.copyToHost()) { hcv =>
+      hcv.getDecimal(rowId.toInt, fromType.precision, fromType.scale)
+    }
+    QueryExecutionErrors.cannotChangeDecimalPrecisionError(
+      value, toType.precision, toType.scale, context
+    )
+  }
 }
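
This Spark 3.3.0 variant forwards `context` and lets `QueryExecutionErrors` raise
Spark's `SparkArithmeticException`, which is why the updated test at the top of the
commit expects that class name on 3.3.0+. Since `SparkArithmeticException` extends
`java.lang.ArithmeticException`, calling code can stay version-agnostic, as in this
sketch (not part of the commit):

    // One catch clause covers both shim families: a plain ArithmeticException
    // before Spark 3.3.0, and SparkArithmeticException (a subclass) on 3.3.0+.
    def reportOverflow(check: => Unit): Unit =
      try {
        check // e.g. the bounds check in GpuCeil/GpuFloor below
      } catch {
        case e: ArithmeticException =>
          println(s"decimal overflow: ${e.getMessage}")
      }
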
@@ -23,6 +23,7 @@ import ai.rapids.cudf.ast.BinaryOperator
 import com.nvidia.spark.rapids._
 
 import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
+import org.apache.spark.sql.rapids.shims.RapidsErrorUtils
 import org.apache.spark.sql.types._
 
 abstract class CudfUnaryMathExpression(name: String) extends GpuUnaryMathExpression(name)
@@ -198,7 +199,9 @@ case class GpuCeil(child: Expression, outputType: DataType)
     withResource(DecimalUtil.outOfBounds(input.getBase, outputType)) { outOfBounds =>
       withResource(outOfBounds.any()) { isAny =>
         if (isAny.isValid && isAny.getBoolean) {
-          throw new ArithmeticException(s"Some data cannot be represented as $outputType")
+          throw RapidsErrorUtils.cannotChangeDecimalPrecisionError(
+            input, outOfBounds, dt, outputType
+          )
         }
       }
     }
@@ -272,7 +275,9 @@ case class GpuFloor(child: Expression, outputType: DataType)
     withResource(DecimalUtil.outOfBounds(input.getBase, outputType)) { outOfBounds =>
       withResource(outOfBounds.any()) { isAny =>
         if (isAny.isValid && isAny.getBoolean) {
-          throw new ArithmeticException(s"Some data cannot be represented as $outputType")
+          throw RapidsErrorUtils.cannotChangeDecimalPrecisionError(
+            input, outOfBounds, dt, outputType
+          )
         }
       }
     }
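
Why these call sites can see out-of-bounds rows at all: floor and ceil over
Decimal(p, s) produce an integral decimal, and Spark caps decimal precision at 38
digits, so a wide or negative-scale input such as the test's Decimal(38, -10)
holds values whose integral form needs more digits than the result type allows. A
sketch of that bound, assuming Spark's usual result-type rule for these expressions
(Decimal(p - s + 1, 0), capped at 38; the helper name here is hypothetical):

    import org.apache.spark.sql.types.DecimalType

    // Result type assigned to floor/ceil over Decimal(p, s), assuming the
    // Decimal(p - s + 1, 0) rule capped at the 38-digit maximum precision.
    def floorCeilResultType(dt: DecimalType): DecimalType = {
      val precision = math.min(dt.precision - dt.scale + 1, DecimalType.MAX_PRECISION)
      DecimalType(precision, 0)
    }

    // Decimal(38, -10) values span up to 48 integral digits, but the result is
    // capped at Decimal(38, 0) -- hence the overflows these hunks now report
    // with a per-value message instead of "Some data cannot be represented".
    // floorCeilResultType(DecimalType(38, -10))  // == DecimalType(38, 0)
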
