Skip to content

Commit

Permalink
fix: Enable CSE in eager if struct are expanded (pola-rs#18426)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored and r-brink committed Aug 29, 2024
1 parent 88d499c commit 89cceb0
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 32 deletions.
2 changes: 0 additions & 2 deletions crates/polars-plan/src/frame/opt_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,9 @@ bitflags! {
const FILE_CACHING = 1 << 6;
/// Pushdown slices/limits.
const SLICE_PUSHDOWN = 1 << 7;
#[cfg(feature = "cse")]
/// Run common-subplan-elimination. This elides duplicate plans and caches their
/// outputs.
const COMM_SUBPLAN_ELIM = 1 << 8;
#[cfg(feature = "cse")]
/// Run common-subexpression-elimination. This elides duplicate expressions and caches their
/// outputs.
const COMM_SUBEXPR_ELIM = 1 << 9;
Expand Down
70 changes: 49 additions & 21 deletions crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ fn expand_expressions(
exprs: Vec<Expr>,
lp_arena: &Arena<IR>,
expr_arena: &mut Arena<AExpr>,
opt_flags: &mut OptFlags,
) -> PolarsResult<Vec<ExprIR>> {
let schema = lp_arena.get(input).schema(lp_arena);
let exprs = rewrite_projections(exprs, &schema, &[])?;
let exprs = rewrite_projections(exprs, &schema, &[], opt_flags)?;
to_expr_irs(exprs, expr_arena)
}

Expand Down Expand Up @@ -57,17 +58,18 @@ pub fn to_alp(
expr_arena: &mut Arena<AExpr>,
lp_arena: &mut Arena<IR>,
// Only `SIMPLIFY_EXPR` and `TYPE_COERCION` are respected.
opt_state: &mut OptFlags,
opt_flags: &mut OptFlags,
) -> PolarsResult<Node> {
let conversion_optimizer = ConversionOptimizer::new(
opt_state.contains(OptFlags::SIMPLIFY_EXPR),
opt_state.contains(OptFlags::TYPE_COERCION),
opt_flags.contains(OptFlags::SIMPLIFY_EXPR),
opt_flags.contains(OptFlags::TYPE_COERCION),
);

let mut ctxt = ConversionContext {
expr_arena,
lp_arena,
conversion_optimizer,
opt_flags,
};

to_alp_impl(lp, &mut ctxt)
Expand All @@ -77,6 +79,7 @@ struct ConversionContext<'a> {
expr_arena: &'a mut Arena<AExpr>,
lp_arena: &'a mut Arena<IR>,
conversion_optimizer: ConversionOptimizer,
opt_flags: &'a mut OptFlags,
}

/// converts LogicalPlan to IR
Expand Down Expand Up @@ -305,7 +308,7 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut ConversionContext) -> PolarsResult<No
DslPlan::Filter { input, predicate } => {
let mut input =
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_input!(filter)))?;
let predicate = expand_filter(predicate, input, ctxt.lp_arena)
let predicate = expand_filter(predicate, input, ctxt.lp_arena, ctxt.opt_flags)
.map_err(|e| e.context(failed_here!(filter)))?;

let predicate_ae = to_expr_ir(predicate.clone(), ctxt.expr_arena)?;
Expand Down Expand Up @@ -378,8 +381,8 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut ConversionContext) -> PolarsResult<No
let input =
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_input!(select)))?;
let schema = ctxt.lp_arena.get(input).schema(ctxt.lp_arena);
let (exprs, schema) =
prepare_projection(expr, &schema).map_err(|e| e.context(failed_here!(select)))?;
let (exprs, schema) = prepare_projection(expr, &schema, ctxt.opt_flags)
.map_err(|e| e.context(failed_here!(select)))?;

if exprs.is_empty() {
ctxt.lp_arena.replace(input, empty_df());
Expand Down Expand Up @@ -442,8 +445,14 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut ConversionContext) -> PolarsResult<No
.cycle()
.zip(sort_options.descending.iter().cycle()),
) {
let exprs = expand_expressions(input, vec![c], ctxt.lp_arena, ctxt.expr_arena)
.map_err(|e| e.context(failed_here!(sort)))?;
let exprs = expand_expressions(
input,
vec![c],
ctxt.lp_arena,
ctxt.expr_arena,
ctxt.opt_flags,
)
.map_err(|e| e.context(failed_here!(sort)))?;

nulls_last.extend(std::iter::repeat(n).take(exprs.len()));
descending.extend(std::iter::repeat(d).take(exprs.len()));
Expand Down Expand Up @@ -489,9 +498,16 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut ConversionContext) -> PolarsResult<No
let input =
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_input!(group_by)))?;

let (keys, aggs, schema) =
resolve_group_by(input, keys, aggs, &options, ctxt.lp_arena, ctxt.expr_arena)
.map_err(|e| e.context(failed_here!(group_by)))?;
let (keys, aggs, schema) = resolve_group_by(
input,
keys,
aggs,
&options,
ctxt.lp_arena,
ctxt.expr_arena,
ctxt.opt_flags,
)
.map_err(|e| e.context(failed_here!(group_by)))?;

let (apply, schema) = if let Some((apply, schema)) = apply {
(Some(apply), schema)
Expand Down Expand Up @@ -614,7 +630,7 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut ConversionContext) -> PolarsResult<No
let input = to_alp_impl(owned(input), ctxt)
.map_err(|e| e.context(failed_input!(with_columns)))?;
let (exprs, schema) =
resolve_with_columns(exprs, input, ctxt.lp_arena, ctxt.expr_arena)
resolve_with_columns(exprs, input, ctxt.lp_arena, ctxt.expr_arena, ctxt.opt_flags)
.map_err(|e| e.context(failed_here!(with_columns)))?;

ctxt.conversion_optimizer
Expand Down Expand Up @@ -680,9 +696,14 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut ConversionContext) -> PolarsResult<No
})
.collect::<Vec<_>>();

let (exprs, schema) =
resolve_with_columns(exprs, input, ctxt.lp_arena, ctxt.expr_arena)
.map_err(|e| e.context(failed_here!(fill_nan)))?;
let (exprs, schema) = resolve_with_columns(
exprs,
input,
ctxt.lp_arena,
ctxt.expr_arena,
ctxt.opt_flags,
)
.map_err(|e| e.context(failed_here!(fill_nan)))?;

ctxt.conversion_optimizer
.fill_scratch(&exprs, ctxt.expr_arena);
Expand Down Expand Up @@ -911,7 +932,12 @@ fn expand_scan_paths_with_hive_update(
Ok(expanded_paths)
}

fn expand_filter(predicate: Expr, input: Node, lp_arena: &Arena<IR>) -> PolarsResult<Expr> {
fn expand_filter(
predicate: Expr,
input: Node,
lp_arena: &Arena<IR>,
opt_flags: &mut OptFlags,
) -> PolarsResult<Expr> {
let schema = lp_arena.get(input).schema(lp_arena);
let predicate = if has_expr(&predicate, |e| match e {
Expr::Column(name) => is_regex_projection(name),
Expand All @@ -924,7 +950,7 @@ fn expand_filter(predicate: Expr, input: Node, lp_arena: &Arena<IR>) -> PolarsRe
| Expr::Nth(_) => true,
_ => false,
}) {
let mut rewritten = rewrite_projections(vec![predicate], &schema, &[])?;
let mut rewritten = rewrite_projections(vec![predicate], &schema, &[], opt_flags)?;
match rewritten.len() {
1 => {
// all good
Expand Down Expand Up @@ -971,10 +997,11 @@ fn resolve_with_columns(
input: Node,
lp_arena: &Arena<IR>,
expr_arena: &mut Arena<AExpr>,
opt_flags: &mut OptFlags,
) -> PolarsResult<(Vec<ExprIR>, SchemaRef)> {
let schema = lp_arena.get(input).schema(lp_arena);
let mut new_schema = (**schema).clone();
let (exprs, _) = prepare_projection(exprs, &schema)?;
let (exprs, _) = prepare_projection(exprs, &schema, opt_flags)?;
let mut output_names = PlHashSet::with_capacity(exprs.len());

let mut arena = Arena::with_capacity(8);
Expand Down Expand Up @@ -1008,10 +1035,11 @@ fn resolve_group_by(
_options: &GroupbyOptions,
lp_arena: &Arena<IR>,
expr_arena: &mut Arena<AExpr>,
opt_flags: &mut OptFlags,
) -> PolarsResult<(Vec<ExprIR>, Vec<ExprIR>, SchemaRef)> {
let current_schema = lp_arena.get(input).schema(lp_arena);
let current_schema = current_schema.as_ref();
let mut keys = rewrite_projections(keys, current_schema, &[])?;
let mut keys = rewrite_projections(keys, current_schema, &[], opt_flags)?;

// Initialize schema from keys
let mut schema = expressions_to_schema(&keys, current_schema, Context::Default)?;
Expand Down Expand Up @@ -1042,7 +1070,7 @@ fn resolve_group_by(
}
let keys_index_len = schema.len();

let aggs = rewrite_projections(aggs, current_schema, &keys)?;
let aggs = rewrite_projections(aggs, current_schema, &keys, opt_flags)?;
if pop_keys {
let _ = keys.pop();
}
Expand Down
45 changes: 38 additions & 7 deletions crates/polars-plan/src/plans/conversion/expr_expansion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ use super::*;
pub(crate) fn prepare_projection(
exprs: Vec<Expr>,
schema: &Schema,
opt_flags: &mut OptFlags,
) -> PolarsResult<(Vec<Expr>, Schema)> {
let exprs = rewrite_projections(exprs, schema, &[])?;
let exprs = rewrite_projections(exprs, schema, &[], opt_flags)?;
let schema = expressions_to_schema(&exprs, schema, Context::Default)?;
Ok((exprs, schema))
}
Expand Down Expand Up @@ -541,14 +542,18 @@ fn prepare_excluded(
}

// functions can have col(["a", "b"]) or col(String) as inputs
fn expand_function_inputs(expr: Expr, schema: &Schema) -> PolarsResult<Expr> {
fn expand_function_inputs(
expr: Expr,
schema: &Schema,
opt_flags: &mut OptFlags,
) -> PolarsResult<Expr> {
expr.try_map_expr(|mut e| match &mut e {
Expr::AnonymousFunction { input, options, .. } | Expr::Function { input, options, .. }
if options
.flags
.contains(FunctionFlags::INPUT_WILDCARD_EXPANSION) =>
{
*input = rewrite_projections(core::mem::take(input), schema, &[]).unwrap();
*input = rewrite_projections(core::mem::take(input), schema, &[], opt_flags).unwrap();
if input.is_empty() && !options.flags.contains(FunctionFlags::ALLOW_EMPTY_INPUTS) {
// Needed to visualize the error
*input = vec![Expr::Literal(LiteralValue::Null)];
Expand Down Expand Up @@ -639,12 +644,27 @@ fn find_flags(expr: &Expr) -> PolarsResult<ExpansionFlags> {
})
}

#[cfg(feature = "dtype-struct")]
fn toggle_cse(opt_flags: &mut OptFlags) {
if opt_flags.contains(OptFlags::EAGER) {
#[cfg(debug_assertions)]
{
use polars_core::config::verbose;
if verbose() {
eprintln!("CSE turned on because of struct expansion")
}
}
*opt_flags |= OptFlags::COMM_SUBEXPR_ELIM;
}
}

/// In case of single col(*) -> do nothing, no selection is the same as select all
/// In other cases replace the wildcard with an expression with all columns
pub(crate) fn rewrite_projections(
exprs: Vec<Expr>,
schema: &Schema,
keys: &[Expr],
opt_flags: &mut OptFlags,
) -> PolarsResult<Vec<Expr>> {
let mut result = Vec::with_capacity(exprs.len() + schema.len());

Expand All @@ -653,7 +673,7 @@ pub(crate) fn rewrite_projections(
let result_offset = result.len();

// Functions can have col(["a", "b"]) or col(String) as inputs.
expr = expand_function_inputs(expr, schema)?;
expr = expand_function_inputs(expr, schema, opt_flags)?;

let mut flags = find_flags(&expr)?;
if flags.has_selector {
Expand All @@ -662,10 +682,11 @@ pub(crate) fn rewrite_projections(
flags.multiple_columns = true;
}

replace_and_add_to_results(expr, flags, &mut result, schema, keys)?;
replace_and_add_to_results(expr, flags, &mut result, schema, keys, opt_flags)?;

#[cfg(feature = "dtype-struct")]
if flags.has_struct_field_by_index {
toggle_cse(opt_flags);
for e in &mut result[result_offset..] {
*e = struct_index_to_field(std::mem::take(e), schema)?;
}
Expand All @@ -680,6 +701,7 @@ fn replace_and_add_to_results(
result: &mut Vec<Expr>,
schema: &Schema,
keys: &[Expr],
opt_flags: &mut OptFlags,
) -> PolarsResult<()> {
if flags.has_nth {
expr = replace_nth(expr, schema);
Expand Down Expand Up @@ -732,19 +754,21 @@ fn replace_and_add_to_results(
&mut intermediate,
schema,
keys,
opt_flags,
)?;

// Then expand the fields and add to the final result vec.
flags.expands_fields = true;
flags.multiple_columns = false;
flags.has_wildcard = false;
for e in intermediate {
replace_and_add_to_results(e, flags, result, schema, keys)?;
replace_and_add_to_results(e, flags, result, schema, keys, opt_flags)?;
}
}
// has only field expansion
// col('a').struct.field('*')
else {
toggle_cse(opt_flags);
expand_struct_fields(e, &expr, result, schema, names, &exclude)?
}
},
Expand Down Expand Up @@ -787,7 +811,14 @@ fn replace_selector_inner(
match s {
Selector::Root(expr) => {
let local_flags = find_flags(&expr)?;
replace_and_add_to_results(*expr, local_flags, scratch, schema, keys)?;
replace_and_add_to_results(
*expr,
local_flags,
scratch,
schema,
keys,
&mut Default::default(),
)?;
members.extend(scratch.drain(..))
},
Selector::Add(lhs, rhs) => {
Expand Down
11 changes: 9 additions & 2 deletions crates/polars-plan/src/plans/optimizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ pub fn optimize(
let opt = StackOptimizer {};
let mut rules: Vec<Box<dyn OptimizationRule>> = Vec::with_capacity(8);

// Unset CSE
// This can be turned on again during ir-conversion.
#[allow(clippy::eq_op)]
#[cfg(feature = "cse")]
if opt_state.contains(OptFlags::EAGER) {
opt_state &= !(OptFlags::COMM_SUBEXPR_ELIM | OptFlags::COMM_SUBEXPR_ELIM);
}
let mut lp_top = to_alp(logical_plan, expr_arena, lp_arena, &mut opt_state)?;

// get toggle values
Expand All @@ -87,10 +94,10 @@ pub fn optimize(
// This keeps eager execution more snappy.
let eager = opt_state.contains(OptFlags::EAGER);
#[cfg(feature = "cse")]
let comm_subplan_elim = opt_state.contains(OptFlags::COMM_SUBPLAN_ELIM) && !eager;
let comm_subplan_elim = opt_state.contains(OptFlags::COMM_SUBPLAN_ELIM);

#[cfg(feature = "cse")]
let comm_subexpr_elim = opt_state.contains(OptFlags::COMM_SUBEXPR_ELIM) && !eager;
let comm_subexpr_elim = opt_state.contains(OptFlags::COMM_SUBEXPR_ELIM);
#[cfg(not(feature = "cse"))]
let comm_subexpr_elim = false;

Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/unit/test_cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,3 +783,15 @@ def test_cse_chunks_18124() -> None:
)
.filter(pl.col("ts_diff") > 1)
).collect().shape == (4, 4)


def test_eager_cse_during_struct_expansion_18411() -> None:
df = pl.DataFrame({"foo": [0, 0, 0, 1, 1]})
vc = pl.col("foo").value_counts()
classes = vc.struct[0]
counts = vc.struct[1]
# Check if output is stable
assert (
df.select(pl.col("foo").replace(classes, counts))
== df.select(pl.col("foo").replace(classes, counts))
)["foo"].all()

0 comments on commit 89cceb0

Please sign in to comment.