Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Add Expr.arg_true #7056

Merged
merged 1 commit into from
Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Manipulation/selection
Expr.append
Expr.arg_sort
Expr.argsort
Expr.arg_true
Expr.backward_fill
Expr.cast
Expr.ceil
Expand Down
35 changes: 35 additions & 0 deletions py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextlib
import math
import os
import random
Expand Down Expand Up @@ -40,6 +41,9 @@
sphinx_accessor,
)

with contextlib.suppress(ImportError): # Module not available when building docs
from polars.polars import arg_where as py_arg_where

if TYPE_CHECKING:
import sys

Expand Down Expand Up @@ -418,6 +422,37 @@ def all(self) -> Self:
"""
return self._from_pyexpr(self._pyexpr.all())

def arg_true(self) -> Self:
"""
Return indices where expression evaluates `True`.

.. warning::
Modifies number of rows returned, so will fail in combination with other
expressions. Use as only expression in `select` / `with_columns`.

Examples
--------
>>> df = pl.DataFrame({"a": [1, 1, 2, 1]})
>>> df.select((pl.col("a") == 1).arg_true())
shape: (3, 1)
┌─────┐
│ a │
│ --- │
│ u32 │
╞═════╡
│ 0 │
│ 1 │
│ 3 │
└─────┘

See Also
--------
Series.arg_true : Return indices where Series is True
pl.arg_where

"""
return self._from_pyexpr(py_arg_where(self._pyexpr))

def sqrt(self) -> Self:
"""
Compute the square root of the elements.
Expand Down
4 changes: 4 additions & 0 deletions py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2912,6 +2912,10 @@ def arg_where(
3
]

See Also
--------
Series.arg_true : Return indices where Series is True

"""
if eager:
if not isinstance(condition, pli.Series):
Expand Down
7 changes: 7 additions & 0 deletions py-polars/tests/unit/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,3 +701,10 @@ def test_exclude_invalid_input(input: tuple[Any, ...]) -> None:
df = pl.DataFrame(schema=["a", "b", "c"])
with pytest.raises(TypeError):
df.select(pl.all().exclude(*input))


def test_arg_true() -> None:
df = pl.DataFrame({"a": [1, 1, 2, 1]})
res = df.select((pl.col("a") == 1).arg_true())
expected = pl.DataFrame([pl.Series("a", [0, 1, 3], dtype=pl.UInt32)])
assert_frame_equal(res, expected)