Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Fix rolling-window count for null input #6344

Merged
merged 20 commits into from
Oct 13, 2020
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
8211a68
Fix rolling-window count for null input
mythrocks Sep 29, 2020
4da93f5
Merge remote-tracking branch 'origin/branch-0.16' into window-count-fix
mythrocks Sep 30, 2020
4d3c7b1
Rolling Window count fix: Fixing python tests.
mythrocks Sep 30, 2020
59a36f3
Merge remote-tracking branch 'origin/branch-0.16' into window-count-fix
mythrocks Oct 1, 2020
3b55909
Rolling Window count fix: Using marks for xfail.
mythrocks Oct 1, 2020
90aa2a9
Rolling Window count fix: Fixed formatting
mythrocks Oct 1, 2020
2be5049
Merge remote-tracking branch 'origin/branch-0.16' into window-count-fix
mythrocks Oct 2, 2020
160d9b2
Rolling Window count fix: Review fixes:
mythrocks Oct 2, 2020
1298b0c
Rolling Window count fix: Review fixes:
mythrocks Oct 6, 2020
8697cf1
Merge remote-tracking branch 'origin/branch-0.16' into window-count-fix
mythrocks Oct 6, 2020
9a4b168
Rolling Window count fix: Review fixes:
mythrocks Oct 6, 2020
45003cd
Merge remote-tracking branch 'origin/branch-0.16' into window-count-fix
mythrocks Oct 6, 2020
6264fcc
Merge remote-tracking branch 'origin/branch-0.16' into window-count-fix
mythrocks Oct 7, 2020
c982404
Merge remote-tracking branch 'origin/branch-0.16' into window-count-fix
mythrocks Oct 8, 2020
d901023
Rolling Window count fix: Fixed formatting
mythrocks Oct 8, 2020
f7a8c90
Rolling Window count fix: Fixed process_rolling_window for COUNT_ALL
mythrocks Oct 8, 2020
f34835b
Rolling Window count fix: Use is_valid_no_check().
mythrocks Oct 8, 2020
62709b8
Rolling Window count fix: Python review
mythrocks Oct 12, 2020
7f9b391
Merge remote-tracking branch 'origin/branch-0.16' into window-count-fix
mythrocks Oct 12, 2020
6445c5d
Rolling Window count fix: Python review
mythrocks Oct 12, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@
- PR #6304 Fix span_tests.cu includes
- PR #6331 Avoids materializing `RangeIndex` during frame concatnation (when not needed)
- PR #6278 Add filter tests for struct columns
- PR #6344 Fix rolling-window count for null input
- PR #6353 Rename `skip_rows` parameter to `skiprows` in `read_parquet`, `read_avro` and `read_orc`
- PR #6361 Detect overflow in hash join
- PR #6397 Fix `build.sh` when `PARALLEL_LEVEL` environment variable isn't set
Expand Down
40 changes: 36 additions & 4 deletions cpp/src/rolling/rolling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ namespace cudf {
namespace detail {
namespace { // anonymous
/**
* @brief Only count operation is executed and count is updated
* @brief Only COUNT_VALID operation is executed and count is updated
* depending on `min_periods` and returns true if it was
* valid, else false.
*/
Expand All @@ -58,7 +58,7 @@ template <typename InputType,
typename agg_op,
aggregation::Kind op,
bool has_nulls>
std::enable_if_t<op == aggregation::COUNT_VALID || op == aggregation::COUNT_ALL, bool> __device__
std::enable_if_t<op == aggregation::COUNT_VALID, bool> __device__
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
process_rolling_window(column_device_view input,
mutable_column_device_view output,
size_type start_index,
Expand All @@ -70,10 +70,42 @@ process_rolling_window(column_device_view input,
// for CUDA 10.0 and below (fixed in CUDA 10.1)
volatile cudf::size_type count = 0;

for (size_type j = start_index; j < end_index; j++) {
if (op == aggregation::COUNT_ALL || !has_nulls || input.is_valid(j)) { count++; }
bool output_is_valid = ((end_index - start_index) >= min_periods);

if (output_is_valid) {
if (!has_nulls) {
count = end_index - start_index;
} else {
for (size_type j = start_index; j < end_index; j++) {
if (input.is_valid(j)) { count++; }
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
}
}
output.element<OutputType>(current_index) = count;
}

return output_is_valid;
}

/**
* @brief Only COUNT_ALL operation is executed and count is updated
* depending on `min_periods` and returns true if it was
* valid, else false.
*/
template <typename InputType,
typename OutputType,
typename agg_op,
aggregation::Kind op,
bool has_nulls>
std::enable_if_t<op == aggregation::COUNT_ALL, bool> __device__
process_rolling_window(column_device_view input,
mutable_column_device_view output,
size_type start_index,
size_type end_index,
size_type current_index,
size_type min_periods)
{
cudf::size_type count = end_index - start_index;

bool output_is_valid = (count >= min_periods);
output.element<OutputType>(current_index) = count;

Expand Down
4 changes: 2 additions & 2 deletions cpp/tests/grouped_rolling/grouped_rolling_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ class GroupedRollingTest : public cudf::test::BaseFixture {
if (include_nulls || !input.nullable() || cudf::bit_is_set(valid_mask, j)) count++;
}

ref_valid[i] = (count >= min_periods);
ref_valid[i] = ((end_index - start_index) >= min_periods);
if (ref_valid[i]) ref_data[i] = count;
}

Expand Down Expand Up @@ -861,7 +861,7 @@ class GroupedTimeRangeRollingTest : public cudf::test::BaseFixture {
if (include_nulls || !input.nullable() || cudf::bit_is_set(valid_mask, j)) count++;
}

ref_valid[i] = (count >= min_periods);
ref_valid[i] = ((end_index - start_index) >= min_periods);
if (ref_valid[i]) ref_data[i] = count;
}

Expand Down
6 changes: 3 additions & 3 deletions cpp/tests/rolling/rolling_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ TEST_F(RollingStringTest, MinPeriods)
cudf::test::strings_column_wrapper expected_max(
{"This", "test", "test", "test", "test", "string", "string", "string", "string"},
{0, 0, 0, 0, 1, 1, 1, 0, 0});
fixed_width_column_wrapper<size_type> expected_count_val({0, 2, 2, 2, 3, 3, 3, 3, 2},
{0, 0, 0, 0, 1, 1, 1, 0, 0});
fixed_width_column_wrapper<size_type> expected_count_val({1, 2, 1, 2, 3, 3, 3, 2, 2},
{1, 1, 1, 1, 1, 1, 1, 1, 0});
fixed_width_column_wrapper<size_type> expected_count_all({3, 4, 4, 4, 4, 4, 4, 3, 2},
{0, 1, 1, 1, 1, 1, 1, 0, 0});

Expand Down Expand Up @@ -248,7 +248,7 @@ class RollingTest : public cudf::test::BaseFixture {
if (include_nulls || !input.nullable() || cudf::bit_is_set(valid_mask, j)) count++;
}

ref_valid[i] = (count >= min_periods);
ref_valid[i] = ((end_index - start_index) >= min_periods);
if (ref_valid[i]) ref_data[i] = count;
}

Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/core/window/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,6 @@ def _apply_agg_dataframe(self, df, agg_name):
return result_df

def _apply_agg(self, agg_name):
if agg_name == "count" and not self._time_window:
self.min_periods = 0
if isinstance(self.obj, cudf.Series):
return self._apply_agg_series(self.obj, agg_name)
else:
Expand Down Expand Up @@ -388,6 +386,8 @@ def _window_to_window_sizes(self, window):
)

def _apply_agg(self, agg_name):
if agg_name == "count" and not self._time_window:
self.min_periods = 0
index = cudf.MultiIndex.from_frame(
cudf.DataFrame(
{
Expand Down
33 changes: 16 additions & 17 deletions python/cudf/cudf/tests/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,7 @@ def test_rollling_series_basic(data, index, agg, nulls, center):
got = getattr(
gsr.rolling(window_size, min_periods, center), agg
)().fillna(-1)
try:
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
assert_eq(expect, got, check_dtype=False, **kwargs)
except AssertionError as e:
if agg == "count" and data != []:
pytest.xfail(
reason="Differ from Pandas behavior for count"
)
else:
raise e
assert_eq(expect, got, check_dtype=False, **kwargs)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -97,17 +89,24 @@ def test_rolling_dataframe_basic(data, agg, nulls, center):
got = getattr(
gdf.rolling(window_size, min_periods, center), agg
)().fillna(-1)
try:
assert_eq(expect, got, check_dtype=False)
except AssertionError as e:
if agg == "count" and len(pdf) > 0:
pytest.xfail(reason="Differ from pandas behavior here")
else:
raise e
assert_eq(expect, got, check_dtype=False)


@pytest.mark.parametrize(
"agg", ["sum", pytest.param("min"), pytest.param("max"), "mean", "count"]
"agg",
[
"sum",
pytest.param("min"),
pytest.param("max"),
"mean",
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
pytest.param(
"count", # Does not follow similar conventions as
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
# with non-offset columns
marks=pytest.mark.xfail(
reason="Differs from pandas behaviour here"
),
),
],
)
def test_rolling_with_offset(agg):
psr = pd.Series(
Expand Down