Skip to content

Commit

Permalink
Fix crash when casting decimals to long (#6103)
Browse files Browse the repository at this point in the history
* Fix Decimal to Long cast

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* addressed review comments

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* added comment

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* addressed review comments

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* renamed var

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* refactored local vals to class vals

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* capitalize constants

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

Co-authored-by: Raza Jafri <rjafri@nvidia.com>
  • Loading branch information
razajafri and razajafri authored Aug 1, 2022
1 parent e2d18fd commit e77b40f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
20 changes: 15 additions & 5 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ object GpuCast extends Arm {
// Pattern for a timestamp string: group 1 captures the leading
// "dddd-dd-dd dd:dd:dd" part (digits only), and the remaining optional groups match what
// follows it up to the end of the string.
// NOTE(review): the '.' that starts each trailing group is unescaped, so it matches any
// single character rather than only a literal dot -- confirm this is intended.
// NOTE(review): presumably used to truncate the fractional-seconds part; the usage is not
// visible in this hunk -- verify against the caller.
private val TIMESTAMP_TRUNCATE_REGEX = "^([0-9]{4}-[0-9]{2}-[0-9]{2} " +
"[0-9]{2}:[0-9]{2}:[0-9]{2})" +
"(.[1-9]*(?:0)?[1-9]+)?(.0*[1-9]+)?(?:.0*)?$"
// Long.MinValue and Long.MaxValue as BigDecimal values. They are built once here and then
// rescaled to the input column's scale to form the bounds of the ANSI decimal-to-long
// range check.
private val BIG_DECIMAL_LONG_MIN = BigDecimal(Long.MinValue)
private val BIG_DECIMAL_LONG_MAX = BigDecimal(Long.MaxValue)

// Message carried by the exception raised when a cast finds a value outside the range
// the target type can represent (for example the ArithmeticException thrown by the ANSI
// decimal-to-long range check).
val INVALID_INPUT_MESSAGE: String = "Column contains at least one value that is not in the " +
"required range"
Expand Down Expand Up @@ -378,11 +380,19 @@ object GpuCast extends Arm {
// ansi cast from larger-than-long integral-like types, to long
case (dt: DecimalType, LongType) if ansiMode =>
// This is a work around for https://github.com/rapidsai/cudf/issues/9282
val min = BigDecimal(Long.MinValue)
.setScale(dt.scale, BigDecimal.RoundingMode.DOWN).bigDecimal
val max = BigDecimal(Long.MaxValue)
.setScale(dt.scale, BigDecimal.RoundingMode.DOWN).bigDecimal
assertValuesInRange(input, Scalar.fromDecimal(min), Scalar.fromDecimal(max))
val min = BIG_DECIMAL_LONG_MIN.setScale(dt.scale, BigDecimal.RoundingMode.DOWN).bigDecimal
val max = BIG_DECIMAL_LONG_MAX.setScale(dt.scale, BigDecimal.RoundingMode.DOWN).bigDecimal
// We deliberately do not follow our usual convention of calling assertValuesInRange()
// here, because the min/max bounds are of a different decimal type (Decimal 128) than
// the incoming input column.
// Range-check the column by reducing it to its minimum and maximum and comparing those
// two values against the rescaled Long bounds. Both reduction results are closed by
// withResource. A reduction result that is not valid is skipped, so it can never trigger
// the failure on its own.
// Throws ArithmeticException(INVALID_INPUT_MESSAGE) if any value is outside the Long range.
withResource(input.min()) { minInput =>
  withResource(input.max()) { maxInput =>
    // Only the sign of compareTo is tested: the Comparable contract guarantees a negative,
    // zero or positive result, not the exact values -1 and 1.
    if ((minInput.isValid && minInput.getBigDecimal().compareTo(min) < 0) ||
        (maxInput.isValid && maxInput.getBigDecimal().compareTo(max) > 0)) {
      throw new ArithmeticException(GpuCast.INVALID_INPUT_MESSAGE)
    }
  }
}
if (dt.precision <= DType.DECIMAL32_MAX_PRECISION && dt.scale < 0) {
// This is a work around for https://github.com/rapidsai/cudf/issues/9281
withResource(input.castTo(DType.create(DType.DTypeEnum.DECIMAL64, -dt.scale))) { tmp =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,11 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {
}
}

// Regression test for the crash when casting decimals to long (#6103): runs the cast to
// LongType over a generated decimal DataFrame and compares the results.
// NOTE(review): the trailing arguments 18 and 3 are presumably the decimal precision and
// scale -- confirm against generateValidValuesDecimalDF.
testSparkResultsAreEqual(
  "ansi_cast decimals to long",
  generateValidValuesDecimalDF(Short.MinValue, Short.MaxValue, 18, 3),
  sparkConf) { decimalDf =>
  testCastTo(DataTypes.LongType)(decimalDf)
}

// Returns a function that gives the expected result of a cast to string: the
// String.valueOf rendering of the input, always wrapped in Some.
private def castToStringExpectedFun[T]: T => Option[String] = { value =>
  val rendered = String.valueOf(value)
  Some(rendered)
}

private def testCastToString[T](dataType: DataType, ansiMode: Boolean,
Expand Down

0 comments on commit e77b40f

Please sign in to comment.