Fix crash when casting decimals to long #6103
Signed-off-by: Raza Jafri <rjafri@nvidia.com>
withResource(input.min()) { min =>
withResource(input.max()) { max =>
`min` or `max` should be just a number, right? So we won't need to wrap them.
This is not just a number, it is a `Scalar` object (https://github.com/rapidsai/cudf/blob/branch-22.08/java/src/main/java/ai/rapids/cudf/Scalar.java#L35), so it has to be closed.
If it is a scalar then we may want to call `isValid` to check the min and max values. Otherwise, if the input is all nulls, these values will be invalid.
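The two points raised above (a scalar result owns a resource that must be closed, and it may be invalid when the input is all nulls) can be sketched together. This is only an illustration: `FakeScalar`, `ScalarHandling`, and `readIfValid` are hypothetical stand-ins, not the cudf `Scalar` class or the plugin's own `withResource` helper.

```scala
// Hypothetical stand-in for a natively backed scalar; NOT the cudf Scalar class.
final class FakeScalar(value: Option[java.math.BigDecimal]) extends AutoCloseable {
  var closed = false
  // An all-null input column would yield an invalid min/max scalar.
  def isValid: Boolean = value.isDefined
  def getBigDecimal: java.math.BigDecimal = value.get
  override def close(): Unit = { closed = true }
}

object ScalarHandling {
  // Minimal loan pattern: the resource is closed even if the body throws.
  def withResource[T <: AutoCloseable, V](r: T)(body: T => V): V =
    try body(r) finally r.close()

  // Reads the value only when the scalar is valid; returns None otherwise.
  def readIfValid(s: FakeScalar): Option[java.math.BigDecimal] =
    withResource(s) { sc => if (sc.isValid) Some(sc.getBigDecimal) else None }
}
```

With this shape, a caller can skip the range comparison entirely when `readIfValid` returns `None`, and the scalar is released on every path.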
.setScale(dt.scale, BigDecimal.RoundingMode.DOWN).bigDecimal
assertValuesInRange(input, Scalar.fromDecimal(min), Scalar.fromDecimal(max))
withResource(input.min()) { min =>
add comment about why we are doing this
In fact, the `assertValuesInRange` function is inefficient: it calls the `less` operator then `any`, then `greater` then `any`, and all of these operations are `O(N)`. We can get the same result with half the computation using this new approach: `min` then `max` in `O(N)`, then compare the min/max values with the boundaries in `O(1)`.
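To make the cost comparison concrete, here is a minimal sketch over a plain in-memory sequence rather than a GPU column. `RangeCheck` and both method names are hypothetical; the first method mimics the four passes described in the comment, and the second mimics the min/max approach.

```scala
object RangeCheck {
  import java.math.{BigDecimal => JBigDecimal}

  // Element-wise style: four O(N) passes (less, any, greater, any).
  def outOfRangeElementWise(xs: Seq[JBigDecimal], lo: JBigDecimal, hi: JBigDecimal): Boolean = {
    val below = xs.map(_.compareTo(lo) < 0)  // "less"
    val anyBelow = below.exists(identity)    // "any"
    val above = xs.map(_.compareTo(hi) > 0)  // "greater"
    val anyAbove = above.exists(identity)    // "any"
    anyBelow || anyAbove
  }

  // Reduction style: two O(N) reductions, then two O(1) comparisons.
  def outOfRangeMinMax(xs: Seq[JBigDecimal], lo: JBigDecimal, hi: JBigDecimal): Boolean =
    xs.nonEmpty && {
      val min = xs.reduce((a, b) => if (a.compareTo(b) <= 0) a else b)
      val max = xs.reduce((a, b) => if (a.compareTo(b) >= 0) a else b)
      min.compareTo(lo) < 0 || max.compareTo(hi) > 0
    }
}
```

Both methods return the same answer for any input; the second simply touches the data in fewer passes and avoids materializing intermediate boolean columns.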
I've filed a corresponding issue: #6130
It would be nice to have an issue associated with this that describes the problem and how to reproduce it.
withResource(input.max()) { max =>
if (min.getBigDecimal().compareTo(bigDecimalMin) == -1 ||
max.getBigDecimal().compareTo(bigDecimalMax) == 1) {
throw new IllegalStateException(GpuCast.INVALID_INPUT_MESSAGE)
this same thing throws in Spark on Cpu?
Maybe we need to add an integration test to compare?
> this same thing throws in Spark on Cpu?
CPU throws an ArithmeticException, but I wanted to match the existing behavior. I have changed it to match the CPU.
> Maybe we need to add an integration test to compare?
Just adding an integration test for a handful of values, or testing a wider range? I experimented with this and it will be a much more involved test, as there are many other types of exceptions that can be thrown in ANSI mode, e.g. overflow.
I can imagine that this will be the case for other cast operations.
If we do not have integration tests to cover the behavior of CPU vs GPU, then we can create a new followup issue to improve the tests including the "Decimal to Long".
+1 for adding a test to make sure that we are not regressing on the exception.
generateValidValuesDecimalDF(Short.MinValue, Short.MaxValue, 18, 3), sparkConf) {
frame => testCastTo(DataTypes.LongType)(frame)
}
Is it possible to test that we actually throw the expected exception?
This test is only for the valid values. Maybe I need to add a test that just checks that an exception is thrown?
We won't need a new case as long as you are throwing the same exception, because this will be covered in `testCastFailsForBadInputs("ansi_cast overflow decimals to longs",..)`.
If we change the behavior to match the CPU, then we would need a new test.
OK, then that test is already testing the exception that is thrown.
Signed-off-by: Raza Jafri <rjafri@nvidia.com>
Signed-off-by: Raza Jafri <rjafri@nvidia.com>
val bigDecimalMin = BigDecimal(Long.MinValue)
.setScale(dt.scale, BigDecimal.RoundingMode.DOWN).bigDecimal
val max = BigDecimal(Long.MaxValue)
val bigDecimalMax = BigDecimal(Long.MaxValue)
bigDecimalMax/Min are not minimums of BigDecimal, the names are confusing. The conversion from long is probably non-trivial, might be worthwhile to make these class vals
The `BigDecimal` will have to be re-scaled, and most of the calculation is done in that method. I am not sure what we will gain by just making a `BigDecimal(Long.MIN)` a class val. I can still do it, but I just wanted to make sure I was understanding you correctly.
Are we gaining anything by repeatedly recomputing constants like `BigDecimal(Long.MaxValue)`?
If nothing else, it's a bunch of unnecessary object allocations:
https://github.com/scala/scala/blob/2.12.x/src/library/scala/math/BigDecimal.scala#L211
https://github.com/frohoff/jdk8u-dev-jdk/blob/da0da73ab82ed714dc5be94acd2f0d00fbdfe2e9/src/share/classes/java/math/BigDecimal.java#L1217-L1223
Note that `BigDecimal` instances are, just like `Integer`, immutable. So `setScale` will produce another object generated from the objects we can cache, such as `BigDecimal(Long.MaxValue)`.
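A small sketch of the caching idea discussed here, using Scala's `BigDecimal` as the diff does. `CachedBounds`, `LongMax`, and `maxAtScale` are hypothetical names chosen for illustration, not identifiers from the PR.

```scala
object CachedBounds {
  // Built once; safe to share because BigDecimal is immutable.
  val LongMax: BigDecimal = BigDecimal(Long.MaxValue)

  // setScale never mutates LongMax; it returns a value at the requested scale,
  // so only the per-scale result is allocated on each call.
  def maxAtScale(scale: Int): java.math.BigDecimal =
    LongMax.setScale(scale, BigDecimal.RoundingMode.DOWN).bigDecimal
}
```

The cached value keeps its original scale of 0 no matter how many times `maxAtScale` is called, which is what makes sharing it across casts safe.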
It is a valid argument, depending on how frequently Decimals are cast to Longs in a workload. The other concern is that using static constants increases the footprint of the VM and slows down initialization.
Anyway, if this is going to be addressed within the same PR, then I suggest creating the constants in a util to be used in different classes.
The repo has two other locations that use constant BigDecimals.
In arithmetic.scala there are two local variables
val zero = BigDecimal(0).bigDecimal
Then we can create three constants: `BigDecimal(0).bigDecimal`, `BigDecimal(Long.MaxValue)`, and `BigDecimal(Long.MinValue)`.
All other constant BigDecimals are in test classes, which we can ignore.
Can we please handle creating of the util class in a separate PR?
I agree, that is fair. Both `BigDecimal(Long.MaxValue)` and `BigDecimal(Long.MinValue)` were not introduced in this PR.
I don't think we have to worry about the constant pool yet just because of the two constants being added. We can do a more sweeping refactoring in a separate PR.
Signed-off-by: Raza Jafri <rjafri@nvidia.com>
Signed-off-by: Raza Jafri <rjafri@nvidia.com>
build
Signed-off-by: Raza Jafri <rjafri@nvidia.com>
LGTM but would be great to see this case be validated against CPU
https://github.com/NVIDIA/spark-rapids/pull/6103/files#diff-e981882f5ee2f922528de849ae5397dd30e5bfe9dd5fdbe3421de4733a0eae1aR392
LGTM.
A minor styling issue: constants should be in capital letters.
I am not sure that we have a clear formatting rule for this, but it looks like all the constants in GpuCast are in capital letters.
@@ -177,6 +177,8 @@ object GpuCast extends Arm {
private val TIMESTAMP_TRUNCATE_REGEX = "^([0-9]{4}-[0-9]{2}-[0-9]{2} " +
"[0-9]{2}:[0-9]{2}:[0-9]{2})" +
"(.[1-9]*(?:0)?[1-9]+)?(.0*[1-9]+)?(?:.0*)?$"
private val bigDecimalLongMin = BigDecimal(Long.MinValue)
private val bigDecimalLongMax = BigDecimal(Long.MaxValue)
private val BIG_DECIMAL_LONG_MIN = BigDecimal(Long.MinValue)
private val BIG_DECIMAL_LONG_MAX = BigDecimal(Long.MaxValue)
val max = BigDecimal(Long.MaxValue)
.setScale(dt.scale, BigDecimal.RoundingMode.DOWN).bigDecimal
val min = bigDecimalLongMin.setScale(dt.scale, BigDecimal.RoundingMode.DOWN).bigDecimal
val max = bigDecimalLongMax.setScale(dt.scale, BigDecimal.RoundingMode.DOWN).bigDecimal
val min = BIG_DECIMAL_LONG_MIN.setScale(dt.scale, BigDecimal.RoundingMode.DOWN).bigDecimal
val max = BIG_DECIMAL_LONG_MAX.setScale(dt.scale, BigDecimal.RoundingMode.DOWN).bigDecimal
build
Signed-off-by: Raza Jafri <rjafri@nvidia.com>
Thanks Raza!
LGTM.
build
This PR avoids a type mismatch that can occur in ANSI mode if the min/max decimal values have a different precision than the column being cast.
e.g.
Casting 222.22 to long triggers a check that the values are in range before they are converted to long values. That check causes a crash, because the minimum value possible in decimal(5,2) results in a Decimal 128 value, and the `assertValuesInRange` method in `GpuCast` then gets a type mismatch error from cudf when it calls `lessThan`.
@ttnghia suggested a way around this: first find the minimum value in the input column, then compare that value to the Decimal 128 value using Java's `BigDecimal.compareTo`, which can handle comparing decimals of different precisions.
fixes #6128
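
As a rough illustration of why a host-side `compareTo` sidesteps the mismatch, the sketch below compares a decimal(5,2) value against long bounds rescaled to scale 2. `CompareAcrossPrecisions` and `inLongRange` are hypothetical names; only `java.math.BigDecimal` calls are real APIs, and nothing here touches cudf.

```scala
object CompareAcrossPrecisions {
  import java.math.{BigDecimal => JBigDecimal}

  // A decimal(5,2) value such as 222.22: precision 5, scale 2.
  val value: JBigDecimal = new JBigDecimal("222.22")

  // Long bounds rescaled to the column's scale; these need far more than 5 digits.
  val lower: JBigDecimal = JBigDecimal.valueOf(Long.MinValue).setScale(2)
  val upper: JBigDecimal = JBigDecimal.valueOf(Long.MaxValue).setScale(2)

  // compareTo orders by numeric value, regardless of precision or scale.
  def inLongRange(v: JBigDecimal): Boolean =
    v.compareTo(lower) >= 0 && v.compareTo(upper) <= 0
}
```

Note that `compareTo` treats `2.0` and `2.00` as equal while `equals` does not, which is why the range check uses `compareTo`.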