Skip to content

Commit

Permalink
feat(python)!: Groupby iteration now returns tuples of (name, data) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Jan 26, 2023
1 parent 9105d5b commit d67bfe3
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 60 deletions.
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/dataframe/groupby.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ This namespace is available after calling :code:`DataFrame.groupby(...)`.
.. autosummary::
:toctree: api/

GroupBy.__iter__
GroupBy.agg
GroupBy.agg_list
GroupBy.apply
Expand Down
112 changes: 58 additions & 54 deletions py-polars/polars/internals/dataframe/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,36 +21,7 @@


class GroupBy(Generic[DF]):
"""
Starts a new GroupBy operation.
You can also loop over this Object to loop over `DataFrames` with unique groups.
Examples
--------
>>> df = pl.DataFrame({"foo": ["a", "a", "b"], "bar": [1, 2, 3]})
>>> for group in df.groupby("foo", maintain_order=True):
... print(group)
...
shape: (2, 2)
┌─────┬─────┐
│ foo ┆ bar │
│ --- ┆ --- │
│ str ┆ i64 │
╞═════╪═════╡
│ a ┆ 1 │
│ a ┆ 2 │
└─────┴─────┘
shape: (1, 2)
┌─────┬─────┐
│ foo ┆ bar │
│ --- ┆ --- │
│ str ┆ i64 │
╞═════╪═════╡
│ b ┆ 3 │
└─────┴─────┘
"""
"""Starts a new GroupBy operation."""

def __init__(
self,
Expand Down Expand Up @@ -83,44 +54,77 @@ def __init__(
self.maintain_order = maintain_order

def __iter__(self) -> GroupBy[DF]:
warnings.warn(
"Return type of groupby iteration will change in the next breaking release."
" Iteration will return tuples of (group_key, data) instead of just data.",
category=FutureWarning,
stacklevel=2,
)
"""
Allows iteration over the groups of the groupby operation.
by = [self.by] if isinstance(self.by, (str, pli.Expr)) else self.by
Returns
-------
Iterator returning tuples of (name, data) for each group.
# Find any single column that is not specified as 'by'
columns = self._df.columns()
by_names = {c if isinstance(c, str) else c.meta.output_name() for c in by}
try:
non_by_col = next(c for c in columns if c not in by_names)
except StopIteration:
non_by_col = None
Examples
--------
>>> df = pl.DataFrame({"foo": ["a", "a", "b"], "bar": [1, 2, 3]})
>>> for name, data in df.groupby("foo"): # doctest: +SKIP
... print(name)
... print(data)
...
a
shape: (2, 2)
┌─────┬─────┐
│ foo ┆ bar │
│ --- ┆ --- │
│ str ┆ i64 │
╞═════╪═════╡
│ a ┆ 1 │
│ a ┆ 2 │
└─────┴─────┘
b
shape: (1, 2)
┌─────┬─────┐
│ foo ┆ bar │
│ --- ┆ --- │
│ str ┆ i64 │
╞═════╪═════╡
│ b ┆ 3 │
└─────┴─────┘
# Get the group indices using that column
if non_by_col is not None:
groups_df = self.agg(pli.col(non_by_col).agg_groups())
group_indices = groups_df.select(non_by_col).to_series()
"""
temp_col = "__POLARS_GB_GROUP_INDICES"
groups_df = (
pli.wrap_df(self._df)
.lazy()
.with_row_count(name=temp_col)
.groupby(self.by, maintain_order=self.maintain_order)
.agg(pli.col(temp_col).list())
.collect(no_optimization=True)
)

group_names = groups_df.select(pli.all().exclude(temp_col))

# When grouping by a single column, group name is a single value
# When grouping by multiple columns, group name is a tuple of values
self._group_names: Iterator[object] | Iterator[tuple[object, ...]]
if isinstance(self.by, (str, pli.Expr)):
self._group_names = iter(group_names.to_series())
else:
# TODO: Properly handle expression input
group_indices = pli.Series([[i] for i in range(self._df.height())])
self._group_names = group_names.iterrows()

self._group_indices = group_indices
self._group_indices = groups_df.select(temp_col).to_series()
self._current_index = 0

return self

def __next__(self) -> DF:
def __next__(self) -> tuple[object, DF] | tuple[tuple[object, ...], DF]:
if self._current_index >= len(self._group_indices):
raise StopIteration

df = self._dataframe_class._from_pydf(self._df)
group = df[self._group_indices[self._current_index]]

group_name = next(self._group_names)
group_data = df[self._group_indices[self._current_index]]
self._current_index += 1
return group

return group_name, group_data

def apply(self, f: Callable[[pli.DataFrame], pli.DataFrame]) -> DF:
"""
Expand Down
19 changes: 13 additions & 6 deletions py-polars/tests/unit/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,23 +77,30 @@ def test_groupby_iteration() -> None:
"baz": [6, 5, 4, 3, 2, 1],
}
)
expected_shapes = [(2, 3), (3, 3), (1, 3)]
expected_names = ["a", "b", "c"]
expected_rows = [
[("a", 1, 6), ("a", 3, 4)],
[("b", 2, 5), ("b", 4, 3), ("b", 5, 2)],
[("c", 6, 1)],
]
for i, group in enumerate(df.groupby("foo", maintain_order=True)):
assert group.shape == expected_shapes[i]
assert group.rows() == expected_rows[i]
for i, (group, data) in enumerate(df.groupby("foo", maintain_order=True)):
assert group == expected_names[i]
assert data.rows() == expected_rows[i]

# Grouped by ALL columns should give groups of a single row
result = list(df.groupby(["foo", "bar", "baz"]))
assert len(result) == 6

# Iterating over groups should also work when grouping by expressions
result = list(df.groupby(["foo", pl.col("bar") * pl.col("baz")]))
assert len(result) == 5
result2 = list(df.groupby(["foo", pl.col("bar") * pl.col("baz")]))
assert len(result2) == 5

# Single column, alias in groupby
df = pl.DataFrame({"foo": [1, 2, 3, 4, 5, 6]})
gb = df.groupby((pl.col("foo") // 2).alias("bar"), maintain_order=True)
result3 = [(group, df.rows()) for group, df in gb]
expected3 = [(0, [(1,)]), (1, [(2,), (3,)]), (2, [(4,), (5,)]), (3, [(6,)])]
assert result3 == expected3


def bad_agg_parameters() -> list[Any]:
Expand Down

0 comments on commit d67bfe3

Please sign in to comment.