
Commit b106fd6

refactor
Signed-off-by: remzi <13716567376yh@gmail.com>
HaoYang670 committed Jun 29, 2022
1 parent 50c840a commit b106fd6
Showing 6 changed files with 46 additions and 132 deletions.
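
At a glance, the refactor narrows the per-Spark-version RapidsErrorUtils.cannotChangeDecimalPrecisionError helpers: they stop extending Arm and no longer take cudf columns, receiving instead a single already-extracted Decimal, while the column scanning moves into a new private RoundingErrorUtil object next to the GPU rounding expressions. A condensed sketch of the signature change follows; the trait names BeforeRefactor and AfterRefactor are illustrative wrappers, not code from the commit, and the sketch assumes the spark-rapids and Spark SQL classes are on the classpath.

import ai.rapids.cudf.ColumnVector
import com.nvidia.spark.rapids.GpuColumnVector
import org.apache.spark.sql.types.{Decimal, DecimalType}

// Old shape: every shim copy dug the offending value out of the GPU columns itself.
trait BeforeRefactor {
  def cannotChangeDecimalPrecisionError(
      values: GpuColumnVector,
      outOfBounds: ColumnVector,
      fromType: DecimalType,
      toType: DecimalType,
      context: String = ""): ArithmeticException
}

// New shape: the shim only formats (or forwards) the error for one scalar value.
trait AfterRefactor {
  def cannotChangeDecimalPrecisionError(
      value: Decimal,
      toType: DecimalType,
      context: String = ""): ArithmeticException
}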
Changed file 1 of 6: RapidsErrorUtils shim (package org.apache.spark.sql.rapids.shims)

@@ -16,13 +16,10 @@

 package org.apache.spark.sql.rapids.shims

-import ai.rapids.cudf.{ColumnVector}
-import com.nvidia.spark.rapids.{Arm, GpuColumnVector}
-
 import org.apache.spark.sql.catalyst.trees.Origin
-import org.apache.spark.sql.types.{DataType, DecimalType}
+import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}

-object RapidsErrorUtils extends Arm {
+object RapidsErrorUtils {
   def invalidArrayIndexError(index: Int, numElements: Int,
       isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
     // Follow the Spark string format before 3.3.0
@@ -56,30 +53,10 @@ object RapidsErrorUtils extends Arm {
     new ArithmeticException(message)
   }

-  /**
-   * Wrapper of the `cannotChangeDecimalPrecisionError` in Spark.
-   *
-   * @param values A decimal column which contains values that try to cast.
-   * @param outOfBounds A boolean column that indicates which value cannot be casted.
-   *                    Users must make sure that there is at least one `true` in this column.
-   * @param fromType The current decimal type.
-   * @param toType The type to cast.
-   * @param context The error context, default value is "".
-   */
   def cannotChangeDecimalPrecisionError(
-      values: GpuColumnVector,
-      outOfBounds: ColumnVector,
-      fromType: DecimalType,
+      value: Decimal,
       toType: DecimalType,
       context: String = ""): ArithmeticException = {
-    val row_id = withResource(outOfBounds.copyToHost()) {hcv =>
-      (0.toLong until outOfBounds.getRowCount())
-        .find(i => !hcv.isNull(i) && hcv.getBoolean(i))
-        .get
-    }
-    val value = withResource(values.copyToHost()){hcv =>
-      hcv.getDecimal(row_id.toInt, fromType.precision, fromType.scale)
-    }
     new ArithmeticException(s"${value.toDebugString} cannot be represented as " +
       s"Decimal(${toType.precision}, ${toType.scale}).")
   }
Changed file 2 of 6: a second RapidsErrorUtils shim copy with the identical change

@@ -16,13 +16,10 @@

 package org.apache.spark.sql.rapids.shims

-import ai.rapids.cudf.{ColumnVector}
-import com.nvidia.spark.rapids.{Arm, GpuColumnVector}
-
 import org.apache.spark.sql.catalyst.trees.Origin
-import org.apache.spark.sql.types.{DataType, DecimalType}
+import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}

-object RapidsErrorUtils extends Arm {
+object RapidsErrorUtils {
   def invalidArrayIndexError(index: Int, numElements: Int,
       isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
     // Follow the Spark string format before 3.3.0
@@ -56,30 +53,10 @@ object RapidsErrorUtils extends Arm {
     new ArithmeticException(message)
   }

-  /**
-   * Wrapper of the `cannotChangeDecimalPrecisionError` in Spark.
-   *
-   * @param values A decimal column which contains values that try to cast.
-   * @param outOfBounds A boolean column that indicates which value cannot be casted.
-   *                    Users must make sure that there is at least one `true` in this column.
-   * @param fromType The current decimal type.
-   * @param toType The type to cast.
-   * @param context The error context, default value is "".
-   */
   def cannotChangeDecimalPrecisionError(
-      values: GpuColumnVector,
-      outOfBounds: ColumnVector,
-      fromType: DecimalType,
+      value: Decimal,
       toType: DecimalType,
       context: String = ""): ArithmeticException = {
-    val row_id = withResource(outOfBounds.copyToHost()) {hcv =>
-      (0.toLong until outOfBounds.getRowCount())
-        .find(i => !hcv.isNull(i) && hcv.getBoolean(i))
-        .get
-    }
-    val value = withResource(values.copyToHost()){hcv =>
-      hcv.getDecimal(row_id.toInt, fromType.precision, fromType.scale)
-    }
     new ArithmeticException(s"${value.toDebugString} cannot be represented as " +
       s"Decimal(${toType.precision}, ${toType.scale}).")
   }
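
The two pre-3.3.0 shims above keep formatting the overflow message themselves; they just receive the scalar instead of computing it. Below is a hedged usage sketch of the slimmed-down helper; the object name, the value 12345.678, and the target type Decimal(5, 2) are made up for illustration and are not part of the commit.

import org.apache.spark.sql.rapids.shims.RapidsErrorUtils
import org.apache.spark.sql.types.{Decimal, DecimalType}

object DecimalPrecisionErrorExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical out-of-range value: precision 8 / scale 3 does not fit Decimal(5, 2).
    val tooWide = Decimal("12345.678")
    val err = RapidsErrorUtils.cannotChangeDecimalPrecisionError(tooWide, DecimalType(5, 2))
    // The message includes the value's debug string and the target precision and scale.
    println(err.getMessage)
  }
}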
Changed file 3 of 6: RapidsErrorUtils shim that delegates to Spark's QueryExecutionErrors

@@ -16,14 +16,11 @@

 package org.apache.spark.sql.rapids.shims

-import ai.rapids.cudf.{ColumnVector}
-import com.nvidia.spark.rapids.{Arm, GpuColumnVector}
-
 import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.types.{DataType, DecimalType}
+import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}

-object RapidsErrorUtils extends Arm{
+object RapidsErrorUtils {
   def invalidArrayIndexError(index: Int, numElements: Int,
       isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
     // Follow the Spark string format before 3.3.0
@@ -57,30 +54,10 @@ object RapidsErrorUtils extends Arm{
     new ArithmeticException(message)
   }

-  /**
-   * Wrapper of the `cannotChangeDecimalPrecisionError` in Spark.
-   *
-   * @param values A decimal column which contains values that try to cast.
-   * @param outOfBounds A boolean column that indicates which value cannot be casted.
-   *                    Users must make sure that there is at least one `true` in this column.
-   * @param fromType The current decimal type.
-   * @param toType The type to cast.
-   * @param context The error context, default value is "".
-   */
   def cannotChangeDecimalPrecisionError(
-      values: GpuColumnVector,
-      outOfBounds: ColumnVector,
-      fromType: DecimalType,
+      value: Decimal,
       toType: DecimalType,
       context: String = ""): ArithmeticException = {
-    val row_id = withResource(outOfBounds.copyToHost()) {hcv =>
-      (0.toLong until outOfBounds.getRowCount())
-        .find(i => !hcv.isNull(i) && hcv.getBoolean(i))
-        .get
-    }
-    val value = withResource(values.copyToHost()){hcv =>
-      hcv.getDecimal(row_id.toInt, fromType.precision, fromType.scale)
-    }
     QueryExecutionErrors.cannotChangeDecimalPrecisionError(
       value, toType.precision, toType.scale
     )
Changed file 4 of 6: another RapidsErrorUtils shim delegating to QueryExecutionErrors

@@ -16,14 +16,11 @@

 package org.apache.spark.sql.rapids.shims

-import ai.rapids.cudf.{ColumnVector}
-import com.nvidia.spark.rapids.{Arm, GpuColumnVector}
-
 import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.types.{DataType, DecimalType}
+import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}

-object RapidsErrorUtils extends Arm {
+object RapidsErrorUtils extends {
   def invalidArrayIndexError(index: Int, numElements: Int,
       isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
     if (isElementAtF) {
@@ -60,30 +57,10 @@ object RapidsErrorUtils extends Arm {
     new ArithmeticException(message)
   }

-  /**
-   * Wrapper of the `cannotChangeDecimalPrecisionError` in Spark.
-   *
-   * @param values A decimal column which contains values that try to cast.
-   * @param outOfBounds A boolean column that indicates which value cannot be casted.
-   *                    Users must make sure that there is at least one `true` in this column.
-   * @param fromType The current decimal type.
-   * @param toType The type to cast.
-   * @param context The error context, default value is "".
-   */
   def cannotChangeDecimalPrecisionError(
-      values: GpuColumnVector,
-      outOfBounds: ColumnVector,
-      fromType: DecimalType,
+      value: Decimal,
       toType: DecimalType,
       context: String = ""): ArithmeticException = {
-    val row_id = withResource(outOfBounds.copyToHost()) {hcv =>
-      (0.toLong until outOfBounds.getRowCount())
-        .find(i => !hcv.isNull(i) && hcv.getBoolean(i))
-        .get
-    }
-    val value = withResource(values.copyToHost()){hcv =>
-      hcv.getDecimal(row_id.toInt, fromType.precision, fromType.scale)
-    }
     QueryExecutionErrors.cannotChangeDecimalPrecisionError(
       value, toType.precision, toType.scale
     )
Changed file 5 of 6: RapidsErrorUtils shim that also forwards the error context string

@@ -16,14 +16,11 @@

 package org.apache.spark.sql.rapids.shims

-import ai.rapids.cudf.{ColumnVector}
-import com.nvidia.spark.rapids.{Arm, GpuColumnVector}
-
 import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.types.{DataType, DecimalType}
+import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}

-object RapidsErrorUtils extends Arm {
+object RapidsErrorUtils {
   def invalidArrayIndexError(index: Int, numElements: Int,
       isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
     if (isElementAtF) {
@@ -59,30 +56,10 @@ object RapidsErrorUtils extends Arm {
     QueryExecutionErrors.arithmeticOverflowError(message, hint, errorContext)
   }

-  /**
-   * Wrapper of the `cannotChangeDecimalPrecisionError` in Spark.
-   *
-   * @param values A decimal column which contains values that try to cast.
-   * @param outOfBounds A boolean column that indicates which value cannot be casted.
-   *                    Users must make sure that there is at least one `true` in this column.
-   * @param fromType The current decimal type.
-   * @param toType The type to cast.
-   * @param context The error context, default value is "".
-   */
   def cannotChangeDecimalPrecisionError(
-      values: GpuColumnVector,
-      outOfBounds: ColumnVector,
-      fromType: DecimalType,
+      value: Decimal,
       toType: DecimalType,
       context: String = ""): ArithmeticException = {
-    val row_id = withResource(outOfBounds.copyToHost()) {hcv =>
-      (0.toLong until outOfBounds.getRowCount())
-        .find(i => !hcv.isNull(i) && hcv.getBoolean(i))
-        .get
-    }
-    val value = withResource(values.copyToHost()){hcv =>
-      hcv.getDecimal(row_id.toInt, fromType.precision, fromType.scale)
-    }
     QueryExecutionErrors.cannotChangeDecimalPrecisionError(
       value, toType.precision, toType.scale, context
     )
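
The three Spark 3.3.0+ shims above differ from the pre-3.3.0 ones only in what they do with the scalar: instead of building a message they hand it to Spark's QueryExecutionErrors.cannotChangeDecimalPrecisionError, and the last shim also threads the optional context string through. Below is a hedged sketch of a call against that newer shape; the object name, the value 9.995, the target type Decimal(3, 2), and the context string are illustrative, not from the commit.

import org.apache.spark.sql.rapids.shims.RapidsErrorUtils
import org.apache.spark.sql.types.{Decimal, DecimalType}

object DecimalPrecisionErrorWithContext {
  def main(args: Array[String]): Unit = {
    // Hypothetical inputs; on these shims the call delegates to Spark's error factory,
    // and only the last shim above forwards `context` to it.
    val err = RapidsErrorUtils.cannotChangeDecimalPrecisionError(
      Decimal("9.995"), DecimalType(3, 2), context = "illustrative error context")
    println(err.getMessage)
  }
}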
Changed file 6 of 6: GPU math expressions (GpuCeil, GpuFloor, GpuRint) and the new private RoundingErrorUtil

@@ -199,7 +199,7 @@ case class GpuCeil(child: Expression, outputType: DataType)
     withResource(DecimalUtil.outOfBounds(input.getBase, outputType)) { outOfBounds =>
       withResource(outOfBounds.any()) { isAny =>
         if (isAny.isValid && isAny.getBoolean) {
-          throw RapidsErrorUtils.cannotChangeDecimalPrecisionError(
+          throw RoundingErrorUtil.cannotChangeDecimalPrecisionError(
            input, outOfBounds, dt, outputType
          )
        }
@@ -275,7 +275,7 @@ case class GpuFloor(child: Expression, outputType: DataType)
     withResource(DecimalUtil.outOfBounds(input.getBase, outputType)) { outOfBounds =>
       withResource(outOfBounds.any()) { isAny =>
         if (isAny.isValid && isAny.getBoolean) {
-          throw RapidsErrorUtils.cannotChangeDecimalPrecisionError(
+          throw RoundingErrorUtil.cannotChangeDecimalPrecisionError(
            input, outOfBounds, dt, outputType
          )
        }
@@ -786,3 +786,32 @@ case class GpuRint(child: Expression) extends CudfUnaryMathExpression("ROUND") {
   override def unaryOp: UnaryOp = UnaryOp.RINT
   override def outputTypeOverride: DType = DType.FLOAT64
 }
+
+private object RoundingErrorUtil extends Arm {
+  /**
+   * Wrapper of the `cannotChangeDecimalPrecisionError` of RapidsErrorUtils.
+   *
+   * @param values A decimal column which contains values that try to cast.
+   * @param outOfBounds A boolean column that indicates which value cannot be casted.
+   *                    Users must make sure that there is at least one `true` in this column.
+   * @param fromType The current decimal type.
+   * @param toType The type to cast.
+   * @param context The error context, default value is "".
+   */
+  def cannotChangeDecimalPrecisionError(
+      values: GpuColumnVector,
+      outOfBounds: ColumnVector,
+      fromType: DecimalType,
+      toType: DecimalType,
+      context: String = ""): ArithmeticException = {
+    val row_id = withResource(outOfBounds.copyToHost()) {hcv =>
+      (0.toLong until outOfBounds.getRowCount())
+        .find(i => !hcv.isNull(i) && hcv.getBoolean(i))
+        .get
+    }
+    val value = withResource(values.copyToHost()){hcv =>
+      hcv.getDecimal(row_id.toInt, fromType.precision, fromType.scale)
+    }
+    RapidsErrorUtils.cannotChangeDecimalPrecisionError(value, toType)
+  }
+}
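
The new RoundingErrorUtil keeps the host-side work in one place: copy the boolean column to the host, find the index of the first non-null true entry, read that row from the decimal column, and pass only that scalar to the shim, with Arm's withResource closing each host copy when its block ends. Below is a stand-alone sketch of the same "first flagged row" lookup on plain Scala collections; the object and method names are illustrative and it has no cudf or Arm dependencies.

object FindFirstOutOfBounds {
  // `None` plays the role of a null entry in the boolean column.
  def firstFlaggedIndex(outOfBounds: Seq[Option[Boolean]]): Long =
    (0L until outOfBounds.length.toLong)
      .find(i => outOfBounds(i.toInt).contains(true))
      .get // mirrors the helper's contract: callers guarantee at least one `true`

  def main(args: Array[String]): Unit =
    println(firstFlaggedIndex(Seq(Some(false), None, Some(true), Some(false)))) // prints 2
}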
