Skip to content

Commit

Permalink
feat(python): Improve Series & Numpy arithmetic
Browse files Browse the repository at this point in the history
  • Loading branch information
zundertj committed Feb 18, 2023
1 parent e3d21cf commit dd8ebe3
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 1 deletion.
5 changes: 4 additions & 1 deletion py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,8 @@ def _arithmetic(self, other: Any, op_s: str, op_ffi: str) -> Series:
other = self.to_frame().select(other).to_series()
if isinstance(other, Series):
return wrap_s(getattr(self._s, op_s)(other._s))

if _check_for_numpy(other) and isinstance(other, np.ndarray):
return wrap_s(getattr(self._s, op_s)(Series(other)._s))
# recurse; the 'if' statement above will ensure we return early
if isinstance(other, (date, datetime, timedelta, str)):
other = Series("", [other])
Expand Down Expand Up @@ -653,6 +654,8 @@ def __pow__(self, power: int | float | Series) -> Series:
raise ValueError(
"first cast to integer before raising datelike dtypes to a power"
)
if _check_for_numpy(power) and isinstance(power, np.ndarray):
power = Series(power)
return self.to_frame().select(pli.col(self.name).pow(power)).to_series()

def __rpow__(self, other: Any) -> Series:
Expand Down
38 changes: 38 additions & 0 deletions py-polars/tests/unit/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2443,3 +2443,41 @@ def test_upper_lower_bounds(
s = pl.Series("s", dtype=dtype)
assert s.lower_bound().item() == lower
assert s.upper_bound().item() == upper


def test_numpy_series_arithmetic() -> None:
sx = pl.Series(values=[1, 2])
y = np.array([3.0, 4.0])

result_add1 = y + sx
result_add2 = sx + y
expected_add = pl.Series([4.0, 6.0], dtype=pl.Float64)
assert_series_equal(result_add1, expected_add) # type: ignore[arg-type]
assert_series_equal(result_add2, expected_add)

result_sub1 = cast(pl.Series, y - sx) # py37 is different vs py311 on this one
expected = pl.Series([2.0, 2.0], dtype=pl.Float64)
assert_series_equal(result_sub1, expected)
result_sub2 = sx - y
expected = pl.Series([-2.0, -2.0], dtype=pl.Float64)
assert_series_equal(result_sub2, expected)

result_mul1 = y * sx
result_mul2 = sx * y
expected = pl.Series([3.0, 8.0], dtype=pl.Float64)
assert_series_equal(result_mul1, expected) # type: ignore[arg-type]
assert_series_equal(result_mul2, expected)

result_div1 = y / sx
expected = pl.Series([3.0, 2.0], dtype=pl.Float64)
assert_series_equal(result_div1, expected) # type: ignore[arg-type]
result_div2 = sx / y
expected = pl.Series([1 / 3, 0.5], dtype=pl.Float64)
assert_series_equal(result_div2, expected)

result_pow1 = y**sx
expected = pl.Series([3.0, 16.0], dtype=pl.Float64)
assert_series_equal(result_pow1, expected) # type: ignore[arg-type]
result_pow2 = sx**y
expected = pl.Series([1.0, 16.0], dtype=pl.Float64)
assert_series_equal(result_pow2, expected) # type: ignore[arg-type]

0 comments on commit dd8ebe3

Please sign in to comment.