Skip to content

Commit

Permalink
fix(rust): Correctly set should_broadcast flag in HStack CSE rewrite (p…
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- authored and atigbadr committed Jul 23, 2024
1 parent 1a73f79 commit 2f6592d
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 8 deletions.
28 changes: 20 additions & 8 deletions crates/polars-plan/src/plans/optimizer/cse/cse_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -798,13 +798,12 @@ impl RewritingVisitor for CommonSubExprOptimizer {
ProjectionOptions {
run_parallel: options.run_parallel,
duplicate_check: options.duplicate_check,
// TODO: Somewhat of a hack, we're
// going to extend the input dataframe
// with the result of evaluating these
// expressions and then select them
// out again. That means that we don't
// want to broadcast them if they turn
// out to be scalars.
// These columns might have different
// lengths from the dataframe, but
// they are only temporaries that will
// be removed by the evaluation of the
// default_exprs and the subsequent
// projection.
should_broadcast: false,
},
)
Expand Down Expand Up @@ -839,7 +838,20 @@ impl RewritingVisitor for CommonSubExprOptimizer {
let input = *input;

let lp = IRBuilder::new(input, &mut arena.1, &mut arena.0)
.with_columns(exprs.cse_exprs().to_vec(), options)
.with_columns(
exprs.cse_exprs().to_vec(),
// These columns might have different
// lengths from the dataframe, but they
// are only temporaries that will be
// removed by the evaluation of the
// default_exprs and the subsequent
// projection.
ProjectionOptions {
run_parallel: options.run_parallel,
duplicate_check: options.duplicate_check,
should_broadcast: false,
},
)
.with_columns(exprs.default_exprs().to_vec(), options)
.build();
let input = arena.0.add(lp);
Expand Down
21 changes: 21 additions & 0 deletions py-polars/tests/unit/test_cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,3 +744,24 @@ def test_hash_empty_series_16577() -> None:
s = pl.Series(values=None)
out = pl.LazyFrame().select(s).collect()
assert out.equals(s.to_frame())


def test_cse_non_scalar_length_mismatch_17732() -> None:
df = pl.LazyFrame({"a": pl.Series(range(30), dtype=pl.Int32)})
got = (
df.lazy()
.with_columns(
pl.col("a").head(5).min().alias("b"),
pl.col("a").head(5).max().alias("c"),
)
.collect(comm_subexpr_elim=True)
)
expect = pl.DataFrame(
{
"a": pl.Series(range(30), dtype=pl.Int32),
"b": pl.Series([0] * 30, dtype=pl.Int32),
"c": pl.Series([4] * 30, dtype=pl.Int32),
}
)

assert_frame_equal(expect, got)

0 comments on commit 2f6592d

Please sign in to comment.