Skip to content

Commit

Permalink
chore(python): Rename pivot aggregate_fn to aggregate_function (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Feb 21, 2023
1 parent ee20b36 commit e703ac2
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 30 deletions.
47 changes: 24 additions & 23 deletions py-polars/polars/internals/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5131,15 +5131,16 @@ def explode(
self.lazy().explode(columns).collect(no_optimization=True)._df
)

@deprecated_alias(aggregate_fn="aggregate_function")
@deprecate_nonkeyword_arguments(
allowed_args=["self", "values", "index", "columns", "aggregate_fn"]
allowed_args=["self", "values", "index", "columns", "aggregate_function"]
)
def pivot(
self,
values: Sequence[str] | str,
index: Sequence[str] | str,
columns: Sequence[str] | str,
aggregate_fn: PivotAgg | pli.Expr = "first",
aggregate_function: PivotAgg | pli.Expr = "first",
maintain_order: bool = True,
sort_columns: bool = False,
separator: str = "_",
Expand All @@ -5157,7 +5158,7 @@ def pivot(
columns
Name of the column(s) whose values will be used as the header of the output
DataFrame.
aggregate_fn : {'first', 'sum', 'max', 'min', 'mean', 'median', 'last', 'count'}
aggregate_function : {'first', 'sum', 'max', 'min', 'mean', 'median', 'last', 'count'}
A predefined aggregate function str or an expression.
maintain_order
Sort the grouped keys so that the output order is predictable.
Expand Down Expand Up @@ -5190,42 +5191,42 @@ def pivot(
│ two ┆ 4 ┆ 5 ┆ 6 │
└─────┴─────┴─────┴─────┘
"""
""" # noqa: W505
if isinstance(values, str):
values = [values]
if isinstance(index, str):
index = [index]
if isinstance(columns, str):
columns = [columns]

if isinstance(aggregate_fn, str):
if aggregate_fn == "first":
aggregate_fn = pli.element().first()
elif aggregate_fn == "sum":
aggregate_fn = pli.element().sum()
elif aggregate_fn == "max":
aggregate_fn = pli.element().max()
elif aggregate_fn == "min":
aggregate_fn = pli.element().min()
elif aggregate_fn == "mean":
aggregate_fn = pli.element().mean()
elif aggregate_fn == "median":
aggregate_fn = pli.element().median()
elif aggregate_fn == "last":
aggregate_fn = pli.element().last()
elif aggregate_fn == "count":
aggregate_fn = pli.count()
if isinstance(aggregate_function, str):
if aggregate_function == "first":
aggregate_function = pli.element().first()
elif aggregate_function == "sum":
aggregate_function = pli.element().sum()
elif aggregate_function == "max":
aggregate_function = pli.element().max()
elif aggregate_function == "min":
aggregate_function = pli.element().min()
elif aggregate_function == "mean":
aggregate_function = pli.element().mean()
elif aggregate_function == "median":
aggregate_function = pli.element().median()
elif aggregate_function == "last":
aggregate_function = pli.element().last()
elif aggregate_function == "count":
aggregate_function = pli.count()
else:
raise ValueError(
f"Argument aggregate fn: '{aggregate_fn}' " f"was not expected."
f"Invalid input for `aggregate_function` argument: {aggregate_function!r}"
)

return self._from_pydf(
self._df.pivot_expr(
values,
index,
columns,
aggregate_fn._pyexpr,
aggregate_function._pyexpr,
maintain_order,
sort_columns,
separator,
Expand Down
20 changes: 13 additions & 7 deletions py-polars/tests/unit/operations/test_pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def test_pivot_list() -> None:
"3": [None, None, [3, 3]],
}
)
out = df.pivot("b", index="a", columns="a", aggregate_fn="first", sort_columns=True)
out = df.pivot(
"b", index="a", columns="a", aggregate_function="first", sort_columns=True
)
assert_frame_equal(out, expected)


Expand All @@ -69,7 +71,7 @@ def test_pivot_aggregate(agg_fn: PivotAgg, expected_rows: list[tuple[Any]]) -> N
}
)
result = df.pivot(
values="c", index="b", columns="a", aggregate_fn=agg_fn, sort_columns=True
values="c", index="b", columns="a", aggregate_function=agg_fn, sort_columns=True
)
assert result.rows() == expected_rows

Expand Down Expand Up @@ -98,12 +100,14 @@ def test_pivot_categorical_index() -> None:
schema=[("A", pl.Categorical), ("B", pl.Categorical)],
)

result = df.pivot(values="B", index=["A"], columns="B", aggregate_fn="count")
result = df.pivot(values="B", index=["A"], columns="B", aggregate_function="count")
expected = {"A": ["Fire", "Water"], "Car": [1, 2], "Ship": [1, None]}
assert result.to_dict(False) == expected

# test expression dispatch
result = df.pivot(values="B", index=["A"], columns="B", aggregate_fn=pl.count())
result = df.pivot(
values="B", index=["A"], columns="B", aggregate_function=pl.count()
)
assert result.to_dict(False) == expected

df = pl.DataFrame(
Expand All @@ -114,7 +118,9 @@ def test_pivot_categorical_index() -> None:
},
schema=[("A", pl.Categorical), ("B", pl.Categorical), ("C", pl.Categorical)],
)
result = df.pivot(values="B", index=["A", "C"], columns="B", aggregate_fn="count")
result = df.pivot(
values="B", index=["A", "C"], columns="B", aggregate_function="count"
)
expected = {
"A": ["Fire", "Water"],
"C": ["Paper", "Paper"],
Expand Down Expand Up @@ -184,7 +190,7 @@ def test_pivot_reinterpret_5907() -> None:
)

result = df.pivot(
index=["A"], values=["C"], columns=["B"], aggregate_fn=pl.element().sum()
index=["A"], values=["C"], columns=["B"], aggregate_function=pl.element().sum()
)
expected = {"A": [3, -2], "x": [100, 50], "y": [500, -80]}
assert result.to_dict(False) == expected
Expand All @@ -195,7 +201,7 @@ class SubClassedDataFrame(pl.DataFrame):
pass

df = SubClassedDataFrame({"a": [1, 2], "b": [3, 4]})
result = df.pivot(values="b", index="a", columns="a", aggregate_fn="first")
result = df.pivot(values="b", index="a", columns="a", aggregate_function="first")
assert isinstance(result, SubClassedDataFrame)


Expand Down

0 comments on commit e703ac2

Please sign in to comment.