Skip to content

Commit

Permalink
Fixing incorrect invalidation of positive cudf scale decimal values (NVIDIA#611)
Browse files Browse the repository at this point in the history

* fixing incorrect invalidation of positive cudf scale decimal values

* signoff

Signed-off-by: Mike Wilson <knobby@burntsheep.com>

* typo fix

Signed-off-by: Mike Wilson <knobby@burntsheep.com>

* Update src/main/cpp/src/cast_string.cu

Co-authored-by: Nghia Truong <nghiatruong.vn@gmail.com>

Signed-off-by: Mike Wilson <knobby@burntsheep.com>
Co-authored-by: Nghia Truong <nghiatruong.vn@gmail.com>
  • Loading branch information
hyperbolic2346 and ttnghia authored Oct 5, 2022
1 parent f124469 commit 3080c53
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 11 deletions.
7 changes: 4 additions & 3 deletions src/main/cpp/src/cast_string.cu
Original file line number Diff line number Diff line change
Expand Up @@ -520,8 +520,9 @@ __global__ void string_to_decimal_kernel(T* out,
}
}

auto const significant_preceeding_zeros = decimal_location < 0 ? abs(decimal_location) : 0;
auto const zeros_to_decimal = std::max(0, decimal_location - total_digits);
auto const significant_preceding_zeros = decimal_location < 0 ? -decimal_location : 0;
auto const zeros_to_decimal = std::max(
0, scale > 0 ? decimal_location - total_digits - scale : decimal_location - total_digits);
auto const significant_digits_before_decimal =
significant_digits_before_decimal_in_string + zeros_to_decimal + rounding_digits;

Expand Down Expand Up @@ -551,7 +552,7 @@ __global__ void string_to_decimal_kernel(T* out,
// thread_value: 12
// result -> 1200
auto const digits_after_decimal =
num_precise_digits - significant_digits_before_decimal + significant_preceeding_zeros;
num_precise_digits - significant_digits_before_decimal + significant_preceding_zeros;
auto const digits_needed_after_decimal =
min(precision - significant_digits_before_decimal, -scale);

Expand Down
58 changes: 50 additions & 8 deletions src/main/cpp/tests/cast_string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,17 +302,38 @@ TEST_F(StringToDecimalTests, ExponentalNotation)

TEST_F(StringToDecimalTests, PositiveScale)
{
auto const strings =
test::strings_column_wrapper({"1234e-1", "12345e1", "-1234.5678", "-0.001234567890123456e6"});
strings_column_view scv{strings};
{
auto const strings =
test::strings_column_wrapper({"1234e-1", "12345e1", "-1234.5678", "-0.001234567890123456e6"});
strings_column_view scv{strings};

auto const result =
spark_rapids_jni::string_to_decimal(6, 2, scv, false, rmm::cuda_stream_default);
auto const result =
spark_rapids_jni::string_to_decimal(6, 2, scv, false, rmm::cuda_stream_default);

test::fixed_point_column_wrapper<int32_t> expected(
{100, 123500, -1200, -1200}, {1, 1, 1, 1}, numeric::scale_type{2});
test::fixed_point_column_wrapper<int32_t> expected(
{1, 1235, -12, -12}, {1, 1, 1, 1}, numeric::scale_type{2});

CUDF_TEST_EXPECT_COLUMNS_EQUAL(result->view(), expected);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(result->view(), expected);
}

{
auto const strings = test::strings_column_wrapper(
{"813847339", "043469773", "548977048", "985946604", "325679554", "null", "957413342",
"541903389", "150050891", "663968655", "976832602", "757172936", "968693314", "106046331",
"965120263", "354546567", "108127101", "339513621", "980338159", "593267777"});
strings_column_view scv{strings};

auto const result =
spark_rapids_jni::string_to_decimal(8, 3, scv, false, rmm::cuda_stream_default);

test::fixed_point_column_wrapper<int32_t> expected(
{813847, 43470, 548977, 985947, 325680, 0, 957413, 541903, 150051, 663969,
976833, 757173, 968693, 106046, 965120, 354547, 108127, 339514, 980338, 593268},
{1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
numeric::scale_type{3});

CUDF_TEST_EXPECT_COLUMNS_EQUAL(result->view(), expected);
}
}

TEST_F(StringToDecimalTests, Edges)
Expand Down Expand Up @@ -472,4 +493,25 @@ TEST_F(StringToDecimalTests, Edges)

CUDF_TEST_EXPECT_COLUMNS_EQUAL(result->view(), expected);
}
{
auto const str = test::strings_column_wrapper({"1234567809"});
strings_column_view scv{str};

auto const result =
spark_rapids_jni::string_to_decimal(8, 3, scv, false, rmm::cuda_stream_default);
test::fixed_point_column_wrapper<int32_t> expected({1234568}, {1}, numeric::scale_type{3});

CUDF_TEST_EXPECT_COLUMNS_EQUAL(result->view(), expected);
}
{
auto const str = test::strings_column_wrapper({"4347202159", "4347802159"});
strings_column_view scv{str};

auto const result =
spark_rapids_jni::string_to_decimal(4, 6, scv, false, rmm::cuda_stream_default);
test::fixed_point_column_wrapper<int32_t> expected(
{4347, 4348}, {1, 1}, numeric::scale_type{6});

CUDF_TEST_EXPECT_COLUMNS_EQUAL(result->view(), expected);
}
}

0 comments on commit 3080c53

Please sign in to comment.