From 6be6466066f393152aaa0a0f28930ff8e0855bef Mon Sep 17 00:00:00 2001 From: Ashwin Srinath <3190405+shwina@users.noreply.github.com> Date: Wed, 8 Jun 2022 17:58:45 -0400 Subject: [PATCH] Maintain the input index in the result of a groupby-transform (#11068) This closes #11067. (I was initially unable to reproduce the original bug locally; having since done so, I can confirm that this change fixes it.) The change also converts a pandas Series or Index passed as the groupby key to its cudf equivalent. Authors: - Ashwin Srinath (https://github.com/shwina) Approvers: - GALI PREM SAGAR (https://github.com/galipremsagar) - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/cudf/pull/11068 --- python/cudf/cudf/core/groupby/groupby.py | 7 ++++++- python/cudf/cudf/tests/test_groupby.py | 23 ++++++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index c68cd66acc0..b5538a0f0a8 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -8,6 +8,7 @@ from typing import Any, Iterable, List, Tuple, Union import numpy as np +import pandas as pd import cudf from cudf._lib import groupby as libgroupby @@ -822,7 +823,7 @@ def transform(self, function): result = result._align_to_index( self.grouping.keys, how="right", allow_non_unique=True ) - result = result.reset_index(drop=True) + result.index = self.obj.index return result def rolling(self, *args, **kwargs): @@ -1663,6 +1664,10 @@ def _handle_by_or_level(self, by=None, level=None): self._handle_mapping(by) elif isinstance(by, Grouper): self._handle_grouper(by) + elif isinstance(by, pd.Series): + self._handle_series(cudf.Series.from_pandas(by)) + elif isinstance(by, pd.Index): + self._handle_index(cudf.Index.from_pandas(by)) else: try: self._handle_label(by) diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py index b1625b5f67e..b4f52d452af 100644 --- a/python/cudf/cudf/tests/test_groupby.py +++ 
b/python/cudf/cudf/tests/test_groupby.py @@ -597,7 +597,16 @@ def test_groupby_column_numeral(): @pytest.mark.parametrize( "series", - [[0, 1, 0], [1, 1, 1], [0, 1, 1], [1, 2, 3], [4, 3, 2], [0, 2, 0]], + [ + [0, 1, 0], + [1, 1, 1], + [0, 1, 1], + [1, 2, 3], + [4, 3, 2], + [0, 2, 0], + pd.Series([0, 2, 0]), + pd.Series([0, 2, 0], index=[0, 2, 1]), + ], ) # noqa: E501 def test_groupby_external_series(series): pdf = pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [1, 2, 1]}) @@ -2581,3 +2590,15 @@ def test_groupby_select_then_diff(): # TODO: Add a test including datetime64[ms] column in input data + + +@pytest.mark.parametrize("by", ["a", ["a", "b"], pd.Series([1, 2, 1, 3])]) +def test_groupby_transform_maintain_index(by): + # test that we maintain the index after a groupby transform + gdf = cudf.DataFrame( + {"a": [1, 1, 1, 2], "b": [1, 2, 1, 2]}, index=[3, 2, 1, 0] + ) + pdf = gdf.to_pandas() + assert_groupby_results_equal( + pdf.groupby(by).transform("max"), gdf.groupby(by).transform("max") + )