Skip to content

Commit

Permalink
Add post-processing steps to `dask_cudf.groupby.CudfSeriesGroupby.aggregate` (#8694)
Browse files Browse the repository at this point in the history

Closes #8655 

Adds post-processing steps to Dask-cuDF's series groupby when it uses the optimized codepaths for aggregations, matching [those done by Dask](https://github.com/dask/dask/blob/8601b540f8e7eac95fa739a5ca28f1d707299ed0/dask/dataframe/groupby.py#L1917-L1921). These steps ensure that the groupby operation always returns a `dask_cudf.Series`; the failure to do so was the problem observed in #8655.

Authors:
  - Charles Blackmon-Luca (https://github.com/charlesbluca)

Approvers:
  - Richard (Rick) Zamora (https://github.com/rjzamora)

URL: #8694
  • Loading branch information
charlesbluca authored Jul 12, 2021
1 parent 0b9ea01 commit 7823a18
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 14 deletions.
2 changes: 1 addition & 1 deletion python/dask_cudf/dask_cudf/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def aggregate(self, arg, split_every=None, split_out=1):
sep=self.sep,
sort=self.sort,
as_index=self.as_index,
)
)[self._slice]

return super().aggregate(
arg, split_every=split_every, split_out=split_out
Expand Down
16 changes: 3 additions & 13 deletions python/dask_cudf/dask_cudf/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,7 @@ def test_groupby_basic_aggs(aggregation):
"func",
[
lambda df: df.groupby("x").agg({"y": "max"}),
pytest.param(
lambda df: df.groupby("x").y.agg(["sum", "max"]),
marks=pytest.mark.skip,
),
lambda df: df.groupby("x").y.agg(["sum", "max"]),
],
)
def test_groupby_agg(func):
Expand Down Expand Up @@ -98,7 +95,6 @@ def test_groupby_agg_empty_partition(tmpdir, split_out):
dd.assert_eq(gb.compute().sort_index(), expect)


@pytest.mark.xfail(reason="cudf issues")
@pytest.mark.parametrize(
"func",
[lambda df: df.groupby("x").std(), lambda df: df.groupby("x").y.std()],
Expand All @@ -115,23 +111,17 @@ def test_groupby_std(func):

ddf = dask_cudf.from_cudf(gdf, npartitions=5)

a = func(gdf.to_pandas())
a = func(gdf).to_pandas()
b = func(ddf).compute().to_pandas()

a.index.name = None
a.name = None
b.index.name = None

dd.assert_eq(a, b)


@pytest.mark.parametrize(
"func",
[
lambda df: df.groupby("x").agg({"y": "collect"}),
pytest.param(
lambda df: df.groupby("x").y.agg("collect"), marks=pytest.mark.skip
),
lambda df: df.groupby("x").y.agg("collect"),
],
)
def test_groupby_collect(func):
Expand Down

0 comments on commit 7823a18

Please sign in to comment.