Skip to content

Commit

Permalink
Fix rolling-window count for null input
Browse files Browse the repository at this point in the history
  • Loading branch information
mythrocks committed Sep 29, 2020
1 parent 937fd7e commit 8211a68
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@
- PR #6304 Fix span_tests.cu includes
- PR #6331 Avoids materializing `RangeIndex` during frame concatenation (when not needed)
- PR #6278 Add filter tests for struct columns
- PR #6344 Fix rolling-window count for null input


# cuDF 0.15.0 (26 Aug 2020)
Expand Down
40 changes: 36 additions & 4 deletions cpp/src/rolling/rolling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ namespace cudf {
namespace detail {
namespace { // anonymous
/**
* @brief Only count operation is executed and count is updated
* @brief Only COUNT_VALID operation is executed and count is updated
* depending on `min_periods` and returns true if it was
* valid, else false.
*/
Expand All @@ -58,7 +58,7 @@ template <typename InputType,
typename agg_op,
aggregation::Kind op,
bool has_nulls>
std::enable_if_t<op == aggregation::COUNT_VALID || op == aggregation::COUNT_ALL, bool> __device__
std::enable_if_t<op == aggregation::COUNT_VALID, bool> __device__
process_rolling_window(column_device_view input,
mutable_column_device_view output,
size_type start_index,
Expand All @@ -70,10 +70,42 @@ process_rolling_window(column_device_view input,
// for CUDA 10.0 and below (fixed in CUDA 10.1)
volatile cudf::size_type count = 0;

for (size_type j = start_index; j < end_index; j++) {
if (op == aggregation::COUNT_ALL || !has_nulls || input.is_valid(j)) { count++; }
if (!has_nulls) {
count = end_index - start_index;
} else {
for (size_type j = start_index; j < end_index; j++) {
if (input.is_valid(j)) { count++; }
}
}

bool output_is_valid = ((end_index - start_index) >= min_periods);
output.element<OutputType>(current_index) = count;

return output_is_valid;
}

/**
* @brief Executes only the COUNT_ALL operation: writes the window size
* (`end_index - start_index`) to the output as the count, and returns true
* if that size is at least `min_periods`, else false.
*/
template <typename InputType,
typename OutputType,
typename agg_op,
aggregation::Kind op,
bool has_nulls>
std::enable_if_t<op == aggregation::COUNT_ALL, bool> __device__
process_rolling_window(column_device_view input,
mutable_column_device_view output,
size_type start_index,
size_type end_index,
size_type current_index,
size_type min_periods)
{
// declare this as volatile to avoid some compiler optimizations that lead to incorrect results
// for CUDA 10.0 and below (fixed in CUDA 10.1)
volatile cudf::size_type count = end_index - start_index;

bool output_is_valid = (count >= min_periods);
output.element<OutputType>(current_index) = count;

Expand Down
4 changes: 2 additions & 2 deletions cpp/tests/grouped_rolling/grouped_rolling_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ class GroupedRollingTest : public cudf::test::BaseFixture {
if (include_nulls || !input.nullable() || cudf::bit_is_set(valid_mask, j)) count++;
}

ref_valid[i] = (count >= min_periods);
ref_valid[i] = ((end_index - start_index) >= min_periods);
if (ref_valid[i]) ref_data[i] = count;
}

Expand Down Expand Up @@ -861,7 +861,7 @@ class GroupedTimeRangeRollingTest : public cudf::test::BaseFixture {
if (include_nulls || !input.nullable() || cudf::bit_is_set(valid_mask, j)) count++;
}

ref_valid[i] = (count >= min_periods);
ref_valid[i] = ((end_index - start_index) >= min_periods);
if (ref_valid[i]) ref_data[i] = count;
}

Expand Down
6 changes: 3 additions & 3 deletions cpp/tests/rolling/rolling_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ TEST_F(RollingStringTest, MinPeriods)
cudf::test::strings_column_wrapper expected_max(
{"This", "test", "test", "test", "test", "string", "string", "string", "string"},
{0, 0, 0, 0, 1, 1, 1, 0, 0});
fixed_width_column_wrapper<size_type> expected_count_val({0, 2, 2, 2, 3, 3, 3, 3, 2},
{0, 0, 0, 0, 1, 1, 1, 0, 0});
fixed_width_column_wrapper<size_type> expected_count_val({1, 2, 1, 2, 3, 3, 3, 2, 2},
{1, 1, 1, 1, 1, 1, 1, 1, 0});
fixed_width_column_wrapper<size_type> expected_count_all({3, 4, 4, 4, 4, 4, 4, 3, 2},
{0, 1, 1, 1, 1, 1, 1, 0, 0});

Expand Down Expand Up @@ -248,7 +248,7 @@ class RollingTest : public cudf::test::BaseFixture {
if (include_nulls || !input.nullable() || cudf::bit_is_set(valid_mask, j)) count++;
}

ref_valid[i] = (count >= min_periods);
ref_valid[i] = ((end_index - start_index) >= min_periods);
if (ref_valid[i]) ref_data[i] = count;
}

Expand Down

0 comments on commit 8211a68

Please sign in to comment.