Skip to content

Commit

Permalink
c
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Oct 14, 2024
1 parent 8b4ac02 commit a3216e9
Show file tree
Hide file tree
Showing 2 changed files with 526 additions and 225 deletions.
225 changes: 1 addition & 224 deletions py-polars/tests/unit/operations/arithmetic/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
UInt32,
UInt64,
)
from polars.exceptions import ColumnNotFoundError, InvalidOperationError, ShapeError
from polars.exceptions import ColumnNotFoundError, InvalidOperationError
from polars.testing import assert_frame_equal, assert_series_equal
from tests.unit.conftest import INTEGER_DTYPES, NUMERIC_DTYPES

Expand Down Expand Up @@ -610,229 +610,6 @@ def test_array_arithmetic_same_size(
)


@pytest.mark.parametrize(
("expected", "expr", "column_names"),
[
([[2, 4], [6]], lambda a, b: a + b, ("a", "a")),
([[0, 0], [0]], lambda a, b: a - b, ("a", "a")),
([[1, 4], [9]], lambda a, b: a * b, ("a", "a")),
([[1.0, 1.0], [1.0]], lambda a, b: a / b, ("a", "a")),
([[0, 0], [0]], lambda a, b: a % b, ("a", "a")),
(
[[3, 4], [7]],
lambda a, b: a + b,
("a", "uint8"),
),
],
)
def test_list_arithmetic_same_size(
expected: Any,
expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series],
column_names: tuple[str, str],
) -> None:
df = pl.DataFrame(
[
pl.Series("a", [[1, 2], [3]]),
pl.Series("uint8", [[2, 2], [4]], dtype=pl.List(pl.UInt8())),
pl.Series("nested", [[[1, 2]], [[3]]]),
pl.Series(
"nested_uint8", [[[1, 2]], [[3]]], dtype=pl.List(pl.List(pl.UInt8()))
),
]
)
# Expr-based arithmetic:
assert_frame_equal(
df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))),
pl.Series(column_names[0], expected).to_frame(),
)
# Direct arithmetic on the Series:
assert_series_equal(
expr(df[column_names[0]], df[column_names[1]]),
pl.Series(column_names[0], expected),
)


@pytest.mark.parametrize(
("a", "b", "expected"),
[
([[1, 2, 3]], [[1, None, 5]], [[2, None, 8]]),
([[2], None, [5]], [None, [3], [2]], [None, None, [7]]),
],
)
def test_list_arithmetic_nulls(a: list[Any], b: list[Any], expected: list[Any]) -> None:
series_a = pl.Series(a)
series_b = pl.Series(b)
series_expected = pl.Series(expected)

# Same dtype:
assert_series_equal(series_a + series_b, series_expected)

# Different dtype:
assert_series_equal(
series_a._recursive_cast_to_dtype(pl.Int32())
+ series_b._recursive_cast_to_dtype(pl.Int64()),
series_expected._recursive_cast_to_dtype(pl.Int64()),
)


def test_list_arithmetic_error_cases() -> None:
# Different series length:
with pytest.raises(InvalidOperationError, match="different lengths"):
_ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1, 2], [3, 4]])
with pytest.raises(InvalidOperationError, match="different lengths"):
_ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1, 2], None])

# Different list length:
with pytest.raises(ShapeError, match="lengths differed at index 0: 2 != 1"):
_ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1]])

with pytest.raises(ShapeError, match="lengths differed at index 0: 2 != 1"):
_ = pl.Series("a", [[1, 2], [2, 3]]) / pl.Series("b", [[1], None])

# Wrong types:
with pytest.raises(
InvalidOperationError, match="add operation not supported for dtypes"
):
_ = pl.Series("a", [[1, 2]]) + pl.Series("b", ["hello"])

# Different nesting:
with pytest.raises(
InvalidOperationError,
match="cannot add two list columns with non-numeric inner types",
):
_ = pl.Series("a", [[1]]) + pl.Series("b", [[[1]]])


@pytest.mark.parametrize(
("expected", "expr", "column_names"),
[
# All 5 arithmetic operations:
([[3, 4], [6]], lambda a, b: a + b, ("list", "int64")),
([[-1, 0], [0]], lambda a, b: a - b, ("list", "int64")),
([[2, 4], [9]], lambda a, b: a * b, ("list", "int64")),
([[0.5, 1.0], [1.0]], lambda a, b: a / b, ("list", "int64")),
([[1, 0], [0]], lambda a, b: a % b, ("list", "int64")),
# Different types:
(
[[3, 4], [7]],
lambda a, b: a + b,
("list", "uint8"),
),
# Extra nesting + different types:
(
[[[3, 4]], [[8]]],
lambda a, b: a + b,
("nested", "int64"),
),
# Primitive numeric on the left; only addition and multiplication are
# supported:
([[3, 4], [6]], lambda a, b: a + b, ("int64", "list")),
([[2, 4], [9]], lambda a, b: a * b, ("int64", "list")),
# Primitive numeric on the left with different types:
(
[[3, 4], [7]],
lambda a, b: a + b,
("uint8", "list"),
),
(
[[2, 4], [12]],
lambda a, b: a * b,
("uint8", "list"),
),
],
)
def test_list_and_numeric_arithmetic_same_size(
expected: Any,
expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series],
column_names: tuple[str, str],
) -> None:
df = pl.DataFrame(
[
pl.Series("list", [[1, 2], [3]]),
pl.Series("int64", [2, 3], dtype=pl.Int64()),
pl.Series("uint8", [2, 4], dtype=pl.UInt8()),
pl.Series("nested", [[[1, 2]], [[5]]]),
]
)
# Expr-based arithmetic:
assert_frame_equal(
df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))),
pl.Series(column_names[0], expected).to_frame(),
)
# Direct arithmetic on the Series:
assert_series_equal(
expr(df[column_names[0]], df[column_names[1]]),
pl.Series(column_names[0], expected),
)


@pytest.mark.parametrize(
("a", "b", "expected"),
[
# Null on numeric on the right:
([[1, 2], [3]], [1, None], [[2, 3], None]),
# Null on list on the left:
([[[1, 2]], [[3]]], [None, 1], [None, [[4]]]),
# Extra nesting:
([[[2, None]], [[3, 6]]], [3, 4], [[[5, None]], [[7, 10]]]),
],
)
def test_list_and_numeric_arithmetic_nulls(
a: list[Any], b: list[Any], expected: list[Any]
) -> None:
series_a = pl.Series(a)
series_b = pl.Series(b)
series_expected = pl.Series(expected)

# Same dtype:
assert_series_equal(series_a + series_b, series_expected)

# Different dtype:
assert_series_equal(
series_a._recursive_cast_to_dtype(pl.Int32())
+ series_b._recursive_cast_to_dtype(pl.Int64()),
series_expected._recursive_cast_to_dtype(pl.Int64()),
)

# Swap sides:
assert_series_equal(series_b + series_a, series_expected)
assert_series_equal(
series_b._recursive_cast_to_dtype(pl.Int32())
+ series_a._recursive_cast_to_dtype(pl.Int64()),
series_expected._recursive_cast_to_dtype(pl.Int64()),
)


def test_list_and_numeric_arithmetic_error_cases() -> None:
# Different series length:
with pytest.raises(
InvalidOperationError, match="series of different lengths: got 3 and 2"
):
_ = pl.Series("a", [[1, 2], [3, 4], [5, 6]]) + pl.Series("b", [1, 2])
with pytest.raises(
InvalidOperationError, match="series of different lengths: got 3 and 2"
):
_ = pl.Series("a", [[1, 2], [3, 4], [5, 6]]) / pl.Series("b", [1, None])

# Wrong types:
with pytest.raises(
InvalidOperationError, match="they and other Series are numeric"
):
_ = pl.Series("a", [[1, 2], [3, 4]]) + pl.Series("b", ["hello", "world"])

# Numeric on right and list on left doesn't work for subtraction, division,
# or reminder, since they're not commutative operations and it seems
# semantically weird.
numeric = pl.Series("a", [1, 2])
list_num = pl.Series("b", [[3, 4], [5, 6]])
with pytest.raises(InvalidOperationError, match="operation not supported"):
numeric / list_num
with pytest.raises(InvalidOperationError, match="operation not supported"):
numeric - list_num
with pytest.raises(InvalidOperationError, match="operation not supported"):
numeric % list_num


def test_schema_owned_arithmetic_5669() -> None:
df = (
pl.LazyFrame({"A": [1, 2, 3]})
Expand Down
Loading

0 comments on commit a3216e9

Please sign in to comment.