From d6f0cfe3309d53a6e4fe51c7494b70f294c0c61b Mon Sep 17 00:00:00 2001
From: Raza Jafri
Date: Wed, 13 Dec 2023 11:20:18 -0800
Subject: [PATCH] addressed review comments

---
 src/main/cpp/src/decimal_utils.cu             | 37 ++++++++++---------
 .../nvidia/spark/rapids/jni/DecimalUtils.java |  2 +-
 2 files changed, 20 insertions(+), 19 deletions(-)

diff --git a/src/main/cpp/src/decimal_utils.cu b/src/main/cpp/src/decimal_utils.cu
index 84e6baa7ac..a9d9fbf890 100644
--- a/src/main/cpp/src/decimal_utils.cu
+++ b/src/main/cpp/src/decimal_utils.cu
@@ -658,7 +658,7 @@ struct dec128_multiplier {
                     cudf::mutable_column_view const& product_view,
                     cudf::column_view const& a_col,
                     cudf::column_view const& b_col,
-                    bool const& cast_interim_result)
+                    bool const cast_interim_result)
     : overflows(overflows),
       a_data(a_col.data<__int128_t>()),
       b_data(b_col.data<__int128_t>()),
@@ -679,23 +679,24 @@ struct dec128_multiplier {
 
     int dec_precision = precision10(product);
 
-    int mult_scale = a_scale + b_scale;
-
-    // According to https://issues.apache.org/jira/browse/SPARK-40129
-    // and https://issues.apache.org/jira/browse/SPARK-45786, Spark has a bug in
-    // versions 3.2.4, 3.3.3, 3.4.1, 3.5.0 and 4.0.0 The bug is fixed for later versions but to
-    // match the legacy behavior we need to first round the result to a precision of 38 then we need
-    // to round the result to the final scale that we care about.
-    if (cast_interim_result) {
-      int first_div_precision = dec_precision - 38;
-      if (first_div_precision > 0) {
-        auto const first_div_scale_divisor = pow_ten(first_div_precision).as_128_bits();
-        product = divide_and_round(product, first_div_scale_divisor);
-
-        // a_scale and b_scale are negative. first_div_precision is not
-        mult_scale = a_scale + b_scale + first_div_precision;
-      }
-    }
+    int const mult_scale = [&]() {
+      // According to https://issues.apache.org/jira/browse/SPARK-40129
+      // and https://issues.apache.org/jira/browse/SPARK-45786, Spark has a bug in
+      // versions 3.2.4, 3.3.3, 3.4.1, 3.5.0 and 4.0.0. The bug is fixed for later versions, but to
+      // match the legacy behavior we need to first round the result to a precision of 38 and then
+      // round the result to the final scale that we care about.
+      if (cast_interim_result) {
+        int first_div_precision = dec_precision - 38;
+        if (first_div_precision > 0) {
+          auto const first_div_scale_divisor = pow_ten(first_div_precision).as_128_bits();
+          product = divide_and_round(product, first_div_scale_divisor);
+
+          // a_scale and b_scale are negative. first_div_precision is not
+          return a_scale + b_scale + first_div_precision;
+        }
+      }
+      return a_scale + b_scale;
+    }();
 
     int exponent = prod_scale - mult_scale;
     if (exponent < 0) {
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java
index ae9d9b16c4..17337691c5 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java
@@ -44,7 +44,7 @@ public class DecimalUtils {
    * row.
    */
   public static Table multiply128(ColumnView a, ColumnView b, int productScale) {
-    return new Table(multiply128(a.getNativeView(), b.getNativeView(), productScale, false));
+    return new Table(multiply128(a.getNativeView(), b.getNativeView(), productScale, true));
   }
 
   /**
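
Note for reviewers (not part of the patch): below is a minimal standalone sketch of the interim rounding step that cast_interim_result enables, i.e. first round the product down to 38 digits of precision and adjust the scale, then rescale to the requested final scale. The helper names (precision10, pow_ten, divide_and_round) mirror the ones in decimal_utils.cu, but the bodies here are simplified host-side stand-ins on __int128_t rather than the 256-bit intermediate the real kernel uses, so this is only an illustration of the idea.

// Standalone host-side sketch of the legacy two-step rounding. Simplified
// stand-ins for the decimal_utils.cu helpers; illustration only.
#include <iostream>

using int128 = __int128_t;

// Count base-10 digits of |v| (simplified stand-in for precision10()).
int precision10(int128 v)
{
  if (v < 0) v = -v;
  int digits = 1;
  while (v >= 10) { v /= 10; ++digits; }
  return digits;
}

// 10^e as a 128-bit value (simplified stand-in for pow_ten()).
int128 pow_ten(int e)
{
  int128 r = 1;
  while (e-- > 0) { r *= 10; }
  return r;
}

// Divide with half-away-from-zero rounding (stand-in for divide_and_round()).
int128 divide_and_round(int128 n, int128 d)
{
  int128 q = n / d;
  int128 r = n % d;
  if (r < 0) { r = -r; }
  if (r >= d - r) { q += (n < 0 ? -1 : 1); }
  return q;
}

int main()
{
  // Pretend this is a * b: 39 digits, one more than DECIMAL128 allows.
  int128 product  = 15 * pow_ten(37);
  int mult_scale  = (-20) + (-19);  // a_scale + b_scale, both negative

  // Legacy Spark behavior: first round to 38 digits of precision and
  // shift the scale by the digits dropped ...
  int const first_div_precision = precision10(product) - 38;
  if (first_div_precision > 0) {
    product    = divide_and_round(product, pow_ten(first_div_precision));
    mult_scale = mult_scale + first_div_precision;
  }
  // ... then a second divide_and_round to the requested final scale would
  // follow, exactly as the exponent adjustment in the kernel does (elided).
  std::cout << "interim precision: " << precision10(product)
            << ", interim scale: " << mult_scale << "\n";
  return 0;
}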