Skip to content

Commit

Permalink
fix: Include predicate in cache state union
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Apr 16, 2024
1 parent bead2eb commit 3e6bb45
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
16 changes: 13 additions & 3 deletions crates/polars-plan/src/logical_plan/optimizer/cache_states.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use super::*;
fn get_upper_projections(
parent: Node,
lp_arena: &Arena<IR>,
expr_arena: &Arena<AExpr>,
names_scratch: &mut Vec<ColumnName>,
) -> bool {
let parent = lp_arena.get(parent);
Expand All @@ -17,7 +18,12 @@ fn get_upper_projections(
names_scratch.extend(iter);
false
},
Filter { .. } => true,
Filter { predicate, .. } => {
// Also add predicate, as the projection is above the filter node.
names_scratch.extend(aexpr_to_leaf_names(predicate.node(), expr_arena));

true
},
// Only filter and projection nodes are allowed, any other node we stop.
_ => false,
}
Expand Down Expand Up @@ -198,8 +204,12 @@ pub(super) fn set_cache_states(
let mut found_columns = false;

for parent_node in frame.parent.into_iter().flatten() {
let keep_going =
get_upper_projections(parent_node, lp_arena, &mut names_scratch);
let keep_going = get_upper_projections(
parent_node,
lp_arena,
expr_arena,
&mut names_scratch,
);
if !names_scratch.is_empty() {
found_columns = true;
v.names_union.extend(names_scratch.drain(..));
Expand Down
14 changes: 14 additions & 0 deletions py-polars/tests/unit/test_cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,3 +684,17 @@ def test_cse_and_schema_update_projection_pd(capfd: Any, monkeypatch: Any) -> No
}
captured = capfd.readouterr().err
assert "1 CSE" in captured


@pytest.mark.debug()
def test_cse_predicate_self_join(capfd: Any, monkeypatch: Any) -> None:
monkeypatch.setenv("POLARS_VERBOSE", "1")
y = pl.LazyFrame({"a": [1], "b": [2], "y": [3]})

xf = y.filter(pl.col("y") == 2).select(["a", "b"])
y_xf = y.join(xf, on=["a", "b"], how="left")

y_xf_c = y_xf.select("a", "b")
assert y_xf_c.collect().to_dict(as_series=False) == {"a": [1], "b": [2]}
captured = capfd.readouterr().err
assert "CACHE HIT" in captured

0 comments on commit 3e6bb45

Please sign in to comment.