Skip to content

Commit

Permalink
Avoid using invlalid null_mask for zero null_count (NVIDIA#1364)
Browse files Browse the repository at this point in the history
Fixes NVIDIA#1363

- adds a unit test reproducing the issue
- adds a null_count check

Signed-off-by: Gera Shegalov <gera@apache.org>
  • Loading branch information
gerashegalov authored Aug 23, 2023
1 parent 82a5c2c commit 853a79a
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 30 deletions.
25 changes: 15 additions & 10 deletions src/main/cpp/src/CastStringJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,9 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_CastStrings_toIntegersW

auto const validity_regex = strings::regex_program::create(validity_regex_str);
auto const valid_rows = strings::matches_re(input_view, *validity_regex);
auto const prepped_table = strings::extract(input_view, *validity_regex);
const strings_column_view prepped_view{prepped_table->get_column(0)};
auto int_col = [&] {
auto const int_col = [&] {
auto const prepped_table = strings::extract(input_view, *validity_regex);
const strings_column_view prepped_view{prepped_table->get_column(0)};
switch (base) {
case 10: {
return strings::to_integers(prepped_view, res_data_type);
Expand All @@ -174,14 +174,19 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_CastStrings_toIntegersW
auto unmatched_implies_zero = copy_if_else(*int_col, zero_scalar, *valid_rows);

// output nulls: original + all rows matching \s*

auto const space_only_regex = strings::regex_program::create(R"(^\s*$)");
auto const extra_null_rows = strings::matches_re(input_view, *space_only_regex);
auto const extra_mask = unary_operation(*extra_null_rows, unary_operator::NOT);

auto const original_mask = mask_to_bools(input_view.null_mask(), 0, input_view.size());
auto const new_mask = binary_operation(
*original_mask, *extra_mask, binary_operator::BITWISE_AND, data_type(type_id::BOOL8));
auto const new_mask = [&] {
auto const extra_null_rows = strings::matches_re(input_view, *space_only_regex);
auto extra_mask = unary_operation(*extra_null_rows, unary_operator::NOT);
if (input_view.null_count() > 0) {
return binary_operation(*mask_to_bools(input_view.null_mask(), 0, input_view.size()),
*extra_mask,
binary_operator::BITWISE_AND,
data_type(type_id::BOOL8));
} else {
return extra_mask;
}
}();

auto const [null_mask, null_count] = bools_to_mask(*new_mask);
unmatched_implies_zero->set_null_mask(*null_mask, null_count);
Expand Down
64 changes: 44 additions & 20 deletions src/test/java/com/nvidia/spark/rapids/jni/CastStringsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,45 @@ void castToDecimalNoStripTest() {
}
}

private void convTestInternal(Table input, Table expected, int fromBase) {
try(
ColumnVector intCol = CastStrings.toIntegersWithBase(input.getColumn(0), fromBase, false,
DType.UINT64);
ColumnVector decStrCol = CastStrings.fromIntegersWithBase(intCol, 10);
ColumnVector hexStrCol = CastStrings.fromIntegersWithBase(intCol, 16);
) {
AssertUtils.assertColumnsAreEqual(expected.getColumn(0), decStrCol, "decStrCol");
AssertUtils.assertColumnsAreEqual(expected.getColumn(1), hexStrCol, "hexStrCol");
}
}

@Test
void baseDec2HexTest() {
try(
void baseDec2HexTestNoNulls() {
try (
Table input = new Table.TestBuilder().column(
"510",
"00510",
"00-510"
).build();

Table expected = new Table.TestBuilder().column(
"510",
"510",
"0"
).column(
"1FE",
"1FE",
"0"
).build()
)
{
convTestInternal(input, expected, 10);
}
}

@Test
void baseDec2HexTestMixed() {
try (
Table input = new Table.TestBuilder().column(
null,
" ",
Expand Down Expand Up @@ -229,16 +264,10 @@ void baseDec2HexTest() {
"1FE",
"1FE",
"0"
).build();

ColumnVector intCol = CastStrings.toIntegersWithBase(input.getColumn(0), 10, false,
DType.UINT64);
ColumnVector decStrCol = CastStrings.fromIntegersWithBase(intCol, 10);
ColumnVector hexStrCol = CastStrings.fromIntegersWithBase(intCol, 16);
) {
ai.rapids.cudf.TableDebug.get().debug("intCol", intCol);
AssertUtils.assertColumnsAreEqual(expected.getColumn(0), decStrCol, "decStrCol");
AssertUtils.assertColumnsAreEqual(expected.getColumn(1), hexStrCol, "hexStrCol");
).build()
)
{
convTestInternal(input, expected, 10);
}
}

Expand Down Expand Up @@ -290,14 +319,9 @@ void baseHex2DecTest() {
"0",
"0"
).build();

ColumnVector intCol = CastStrings.toIntegersWithBase(input.getColumn(0), 16, false, DType.UINT64);
ColumnVector decStrCol = CastStrings.fromIntegersWithBase(intCol, 10);
ColumnVector hexStrCol = CastStrings.fromIntegersWithBase(intCol, 16);
) {
ai.rapids.cudf.TableDebug.get().debug("intCol", intCol);
AssertUtils.assertColumnsAreEqual(expected.getColumn(0), decStrCol, "decStrCol");
AssertUtils.assertColumnsAreEqual(expected.getColumn(1), hexStrCol, "hexStrCol");
)
{
convTestInternal(input, expected, 16);
}
}
}

0 comments on commit 853a79a

Please sign in to comment.