Skip to content

Commit

Permalink
test(python): Use assert_series_equal instead of `s.series_equal(...)` (#6582)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Jan 31, 2023
1 parent 307140e commit 628311e
Show file tree
Hide file tree
Showing 13 changed files with 160 additions and 163 deletions.
6 changes: 3 additions & 3 deletions py-polars/tests/unit/io/test_other.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import cast

import polars as pl
from polars.testing import assert_frame_equal
from polars.testing import assert_frame_equal, assert_series_equal


def test_copy() -> None:
Expand All @@ -13,8 +13,8 @@ def test_copy() -> None:
assert_frame_equal(copy.deepcopy(df), df)

a = pl.Series("a", [1, 2])
assert copy.copy(a).series_equal(a, True)
assert copy.deepcopy(a).series_equal(a, True)
assert_series_equal(copy.copy(a), a)
assert_series_equal(copy.deepcopy(a), a)


def test_categorical_round_trip() -> None:
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/io/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
import pickle

import polars as pl
from polars.testing import assert_frame_equal
from polars.testing import assert_frame_equal, assert_series_equal


def test_pickle() -> None:
a = pl.Series("a", [1, 2])
b = pickle.dumps(a)
out = pickle.loads(b)
assert a.series_equal(out)
assert_series_equal(a, out)
df = pl.DataFrame({"a": [1, 2], "b": ["a", None], "c": [True, False]})
b = pickle.dumps(df)
out = pickle.loads(b)
Expand Down
22 changes: 13 additions & 9 deletions py-polars/tests/unit/test_datelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@

import polars as pl
from polars.datatypes import DATETIME_DTYPES, DTYPE_TEMPORAL_UNITS, PolarsTemporalType
from polars.testing import assert_frame_equal, assert_series_equal
from polars.testing import (
assert_frame_equal,
assert_series_equal,
assert_series_not_equal,
)

if TYPE_CHECKING:
from polars.internals.type_aliases import TimeUnit
Expand Down Expand Up @@ -152,7 +156,7 @@ def test_series_add_timedelta() -> None:
out = pl.Series(
[datetime(2027, 5, 19), datetime(2054, 10, 4), datetime(2082, 2, 19)]
)
assert (dates + timedelta(days=10_000)).series_equal(out)
assert_series_equal((dates + timedelta(days=10_000)), out)


def test_series_add_datetime() -> None:
Expand Down Expand Up @@ -342,18 +346,17 @@ def test_timezone() -> None:
data = pa.array([1000, 2000], type=ts)
s = cast(pl.Series, pl.from_arrow(data))

# with timezone; we do expect a warning here
tz_ts = pa.timestamp("s", tz="America/New_York")
tz_data = pa.array([1000, 2000], type=tz_ts)
# with pytest.warns(Warning):
tz_s = cast(pl.Series, pl.from_arrow(tz_data))

# different timezones are not considered equal
# we check both `null_equal=True` and `null_equal=False`
    # https://github.com/pola-rs/polars/issues/5023
assert not s.series_equal(tz_s, null_equal=False)
assert not s.series_equal(tz_s, null_equal=True)
assert s.cast(int).series_equal(tz_s.cast(int))
assert_series_not_equal(tz_s, s)
assert_series_equal(s.cast(int), tz_s.cast(int))


def test_to_list() -> None:
Expand Down Expand Up @@ -906,11 +909,12 @@ def test_epoch() -> None:
dates = pl.Series("dates", [datetime(2001, 1, 1), datetime(2001, 2, 1, 10, 8, 9)])

for unit in DTYPE_TEMPORAL_UNITS:
assert dates.dt.epoch(unit).series_equal(dates.dt.timestamp(unit))
assert_series_equal(dates.dt.epoch(unit), dates.dt.timestamp(unit))

assert dates.dt.epoch("s").series_equal(dates.dt.timestamp("ms") // 1000)
assert dates.dt.epoch("d").series_equal(
(dates.dt.timestamp("ms") // (1000 * 3600 * 24)).cast(pl.Int32)
assert_series_equal(dates.dt.epoch("s"), dates.dt.timestamp("ms") // 1000)
assert_series_equal(
dates.dt.epoch("d"),
(dates.dt.timestamp("ms") // (1000 * 3600 * 24)).cast(pl.Int32),
)


Expand Down
21 changes: 12 additions & 9 deletions py-polars/tests/unit/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ def test_concat() -> None:

def test_arg_where() -> None:
s = pl.Series([True, False, True, False])
assert pl.arg_where(s, eager=True).cast(int).series_equal(pl.Series([0, 2]))
assert_series_equal(pl.arg_where(s, eager=True).cast(int), pl.Series([0, 2]))


def test_get_dummies() -> None:
Expand Down Expand Up @@ -810,14 +810,17 @@ def test_df_stats(df: pl.DataFrame) -> None:
def test_df_fold() -> None:
df = pl.DataFrame({"a": [2, 1, 3], "b": [1, 2, 3], "c": [1.0, 2.0, 3.0]})

assert df.fold(lambda s1, s2: s1 + s2).series_equal(pl.Series("a", [4.0, 5.0, 9.0]))
assert df.fold(lambda s1, s2: s1.zip_with(s1 < s2, s2)).series_equal(
pl.Series("a", [1.0, 1.0, 3.0])
assert_series_equal(
df.fold(lambda s1, s2: s1 + s2), pl.Series("a", [4.0, 5.0, 9.0])
)
assert_series_equal(
df.fold(lambda s1, s2: s1.zip_with(s1 < s2, s2)),
pl.Series("a", [1.0, 1.0, 3.0]),
)

df = pl.DataFrame({"a": ["foo", "bar", "2"], "b": [1, 2, 3], "c": [1.0, 2.0, 3.0]})
out = df.fold(lambda s1, s2: s1 + s2)
out.series_equal(pl.Series("", ["foo11", "bar22", "233"]))
assert_series_equal(out, pl.Series("a", ["foo11.0", "bar22.0", "233.0"]))

df = pl.DataFrame({"a": [3, 2, 1], "b": [1, 2, 3], "c": [1.0, 2.0, 3.0]})
# just check dispatch. values are tested on rust side.
Expand All @@ -827,7 +830,7 @@ def test_df_fold() -> None:
assert len(df.max(axis=1)) == 3

df_width_one = df[["a"]]
assert df_width_one.fold(lambda s1, s2: s1).series_equal(df["a"])
assert_series_equal(df_width_one.fold(lambda s1, s2: s1), df["a"])


def test_df_apply() -> None:
Expand Down Expand Up @@ -1847,13 +1850,13 @@ def test_shift_and_fill() -> None:

def test_is_duplicated() -> None:
df = pl.DataFrame({"foo": [1, 2, 2], "bar": [6, 7, 7]})
assert df.is_duplicated().series_equal(pl.Series("", [False, True, True]))
assert_series_equal(df.is_duplicated(), pl.Series("", [False, True, True]))


def test_is_unique() -> None:
df = pl.DataFrame({"foo": [1, 2, 2], "bar": [6, 7, 7]})

assert df.is_unique().series_equal(pl.Series("", [True, False, False]))
assert_series_equal(df.is_unique(), pl.Series("", [True, False, False]))
assert df.unique(maintain_order=True).rows() == [(1, 6), (2, 7)]
assert df.n_unique() == 2

Expand Down Expand Up @@ -1979,7 +1982,7 @@ def test_get_item() -> None:
assert_frame_equal(df[:, :], df)

# str, always refers to a column name
assert df["a"].series_equal(pl.Series("a", [1.0, 2.0, 3.0, 4.0]))
assert_series_equal(df["a"], pl.Series("a", [1.0, 2.0, 3.0, 4.0]))

# int, always refers to a row index (zero-based): index=1 => second row
assert_frame_equal(df[1], pl.DataFrame({"a": [2.0], "b": [4]}))
Expand Down
6 changes: 3 additions & 3 deletions py-polars/tests/unit/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_shuffle() -> None:
result1 = pl.select(pl.lit(s).shuffle()).to_series()
random.seed(1)
result2 = pl.select(pl.lit(s).shuffle()).to_series()
assert result1.series_equal(result2)
assert_series_equal(result1, result2)


def test_sample() -> None:
Expand All @@ -124,7 +124,7 @@ def test_sample() -> None:
result1 = pl.select(pl.lit(a).sample(n=10)).to_series()
random.seed(1)
result2 = pl.select(pl.lit(a).sample(n=10)).to_series()
assert result1.series_equal(result2)
assert_series_equal(result1, result2)


def test_map_alias() -> None:
Expand Down Expand Up @@ -474,7 +474,7 @@ def test_rank_so_4109() -> None:
def test_unique_empty() -> None:
for dt in [pl.Utf8, pl.Boolean, pl.Int32, pl.UInt32]:
s = pl.Series([], dtype=dt)
assert s.unique().series_equal(s)
assert_series_equal(s.unique(), s)


@typing.no_type_check
Expand Down
7 changes: 4 additions & 3 deletions py-polars/tests/unit/test_folds.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import polars as pl
from polars.testing import assert_series_equal


def test_fold() -> None:
Expand All @@ -10,9 +11,9 @@ def test_fold() -> None:
pl.min(["a", pl.col("b") ** 2]),
]
)
assert out["sum"].series_equal(pl.Series("sum", [2.0, 4.0, 6.0]))
assert out["max"].series_equal(pl.Series("max", [1.0, 4.0, 9.0]))
assert out["min"].series_equal(pl.Series("min", [1.0, 2.0, 3.0]))
assert_series_equal(out["sum"], pl.Series("sum", [2.0, 4.0, 6.0]))
assert_series_equal(out["max"], pl.Series("max", [1.0, 4.0, 9.0]))
assert_series_equal(out["min"], pl.Series("min", [1.0, 2.0, 3.0]))

out = df.select(
pl.fold(acc=pl.lit(0), f=lambda acc, x: acc + x, exprs=pl.all()).alias("foo")
Expand Down
6 changes: 3 additions & 3 deletions py-polars/tests/unit/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

import polars as pl
from polars.testing import assert_frame_equal
from polars.testing import assert_frame_equal, assert_series_equal


def test_date_datetime() -> None:
Expand All @@ -23,8 +23,8 @@ def test_date_datetime() -> None:
pl.date("year", "month", "day").dt.day().cast(int).alias("date"),
]
)
assert out["date"].series_equal(df["day"].rename("date"))
assert out["h2"].series_equal(df["hour"].rename("h2"))
assert_series_equal(out["date"], df["day"].rename("date"))
assert_series_equal(out["h2"], df["hour"].rename("h2"))


def test_diag_concat() -> None:
Expand Down
7 changes: 4 additions & 3 deletions py-polars/tests/unit/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import polars as pl
from polars.datatypes import dtype_to_py_type
from polars.testing import assert_series_equal


def test_df_from_numpy() -> None:
Expand Down Expand Up @@ -195,7 +196,7 @@ def test_arrow_dict_to_polars() -> None:
values=["AAA", "BBB", "CCC", "DDD", "BBB", "AAA", "CCC", "DDD", "DDD", "CCC"],
)

assert s.series_equal(pl.Series("pa_dict", pa_dict))
assert_series_equal(s, pl.Series("pa_dict", pa_dict))


def test_arrow_list_chunked_array() -> None:
Expand Down Expand Up @@ -244,7 +245,7 @@ def test_from_dict() -> None:
df = pl.from_dict(data)
assert df.shape == (2, 2)
for s1, s2 in zip(list(df), [pl.Series("a", [1, 2]), pl.Series("b", [3, 4])]):
assert s1.series_equal(s2)
assert_series_equal(s1, s2)


def test_from_dict_struct() -> None:
Expand Down Expand Up @@ -555,7 +556,7 @@ def test_cat_int_types_3500() -> None:

for t in [int_dict_type, uint_dict_type]:
s = cast(pl.Series, pl.from_arrow(pyarrow_array.cast(t)))
assert s.series_equal(pl.Series(["a", "a", "b"]).cast(pl.Categorical))
assert_series_equal(s, pl.Series(["a", "a", "b"]).cast(pl.Categorical))


def test_from_pyarrow_chunked_array() -> None:
Expand Down
6 changes: 3 additions & 3 deletions py-polars/tests/unit/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest

import polars as pl
from polars.testing import assert_frame_equal
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
from polars.internals.type_aliases import JoinStrategy
Expand Down Expand Up @@ -339,11 +339,11 @@ def test_join() -> None:
)

joined = df_left.join(df_right, left_on="a", right_on="a").sort("a")
assert joined["b"].series_equal(pl.Series("b", [1, 3, 2, 2]))
assert_series_equal(joined["b"], pl.Series("b", [1, 3, 2, 2]))

joined = df_left.join(df_right, left_on="a", right_on="a", how="left").sort("a")
assert joined["c_right"].is_null().sum() == 1
assert joined["b"].series_equal(pl.Series("b", [1, 3, 2, 2, 4]))
assert_series_equal(joined["b"], pl.Series("b", [1, 3, 2, 2, 4]))

joined = df_left.join(df_right, left_on="a", right_on="a", how="outer").sort("a")
assert joined["c_right"].null_count() == 1
Expand Down
Loading

0 comments on commit 628311e

Please sign in to comment.