From 900dc3b60c0ba050d9b19c936b772e101fda830e Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Mon, 14 Oct 2024 22:18:20 +1100 Subject: [PATCH] test: Add more tests for list arithmetic (#19225) Co-authored-by: Itamar Turner-Trauring --- .../operations/arithmetic/test_arithmetic.py | 95 +--- .../arithmetic/test_list_arithmetic.py | 526 +++++++++++++++++- 2 files changed, 526 insertions(+), 95 deletions(-) diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index b04cb92b8889..f989bc4681e2 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -21,7 +21,7 @@ UInt32, UInt64, ) -from polars.exceptions import ColumnNotFoundError, InvalidOperationError, ShapeError +from polars.exceptions import ColumnNotFoundError, InvalidOperationError from polars.testing import assert_frame_equal, assert_series_equal from tests.unit.conftest import INTEGER_DTYPES, NUMERIC_DTYPES @@ -610,99 +610,6 @@ def test_array_arithmetic_same_size( ) -@pytest.mark.parametrize( - ("expected", "expr", "column_names"), - [ - ([[2, 4], [6]], lambda a, b: a + b, ("a", "a")), - ([[0, 0], [0]], lambda a, b: a - b, ("a", "a")), - ([[1, 4], [9]], lambda a, b: a * b, ("a", "a")), - ([[1.0, 1.0], [1.0]], lambda a, b: a / b, ("a", "a")), - ([[0, 0], [0]], lambda a, b: a % b, ("a", "a")), - ( - [[3, 4], [7]], - lambda a, b: a + b, - ("a", "uint8"), - ), - ], -) -def test_list_arithmetic_same_size( - expected: Any, - expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series], - column_names: tuple[str, str], -) -> None: - df = pl.DataFrame( - [ - pl.Series("a", [[1, 2], [3]]), - pl.Series("uint8", [[2, 2], [4]], dtype=pl.List(pl.UInt8())), - pl.Series("nested", [[[1, 2]], [[3]]]), - pl.Series( - "nested_uint8", [[[1, 2]], [[3]]], dtype=pl.List(pl.List(pl.UInt8())) - ), - ] - ) - # Expr-based arithmetic: - assert_frame_equal( - df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))), - pl.Series(column_names[0], expected).to_frame(), - ) - # Direct arithmetic on the Series: - assert_series_equal( - expr(df[column_names[0]], df[column_names[1]]), - pl.Series(column_names[0], expected), - ) - - -@pytest.mark.parametrize( - ("a", "b", "expected"), - [ - ([[1, 2, 3]], [[1, None, 5]], [[2, None, 8]]), - ([[2], None, [5]], [None, [3], [2]], [None, None, [7]]), - ], -) -def test_list_arithmetic_nulls(a: list[Any], b: list[Any], expected: list[Any]) -> None: - series_a = pl.Series(a) - series_b = pl.Series(b) - series_expected = pl.Series(expected) - - # Same dtype: - assert_series_equal(series_a + series_b, series_expected) - - # Different dtype: - assert_series_equal( - series_a._recursive_cast_to_dtype(pl.Int32()) - + series_b._recursive_cast_to_dtype(pl.Int64()), - series_expected._recursive_cast_to_dtype(pl.Int64()), - ) - - -def test_list_arithmetic_error_cases() -> None: - # Different series length: - with pytest.raises(InvalidOperationError, match="different lengths"): - _ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1, 2], [3, 4]]) - with pytest.raises(InvalidOperationError, match="different lengths"): - _ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1, 2], None]) - - # Different list length: - with pytest.raises(ShapeError, match="lengths differed at index 0: 2 != 1"): - _ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1]]) - - with pytest.raises(ShapeError, match="lengths differed at index 0: 2 != 1"): - _ = pl.Series("a", [[1, 2], [2, 3]]) / pl.Series("b", [[1], None]) - - # Wrong types: - with pytest.raises( - InvalidOperationError, match="add operation not supported for dtypes" - ): - _ = pl.Series("a", [[1, 2]]) + pl.Series("b", ["hello"]) - - # Different nesting: - with pytest.raises( - InvalidOperationError, - match="cannot add two list columns with non-numeric inner types", - ): - _ = pl.Series("a", [[1]]) + pl.Series("b", [[[1]]]) - - def test_schema_owned_arithmetic_5669() -> None: df = ( pl.LazyFrame({"A": [1, 2, 3]}) diff --git a/py-polars/tests/unit/operations/arithmetic/test_list_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_list_arithmetic.py index c2a9a9186d31..64ad0e533d8d 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_list_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_list_arithmetic.py @@ -7,7 +7,7 @@ import polars as pl from polars.exceptions import InvalidOperationError, ShapeError -from polars.testing import assert_series_equal +from polars.testing import assert_frame_equal, assert_series_equal def exec_op_with_series(lhs: pl.Series, rhs: pl.Series, op: Any) -> pl.Series: @@ -528,3 +528,527 @@ def test_list_date_to_numeric_arithmetic_raises_error( # is being raised by checks on the Python side that should be moved to Rust. with pytest.raises((InvalidOperationError, TypeError)): exec_op(l, r, op) + + +@pytest.mark.parametrize( + ("expected", "expr", "column_names"), + [ + ([[2, 4], [6]], lambda a, b: a + b, ("a", "a")), + ([[0, 0], [0]], lambda a, b: a - b, ("a", "a")), + ([[1, 4], [9]], lambda a, b: a * b, ("a", "a")), + ([[1.0, 1.0], [1.0]], lambda a, b: a / b, ("a", "a")), + ([[0, 0], [0]], lambda a, b: a % b, ("a", "a")), + ( + [[3, 4], [7]], + lambda a, b: a + b, + ("a", "uint8"), + ), + ], +) +def test_list_arithmetic_same_size( + expected: Any, + expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series], + column_names: tuple[str, str], +) -> None: + df = pl.DataFrame( + [ + pl.Series("a", [[1, 2], [3]]), + pl.Series("uint8", [[2, 2], [4]], dtype=pl.List(pl.UInt8())), + pl.Series("nested", [[[1, 2]], [[3]]]), + pl.Series( + "nested_uint8", [[[1, 2]], [[3]]], dtype=pl.List(pl.List(pl.UInt8())) + ), + ] + ) + # Expr-based arithmetic: + assert_frame_equal( + df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))), + pl.Series(column_names[0], expected).to_frame(), + ) + # Direct arithmetic on the Series: + assert_series_equal( + expr(df[column_names[0]], df[column_names[1]]), + pl.Series(column_names[0], expected), + ) + + +@pytest.mark.parametrize( + ("a", "b", "expected"), + [ + ([[1, 2, 3]], [[1, None, 5]], [[2, None, 8]]), + ([[2], None, [5]], [None, [3], [2]], [None, None, [7]]), + ], +) +def test_list_arithmetic_nulls(a: list[Any], b: list[Any], expected: list[Any]) -> None: + series_a = pl.Series(a) + series_b = pl.Series(b) + series_expected = pl.Series(expected) + + # Same dtype: + assert_series_equal(series_a + series_b, series_expected) + + # Different dtype: + assert_series_equal( + series_a._recursive_cast_to_dtype(pl.Int32()) + + series_b._recursive_cast_to_dtype(pl.Int64()), + series_expected._recursive_cast_to_dtype(pl.Int64()), + ) + + +def test_list_arithmetic_error_cases() -> None: + # Different series length: + with pytest.raises(InvalidOperationError, match="different lengths"): + _ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1, 2], [3, 4]]) + with pytest.raises(InvalidOperationError, match="different lengths"): + _ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1, 2], None]) + + # Different list length: + with pytest.raises(ShapeError, match="lengths differed at index 0: 2 != 1"): + _ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1]]) + + with pytest.raises(ShapeError, match="lengths differed at index 0: 2 != 1"): + _ = pl.Series("a", [[1, 2], [2, 3]]) / pl.Series("b", [[1], None]) + + # Wrong types: + with pytest.raises( + InvalidOperationError, match="add operation not supported for dtypes" + ): + _ = pl.Series("a", [[1, 2]]) + pl.Series("b", ["hello"]) + + # Different nesting: + with pytest.raises( + InvalidOperationError, + match="cannot add two list columns with non-numeric inner types", + ): + _ = pl.Series("a", [[1]]) + pl.Series("b", [[[1]]]) + + +@pytest.mark.parametrize( + ("expected", "expr", "column_names"), + [ + # All 5 arithmetic operations: + ([[3, 4], [6]], lambda a, b: a + b, ("list", "int64")), + ([[-1, 0], [0]], lambda a, b: a - b, ("list", "int64")), + ([[2, 4], [9]], lambda a, b: a * b, ("list", "int64")), + ([[0.5, 1.0], [1.0]], lambda a, b: a / b, ("list", "int64")), + ([[1, 0], [0]], lambda a, b: a % b, ("list", "int64")), + # Different types: + ( + [[3, 4], [7]], + lambda a, b: a + b, + ("list", "uint8"), + ), + # Extra nesting + different types: + ( + [[[3, 4]], [[8]]], + lambda a, b: a + b, + ("nested", "int64"), + ), + # Primitive numeric on the left; only addition and multiplication are + # supported: + ([[3, 4], [6]], lambda a, b: a + b, ("int64", "list")), + ([[2, 4], [9]], lambda a, b: a * b, ("int64", "list")), + # Primitive numeric on the left with different types: + ( + [[3, 4], [7]], + lambda a, b: a + b, + ("uint8", "list"), + ), + ( + [[2, 4], [12]], + lambda a, b: a * b, + ("uint8", "list"), + ), + ], +) +def test_list_and_numeric_arithmetic_same_size( + expected: Any, + expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series], + column_names: tuple[str, str], +) -> None: + df = pl.DataFrame( + [ + pl.Series("list", [[1, 2], [3]]), + pl.Series("int64", [2, 3], dtype=pl.Int64()), + pl.Series("uint8", [2, 4], dtype=pl.UInt8()), + pl.Series("nested", [[[1, 2]], [[5]]]), + ] + ) + # Expr-based arithmetic: + assert_frame_equal( + df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))), + pl.Series(column_names[0], expected).to_frame(), + ) + # Direct arithmetic on the Series: + assert_series_equal( + expr(df[column_names[0]], df[column_names[1]]), + pl.Series(column_names[0], expected), + ) + + +@pytest.mark.parametrize( + ("a", "b", "expected"), + [ + # Null on numeric on the right: + ([[1, 2], [3]], [1, None], [[2, 3], [None]]), + # Null on list on the left: + ([[[1, 2]], [[3]]], [None, 1], [[[None, None]], [[4]]]), + # Extra nesting: + ([[[2, None]], [[3, 6]]], [3, 4], [[[5, None]], [[7, 10]]]), + ], +) +def test_list_and_numeric_arithmetic_nulls( + a: list[Any], b: list[Any], expected: list[Any] +) -> None: + series_a = pl.Series(a) + series_b = pl.Series(b) + series_expected = pl.Series(expected, dtype=series_a.dtype) + + # Same dtype: + assert_series_equal(series_a + series_b, series_expected) + + # Different dtype: + assert_series_equal( + series_a._recursive_cast_to_dtype(pl.Int32()) + + series_b._recursive_cast_to_dtype(pl.Int64()), + series_expected._recursive_cast_to_dtype(pl.Int64()), + ) + + # Swap sides: + assert_series_equal(series_b + series_a, series_expected) + assert_series_equal( + series_b._recursive_cast_to_dtype(pl.Int32()) + + series_a._recursive_cast_to_dtype(pl.Int64()), + series_expected._recursive_cast_to_dtype(pl.Int64()), + ) + + +def test_list_and_numeric_arithmetic_error_cases() -> None: + # Different series length: + with pytest.raises( + InvalidOperationError, match="series of different lengths: got 3 and 2" + ): + _ = pl.Series("a", [[1, 2], [3, 4], [5, 6]]) + pl.Series("b", [1, 2]) + with pytest.raises( + InvalidOperationError, match="series of different lengths: got 3 and 2" + ): + _ = pl.Series("a", [[1, 2], [3, 4], [5, 6]]) / pl.Series("b", [1, None]) + + # Wrong types: + with pytest.raises( + InvalidOperationError, match="add operation not supported for dtypes" + ): + _ = pl.Series("a", [[1, 2], [3, 4]]) + pl.Series("b", ["hello", "world"]) + + +@pytest.mark.parametrize("broadcast", [True, False]) +@pytest.mark.parametrize("dtype", [pl.Int64(), pl.Float64()]) +def test_list_arithmetic_div_ops_zero_denominator( + broadcast: bool, dtype: pl.DataType +) -> None: + # Notes + # * truediv (/) on integers upcasts to Float64 + # * Otherwise, we test floordiv (//) and module/rem (%) + # * On integers, 0-denominator is expected to output NULL + # * On floats, 0-denominator has different outputs, e.g. NaN, Inf, depending + # on a few factors (e.g. whether the numerator is also 0). + + s = pl.Series([[0], [1], [None], None]).cast(pl.List(dtype)) + + n = 1 if broadcast else s.len() + + # list<->primitive + + # truediv + assert_series_equal( + pl.Series([1]).new_from_index(0, n) / s, + pl.Series([[float("inf")], [1.0], [None], None], dtype=pl.List(pl.Float64)), + ) + + assert_series_equal( + s / pl.Series([1]).new_from_index(0, n), + pl.Series([[float("inf")], [1.0], [None], None], dtype=pl.List(pl.Float64)), + ) + + # floordiv + assert_series_equal( + pl.Series([1]).new_from_index(0, n) // s, + ( + pl.Series([[None], [1], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series([[float("inf")], [1.0], [None], None], dtype=s.dtype) + ), + ) + + assert_series_equal( + s // pl.Series([0]).new_from_index(0, n), + ( + pl.Series([[None], [None], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series( + [[float("nan")], [float("inf")], [None], None], dtype=s.dtype + ) + ), + ) + + # rem + assert_series_equal( + pl.Series([1]).new_from_index(0, n) % s, + ( + pl.Series([[None], [0], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series([[float("nan")], [0.0], [None], None], dtype=s.dtype) + ), + ) + + assert_series_equal( + s % pl.Series([0]).new_from_index(0, n), + ( + pl.Series([[None], [None], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series( + [[float("nan")], [float("nan")], [None], None], dtype=s.dtype + ) + ), + ) + + # list<->list + + # truediv + assert_series_equal( + pl.Series([[1]]).new_from_index(0, n) / s, + pl.Series([[float("inf")], [1.0], [None], None], dtype=pl.List(pl.Float64)), + ) + + assert_series_equal( + s / pl.Series([[0]]).new_from_index(0, n), + pl.Series( + [[float("nan")], [float("inf")], [None], None], dtype=pl.List(pl.Float64) + ), + ) + + # floordiv + assert_series_equal( + pl.Series([[1]]).new_from_index(0, n) // s, + ( + pl.Series([[None], [1], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series([[float("inf")], [1.0], [None], None], dtype=s.dtype) + ), + ) + + assert_series_equal( + s // pl.Series([[0]]).new_from_index(0, n), + ( + pl.Series([[None], [None], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series( + [[float("nan")], [float("inf")], [None], None], dtype=s.dtype + ) + ), + ) + + # rem + assert_series_equal( + pl.Series([[1]]).new_from_index(0, n) % s, + ( + pl.Series([[None], [0], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series([[float("nan")], [0.0], [None], None], dtype=s.dtype) + ), + ) + + assert_series_equal( + s % pl.Series([[0]]).new_from_index(0, n), + ( + pl.Series([[None], [None], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series( + [[float("nan")], [float("nan")], [None], None], dtype=s.dtype + ) + ), + ) + + +def test_list_to_primitive_arithmetic() -> None: + # Input data + # * List type: List(List(List(Int16))) (triple-nested) + # * Numeric type: Int32 + # + # Tests run + # Broadcast Operation + # | L | R | + # * list<->primitive | | | floor_div + # * primitive<->list | | | floor_div + # * list<->primitive | | * | subtract + # * primitive<->list | * | | subtract + # * list<->primitive | * | | subtract + # * primitive<->list | | * | subtract + # + # Notes + # * In floor_div, we check that results from a 0 denominator are masked out + # * We choose floor_div and subtract as they emit different results when + # sides are swapped + + # Create some non-zero start offsets and masked out rows. + lhs = ( + pl.Series( + [ + [[[None, None, None, None, None]]], # sliced out + # Nulls at every level XO + [[[3, 7]], [[-3], [None], [], [], None], [], None], + [[[1, 2, 3, 4, 5]]], # masked out + [[[3, 7]], [[0], [None], [], [], None]], + [[[3, 7]]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int16))), + ) + .slice(1) + .to_frame() + .select(pl.when(pl.int_range(pl.len()) != 1).then(pl.first())) + .to_series() + ) + + # Note to reader: This is what our LHS looks like + assert_series_equal( + lhs, + pl.Series( + [ + [[[3, 7]], [[-3], [None], [], [], None], [], None], + None, + [[[3, 7]], [[0], [None], [], [], None]], + [[[3, 7]]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int16))), + ), + ) + + class _: + # Floor div, no broadcasting + rhs = pl.Series([5, 1, 0, None], dtype=pl.Int32) + + assert len(lhs) == len(rhs) + + expect = pl.Series( + [ + [[[0, 1]], [[-1], [None], [], [], None], [], None], + None, + [[[None, None]], [[None], [None], [], [], None]], + [[[None, None]]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int32))), + ) + + out = ( + pl.select(l=lhs, r=rhs) + .select(pl.col("l") // pl.col("r")) + .to_series() + .alias("") + ) + + assert_series_equal(out, expect) + + # Flipped + + expect = pl.Series( # noqa: PIE794 + [ + [[[1, 0]], [[-2], [None], [], [], None], [], None], + None, + [[[0, 0]], [[None], [None], [], [], None]], + [[[None, None]]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int32))), + ) + + out = ( # noqa: PIE794 + pl.select(l=lhs, r=rhs) + .select(pl.col("r") // pl.col("l")) + .to_series() + .alias("") + ) + + assert_series_equal(out, expect) + + class _: # type: ignore[no-redef] + # Subtraction with broadcasting + rhs = pl.Series([1], dtype=pl.Int32) + + expect = pl.Series( + [ + [[[2, 6]], [[-4], [None], [], [], None], [], None], + None, + [[[2, 6]], [[-1], [None], [], [], None]], + [[[2, 6]]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int32))), + ) + + out = pl.select(l=lhs).select(pl.col("l") - rhs).to_series().alias("") + + assert_series_equal(out, expect) + + # Flipped + + expect = pl.Series( # noqa: PIE794 + [ + [[[-2, -6]], [[4], [None], [], [], None], [], None], + None, + [[[-2, -6]], [[1], [None], [], [], None]], + [[[-2, -6]]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int32))), + ) + + out = pl.select(l=lhs).select(rhs - pl.col("l")).to_series().alias("") # noqa: PIE794 + + assert_series_equal(out, expect) + + # Test broadcasting of the list side + lhs = lhs.slice(2, 1) + # Note to reader: This is what our LHS looks like + assert_series_equal( + lhs, + pl.Series( + [ + [[[3, 7]], [[0], [None], [], [], None]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int16))), + ), + ) + + assert len(lhs) == 1 + + class _: # type: ignore[no-redef] + rhs = pl.Series([1, 2, 3, None, 5], dtype=pl.Int32) + + expect = pl.Series( + [ + [[[2, 6]], [[-1], [None], [], [], None]], + [[[1, 5]], [[-2], [None], [], [], None]], + [[[0, 4]], [[-3], [None], [], [], None]], + [[[None, None]], [[None], [None], [], [], None]], + [[[-2, 2]], [[-5], [None], [], [], None]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int32))), + ) + + out = pl.select(r=rhs).select(lhs - pl.col("r")).to_series().alias("") + + assert_series_equal(out, expect) + + # Flipped + + expect = pl.Series( # noqa: PIE794 + [ + [[[-2, -6]], [[1], [None], [], [], None]], + [[[-1, -5]], [[2], [None], [], [], None]], + [[[0, -4]], [[3], [None], [], [], None]], + [[[None, None]], [[None], [None], [], [], None]], + [[[2, -2]], [[5], [None], [], [], None]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int32))), + ) + + out = pl.select(r=rhs).select(pl.col("r") - lhs).to_series().alias("") # noqa: PIE794 + + assert_series_equal(out, expect)