Skip to content

Commit

Permalink
Support casting from decimal to decimal (NVIDIA#1532)
Browse files Browse the repository at this point in the history
Signed-off-by: sperlingxx <lovedreamf@gmail.com>
  • Loading branch information
sperlingxx authored Jan 19, 2021
1 parent 7a39efb commit feba10a
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 85 deletions.
4 changes: 2 additions & 2 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -17017,7 +17017,7 @@ and the accelerator produces the same result.
<td> </td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td>S*</td>
<td> </td>
<td> </td>
<td> </td>
Expand Down Expand Up @@ -17421,7 +17421,7 @@ and the accelerator produces the same result.
<td> </td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td>S*</td>
<td> </td>
<td> </td>
<td> </td>
Expand Down
206 changes: 131 additions & 75 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -380,84 +380,13 @@ case class GpuCast(
input.getBase.asByteList(true)

case (ShortType | IntegerType | LongType, dt: DecimalType) =>

// Use INT64 bounds instead of FLOAT64 bounds, which enables precise comparison.
val (lowBound, upBound) = math.pow(10, dt.precision - dt.scale) match {
case bound if bound > Long.MaxValue => (Long.MinValue, Long.MaxValue)
case bound => (-bound.toLong + 1, bound.toLong - 1)
}
val checkedInput = if (ansiMode) {
assertValuesInRange(input.getBase,
minValue = Scalar.fromLong(lowBound),
maxValue = Scalar.fromLong(upBound))
input.getBase.incRefCount()
} else {
replaceOutOfRangeValues(input.getBase,
minValue = Scalar.fromLong(lowBound),
maxValue = Scalar.fromLong(upBound),
replaceValue = Scalar.fromNull(input.getBase.getType))
}

withResource(checkedInput) { checked =>
if (dt.scale < 0) {
// Rounding is essential when scale is negative,
// so we apply HALF_UP rounding manually to keep align with CpuCast.
withResource(checked.castTo(DType.create(DType.DTypeEnum.DECIMAL64, 0))) {
scaleZero => scaleZero.round(dt.scale, ai.rapids.cudf.RoundMode.HALF_UP)
}
} else if (dt.scale > 0) {
// Integer will be enlarged during casting if scale > 0, so we cast input to INT64
// before casting it to decimal in case of overflow.
withResource(checked.castTo(DType.INT64)) { long =>
long.castTo(DType.create(DType.DTypeEnum.DECIMAL64, -dt.scale))
}
} else {
checked.castTo(DType.create(DType.DTypeEnum.DECIMAL64, -dt.scale))
}
}
castIntegralsToDecimal(input.getBase, dt)

case (FloatType | DoubleType, dt: DecimalType) =>
// Approach to minimize difference between CPUCast and GPUCast:
// step 1. cast input to FLOAT64 (if necessary)
// step 2. cast FLOAT64 to container DECIMAL (who keeps one more digit for rounding)
// step 3. perform HALF_UP rounding on container DECIMAL
val checkedInput = withResource(input.getBase.castTo(DType.FLOAT64)) { double =>
val roundedDouble = double.round(dt.scale, ai.rapids.cudf.RoundMode.HALF_UP)
withResource(roundedDouble) { rounded =>
// We rely on containerDecimal to perform preciser rounding. So, we have to take extra
// space cost of container into consideration when we run bound check.
val containerScaleBound = DType.DECIMAL64_MAX_PRECISION - (dt.scale + 1)
val bound = math.pow(10, (dt.precision - dt.scale) min containerScaleBound)
if (ansiMode) {
assertValuesInRange(rounded,
minValue = Scalar.fromDouble(-bound),
maxValue = Scalar.fromDouble(bound),
inclusiveMin = false,
inclusiveMax = false)
rounded.incRefCount()
} else {
replaceOutOfRangeValues(rounded,
minValue = Scalar.fromDouble(-bound),
maxValue = Scalar.fromDouble(bound),
inclusiveMin = false,
inclusiveMax = false,
replaceValue = Scalar.fromNull(DType.FLOAT64))
}
}
}
castFloatsToDecimal(input.getBase, dt)

withResource(checkedInput) { checked =>
// If target scale reaches DECIMAL64_MAX_PRECISION, container DECIMAL can not
// be created because of precision overflow. In this case, we perform casting op directly.
if (DType.DECIMAL64_MAX_PRECISION == dt.scale) {
checked.castTo(DType.create(DType.DTypeEnum.DECIMAL64, -dt.scale))
} else {
val containerType = DType.create(DType.DTypeEnum.DECIMAL64, -(dt.scale + 1))
withResource(checked.castTo(containerType)) { container =>
container.round(dt.scale, ai.rapids.cudf.RoundMode.HALF_UP)
}
}
}
case (from: DecimalType, to: DecimalType) =>
castDecimalToDecimal(input.getBase, from, to)

case _ =>
input.getBase.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
Expand Down Expand Up @@ -917,4 +846,131 @@ case class GpuCast(
}
}

/**
 * Casts an integral column (SHORT / INT / LONG) to a DECIMAL64 column of type `dt`.
 *
 * Overflow handling: values whose magnitude cannot fit in `dt.precision - dt.scale`
 * integral digits are treated as out of range. In ANSI mode an error is raised via
 * `assertValuesInRange`; otherwise the offending rows are replaced with nulls.
 *
 * @param input the integral input column; this method does not close it (it only bumps
 *              the refcount when it passes the column through unchanged)
 * @param dt the target Spark decimal type (assumed to fit in DECIMAL64)
 * @return a new DECIMAL64 column owned by the caller
 */
private def castIntegralsToDecimal(input: ColumnVector, dt: DecimalType): ColumnVector = {

// Use INT64 bounds instead of FLOAT64 bounds, which enables precise comparison.
// 10^(precision - scale) is the exclusive magnitude limit for representable values;
// when it exceeds Long.MaxValue every long value fits, so use the full Long range.
val (lowBound, upBound) = math.pow(10, dt.precision - dt.scale) match {
case bound if bound > Long.MaxValue => (Long.MinValue, Long.MaxValue)
case bound => (-bound.toLong + 1, bound.toLong - 1)
}
// At first, we conduct overflow check onto input column.
// Then, we cast checked input into target decimal type.
// NOTE(review): the Scalar bounds created here are not explicitly closed — confirm that
// assertValuesInRange / replaceOutOfRangeValues take ownership of their Scalar arguments.
val checkedInput = if (ansiMode) {
assertValuesInRange(input,
minValue = Scalar.fromLong(lowBound),
maxValue = Scalar.fromLong(upBound))
input.incRefCount()
} else {
replaceOutOfRangeValues(input,
minValue = Scalar.fromLong(lowBound),
maxValue = Scalar.fromLong(upBound),
replaceValue = Scalar.fromNull(input.getType))
}

withResource(checkedInput) { checked =>
if (dt.scale < 0) {
// Rounding is essential when scale is negative,
// so we apply HALF_UP rounding manually to keep aligned with CpuCast.
withResource(checked.castTo(DType.create(DType.DTypeEnum.DECIMAL64, 0))) {
scaleZero => scaleZero.round(dt.scale, ai.rapids.cudf.RoundMode.HALF_UP)
}
} else if (dt.scale > 0) {
// Integer will be enlarged during casting if scale > 0, so we cast input to INT64
// before casting it to decimal in case of overflow.
withResource(checked.castTo(DType.INT64)) { long =>
long.castTo(DType.create(DType.DTypeEnum.DECIMAL64, -dt.scale))
}
} else {
// scale == 0: a direct cast is exact, no widening or rounding needed.
checked.castTo(DType.create(DType.DTypeEnum.DECIMAL64, -dt.scale))
}
}
}

/**
 * Casts a floating-point column (FLOAT / DOUBLE) to a DECIMAL64 column of type `dt`.
 *
 * Approach to minimize difference between CPUCast and GPUCast:
 *   step 1. cast input to FLOAT64 (if necessary)
 *   step 2. cast FLOAT64 to a "container" DECIMAL that keeps one extra digit for rounding
 *   step 3. perform HALF_UP rounding on the container DECIMAL
 *
 * Overflow handling: out-of-range values raise an error in ANSI mode, otherwise they are
 * replaced with nulls before the decimal cast.
 *
 * @param input the floating-point input column; not closed by this method
 * @param dt the target Spark decimal type (assumed to fit in DECIMAL64)
 * @return a new DECIMAL64 column owned by the caller
 */
private def castFloatsToDecimal(input: ColumnVector, dt: DecimalType): ColumnVector = {

val checkedInput = withResource(input.castTo(DType.FLOAT64)) { double =>
// Pre-round at the target scale so the bound check below sees the post-rounding values.
val roundedDouble = double.round(dt.scale, ai.rapids.cudf.RoundMode.HALF_UP)
withResource(roundedDouble) { rounded =>
// We rely on containerDecimal to perform more precise rounding. So, we have to take the
// extra space cost of the container into consideration when we run the bound check.
val containerScaleBound = DType.DECIMAL64_MAX_PRECISION - (dt.scale + 1)
val bound = math.pow(10, (dt.precision - dt.scale) min containerScaleBound)
// Bounds are exclusive: a value equal to 10^(precision - scale) already overflows.
if (ansiMode) {
assertValuesInRange(rounded,
minValue = Scalar.fromDouble(-bound),
maxValue = Scalar.fromDouble(bound),
inclusiveMin = false,
inclusiveMax = false)
rounded.incRefCount()
} else {
replaceOutOfRangeValues(rounded,
minValue = Scalar.fromDouble(-bound),
maxValue = Scalar.fromDouble(bound),
inclusiveMin = false,
inclusiveMax = false,
replaceValue = Scalar.fromNull(DType.FLOAT64))
}
}
}

withResource(checkedInput) { checked =>
// If target scale reaches DECIMAL64_MAX_PRECISION, container DECIMAL can not
// be created because of precision overflow. In this case, we perform casting op directly.
if (DType.DECIMAL64_MAX_PRECISION == dt.scale) {
checked.castTo(DType.create(DType.DTypeEnum.DECIMAL64, -dt.scale))
} else {
// Cast through a container with one extra fractional digit, then round back to the
// target scale so ties are resolved HALF_UP like the CPU implementation.
val containerType = DType.create(DType.DTypeEnum.DECIMAL64, -(dt.scale + 1))
withResource(checked.castTo(containerType)) { container =>
container.round(dt.scale, ai.rapids.cudf.RoundMode.HALF_UP)
}
}
}
}

/**
 * Casts a DECIMAL64 column from decimal type `from` to decimal type `to`.
 *
 * When the target scale is larger than the source scale, rescaling multiplies the
 * unscaled value and may overflow DECIMAL64; such values raise an error in ANSI mode or
 * become null otherwise. When the target scale is smaller, values are rounded HALF_UP.
 *
 * @param input the decimal input column (at scale `from.scale`); not closed by this method
 * @param from the source Spark decimal type
 * @param to the target Spark decimal type
 * @return a new DECIMAL64 column owned by the caller
 */
private def castDecimalToDecimal(input: ColumnVector,
from: DecimalType,
to: DecimalType): ColumnVector = {

// At first, we conduct overflow check onto input column.
// Then, we cast checked input into target decimal type.
val checkedInput = if (to.scale <= from.scale) {
// No need to promote precision unless target scale is larger than the source one,
// which indicates the cast is always valid when to.scale <= from.scale.
input.incRefCount()
} else {
// Check whether there exists overflow during promoting precision or not.
// We do NOT use `Scalar.fromDecimal(-to.scale, math.pow(10, 18).toLong)` here, because
// cuDF binaryOperation on decimal will rescale right input to fit the left one.
// The rescaling may lead to overflow.
// absBound is expressed at the SOURCE scale: any unscaled value reaching
// 10^(MAX_PRECISION + from.scale - to.scale) would overflow once rescaled to to.scale.
val absBound = math.pow(10, DType.DECIMAL64_MAX_PRECISION + from.scale - to.scale).toLong
if (ansiMode) {
assertValuesInRange(input,
minValue = Scalar.fromDecimal(-from.scale, -absBound),
maxValue = Scalar.fromDecimal(-from.scale, absBound),
inclusiveMin = false, inclusiveMax = false)
input.incRefCount()
} else {
replaceOutOfRangeValues(input,
minValue = Scalar.fromDecimal(-from.scale, -absBound),
maxValue = Scalar.fromDecimal(-from.scale, absBound),
replaceValue = Scalar.fromNull(input.getType),
inclusiveMin = false, inclusiveMax = false)
}
}

withResource(checkedInput) { checked =>
to.scale - from.scale match {
case 0 =>
// Same scale: the column is already in the target representation.
checked.incRefCount()
case diff if diff > 0 =>
// Scale up: overflow was already screened above, so a plain cast is safe.
checked.castTo(GpuColumnVector.getNonNestedRapidsType(to))
case _ =>
// Scale down: round HALF_UP to match the CPU cast semantics.
checked.round(to.scale, ai.rapids.cudf.RoundMode.HALF_UP)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ class CastChecks extends ExprChecks {
val binaryChecks: TypeSig = none
val sparkBinarySig: TypeSig = STRING + BINARY

val decimalChecks: TypeSig = none
val decimalChecks: TypeSig = DECIMAL
val sparkDecimalSig: TypeSig = numeric + BOOLEAN + TIMESTAMP + STRING

val calendarChecks: TypeSig = none
Expand Down
73 changes: 66 additions & 7 deletions tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import java.sql.Timestamp
import java.time.LocalDateTime
import java.util.TimeZone

import scala.collection.JavaConverters._

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{AnsiCast, Cast}
Expand Down Expand Up @@ -476,6 +478,38 @@ class CastOpSuite extends GpuExpressionTestSuite {
}
}

test("cast decimal to decimal") {
  // Exercise the three relationships between source scale and target scale:
  // equal, shrinking (fromScale > toScale) and growing (fromScale < toScale).
  // Each entry is ((precision, fromScale), toScale); a fresh seeded Random is
  // created per case so every run of each case sees the same data.
  val castCases = Seq(
    // fromScale == toScale
    ((18, 0), 0),
    ((18, 2), 2),
    // fromScale > toScale
    ((18, 1), -1),
    ((18, 10), 2),
    ((18, 18), 15),
    // fromScale < toScale
    ((18, 0), 3),
    ((18, 5), 10),
    ((18, 10), 17))
  castCases.foreach { case ((precision, fromScale), toScale) =>
    testCastToDecimal(DataTypes.createDecimalType(precision, fromScale),
      scale = toScale,
      customRandGenerator = Some(new scala.util.Random(1234L)))
  }
}

test("Detect overflow from numeric types to decimal") {
def intGenerator(column: Seq[Int])(ss: SparkSession): DataFrame = {
import ss.sqlContext.implicits._
Expand All @@ -493,6 +527,11 @@ class CastOpSuite extends GpuExpressionTestSuite {
import ss.sqlContext.implicits._
column.toDF("col")
}
def decimalGenerator(column: Seq[Decimal], decType: DecimalType
)(ss: SparkSession): DataFrame = {
val field = StructField("col", decType)
ss.createDataFrame(column.map(Row(_)).asJava, StructType(Seq(field)))
}
def nonOverflowCase(dataType: DataType,
generator: SparkSession => DataFrame,
precision: Int,
Expand Down Expand Up @@ -556,6 +595,15 @@ class CastOpSuite extends GpuExpressionTestSuite {
generator = floatGenerator(Seq(12345.678f)))
overflowCase(DataTypes.DoubleType, precision = 15, scale = -5,
generator = doubleGenerator(Seq(1.23e21)))

// Test 4: overflow caused by decimal rescaling
val decType = DataTypes.createDecimalType(18, 0)
nonOverflowCase(decType,
precision = 18, scale = 10,
generator = decimalGenerator(Seq(Decimal(99999999L)), decType))
overflowCase(decType,
precision = 18, scale = 10,
generator = decimalGenerator(Seq(Decimal(100000000L)), decType))
}

protected def testCastToDecimal(
Expand Down Expand Up @@ -595,7 +643,7 @@ class CastOpSuite extends GpuExpressionTestSuite {
if (!gpuOnly) {
val (fromCpu, fromGpu) = runOnCpuAndGpu(createDF, execFun, conf, repart = 0)
val (cpuResult, gpuResult) = dataType match {
case ShortType | IntegerType | LongType =>
case ShortType | IntegerType | LongType | _: DecimalType =>
fromCpu.map(r => Row(r.getDecimal(1))) -> fromGpu.map(r => Row(r.getDecimal(1)))
case FloatType | DoubleType =>
// There may be tiny difference between CPU and GPU result when casting from double
Expand Down Expand Up @@ -631,7 +679,7 @@ class CastOpSuite extends GpuExpressionTestSuite {
}
}
val scaleRnd = new scala.util.Random(enhancedRnd.nextLong())
val rawColumn: Seq[AnyVal] = (0 until rowCount).map { _ =>
val rawColumn: Seq[Any] = (0 until rowCount).map { _ =>
val scale = 18 - scaleRnd.nextInt(integralSize + 1)
dataType match {
case ShortType =>
Expand All @@ -642,16 +690,27 @@ class CastOpSuite extends GpuExpressionTestSuite {
enhancedRnd.nextLong() / math.pow(10, scale max 0).toLong
case FloatType | DoubleType =>
enhancedRnd.nextLong() / math.pow(10, scale + 2)
case dt: DecimalType =>
val unscaledValue = (enhancedRnd.nextLong() * math.pow(10, dt.precision - 18)).toLong
Decimal.createUnsafe(unscaledValue, dt.precision, dt.scale)
case _ =>
throw new IllegalArgumentException(s"unsupported dataType: $dataType")
}
}
dataType match {
case ShortType => rawColumn.map(_.asInstanceOf[Long].toShort).toDF("col")
case IntegerType => rawColumn.map(_.asInstanceOf[Long].toInt).toDF("col")
case LongType => rawColumn.map(_.asInstanceOf[Long]).toDF("col")
case FloatType => rawColumn.map(_.asInstanceOf[Double].toFloat).toDF("col")
case DoubleType => rawColumn.map(_.asInstanceOf[Double]).toDF("col")
case ShortType =>
rawColumn.map(_.asInstanceOf[Long].toShort).toDF("col")
case IntegerType =>
rawColumn.map(_.asInstanceOf[Long].toInt).toDF("col")
case LongType =>
rawColumn.map(_.asInstanceOf[Long]).toDF("col")
case FloatType =>
rawColumn.map(_.asInstanceOf[Double].toFloat).toDF("col")
case DoubleType =>
rawColumn.map(_.asInstanceOf[Double]).toDF("col")
case dt: DecimalType =>
val row = rawColumn.map(e => Row(e.asInstanceOf[Decimal])).asJava
ss.createDataFrame(row, StructType(Seq(StructField("col", dt))))
}
}
}
Expand Down

0 comments on commit feba10a

Please sign in to comment.