Make the error message of changing decimal type the same as Spark's [databricks] #5915

Merged

Changes from 1 commit
6 changes: 4 additions & 2 deletions integration_tests/src/main/python/arithmetic_ops_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from logging import exception
 import pytest
 
 from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error, assert_gpu_fallback_collect, assert_gpu_and_cpu_are_equal_sql
@@ -449,14 +450,15 @@ def test_ceil_scale_zero(data_gen):
 
 @pytest.mark.parametrize('data_gen', [_decimal_gen_36_neg5, _decimal_gen_38_neg10], ids=idfn)
 def test_floor_ceil_overflow(data_gen):
+    exception_type = "java.lang.ArithmeticException" if is_before_spark_330() else "SparkArithmeticException"
     assert_gpu_and_cpu_error(
         lambda spark: unary_op_df(spark, data_gen).selectExpr('floor(a)').collect(),
         conf={},
-        error_message="ArithmeticException")
+        error_message=exception_type)
     assert_gpu_and_cpu_error(
         lambda spark: unary_op_df(spark, data_gen).selectExpr('ceil(a)').collect(),
         conf={},
-        error_message="ArithmeticException")
+        error_message=exception_type)
 
 @pytest.mark.parametrize('data_gen', double_gens, ids=idfn)
 def test_rint(data_gen):
@@ -16,10 +16,13 @@
 
 package org.apache.spark.sql.rapids.shims
 
+import ai.rapids.cudf.ColumnVector
+import com.nvidia.spark.rapids.{Arm, GpuColumnVector}
+
 import org.apache.spark.sql.catalyst.trees.Origin
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, DecimalType}
 
-object RapidsErrorUtils {
+object RapidsErrorUtils extends Arm {
   def invalidArrayIndexError(index: Int, numElements: Int,
       isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
     // Follow the Spark string format before 3.3.0
@@ -52,4 +55,32 @@ object RapidsErrorUtils {
       errorContext: String = ""): ArithmeticException = {
     new ArithmeticException(message)
   }
+
+  /**
+   * Wrapper of Spark's `cannotChangeDecimalPrecisionError`.
+   *
+   * @param values A decimal column holding the values being cast.
+   * @param outOfBounds A boolean column marking which values cannot be cast.
+   *                    Callers must ensure it contains at least one `true`.
+   * @param fromType The source decimal type.
+   * @param toType The target decimal type.
+   * @param context The error context; defaults to "".
+   */
+  def cannotChangeDecimalPrecisionError(
+      values: GpuColumnVector,
+      outOfBounds: ColumnVector,
+      fromType: DecimalType,
+      toType: DecimalType,
+      context: String = ""): ArithmeticException = {
+    val rowId = withResource(outOfBounds.copyToHost()) { hcv =>
+      (0L until outOfBounds.getRowCount())
+        .find(i => !hcv.isNull(i) && hcv.getBoolean(i))
+        .get
+    }
+    val value = withResource(values.copyToHost()) { hcv =>
+      hcv.getDecimal(rowId.toInt, fromType.precision, fromType.scale)
+    }
+    new ArithmeticException(s"${value.toDebugString} cannot be represented as " +
+      s"Decimal(${toType.precision}, ${toType.scale}).")
+  }
 }
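
As a quick illustration of the message format the pre-3.3.0 shims above build (a sketch, not PR code; the exact `toDebugString` rendering depends on the decimal's internal storage):

import org.apache.spark.sql.types.Decimal

// A value that fits Decimal(7, 2) but would not fit Decimal(5, 2).
val value = Decimal(BigDecimal("12345.67"), 7, 2)
// The shim builds its message from Decimal.toDebugString, which renders as
// Decimal(expanded,<value>,<precision>,<scale>) or Decimal(compact,<unscaled>,<precision>,<scale>).
val message = s"${value.toDebugString} cannot be represented as Decimal(5, 2)."
// e.g. "Decimal(expanded,12345.67,7,2) cannot be represented as Decimal(5, 2)."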
@@ -16,10 +16,13 @@
 
 package org.apache.spark.sql.rapids.shims
 
+import ai.rapids.cudf.ColumnVector
+import com.nvidia.spark.rapids.{Arm, GpuColumnVector}
+
 import org.apache.spark.sql.catalyst.trees.Origin
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, DecimalType}
 
-object RapidsErrorUtils {
+object RapidsErrorUtils extends Arm {
   def invalidArrayIndexError(index: Int, numElements: Int,
       isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
     // Follow the Spark string format before 3.3.0
@@ -52,4 +55,32 @@ object RapidsErrorUtils {
      errorContext: String = ""): ArithmeticException = {
     new ArithmeticException(message)
   }
+
+  /**
+   * Wrapper of Spark's `cannotChangeDecimalPrecisionError`.
+   *
+   * @param values A decimal column holding the values being cast.
+   * @param outOfBounds A boolean column marking which values cannot be cast.
+   *                    Callers must ensure it contains at least one `true`.
+   * @param fromType The source decimal type.
+   * @param toType The target decimal type.
+   * @param context The error context; defaults to "".
+   */
+  def cannotChangeDecimalPrecisionError(
+      values: GpuColumnVector,
+      outOfBounds: ColumnVector,
+      fromType: DecimalType,
+      toType: DecimalType,
+      context: String = ""): ArithmeticException = {
+    val rowId = withResource(outOfBounds.copyToHost()) { hcv =>
+      (0L until outOfBounds.getRowCount())
+        .find(i => !hcv.isNull(i) && hcv.getBoolean(i))
+        .get
+    }
+    val value = withResource(values.copyToHost()) { hcv =>
+      hcv.getDecimal(rowId.toInt, fromType.precision, fromType.scale)
+    }
+    new ArithmeticException(s"${value.toDebugString} cannot be represented as " +
+      s"Decimal(${toType.precision}, ${toType.scale}).")
+  }
 }
@@ -16,11 +16,14 @@
 
 package org.apache.spark.sql.rapids.shims
 
+import ai.rapids.cudf.ColumnVector
+import com.nvidia.spark.rapids.{Arm, GpuColumnVector}
+
 import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, DecimalType}
 
-object RapidsErrorUtils {
+object RapidsErrorUtils extends Arm {
   def invalidArrayIndexError(index: Int, numElements: Int,
       isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
     // Follow the Spark string format before 3.3.0
@@ -53,5 +56,34 @@ object RapidsErrorUtils {
      errorContext: String = ""): ArithmeticException = {
     new ArithmeticException(message)
   }
+
+  /**
+   * Wrapper of Spark's `cannotChangeDecimalPrecisionError`.
+   *
+   * @param values A decimal column holding the values being cast.
+   * @param outOfBounds A boolean column marking which values cannot be cast.
+   *                    Callers must ensure it contains at least one `true`.
+   * @param fromType The source decimal type.
+   * @param toType The target decimal type.
+   * @param context The error context; defaults to "".
+   */
+  def cannotChangeDecimalPrecisionError(
+      values: GpuColumnVector,
+      outOfBounds: ColumnVector,
+      fromType: DecimalType,
+      toType: DecimalType,
+      context: String = ""): ArithmeticException = {
+    val rowId = withResource(outOfBounds.copyToHost()) { hcv =>
+      (0L until outOfBounds.getRowCount())
+        .find(i => !hcv.isNull(i) && hcv.getBoolean(i))
+        .get
+    }
+    val value = withResource(values.copyToHost()) { hcv =>
+      hcv.getDecimal(rowId.toInt, fromType.precision, fromType.scale)
+    }
+    QueryExecutionErrors.cannotChangeDecimalPrecisionError(
+      value, toType.precision, toType.scale
+    )
+  }
 }
 
@@ -16,11 +16,14 @@
 
 package org.apache.spark.sql.rapids.shims
 
+import ai.rapids.cudf.ColumnVector
+import com.nvidia.spark.rapids.{Arm, GpuColumnVector}
+
 import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, DecimalType}
 
-object RapidsErrorUtils {
+object RapidsErrorUtils extends Arm {
   def invalidArrayIndexError(index: Int, numElements: Int,
       isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
     if (isElementAtF) {
@@ -56,4 +59,33 @@ object RapidsErrorUtils {
      errorContext: String = ""): ArithmeticException = {
     new ArithmeticException(message)
   }
+
+  /**
+   * Wrapper of Spark's `cannotChangeDecimalPrecisionError`.
+   *
+   * @param values A decimal column holding the values being cast.
+   * @param outOfBounds A boolean column marking which values cannot be cast.
+   *                    Callers must ensure it contains at least one `true`.
+   * @param fromType The source decimal type.
+   * @param toType The target decimal type.
+   * @param context The error context; defaults to "".
+   */
+  def cannotChangeDecimalPrecisionError(
+      values: GpuColumnVector,
+      outOfBounds: ColumnVector,
+      fromType: DecimalType,
+      toType: DecimalType,
+      context: String = ""): ArithmeticException = {
+    val rowId = withResource(outOfBounds.copyToHost()) { hcv =>
+      (0L until outOfBounds.getRowCount())
+        .find(i => !hcv.isNull(i) && hcv.getBoolean(i))
+        .get
+    }
+    val value = withResource(values.copyToHost()) { hcv =>
+      hcv.getDecimal(rowId.toInt, fromType.precision, fromType.scale)
+    }
+    QueryExecutionErrors.cannotChangeDecimalPrecisionError(
+      value, toType.precision, toType.scale
+    )
+  }
 }

Review thread on the `val rowId` line:

Collaborator: My only real nit is that there is a lot of copy and paste between the 4 implementations of this code. Is there a way we can make this more common?

@HaoYang670 (Collaborator, Author) on Jun 28, 2022: Yes, I can extract some common expressions into a new helper function. But I am not sure where to put the helper.

Collaborator: Only GpuCeil and GpuFloor do this. I would create a trait or an object in mathExpressions.scala and name it something like FloorCeilErrorUtil or RoundingErrorUtil and add it there.

@HaoYang670 (Collaborator, Author): Updated!
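
A sketch of the shared helper suggested in the thread above, assuming the reviewer's RoundingErrorUtil naming; the method name and exact signature are illustrative, not the merged code:

import com.nvidia.spark.rapids.{Arm, DecimalUtil, GpuColumnVector}

import org.apache.spark.sql.rapids.shims.RapidsErrorUtils
import org.apache.spark.sql.types.DecimalType

// Illustrative only: one common home (e.g. in mathExpressions.scala) for the
// overflow check that GpuCeil and GpuFloor would otherwise duplicate.
trait RoundingErrorUtil extends Arm {
  def checkDecimalRoundingOverflow(
      input: GpuColumnVector,
      fromType: DecimalType,
      toType: DecimalType): Unit = {
    withResource(DecimalUtil.outOfBounds(input.getBase, toType)) { outOfBounds =>
      withResource(outOfBounds.any()) { isAny =>
        if (isAny.isValid && isAny.getBoolean) {
          throw RapidsErrorUtils.cannotChangeDecimalPrecisionError(
            input, outOfBounds, fromType, toType)
        }
      }
    }
  }
}

GpuCeil and GpuFloor would then mix in the trait and call this one method instead of each repeating the check.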
@@ -16,11 +16,14 @@
 
 package org.apache.spark.sql.rapids.shims
 
+import ai.rapids.cudf.ColumnVector
+import com.nvidia.spark.rapids.{Arm, GpuColumnVector}
+
 import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, DecimalType}
 
-object RapidsErrorUtils {
+object RapidsErrorUtils extends Arm {
   def invalidArrayIndexError(index: Int, numElements: Int,
       isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
     if (isElementAtF) {
@@ -55,4 +58,33 @@ object RapidsErrorUtils {
      errorContext: String = ""): ArithmeticException = {
     QueryExecutionErrors.arithmeticOverflowError(message, hint, errorContext)
   }
+
+  /**
+   * Wrapper of Spark's `cannotChangeDecimalPrecisionError`.
+   *
+   * @param values A decimal column holding the values being cast.
+   * @param outOfBounds A boolean column marking which values cannot be cast.
+   *                    Callers must ensure it contains at least one `true`.
+   * @param fromType The source decimal type.
+   * @param toType The target decimal type.
+   * @param context The error context; defaults to "".
+   */
+  def cannotChangeDecimalPrecisionError(
+      values: GpuColumnVector,
+      outOfBounds: ColumnVector,
+      fromType: DecimalType,
+      toType: DecimalType,
+      context: String = ""): ArithmeticException = {
+    val rowId = withResource(outOfBounds.copyToHost()) { hcv =>
+      (0L until outOfBounds.getRowCount())
+        .find(i => !hcv.isNull(i) && hcv.getBoolean(i))
+        .get
+    }
+    val value = withResource(values.copyToHost()) { hcv =>
+      hcv.getDecimal(rowId.toInt, fromType.precision, fromType.scale)
+    }
+    QueryExecutionErrors.cannotChangeDecimalPrecisionError(
+      value, toType.precision, toType.scale, context
+    )
+  }
 }
@@ -23,6 +23,7 @@ import ai.rapids.cudf.ast.BinaryOperator
 import com.nvidia.spark.rapids._
 
 import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
+import org.apache.spark.sql.rapids.shims.RapidsErrorUtils
 import org.apache.spark.sql.types._
 
 abstract class CudfUnaryMathExpression(name: String) extends GpuUnaryMathExpression(name)
@@ -198,7 +199,9 @@ case class GpuCeil(child: Expression, outputType: DataType)
     withResource(DecimalUtil.outOfBounds(input.getBase, outputType)) { outOfBounds =>
       withResource(outOfBounds.any()) { isAny =>
         if (isAny.isValid && isAny.getBoolean) {
-          throw new ArithmeticException(s"Some data cannot be represented as $outputType")
+          throw RapidsErrorUtils.cannotChangeDecimalPrecisionError(
+            input, outOfBounds, dt, outputType
+          )
         }
       }
     }
@@ -272,7 +275,9 @@ case class GpuFloor(child: Expression, outputType: DataType)
     withResource(DecimalUtil.outOfBounds(input.getBase, outputType)) { outOfBounds =>
       withResource(outOfBounds.any()) { isAny =>
         if (isAny.isValid && isAny.getBoolean) {
-          throw new ArithmeticException(s"Some data cannot be represented as $outputType")
+          throw RapidsErrorUtils.cannotChangeDecimalPrecisionError(
+            input, outOfBounds, dt, outputType
+          )
        }
       }
     }
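
For an end-to-end view of the behavior change, here is a hypothetical repro, parallel to test_floor_ceil_overflow above (not part of the PR; the session setup, the legacy negative-scale conf, and the exact message rendering are assumptions):

import org.apache.spark.sql.SparkSession

// Hypothetical repro sketch. Negative-scale decimals (as used by the
// _decimal_gen_36_neg5 / _decimal_gen_38_neg10 test generators) need the
// legacy conf below.
val spark = SparkSession.builder()
  .master("local[1]")
  .config("spark.sql.legacy.allowNegativeScaleOfDecimal", "true")
  .getOrCreate()

// DECIMAL(36, -5) can hold up to 41 digits, but ceil(a) resolves to
// DECIMAL(38, 0), so a 41-digit value overflows after rounding.
val df = spark.sql("SELECT CAST(9e40 AS DECIMAL(36, -5)) AS a")
df.selectExpr("ceil(a)").collect()
// Expected (roughly): java.lang.ArithmeticException before Spark 3.3.0,
// SparkArithmeticException from 3.3.0 on, with a message like
//   Decimal(expanded,9.0E+40,36,-5) cannot be represented as Decimal(38, 0).
// instead of the old GPU-only text "Some data cannot be represented as ...".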