Skip to content

Commit

Permalink
test: Add more tests for list arithmetic (#19225)
Browse files Browse the repository at this point in the history
Co-authored-by: Itamar Turner-Trauring <itamar@pythonspeed.com>
  • Loading branch information
nameexhaustion and pythonspeed authored Oct 14, 2024
1 parent f7c6a05 commit 900dc3b
Show file tree
Hide file tree
Showing 2 changed files with 526 additions and 95 deletions.
95 changes: 1 addition & 94 deletions py-polars/tests/unit/operations/arithmetic/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
UInt32,
UInt64,
)
from polars.exceptions import ColumnNotFoundError, InvalidOperationError, ShapeError
from polars.exceptions import ColumnNotFoundError, InvalidOperationError
from polars.testing import assert_frame_equal, assert_series_equal
from tests.unit.conftest import INTEGER_DTYPES, NUMERIC_DTYPES

Expand Down Expand Up @@ -610,99 +610,6 @@ def test_array_arithmetic_same_size(
)


@pytest.mark.parametrize(
("expected", "expr", "column_names"),
[
([[2, 4], [6]], lambda a, b: a + b, ("a", "a")),
([[0, 0], [0]], lambda a, b: a - b, ("a", "a")),
([[1, 4], [9]], lambda a, b: a * b, ("a", "a")),
([[1.0, 1.0], [1.0]], lambda a, b: a / b, ("a", "a")),
([[0, 0], [0]], lambda a, b: a % b, ("a", "a")),
(
[[3, 4], [7]],
lambda a, b: a + b,
("a", "uint8"),
),
],
)
def test_list_arithmetic_same_size(
expected: Any,
expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series],
column_names: tuple[str, str],
) -> None:
df = pl.DataFrame(
[
pl.Series("a", [[1, 2], [3]]),
pl.Series("uint8", [[2, 2], [4]], dtype=pl.List(pl.UInt8())),
pl.Series("nested", [[[1, 2]], [[3]]]),
pl.Series(
"nested_uint8", [[[1, 2]], [[3]]], dtype=pl.List(pl.List(pl.UInt8()))
),
]
)
# Expr-based arithmetic:
assert_frame_equal(
df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))),
pl.Series(column_names[0], expected).to_frame(),
)
# Direct arithmetic on the Series:
assert_series_equal(
expr(df[column_names[0]], df[column_names[1]]),
pl.Series(column_names[0], expected),
)


@pytest.mark.parametrize(
("a", "b", "expected"),
[
([[1, 2, 3]], [[1, None, 5]], [[2, None, 8]]),
([[2], None, [5]], [None, [3], [2]], [None, None, [7]]),
],
)
def test_list_arithmetic_nulls(a: list[Any], b: list[Any], expected: list[Any]) -> None:
series_a = pl.Series(a)
series_b = pl.Series(b)
series_expected = pl.Series(expected)

# Same dtype:
assert_series_equal(series_a + series_b, series_expected)

# Different dtype:
assert_series_equal(
series_a._recursive_cast_to_dtype(pl.Int32())
+ series_b._recursive_cast_to_dtype(pl.Int64()),
series_expected._recursive_cast_to_dtype(pl.Int64()),
)


def test_list_arithmetic_error_cases() -> None:
# Different series length:
with pytest.raises(InvalidOperationError, match="different lengths"):
_ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1, 2], [3, 4]])
with pytest.raises(InvalidOperationError, match="different lengths"):
_ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1, 2], None])

# Different list length:
with pytest.raises(ShapeError, match="lengths differed at index 0: 2 != 1"):
_ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1]])

with pytest.raises(ShapeError, match="lengths differed at index 0: 2 != 1"):
_ = pl.Series("a", [[1, 2], [2, 3]]) / pl.Series("b", [[1], None])

# Wrong types:
with pytest.raises(
InvalidOperationError, match="add operation not supported for dtypes"
):
_ = pl.Series("a", [[1, 2]]) + pl.Series("b", ["hello"])

# Different nesting:
with pytest.raises(
InvalidOperationError,
match="cannot add two list columns with non-numeric inner types",
):
_ = pl.Series("a", [[1]]) + pl.Series("b", [[[1]]])


def test_schema_owned_arithmetic_5669() -> None:
df = (
pl.LazyFrame({"A": [1, 2, 3]})
Expand Down
Loading

0 comments on commit 900dc3b

Please sign in to comment.