Skip to content

Commit

Permalink
fix: Incorrect is_between pushdown to scan_pyarrow_dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Apr 19, 2024
1 parent 0c2783a commit fa64234
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 66 deletions.
8 changes: 4 additions & 4 deletions crates/polars-plan/src/logical_plan/pyarrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,12 @@ pub(super) fn predicate_to_pa(
} else {
let col = predicate_to_pa(input.first()?.node(), expr_arena, args)?;
let left_cmp_op = match closed {
ClosedInterval::Both | ClosedInterval::Left => Operator::Lt,
ClosedInterval::None | ClosedInterval::Right => Operator::LtEq,
ClosedInterval::None | ClosedInterval::Right => Operator::Gt,
ClosedInterval::Both | ClosedInterval::Left => Operator::GtEq,
};
let right_cmp_op = match closed {
ClosedInterval::Both | ClosedInterval::Right => Operator::Gt,
ClosedInterval::None | ClosedInterval::Left => Operator::GtEq,
ClosedInterval::None | ClosedInterval::Left => Operator::Lt,
ClosedInterval::Both | ClosedInterval::Right => Operator::LtEq,
};

let lower = predicate_to_pa(input.get(1)?.node(), expr_arena, args)?;
Expand Down
118 changes: 56 additions & 62 deletions py-polars/tests/unit/io/test_pyarrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,13 @@

def helper_dataset_test(
file_path: Path,
query: Callable[[pl.LazyFrame], pl.DataFrame],
query: Callable[[pl.LazyFrame], pl.LazyFrame],
batch_size: int | None = None,
n_expected: int | None = None,
) -> None:
dset = ds.dataset(file_path, format="ipc")
expected = query(pl.scan_ipc(file_path))
out = query(
pl.scan_pyarrow_dataset(dset, batch_size=batch_size),
)
expected = pl.scan_ipc(file_path).pipe(query).collect()
out = pl.scan_pyarrow_dataset(dset, batch_size=batch_size).pipe(query).collect()
assert_frame_equal(out, expected)
if n_expected is not None:
assert len(out) == n_expected
Expand All @@ -36,107 +34,105 @@ def test_dataset_foo(df: pl.DataFrame, tmp_path: Path) -> None:

helper_dataset_test(
file_path,
lambda lf: lf.filter("bools").select(["bools", "floats", "date"]).collect(),
lambda lf: lf.filter("bools").select("bools", "floats", "date"),
n_expected=1,
)
helper_dataset_test(
file_path,
lambda lf: lf.filter(~pl.col("bools"))
.select(["bools", "floats", "date"])
.collect(),
lambda lf: lf.filter(~pl.col("bools")).select("bools", "floats", "date"),
n_expected=2,
)
helper_dataset_test(
file_path,
lambda lf: lf.filter(pl.col("int_nulls").is_null())
.select(["bools", "floats", "date"])
.collect(),
lambda lf: lf.filter(pl.col("int_nulls").is_null()).select(
"bools", "floats", "date"
),
n_expected=1,
)
helper_dataset_test(
file_path,
lambda lf: lf.filter(pl.col("int_nulls").is_not_null())
.select(["bools", "floats", "date"])
.collect(),
lambda lf: lf.filter(pl.col("int_nulls").is_not_null()).select(
"bools", "floats", "date"
),
n_expected=2,
)
helper_dataset_test(
file_path,
lambda lf: lf.filter(pl.col("int_nulls").is_not_null() == pl.col("bools"))
.select(["bools", "floats", "date"])
.collect(),
lambda lf: lf.filter(
pl.col("int_nulls").is_not_null() == pl.col("bools")
).select("bools", "floats", "date"),
n_expected=0,
)
# this equality on a column with nulls fails as pyarrow has different
# handling of Kleene logic. We leave it for now and document it in the function.
helper_dataset_test(
file_path,
lambda lf: lf.filter(pl.col("int") == 10)
.select(["bools", "floats", "int_nulls"])
.collect(),
lambda lf: lf.filter(pl.col("int") == 10).select(
"bools", "floats", "int_nulls"
),
n_expected=0,
)
helper_dataset_test(
file_path,
lambda lf: lf.filter(pl.col("int") != 10)
.select(["bools", "floats", "int_nulls"])
.collect(),
lambda lf: lf.filter(pl.col("int") != 10).select(
"bools", "floats", "int_nulls"
),
n_expected=3,
)
helper_dataset_test(
file_path,
lambda lf: lf.filter(pl.col("int").is_between(9, 11))
.select(["bools", "floats", "date"])
.collect(),
n_expected=0,
)

for closed, n_expected in zip(["both", "left", "right", "none"], [3, 2, 2, 1]):
helper_dataset_test(
file_path,
lambda lf, closed=closed: lf.filter( # type: ignore[misc]
pl.col("int").is_between(1, 3, closed=closed)
).select("bools", "floats", "date"),
n_expected=n_expected,
)
# this predicate is not supported by pyarrow
# check if we still do it on our side
helper_dataset_test(
file_path,
lambda lf: lf.filter(pl.col("floats").sum().over("date") == 10)
.select(["bools", "floats", "date"])
.collect(),
lambda lf: lf.filter(pl.col("floats").sum().over("date") == 10).select(
"bools", "floats", "date"
),
n_expected=0,
)
# temporal types
helper_dataset_test(
file_path,
lambda lf: lf.filter(pl.col("date") < date(1972, 1, 1))
.select(["bools", "floats", "date"])
.collect(),
lambda lf: lf.filter(pl.col("date") < date(1972, 1, 1)).select(
"bools", "floats", "date"
),
n_expected=1,
)
helper_dataset_test(
file_path,
lambda lf: lf.filter(pl.col("datetime") > datetime(1970, 1, 1, second=13))
.select(["bools", "floats", "date"])
.collect(),
lambda lf: lf.filter(
pl.col("datetime") > datetime(1970, 1, 1, second=13)
).select("bools", "floats", "date"),
n_expected=1,
)
# not yet supported in pyarrow
helper_dataset_test(
file_path,
lambda lf: lf.filter(pl.col("time") >= time(microsecond=100))
.select(["bools", "time", "date"])
.collect(),
lambda lf: lf.filter(pl.col("time") >= time(microsecond=100)).select(
"bools", "time", "date"
),
n_expected=3,
)
# pushdown is_in
helper_dataset_test(
file_path,
lambda lf: lf.filter(pl.col("int").is_in([1, 3, 20]))
.select(["bools", "floats", "date"])
.collect(),
lambda lf: lf.filter(pl.col("int").is_in([1, 3, 20])).select(
"bools", "floats", "date"
),
n_expected=2,
)
helper_dataset_test(
file_path,
lambda lf: lf.filter(
pl.col("date").is_in([date(1973, 8, 17), date(1973, 5, 19)])
)
.select(["bools", "floats", "date"])
.collect(),
).select("bools", "floats", "date"),
n_expected=2,
)
helper_dataset_test(
Expand All @@ -148,40 +144,38 @@ def test_dataset_foo(df: pl.DataFrame, tmp_path: Path) -> None:
datetime(1970, 1, 1, 0, 0, 13, 241324),
]
)
)
.select(["bools", "floats", "date"])
.collect(),
).select("bools", "floats", "date"),
n_expected=2,
)
helper_dataset_test(
file_path,
lambda lf: lf.filter(pl.col("int").is_in(list(range(120))))
.select(["bools", "floats", "date"])
.collect(),
lambda lf: lf.filter(pl.col("int").is_in(list(range(120)))).select(
"bools", "floats", "date"
),
n_expected=3,
)
# TODO: remove string cache
with pl.StringCache():
helper_dataset_test(
file_path,
lambda lf: lf.filter(pl.col("cat").is_in([]))
.select(["bools", "floats", "date"])
.collect(),
lambda lf: lf.filter(pl.col("cat").is_in([])).select(
"bools", "floats", "date"
),
n_expected=0,
)
helper_dataset_test(
file_path,
lambda lf: lf.select(pl.exclude("enum")).collect(),
lambda lf: lf.select(pl.exclude("enum")),
batch_size=2,
n_expected=3,
)

# direct filter
helper_dataset_test(
file_path,
lambda lf: lf.filter(pl.Series([True, False, True]))
.select(["bools", "floats", "date"])
.collect(),
lambda lf: lf.filter(pl.Series([True, False, True])).select(
"bools", "floats", "date"
),
n_expected=2,
)

Expand Down

0 comments on commit fa64234

Please sign in to comment.