Skip to content

Commit

Permalink
feat(python): Support Numpy ufunc with more than one expression
Browse files Browse the repository at this point in the history
We use the pl.reduce trick basically, where the limitation is that non-expressions have to be passed in as kwargs. That is probably the safest anyway.

Related issues:
#6770 : brought up no support for multiple expression, have added a ValueError in response

#5713 : reminder that there is the `pl.reduce` trick
  • Loading branch information
zundertj committed Apr 1, 2023
1 parent 53d542f commit d46d45a
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 5 deletions.
14 changes: 9 additions & 5 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import random
import warnings
from datetime import timedelta
from functools import reduce
from functools import partial, reduce
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -47,7 +47,7 @@

with contextlib.suppress(ImportError): # Module not available when building docs
from polars.polars import arg_where as py_arg_where

from polars.polars import reduce as pyreduce
if TYPE_CHECKING:
import sys

Expand Down Expand Up @@ -282,9 +282,13 @@ def __array_ufunc__(
"""Numpy universal functions."""
num_expr = sum(isinstance(inp, Expr) for inp in inputs)
if num_expr > 1:
raise ValueError(
f"Numpy ufunc can only be used with one expression, {num_expr} given. Use `pl.reduce` to call numpy functions over multiple expressions."
)
if num_expr < len(inputs):
raise ValueError(
"Numpy ufunc with more than one expression can only be used if all non-expression inputs are provided as keyword arguments only"
)

exprs = selection_to_pyexpr_list(inputs)
return self._from_pyexpr(pyreduce(partial(ufunc, **kwargs), exprs))

def function(s: Series) -> Series: # pragma: no cover
args = [inp if not isinstance(inp, Expr) else s for inp in inputs]
Expand Down
35 changes: 35 additions & 0 deletions py-polars/tests/unit/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -3654,6 +3654,41 @@ def test_ufunc_expr_not_first() -> None:
assert_frame_equal(out, expected)


def test_ufunc_multiple_expressions() -> None:
# example from https://github.com/pola-rs/polars/issues/6770
df = pl.DataFrame(
{
"v": [
-4.293,
-2.4659,
-1.8378,
-0.2821,
-4.5649,
-3.8128,
-7.4274,
3.3443,
3.8604,
-4.2200,
],
"u": [
-11.2268,
6.3478,
7.1681,
3.4986,
2.7320,
-1.0695,
-10.1408,
11.2327,
6.6623,
-8.1412,
],
}
)
expected = np.arctan2(df.get_column("v"), df.get_column("u"))
result = df.select(np.arctan2(pl.col("v"), pl.col("u")))[:, 0] # type: ignore[call-overload]
assert_series_equal(expected, result) # type: ignore[arg-type]


def test_window_deadlock() -> None:
np.random.seed(12)

Expand Down

0 comments on commit d46d45a

Please sign in to comment.