Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support casting from decimal to decimal #1532

Merged
merged 3 commits into from
Jan 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -17017,7 +17017,7 @@ and the accelerator produces the same result.
<td> </td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td>S*</td>
<td> </td>
<td> </td>
<td> </td>
Expand Down Expand Up @@ -17421,7 +17421,7 @@ and the accelerator produces the same result.
<td> </td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td>S*</td>
<td> </td>
<td> </td>
<td> </td>
Expand Down
206 changes: 131 additions & 75 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -380,84 +380,13 @@ case class GpuCast(
input.getBase.asByteList(true)

case (ShortType | IntegerType | LongType, dt: DecimalType) =>

// Use INT64 bounds instead of FLOAT64 bounds, which enables precise comparison.
val (lowBound, upBound) = math.pow(10, dt.precision - dt.scale) match {
case bound if bound > Long.MaxValue => (Long.MinValue, Long.MaxValue)
case bound => (-bound.toLong + 1, bound.toLong - 1)
}
val checkedInput = if (ansiMode) {
assertValuesInRange(input.getBase,
minValue = Scalar.fromLong(lowBound),
maxValue = Scalar.fromLong(upBound))
input.getBase.incRefCount()
} else {
replaceOutOfRangeValues(input.getBase,
minValue = Scalar.fromLong(lowBound),
maxValue = Scalar.fromLong(upBound),
replaceValue = Scalar.fromNull(input.getBase.getType))
}

withResource(checkedInput) { checked =>
if (dt.scale < 0) {
// Rounding is essential when scale is negative,
// so we apply HALF_UP rounding manually to keep align with CpuCast.
withResource(checked.castTo(DType.create(DType.DTypeEnum.DECIMAL64, 0))) {
scaleZero => scaleZero.round(dt.scale, ai.rapids.cudf.RoundMode.HALF_UP)
}
} else if (dt.scale > 0) {
// Integer will be enlarged during casting if scale > 0, so we cast input to INT64
// before casting it to decimal in case of overflow.
withResource(checked.castTo(DType.INT64)) { long =>
long.castTo(DType.create(DType.DTypeEnum.DECIMAL64, -dt.scale))
}
} else {
checked.castTo(DType.create(DType.DTypeEnum.DECIMAL64, -dt.scale))
}
}
castIntegralsToDecimal(input.getBase, dt)

case (FloatType | DoubleType, dt: DecimalType) =>
// Approach to minimize difference between CPUCast and GPUCast:
// step 1. cast input to FLOAT64 (if necessary)
// step 2. cast FLOAT64 to container DECIMAL (who keeps one more digit for rounding)
// step 3. perform HALF_UP rounding on container DECIMAL
val checkedInput = withResource(input.getBase.castTo(DType.FLOAT64)) { double =>
val roundedDouble = double.round(dt.scale, ai.rapids.cudf.RoundMode.HALF_UP)
withResource(roundedDouble) { rounded =>
// We rely on containerDecimal to perform preciser rounding. So, we have to take extra
// space cost of container into consideration when we run bound check.
val containerScaleBound = DType.DECIMAL64_MAX_PRECISION - (dt.scale + 1)
val bound = math.pow(10, (dt.precision - dt.scale) min containerScaleBound)
if (ansiMode) {
assertValuesInRange(rounded,
minValue = Scalar.fromDouble(-bound),
maxValue = Scalar.fromDouble(bound),
inclusiveMin = false,
inclusiveMax = false)
rounded.incRefCount()
} else {
replaceOutOfRangeValues(rounded,
minValue = Scalar.fromDouble(-bound),
maxValue = Scalar.fromDouble(bound),
inclusiveMin = false,
inclusiveMax = false,
replaceValue = Scalar.fromNull(DType.FLOAT64))
}
}
}
castFloatsToDecimal(input.getBase, dt)

withResource(checkedInput) { checked =>
// If target scale reaches DECIMAL64_MAX_PRECISION, container DECIMAL can not
// be created because of precision overflow. In this case, we perform casting op directly.
if (DType.DECIMAL64_MAX_PRECISION == dt.scale) {
checked.castTo(DType.create(DType.DTypeEnum.DECIMAL64, -dt.scale))
} else {
val containerType = DType.create(DType.DTypeEnum.DECIMAL64, -(dt.scale + 1))
withResource(checked.castTo(containerType)) { container =>
container.round(dt.scale, ai.rapids.cudf.RoundMode.HALF_UP)
}
}
}
case (from: DecimalType, to: DecimalType) =>
castDecimalToDecimal(input.getBase, from, to)

case _ =>
input.getBase.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
Expand Down Expand Up @@ -917,4 +846,131 @@ case class GpuCast(
}
}

/**
 * Casts an integral column (SHORT / INT / LONG) to a DECIMAL64 column of type `dt`.
 *
 * Values whose magnitude does not fit in `dt.precision - dt.scale` integral digits are
 * screened first: in ANSI mode an out-of-range value triggers `assertValuesInRange`,
 * otherwise it is replaced with null by `replaceOutOfRangeValues`.
 *
 * @param input the integral input column; this method does not close it
 * @param dt    the target Spark decimal type (precision and scale)
 * @return a new DECIMAL64 column owned by the caller
 */
private def castIntegralsToDecimal(input: ColumnVector, dt: DecimalType): ColumnVector = {

  // Use INT64 bounds instead of FLOAT64 bounds, which enables precise comparison.
  val (lowBound, upBound) = math.pow(10, dt.precision - dt.scale) match {
    // 10^(precision - scale) can exceed the Long range; then every Long value fits.
    case bound if bound > Long.MaxValue => (Long.MinValue, Long.MaxValue)
    case bound => (-bound.toLong + 1, bound.toLong - 1)
  }
  // At first, we conduct overflow check onto input column.
  // Then, we cast checked input into target decimal type.
  val checkedInput = if (ansiMode) {
    assertValuesInRange(input,
      minValue = Scalar.fromLong(lowBound),
      maxValue = Scalar.fromLong(upBound))
    // incRefCount so the withResource below does not release the caller's column.
    input.incRefCount()
  } else {
    replaceOutOfRangeValues(input,
      minValue = Scalar.fromLong(lowBound),
      maxValue = Scalar.fromLong(upBound),
      replaceValue = Scalar.fromNull(input.getType))
  }

  withResource(checkedInput) { checked =>
    if (dt.scale < 0) {
      // Rounding is essential when scale is negative,
      // so we apply HALF_UP rounding manually to keep align with CpuCast.
      // Note: cuDF DType stores the decimal scale negated relative to Spark's.
      withResource(checked.castTo(DType.create(DType.DTypeEnum.DECIMAL64, 0))) {
        scaleZero => scaleZero.round(dt.scale, ai.rapids.cudf.RoundMode.HALF_UP)
      }
    } else if (dt.scale > 0) {
      // Integer will be enlarged during casting if scale > 0, so we cast input to INT64
      // before casting it to decimal in case of overflow.
      withResource(checked.castTo(DType.INT64)) { long =>
        long.castTo(DType.create(DType.DTypeEnum.DECIMAL64, -dt.scale))
      }
    } else {
      checked.castTo(DType.create(DType.DTypeEnum.DECIMAL64, -dt.scale))
    }
  }
}

/**
 * Casts a floating point column (FLOAT / DOUBLE) to a DECIMAL64 column of type `dt`.
 *
 * Rounds in FLOAT64 first, bound-checks the rounded values (ANSI mode throws via
 * `assertValuesInRange`; otherwise out-of-range values become null), then casts through
 * a "container" decimal that keeps one extra fractional digit so the final HALF_UP
 * rounding matches the CPU cast more closely.
 *
 * @param input the floating point input column; this method does not close it
 * @param dt    the target Spark decimal type (precision and scale)
 * @return a new DECIMAL64 column owned by the caller
 */
private def castFloatsToDecimal(input: ColumnVector, dt: DecimalType): ColumnVector = {

  // Approach to minimize difference between CPUCast and GPUCast:
  // step 1. cast input to FLOAT64 (if necessary)
  // step 2. cast FLOAT64 to container DECIMAL (who keeps one more digit for rounding)
  // step 3. perform HALF_UP rounding on container DECIMAL
  val checkedInput = withResource(input.castTo(DType.FLOAT64)) { double =>
    val roundedDouble = double.round(dt.scale, ai.rapids.cudf.RoundMode.HALF_UP)
    withResource(roundedDouble) { rounded =>
      // We rely on containerDecimal to perform preciser rounding. So, we have to take extra
      // space cost of container into consideration when we run bound check.
      val containerScaleBound = DType.DECIMAL64_MAX_PRECISION - (dt.scale + 1)
      val bound = math.pow(10, (dt.precision - dt.scale) min containerScaleBound)
      // Bounds are exclusive: a value equal to +/-bound would overflow dt.precision.
      if (ansiMode) {
        assertValuesInRange(rounded,
          minValue = Scalar.fromDouble(-bound),
          maxValue = Scalar.fromDouble(bound),
          inclusiveMin = false,
          inclusiveMax = false)
        // incRefCount so the withResource below does not release the rounded column twice.
        rounded.incRefCount()
      } else {
        replaceOutOfRangeValues(rounded,
          minValue = Scalar.fromDouble(-bound),
          maxValue = Scalar.fromDouble(bound),
          inclusiveMin = false,
          inclusiveMax = false,
          replaceValue = Scalar.fromNull(DType.FLOAT64))
      }
    }
  }

  withResource(checkedInput) { checked =>
    // If target scale reaches DECIMAL64_MAX_PRECISION, container DECIMAL can not
    // be created because of precision overflow. In this case, we perform casting op directly.
    if (DType.DECIMAL64_MAX_PRECISION == dt.scale) {
      checked.castTo(DType.create(DType.DTypeEnum.DECIMAL64, -dt.scale))
    } else {
      // Container keeps one extra fractional digit (scale + 1) for accurate HALF_UP rounding.
      val containerType = DType.create(DType.DTypeEnum.DECIMAL64, -(dt.scale + 1))
      withResource(checked.castTo(containerType)) { container =>
        container.round(dt.scale, ai.rapids.cudf.RoundMode.HALF_UP)
      }
    }
  }
}

/**
 * Casts a decimal column from one decimal type to another.
 *
 * When `to.scale <= from.scale` only fractional digits are dropped (with HALF_UP
 * rounding), which cannot overflow DECIMAL64 precision, so no bound check is needed.
 * When `to.scale > from.scale` the unscaled value grows by 10^(to.scale - from.scale)
 * during rescaling, so values are bound-checked first: ANSI mode raises on overflow via
 * `assertValuesInRange`, otherwise overflowing values are replaced with null.
 *
 * NOTE: two lines of GitHub review-thread residue ("sperlingxx marked this conversation
 * as resolved." / "Show resolved Hide resolved") were embedded in this method's body and
 * have been removed — they were not valid Scala and would not compile.
 *
 * @param input the decimal input column; this method does not close it
 * @param from  the source Spark decimal type
 * @param to    the target Spark decimal type
 * @return a new decimal column of type `to`, owned by the caller
 */
private def castDecimalToDecimal(input: ColumnVector,
    from: DecimalType,
    to: DecimalType): ColumnVector = {

  // At first, we conduct overflow check onto input column.
  // Then, we cast checked input into target decimal type.
  val checkedInput = if (to.scale <= from.scale) {
    // No need to promote precision unless target scale is larger than the source one,
    // which indicates the cast is always valid when to.scale <= from.scale.
    input.incRefCount()
  } else {
    // Check whether there exists overflow during promoting precision or not.
    // We do NOT use `Scalar.fromDecimal(-to.scale, math.pow(10, 18).toLong)` here, because
    // cuDF binaryOperation on decimal will rescale right input to fit the left one.
    // The rescaling may lead to overflow.
    val absBound = math.pow(10, DType.DECIMAL64_MAX_PRECISION + from.scale - to.scale).toLong
    if (ansiMode) {
      assertValuesInRange(input,
        minValue = Scalar.fromDecimal(-from.scale, -absBound),
        maxValue = Scalar.fromDecimal(-from.scale, absBound),
        inclusiveMin = false, inclusiveMax = false)
      // incRefCount so the withResource below does not release the caller's column.
      input.incRefCount()
    } else {
      replaceOutOfRangeValues(input,
        minValue = Scalar.fromDecimal(-from.scale, -absBound),
        maxValue = Scalar.fromDecimal(-from.scale, absBound),
        replaceValue = Scalar.fromNull(input.getType),
        inclusiveMin = false, inclusiveMax = false)
    }
  }

  withResource(checkedInput) { checked =>
    to.scale - from.scale match {
      case 0 =>
        // Same scale: representation is already correct; just hand out a new reference.
        checked.incRefCount()
      case diff if diff > 0 =>
        // Larger target scale: rescale upward (overflow was screened above).
        checked.castTo(GpuColumnVector.getNonNestedRapidsType(to))
      case _ =>
        // Smaller target scale: drop digits with HALF_UP rounding to match the CPU cast.
        checked.round(to.scale, ai.rapids.cudf.RoundMode.HALF_UP)
    }
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ class CastChecks extends ExprChecks {
val binaryChecks: TypeSig = none
val sparkBinarySig: TypeSig = STRING + BINARY

val decimalChecks: TypeSig = none
val decimalChecks: TypeSig = DECIMAL
val sparkDecimalSig: TypeSig = numeric + BOOLEAN + TIMESTAMP + STRING

val calendarChecks: TypeSig = none
Expand Down
73 changes: 66 additions & 7 deletions tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import java.sql.Timestamp
import java.time.LocalDateTime
import java.util.TimeZone

import scala.collection.JavaConverters._

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{AnsiCast, Cast}
Expand Down Expand Up @@ -476,6 +478,38 @@ class CastOpSuite extends GpuExpressionTestSuite {
}
}

// Exercises GPU decimal-to-decimal casting across the three scale relationships
// handled by the plugin (equal, shrinking, and growing scale). A fixed RNG seed
// keeps the generated data deterministic across runs. Note: the `scale` argument
// is the TARGET scale passed to testCastToDecimal; the DecimalType is the source.
test("cast decimal to decimal") {
  // fromScale == toScale: no rescaling needed, values pass through unchanged.
  testCastToDecimal(DataTypes.createDecimalType(18, 0),
    scale = 0,
    customRandGenerator = Some(new scala.util.Random(1234L)))
  testCastToDecimal(DataTypes.createDecimalType(18, 2),
    scale = 2,
    customRandGenerator = Some(new scala.util.Random(1234L)))

  // fromScale > toScale: digits are dropped, so rounding behavior is exercised.
  testCastToDecimal(DataTypes.createDecimalType(18, 1),
    scale = -1,
    customRandGenerator = Some(new scala.util.Random(1234L)))
  testCastToDecimal(DataTypes.createDecimalType(18, 10),
    scale = 2,
    customRandGenerator = Some(new scala.util.Random(1234L)))
  testCastToDecimal(DataTypes.createDecimalType(18, 18),
    scale = 15,
    customRandGenerator = Some(new scala.util.Random(1234L)))

  // fromScale < toScale: the unscaled value grows, exercising the overflow screening.
  testCastToDecimal(DataTypes.createDecimalType(18, 0),
    scale = 3,
    customRandGenerator = Some(new scala.util.Random(1234L)))
  testCastToDecimal(DataTypes.createDecimalType(18, 5),
    scale = 10,
    customRandGenerator = Some(new scala.util.Random(1234L)))
  testCastToDecimal(DataTypes.createDecimalType(18, 10),
    scale = 17,
    customRandGenerator = Some(new scala.util.Random(1234L)))
}

test("Detect overflow from numeric types to decimal") {
def intGenerator(column: Seq[Int])(ss: SparkSession): DataFrame = {
import ss.sqlContext.implicits._
Expand All @@ -493,6 +527,11 @@ class CastOpSuite extends GpuExpressionTestSuite {
import ss.sqlContext.implicits._
column.toDF("col")
}
def decimalGenerator(column: Seq[Decimal], decType: DecimalType
)(ss: SparkSession): DataFrame = {
val field = StructField("col", decType)
ss.createDataFrame(column.map(Row(_)).asJava, StructType(Seq(field)))
}
def nonOverflowCase(dataType: DataType,
generator: SparkSession => DataFrame,
precision: Int,
Expand Down Expand Up @@ -556,6 +595,15 @@ class CastOpSuite extends GpuExpressionTestSuite {
generator = floatGenerator(Seq(12345.678f)))
overflowCase(DataTypes.DoubleType, precision = 15, scale = -5,
generator = doubleGenerator(Seq(1.23e21)))

// Test 4: overflow caused by decimal rescaling
val decType = DataTypes.createDecimalType(18, 0)
nonOverflowCase(decType,
precision = 18, scale = 10,
generator = decimalGenerator(Seq(Decimal(99999999L)), decType))
overflowCase(decType,
precision = 18, scale = 10,
generator = decimalGenerator(Seq(Decimal(100000000L)), decType))
}

protected def testCastToDecimal(
Expand Down Expand Up @@ -595,7 +643,7 @@ class CastOpSuite extends GpuExpressionTestSuite {
if (!gpuOnly) {
val (fromCpu, fromGpu) = runOnCpuAndGpu(createDF, execFun, conf, repart = 0)
val (cpuResult, gpuResult) = dataType match {
case ShortType | IntegerType | LongType =>
case ShortType | IntegerType | LongType | _: DecimalType =>
fromCpu.map(r => Row(r.getDecimal(1))) -> fromGpu.map(r => Row(r.getDecimal(1)))
case FloatType | DoubleType =>
// There may be tiny difference between CPU and GPU result when casting from double
Expand Down Expand Up @@ -631,7 +679,7 @@ class CastOpSuite extends GpuExpressionTestSuite {
}
}
val scaleRnd = new scala.util.Random(enhancedRnd.nextLong())
val rawColumn: Seq[AnyVal] = (0 until rowCount).map { _ =>
val rawColumn: Seq[Any] = (0 until rowCount).map { _ =>
val scale = 18 - scaleRnd.nextInt(integralSize + 1)
dataType match {
case ShortType =>
Expand All @@ -642,16 +690,27 @@ class CastOpSuite extends GpuExpressionTestSuite {
enhancedRnd.nextLong() / math.pow(10, scale max 0).toLong
case FloatType | DoubleType =>
enhancedRnd.nextLong() / math.pow(10, scale + 2)
case dt: DecimalType =>
val unscaledValue = (enhancedRnd.nextLong() * math.pow(10, dt.precision - 18)).toLong
Decimal.createUnsafe(unscaledValue, dt.precision, dt.scale)
case _ =>
throw new IllegalArgumentException(s"unsupported dataType: $dataType")
}
}
dataType match {
case ShortType => rawColumn.map(_.asInstanceOf[Long].toShort).toDF("col")
case IntegerType => rawColumn.map(_.asInstanceOf[Long].toInt).toDF("col")
case LongType => rawColumn.map(_.asInstanceOf[Long]).toDF("col")
case FloatType => rawColumn.map(_.asInstanceOf[Double].toFloat).toDF("col")
case DoubleType => rawColumn.map(_.asInstanceOf[Double]).toDF("col")
case ShortType =>
rawColumn.map(_.asInstanceOf[Long].toShort).toDF("col")
case IntegerType =>
rawColumn.map(_.asInstanceOf[Long].toInt).toDF("col")
case LongType =>
rawColumn.map(_.asInstanceOf[Long]).toDF("col")
case FloatType =>
rawColumn.map(_.asInstanceOf[Double].toFloat).toDF("col")
case DoubleType =>
rawColumn.map(_.asInstanceOf[Double]).toDF("col")
case dt: DecimalType =>
val row = rawColumn.map(e => Row(e.asInstanceOf[Decimal])).asJava
ss.createDataFrame(row, StructType(Seq(StructField("col", dt))))
}
}
}
Expand Down