Skip to content

Commit

Permalink
Push decimal workarounds to cuDF (#4822)
Browse files Browse the repository at this point in the history
Signed-off-by: sperlingxx <lovedreamf@gmail.com>

Closes #3793

Pushes cuDF-related decimal utilities down to cuDF. This PR is relied on cuDF changes: rapidsai/cudf#9809, rapidsai/cudf#9907 and rapidsai/cudf#10316.
  • Loading branch information
sperlingxx authored Feb 23, 2022
1 parent 7d1662e commit 55e0593
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 237 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,7 @@ private static DType toRapidsOrNull(DataType type) {
} else if (type instanceof DecimalType) {
// Decimal supportable check has been conducted in the GPU plan overriding stage.
// So, we don't have to handle decimal-supportable problem at here.
DecimalType dt = (DecimalType) type;
return DecimalUtil.createCudfDecimal(dt.precision(), dt.scale());
return DecimalUtil.createCudfDecimal((DecimalType) type);
} else if (type instanceof GpuUnsignedIntegerType) {
return DType.UINT32;
} else if (type instanceof GpuUnsignedLongType) {
Expand Down
199 changes: 7 additions & 192 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/DecimalUtil.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,191 +17,17 @@
package com.nvidia.spark.rapids

import ai.rapids.cudf
import ai.rapids.cudf.DType
import ai.rapids.cudf.{DecimalUtils, DType}

import org.apache.spark.sql.types._

object DecimalUtil extends Arm {

def createCudfDecimal(dt: DecimalType): DType = {
createCudfDecimal(dt.precision, dt.scale)
}

def createCudfDecimal(precision: Int, scale: Int): DType = {
if (precision <= DType.DECIMAL32_MAX_PRECISION) {
DType.create(DType.DTypeEnum.DECIMAL32, -scale)
} else if (precision <= DType.DECIMAL64_MAX_PRECISION) {
DType.create(DType.DTypeEnum.DECIMAL64, -scale)
} else if (precision <= DType.DECIMAL128_MAX_PRECISION) {
DType.create(DType.DTypeEnum.DECIMAL128, -scale)
} else {
throw new IllegalArgumentException(s"precision overflow: $precision")
}
}

def getMaxPrecision(dt: DType): Int = dt.getTypeId match {
case DType.DTypeEnum.DECIMAL32 => DType.DECIMAL32_MAX_PRECISION
case DType.DTypeEnum.DECIMAL64 => DType.DECIMAL64_MAX_PRECISION
case _ if dt.isDecimalType => DType.DECIMAL128_MAX_PRECISION
case _ => throw new IllegalArgumentException(s"not a decimal type: $dt")
}

/**
* Returns two BigDecimals that are exactly the
* (smallest value `toType` can hold, largest value `toType` can hold).
*
* Be very careful when comparing these CUDF decimal comparisons really only work
* when both types are already the same precision and scale, and when you change the scale
* you end up losing information.
*/
def bounds(toType: DecimalType): (BigDecimal, BigDecimal) = {
val boundStr = ("9" * toType.precision) + "e" + (-toType.scale)
val toUpperBound = BigDecimal(boundStr)
val toLowerBound = BigDecimal("-" + boundStr)
(toLowerBound, toUpperBound)
}
def createCudfDecimal(dt: DecimalType): DType =
DecimalUtils.createDecimalType(dt.precision, dt.scale)

/**
* CUDF can have overflow issues when rounding values. This works around those issues for you.
* @param input the input data to round.
* @param decimalPlaces the decimal places to round to
* @param mode the rounding mode
* @return the rounded data.
*/
def round(input: cudf.ColumnView,
decimalPlaces: Int,
mode: cudf.RoundMode): cudf.ColumnVector = {
assert(input.getType.isDecimalType)
val cudfInputScale = input.getType.getScale
if (cudfInputScale >= -decimalPlaces) {
// No issues with overflow for these cases, so just do it.
input.round(decimalPlaces, mode)
} else {
// We actually will need to round because we will be losing some information during the round
// The DECIMAL type we use needs to be able to hold
// `std::pow(10, std::abs(decimal_places + input.type().scale()));`
// in it without overflowing.
val scaleMovement = Math.abs(decimalPlaces + cudfInputScale)
val maxInputPrecision = getMaxPrecision(input.getType)
if (scaleMovement > maxInputPrecision) {
// This is going to overflow unless we do something else first. But for round to work all
// we actually need is 1 decimal place more than the target decimalPlaces, so we can cast
// to this first (which will truncate the extra information), and then round to the desired
// result
val intermediateDType = DType.create(input.getType.getTypeId, (-decimalPlaces) + 1)
withResource(input.castTo(intermediateDType)) { truncated =>
truncated.round(decimalPlaces, mode)
}
} else {
input.round(decimalPlaces, mode)
}
}
}

/**
* Because CUDF can have issues with comparing decimal values that have different precision
* and scale accurately it takes some special steps to do this. This handles the corner cases
* for you.
*/
def lessThan(lhs: cudf.ColumnView, rhs: BigDecimal): cudf.ColumnVector = {
assert(lhs.getType.isDecimalType)
val cudfScale = lhs.getType.getScale
val cudfPrecision = getMaxPrecision(lhs.getType)

// First we have to round the scalar (rhs) to the same scale as lhs. Because this is a
// less than and it is rhs that we are rounding, we will round away from 0 (UP)
// to make sure we always return the correct value.
// For example:
// 100.1 < 100.19
// If we rounded down the rhs 100.19 would become 100.1, and now 100.1 is not < 100.1

val roundedRhs = rhs.setScale(-cudfScale, BigDecimal.RoundingMode.UP)

if (roundedRhs.precision > cudfPrecision) {
// converting rhs to the same precision as lhs would result in an overflow/error, but
// the scale is the same so we can still figure this out. For example if LHS precision is
// 4 and RHS precision is 5 we get the following...
// 9999 < 99999 => true
// -9999 < 99999 => true
// 9999 < -99999 => false
// -9999 < -99999 => false
// so the result should be the same as RHS > 0
withResource(cudf.Scalar.fromBool(roundedRhs > 0)) { rhsGtZero =>
cudf.ColumnVector.fromScalar(rhsGtZero, lhs.getRowCount.toInt)
}
} else {
val sparkType = DecimalType(cudfPrecision, -cudfScale)
withResource(GpuScalar.from(roundedRhs, sparkType)) { scalarRhs =>
lhs.lessThan(scalarRhs)
}
}
}

def lessThan(lhs: cudf.BinaryOperable, rhs: BigDecimal, numRows: Int): cudf.ColumnVector =
lhs match {
case cv: cudf.ColumnVector =>
lessThan(cv, rhs)
case s: cudf.Scalar =>
if (s.isValid) {
val isLess = (s.getBigDecimal.compareTo(rhs) < 0)
withResource(cudf.Scalar.fromBool(isLess)) { n =>
cudf.ColumnVector.fromScalar(n, numRows)
}
} else {
withResource(cudf.Scalar.fromNull(DType.BOOL8)) { n =>
cudf.ColumnVector.fromScalar(n, numRows)
}
}
}

/**
* Because CUDF can have issues with comparing decimal values that have different precision
* and scale accurately it takes some special steps to do this. This handles the corner cases
* for you.
*/
def greaterThan(lhs: cudf.ColumnView, rhs: BigDecimal): cudf.ColumnVector = {
assert(lhs.getType.isDecimalType)
val cudfScale = lhs.getType.getScale
val cudfPrecision = getMaxPrecision(lhs.getType)

// First we have to round the scalar (rhs) to the same scale as lhs. Because this is a
// greater than and it is rhs that we are rounding, we will round towards 0 (DOWN)
// to make sure we always return the correct value.
// For example:
// 100.2 > 100.19
// If we rounded up the rhs 100.19 would become 100.2, and now 100.2 is not > 100.2

val roundedRhs = rhs.setScale(-cudfScale, BigDecimal.RoundingMode.DOWN)

if (roundedRhs.precision > cudfPrecision) {
// converting rhs to the same precision as lhs would result in an overflow/error, but
// the scale is the same so we can still figure this out. For example if LHS precision is
// 4 and RHS precision is 5 we get the following...
// 9999 > 99999 => false
// -9999 > 99999 => false
// 9999 > -99999 => true
// -9999 > -99999 => true
// so the result should be the same as RHS < 0
withResource(cudf.Scalar.fromBool(roundedRhs < 0)) { rhsLtZero =>
cudf.ColumnVector.fromScalar(rhsLtZero, lhs.getRowCount.toInt)
}
} else {
val sparkType = DecimalType(cudfPrecision, -cudfScale)
withResource(GpuScalar.from(roundedRhs, sparkType)) { scalarRhs =>
lhs.greaterThan(scalarRhs)
}
}
}

def outOfBounds(input: cudf.ColumnView, to: DecimalType): cudf.ColumnVector = {
val (lowerBound, upperBound) = bounds(to)

withResource(greaterThan(input, upperBound)) { over =>
withResource(lessThan(input, lowerBound)) { under =>
over.or(under)
}
}
}
def outOfBounds(input: cudf.ColumnView, to: DecimalType): cudf.ColumnVector =
DecimalUtils.outOfBounds(input, to.precision, to.scale)

/**
* Return the size in bytes of the Fixed-width data types.
Expand All @@ -214,24 +40,13 @@ object DecimalUtil extends Arm {
}
}

/**
* Get the number of decimal places needed to hold the integral type held by this column
*/
def getPrecisionForIntegralType(input: DType): Int = input match {
case DType.INT8 => 3 // -128 to 127
case DType.INT16 => 5 // -32768 to 32767
case DType.INT32 => 10 // -2147483648 to 2147483647
case DType.INT64 => 19 // -9223372036854775808 to 9223372036854775807
case t => throw new IllegalArgumentException(s"Unsupported type $t")
}
// The following types were copied from Spark's DecimalType class
private val BooleanDecimal = DecimalType(1, 0)

def optionallyAsDecimalType(t: DataType): Option[DecimalType] = t match {
case dt: DecimalType => Some(dt)
case ByteType | ShortType | IntegerType | LongType =>
val prec = DecimalUtil.getPrecisionForIntegralType(GpuColumnVector.getNonNestedRapidsType(t))
Some(DecimalType(prec, 0))
Some(DecimalType(GpuColumnVector.getNonNestedRapidsType(t).getPrecisionForInt, 0))
case BooleanType => Some(BooleanDecimal)
case _ => None
}
Expand Down
31 changes: 10 additions & 21 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.time.DateTimeException

import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar}
import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DecimalUtils, DType, Scalar}
import ai.rapids.cudf
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.shims.v2.YearParseUtil
Expand Down Expand Up @@ -1361,10 +1361,10 @@ object GpuCast extends Arm {
input: ColumnView,
dt: DecimalType,
ansiMode: Boolean): ColumnVector = {
val prec = DecimalUtil.getPrecisionForIntegralType(input.getType)
val prec = input.getType.getPrecisionForInt
// Cast input to decimal
val inputDecimalType = new DecimalType(prec, 0)
withResource(input.castTo(DecimalUtil.createCudfDecimal(prec, 0))) { castedInput =>
withResource(input.castTo(DecimalUtil.createCudfDecimal(inputDecimalType))) { castedInput =>
castDecimalToDecimal(castedInput, inputDecimalType, dt, ansiMode)
}
}
Expand Down Expand Up @@ -1404,15 +1404,15 @@ object GpuCast extends Arm {
}

withResource(checkedInput) { checked =>
val targetType = DecimalUtil.createCudfDecimal(dt.precision, dt.scale)
val targetType = DecimalUtil.createCudfDecimal(dt)
// If target scale reaches DECIMAL128_MAX_PRECISION, container DECIMAL can not
// be created because of precision overflow. In this case, we perform casting op directly.
val casted = if (DecimalUtil.getMaxPrecision(targetType) == dt.scale) {
val casted = if (targetType.getDecimalMaxPrecision == dt.scale) {
checked.castTo(targetType)
} else {
val containerType = DecimalUtil.createCudfDecimal(dt.precision, dt.scale + 1)
val containerType = DecimalUtils.createDecimalType(dt.precision, dt.scale + 1)
withResource(checked.castTo(containerType)) { container =>
DecimalUtil.round(container, dt.scale, cudf.RoundMode.HALF_UP)
container.round(dt.scale, cudf.RoundMode.HALF_UP)
}
}
// Cast NaN values to nulls
Expand Down Expand Up @@ -1458,8 +1458,8 @@ object GpuCast extends Arm {
from: DecimalType,
to: DecimalType,
ansiMode: Boolean): ColumnVector = {
val toDType = DecimalUtil.createCudfDecimal(to.precision, to.scale)
val fromDType = DecimalUtil.createCudfDecimal(from.precision, from.scale)
val toDType = DecimalUtil.createCudfDecimal(to)
val fromDType = DecimalUtil.createCudfDecimal(from)

val fromWholeNumPrecision = from.precision - from.scale
val toWholeNumPrecision = to.precision - to.scale
Expand All @@ -1484,18 +1484,7 @@ object GpuCast extends Arm {
val rounded = if (!isScaleUpcast) {
// We have to round the data to the desired scale. Spark uses HALF_UP rounding in
// this case so we need to also.

// Rounding up can cause overflow, but if the input is in the proper range for Spark
// the overflow will fit in the current CUDF type without the need to cast it.
// Int.MinValue = -2147483648
// DECIMAL32 min unscaled = -999999999
// DECIMAL32 min unscaled and rounded = -1000000000 (Which fits)
// Long.MinValue = -9223372036854775808
// DECIMAL64 min unscaled = -999999999999999999
// DECIMAL64 min unscaled and rounded = -1000000000000000000 (Which fits)
// That means we don't need to cast it to a wider type first, we just need to be sure
// that we do boundary checks, if we did need to round
DecimalUtil.round(input, to.scale, cudf.RoundMode.HALF_UP)
input.round(to.scale, cudf.RoundMode.HALF_UP)
} else {
input.copyToColumnVector()
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,7 +16,7 @@
package com.nvidia.spark.rapids

import ai.rapids.cudf
import ai.rapids.cudf.{ColumnVector, DType, Scalar}
import ai.rapids.cudf.{ColumnVector, DecimalUtils, DType, Scalar}

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.types.{DataType, DecimalType, LongType}
Expand Down Expand Up @@ -47,7 +47,7 @@ case class GpuCheckOverflow(child: Expression,
val rounded = if (resultDType.equals(base.getType)) {
base.incRefCount()
} else {
withResource(DecimalUtil.round(base, dataType.scale, cudf.RoundMode.HALF_UP)) { rounded =>
withResource(base.round(dataType.scale, cudf.RoundMode.HALF_UP)) { rounded =>
if (resultDType.getTypeId != base.getType.getTypeId) {
rounded.castTo(resultDType)
} else {
Expand Down Expand Up @@ -98,12 +98,12 @@ case class GpuMakeDecimal(
override def toString: String = s"MakeDecimal($child,$precision,$sparkScale)"

private lazy val (minValue, maxValue) = {
val (minDec, maxDec) = DecimalUtil.bounds(dataType)
(minDec.bigDecimal.unscaledValue().longValue(), maxDec.bigDecimal.unscaledValue().longValue())
val bounds = DecimalUtils.bounds(dataType.precision, dataType.scale)
(bounds.getKey.unscaledValue().longValue(), bounds.getValue.unscaledValue().longValue())
}

override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
val outputType = DecimalUtil.createCudfDecimal(precision, sparkScale)
val outputType = DecimalUtils.createDecimalType(precision, sparkScale)
val base = input.getBase
val outOfBounds = withResource(Scalar.fromLong(maxValue)) { maxScalar =>
withResource(base.greaterThan(maxScalar)) { over =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,14 +172,14 @@ object GpuAdd extends Arm {
// Overflow happens if the arguments have the same signs and it is different from the sign of
// the result
val numRows = ret.getRowCount.toInt
val zero = BigDecimal(0)
withResource(DecimalUtil.lessThan(rhs, zero, numRows)) { rhsLz =>
val argsSignSame = withResource(DecimalUtil.lessThan(lhs, zero, numRows)) { lhsLz =>
val zero = BigDecimal(0).bigDecimal
withResource(DecimalUtils.lessThan(rhs, zero, numRows)) { rhsLz =>
val argsSignSame = withResource(DecimalUtils.lessThan(lhs, zero, numRows)) { lhsLz =>
lhsLz.equalTo(rhsLz)
}
withResource(argsSignSame) { argsSignSame =>
val resultAndRhsDifferentSign =
withResource(DecimalUtil.lessThan(ret, zero)) { resultLz =>
withResource(DecimalUtils.lessThan(ret, zero)) { resultLz =>
rhsLz.notEqualTo(resultLz)
}
withResource(resultAndRhsDifferentSign) { resultAndRhsDifferentSign =>
Expand Down Expand Up @@ -286,14 +286,14 @@ case class GpuSubtract(
// Overflow happens if the arguments have different signs and the sign of the result is
// different from the sign of subtractend (RHS).
val numRows = ret.getRowCount.toInt
val zero = BigDecimal(0)
val overflow = withResource(DecimalUtil.lessThan(rhs, zero, numRows)) { rhsLz =>
val argsSignDifferent = withResource(DecimalUtil.lessThan(lhs, zero, numRows)) { lhsLz =>
val zero = BigDecimal(0).bigDecimal
val overflow = withResource(DecimalUtils.lessThan(rhs, zero, numRows)) { rhsLz =>
val argsSignDifferent = withResource(DecimalUtils.lessThan(lhs, zero, numRows)) { lhsLz =>
lhsLz.notEqualTo(rhsLz)
}
withResource(argsSignDifferent) { argsSignDifferent =>
val resultAndSubtrahendSameSign =
withResource(DecimalUtil.lessThan(ret, zero)) { resultLz =>
withResource(DecimalUtils.lessThan(ret, zero)) { resultLz =>
rhsLz.equalTo(resultLz)
}
withResource(resultAndSubtrahendSameSign) { resultAndSubtrahendSameSign =>
Expand Down
Loading

0 comments on commit 55e0593

Please sign in to comment.