Skip to content

Commit

Permalink
Maintain the input index in the result of a groupby-transform (NVIDIA…
Browse files Browse the repository at this point in the history
…#11068)

I believe this should close NVIDIA#11067, but I'm unable to reproduce the original bug locally. Will report back here once I'm able to do that.

Edit: it does — this change closes NVIDIA#11067.

Authors:
  - Ashwin Srinath (https://github.com/shwina)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)
  - Bradley Dice (https://github.com/bdice)

URL: rapidsai/cudf#11068
  • Loading branch information
shwina authored Jun 8, 2022
1 parent a00cca6 commit 6be6466
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
7 changes: 6 additions & 1 deletion python/cudf/cudf/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any, Iterable, List, Tuple, Union

import numpy as np
import pandas as pd

import cudf
from cudf._lib import groupby as libgroupby
Expand Down Expand Up @@ -822,7 +823,7 @@ def transform(self, function):
result = result._align_to_index(
self.grouping.keys, how="right", allow_non_unique=True
)
result = result.reset_index(drop=True)
result.index = self.obj.index
return result

def rolling(self, *args, **kwargs):
Expand Down Expand Up @@ -1663,6 +1664,10 @@ def _handle_by_or_level(self, by=None, level=None):
self._handle_mapping(by)
elif isinstance(by, Grouper):
self._handle_grouper(by)
elif isinstance(by, pd.Series):
self._handle_series(cudf.Series.from_pandas(by))
elif isinstance(by, pd.Index):
self._handle_index(cudf.Index.from_pandas(by))
else:
try:
self._handle_label(by)
Expand Down
23 changes: 22 additions & 1 deletion python/cudf/cudf/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,16 @@ def test_groupby_column_numeral():

@pytest.mark.parametrize(
"series",
[[0, 1, 0], [1, 1, 1], [0, 1, 1], [1, 2, 3], [4, 3, 2], [0, 2, 0]],
[
[0, 1, 0],
[1, 1, 1],
[0, 1, 1],
[1, 2, 3],
[4, 3, 2],
[0, 2, 0],
pd.Series([0, 2, 0]),
pd.Series([0, 2, 0], index=[0, 2, 1]),
],
) # noqa: E501
def test_groupby_external_series(series):
pdf = pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [1, 2, 1]})
Expand Down Expand Up @@ -2581,3 +2590,15 @@ def test_groupby_select_then_diff():


# TODO: Add a test that includes a datetime64[ms] column in the input data


@pytest.mark.parametrize("by", ["a", ["a", "b"], pd.Series([1, 2, 1, 3])])
def test_groupby_transform_maintain_index(by):
    """Check that a groupby-transform keeps the index of the input frame.

    A cudf DataFrame is built with a non-default index (``[3, 2, 1, 0]``),
    mirrored to pandas, and the ``transform("max")`` result of each library
    is compared for every parametrized ``by`` (a column label, a list of
    column labels, or an external pandas Series).
    """
    columns = {"a": [1, 1, 1, 2], "b": [1, 2, 1, 2]}
    cudf_frame = cudf.DataFrame(columns, index=[3, 2, 1, 0])
    pandas_frame = cudf_frame.to_pandas()

    # pandas is the reference; cudf must reproduce it.
    expected = pandas_frame.groupby(by).transform("max")
    actual = cudf_frame.groupby(by).transform("max")

    assert_groupby_results_equal(expected, actual)

0 comments on commit 6be6466

Please sign in to comment.