Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Enable CSE in eager if struct are expanded #18426

Merged
merged 4 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions crates/polars-plan/src/frame/opt_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,9 @@ bitflags! {
const FILE_CACHING = 1 << 6;
/// Pushdown slices/limits.
const SLICE_PUSHDOWN = 1 << 7;
#[cfg(feature = "cse")]
/// Run common-subplan-elimination. This elides duplicate plans and caches their
/// outputs.
const COMM_SUBPLAN_ELIM = 1 << 8;
#[cfg(feature = "cse")]
/// Run common-subexpression-elimination. This elides duplicate expressions and caches their
/// outputs.
const COMM_SUBEXPR_ELIM = 1 << 9;
Expand Down
70 changes: 49 additions & 21 deletions crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ fn expand_expressions(
exprs: Vec<Expr>,
lp_arena: &Arena<IR>,
expr_arena: &mut Arena<AExpr>,
opt_flags: &mut OptFlags,
) -> PolarsResult<Vec<ExprIR>> {
let schema = lp_arena.get(input).schema(lp_arena);
let exprs = rewrite_projections(exprs, &schema, &[])?;
let exprs = rewrite_projections(exprs, &schema, &[], opt_flags)?;
to_expr_irs(exprs, expr_arena)
}

Expand Down Expand Up @@ -57,17 +58,18 @@ pub fn to_alp(
expr_arena: &mut Arena<AExpr>,
lp_arena: &mut Arena<IR>,
// Only `SIMPLIFY_EXPR` and `TYPE_COERCION` are respected.
opt_state: &mut OptFlags,
opt_flags: &mut OptFlags,
) -> PolarsResult<Node> {
let conversion_optimizer = ConversionOptimizer::new(
opt_state.contains(OptFlags::SIMPLIFY_EXPR),
opt_state.contains(OptFlags::TYPE_COERCION),
opt_flags.contains(OptFlags::SIMPLIFY_EXPR),
opt_flags.contains(OptFlags::TYPE_COERCION),
);

let mut ctxt = ConversionContext {
expr_arena,
lp_arena,
conversion_optimizer,
opt_flags,
};

to_alp_impl(lp, &mut ctxt)
Expand All @@ -77,6 +79,7 @@ struct ConversionContext<'a> {
expr_arena: &'a mut Arena<AExpr>,
lp_arena: &'a mut Arena<IR>,
conversion_optimizer: ConversionOptimizer,
opt_flags: &'a mut OptFlags,
}

/// converts LogicalPlan to IR
Expand Down Expand Up @@ -305,7 +308,7 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut ConversionContext) -> PolarsResult<No
DslPlan::Filter { input, predicate } => {
let mut input =
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_input!(filter)))?;
let predicate = expand_filter(predicate, input, ctxt.lp_arena)
let predicate = expand_filter(predicate, input, ctxt.lp_arena, ctxt.opt_flags)
.map_err(|e| e.context(failed_here!(filter)))?;

let predicate_ae = to_expr_ir(predicate.clone(), ctxt.expr_arena)?;
Expand Down Expand Up @@ -378,8 +381,8 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut ConversionContext) -> PolarsResult<No
let input =
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_input!(select)))?;
let schema = ctxt.lp_arena.get(input).schema(ctxt.lp_arena);
let (exprs, schema) =
prepare_projection(expr, &schema).map_err(|e| e.context(failed_here!(select)))?;
let (exprs, schema) = prepare_projection(expr, &schema, ctxt.opt_flags)
.map_err(|e| e.context(failed_here!(select)))?;

if exprs.is_empty() {
ctxt.lp_arena.replace(input, empty_df());
Expand Down Expand Up @@ -442,8 +445,14 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut ConversionContext) -> PolarsResult<No
.cycle()
.zip(sort_options.descending.iter().cycle()),
) {
let exprs = expand_expressions(input, vec![c], ctxt.lp_arena, ctxt.expr_arena)
.map_err(|e| e.context(failed_here!(sort)))?;
let exprs = expand_expressions(
input,
vec![c],
ctxt.lp_arena,
ctxt.expr_arena,
ctxt.opt_flags,
)
.map_err(|e| e.context(failed_here!(sort)))?;

nulls_last.extend(std::iter::repeat(n).take(exprs.len()));
descending.extend(std::iter::repeat(d).take(exprs.len()));
Expand Down Expand Up @@ -489,9 +498,16 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut ConversionContext) -> PolarsResult<No
let input =
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_input!(group_by)))?;

let (keys, aggs, schema) =
resolve_group_by(input, keys, aggs, &options, ctxt.lp_arena, ctxt.expr_arena)
.map_err(|e| e.context(failed_here!(group_by)))?;
let (keys, aggs, schema) = resolve_group_by(
input,
keys,
aggs,
&options,
ctxt.lp_arena,
ctxt.expr_arena,
ctxt.opt_flags,
)
.map_err(|e| e.context(failed_here!(group_by)))?;

let (apply, schema) = if let Some((apply, schema)) = apply {
(Some(apply), schema)
Expand Down Expand Up @@ -614,7 +630,7 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut ConversionContext) -> PolarsResult<No
let input = to_alp_impl(owned(input), ctxt)
.map_err(|e| e.context(failed_input!(with_columns)))?;
let (exprs, schema) =
resolve_with_columns(exprs, input, ctxt.lp_arena, ctxt.expr_arena)
resolve_with_columns(exprs, input, ctxt.lp_arena, ctxt.expr_arena, ctxt.opt_flags)
.map_err(|e| e.context(failed_here!(with_columns)))?;

ctxt.conversion_optimizer
Expand Down Expand Up @@ -680,9 +696,14 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut ConversionContext) -> PolarsResult<No
})
.collect::<Vec<_>>();

let (exprs, schema) =
resolve_with_columns(exprs, input, ctxt.lp_arena, ctxt.expr_arena)
.map_err(|e| e.context(failed_here!(fill_nan)))?;
let (exprs, schema) = resolve_with_columns(
exprs,
input,
ctxt.lp_arena,
ctxt.expr_arena,
ctxt.opt_flags,
)
.map_err(|e| e.context(failed_here!(fill_nan)))?;

ctxt.conversion_optimizer
.fill_scratch(&exprs, ctxt.expr_arena);
Expand Down Expand Up @@ -911,7 +932,12 @@ fn expand_scan_paths_with_hive_update(
Ok(expanded_paths)
}

fn expand_filter(predicate: Expr, input: Node, lp_arena: &Arena<IR>) -> PolarsResult<Expr> {
fn expand_filter(
predicate: Expr,
input: Node,
lp_arena: &Arena<IR>,
opt_flags: &mut OptFlags,
) -> PolarsResult<Expr> {
let schema = lp_arena.get(input).schema(lp_arena);
let predicate = if has_expr(&predicate, |e| match e {
Expr::Column(name) => is_regex_projection(name),
Expand All @@ -924,7 +950,7 @@ fn expand_filter(predicate: Expr, input: Node, lp_arena: &Arena<IR>) -> PolarsRe
| Expr::Nth(_) => true,
_ => false,
}) {
let mut rewritten = rewrite_projections(vec![predicate], &schema, &[])?;
let mut rewritten = rewrite_projections(vec![predicate], &schema, &[], opt_flags)?;
match rewritten.len() {
1 => {
// all good
Expand Down Expand Up @@ -971,10 +997,11 @@ fn resolve_with_columns(
input: Node,
lp_arena: &Arena<IR>,
expr_arena: &mut Arena<AExpr>,
opt_flags: &mut OptFlags,
) -> PolarsResult<(Vec<ExprIR>, SchemaRef)> {
let schema = lp_arena.get(input).schema(lp_arena);
let mut new_schema = (**schema).clone();
let (exprs, _) = prepare_projection(exprs, &schema)?;
let (exprs, _) = prepare_projection(exprs, &schema, opt_flags)?;
let mut output_names = PlHashSet::with_capacity(exprs.len());

let mut arena = Arena::with_capacity(8);
Expand Down Expand Up @@ -1008,10 +1035,11 @@ fn resolve_group_by(
_options: &GroupbyOptions,
lp_arena: &Arena<IR>,
expr_arena: &mut Arena<AExpr>,
opt_flags: &mut OptFlags,
) -> PolarsResult<(Vec<ExprIR>, Vec<ExprIR>, SchemaRef)> {
let current_schema = lp_arena.get(input).schema(lp_arena);
let current_schema = current_schema.as_ref();
let mut keys = rewrite_projections(keys, current_schema, &[])?;
let mut keys = rewrite_projections(keys, current_schema, &[], opt_flags)?;

// Initialize schema from keys
let mut schema = expressions_to_schema(&keys, current_schema, Context::Default)?;
Expand Down Expand Up @@ -1042,7 +1070,7 @@ fn resolve_group_by(
}
let keys_index_len = schema.len();

let aggs = rewrite_projections(aggs, current_schema, &keys)?;
let aggs = rewrite_projections(aggs, current_schema, &keys, opt_flags)?;
if pop_keys {
let _ = keys.pop();
}
Expand Down
45 changes: 38 additions & 7 deletions crates/polars-plan/src/plans/conversion/expr_expansion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ use super::*;
pub(crate) fn prepare_projection(
exprs: Vec<Expr>,
schema: &Schema,
opt_flags: &mut OptFlags,
) -> PolarsResult<(Vec<Expr>, Schema)> {
let exprs = rewrite_projections(exprs, schema, &[])?;
let exprs = rewrite_projections(exprs, schema, &[], opt_flags)?;
let schema = expressions_to_schema(&exprs, schema, Context::Default)?;
Ok((exprs, schema))
}
Expand Down Expand Up @@ -541,14 +542,18 @@ fn prepare_excluded(
}

// functions can have col(["a", "b"]) or col(String) as inputs
fn expand_function_inputs(expr: Expr, schema: &Schema) -> PolarsResult<Expr> {
fn expand_function_inputs(
expr: Expr,
schema: &Schema,
opt_flags: &mut OptFlags,
) -> PolarsResult<Expr> {
expr.try_map_expr(|mut e| match &mut e {
Expr::AnonymousFunction { input, options, .. } | Expr::Function { input, options, .. }
if options
.flags
.contains(FunctionFlags::INPUT_WILDCARD_EXPANSION) =>
{
*input = rewrite_projections(core::mem::take(input), schema, &[]).unwrap();
*input = rewrite_projections(core::mem::take(input), schema, &[], opt_flags).unwrap();
if input.is_empty() && !options.flags.contains(FunctionFlags::ALLOW_EMPTY_INPUTS) {
// Needed to visualize the error
*input = vec![Expr::Literal(LiteralValue::Null)];
Expand Down Expand Up @@ -639,12 +644,27 @@ fn find_flags(expr: &Expr) -> PolarsResult<ExpansionFlags> {
})
}

#[cfg(feature = "dtype-struct")]
fn toggle_cse(opt_flags: &mut OptFlags) {
if opt_flags.contains(OptFlags::EAGER) {
#[cfg(debug_assertions)]
{
use polars_core::config::verbose;
if verbose() {
eprintln!("CSE turned on because of struct expansion")
}
}
*opt_flags |= OptFlags::COMM_SUBEXPR_ELIM;
}
}

/// In case of single col(*) -> do nothing, no selection is the same as select all
/// In other cases replace the wildcard with an expression with all columns
pub(crate) fn rewrite_projections(
exprs: Vec<Expr>,
schema: &Schema,
keys: &[Expr],
opt_flags: &mut OptFlags,
) -> PolarsResult<Vec<Expr>> {
let mut result = Vec::with_capacity(exprs.len() + schema.len());

Expand All @@ -653,7 +673,7 @@ pub(crate) fn rewrite_projections(
let result_offset = result.len();

// Functions can have col(["a", "b"]) or col(String) as inputs.
expr = expand_function_inputs(expr, schema)?;
expr = expand_function_inputs(expr, schema, opt_flags)?;

let mut flags = find_flags(&expr)?;
if flags.has_selector {
Expand All @@ -662,10 +682,11 @@ pub(crate) fn rewrite_projections(
flags.multiple_columns = true;
}

replace_and_add_to_results(expr, flags, &mut result, schema, keys)?;
replace_and_add_to_results(expr, flags, &mut result, schema, keys, opt_flags)?;

#[cfg(feature = "dtype-struct")]
if flags.has_struct_field_by_index {
toggle_cse(opt_flags);
for e in &mut result[result_offset..] {
*e = struct_index_to_field(std::mem::take(e), schema)?;
}
Expand All @@ -680,6 +701,7 @@ fn replace_and_add_to_results(
result: &mut Vec<Expr>,
schema: &Schema,
keys: &[Expr],
opt_flags: &mut OptFlags,
) -> PolarsResult<()> {
if flags.has_nth {
expr = replace_nth(expr, schema);
Expand Down Expand Up @@ -732,19 +754,21 @@ fn replace_and_add_to_results(
&mut intermediate,
schema,
keys,
opt_flags,
)?;

// Then expand the fields and add to the final result vec.
flags.expands_fields = true;
flags.multiple_columns = false;
flags.has_wildcard = false;
for e in intermediate {
replace_and_add_to_results(e, flags, result, schema, keys)?;
replace_and_add_to_results(e, flags, result, schema, keys, opt_flags)?;
}
}
// has only field expansion
// col('a').struct.field('*')
else {
toggle_cse(opt_flags);
expand_struct_fields(e, &expr, result, schema, names, &exclude)?
}
},
Expand Down Expand Up @@ -787,7 +811,14 @@ fn replace_selector_inner(
match s {
Selector::Root(expr) => {
let local_flags = find_flags(&expr)?;
replace_and_add_to_results(*expr, local_flags, scratch, schema, keys)?;
replace_and_add_to_results(
*expr,
local_flags,
scratch,
schema,
keys,
&mut Default::default(),
)?;
members.extend(scratch.drain(..))
},
Selector::Add(lhs, rhs) => {
Expand Down
11 changes: 9 additions & 2 deletions crates/polars-plan/src/plans/optimizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ pub fn optimize(
let opt = StackOptimizer {};
let mut rules: Vec<Box<dyn OptimizationRule>> = Vec::with_capacity(8);

// Unset CSE
// This can be turned on again during ir-conversion.
#[allow(clippy::eq_op)]
#[cfg(feature = "cse")]
if opt_state.contains(OptFlags::EAGER) {
opt_state &= !(OptFlags::COMM_SUBEXPR_ELIM | OptFlags::COMM_SUBEXPR_ELIM);
}
let mut lp_top = to_alp(logical_plan, expr_arena, lp_arena, &mut opt_state)?;

// get toggle values
Expand All @@ -87,10 +94,10 @@ pub fn optimize(
// This keeps eager execution more snappy.
let eager = opt_state.contains(OptFlags::EAGER);
#[cfg(feature = "cse")]
let comm_subplan_elim = opt_state.contains(OptFlags::COMM_SUBPLAN_ELIM) && !eager;
let comm_subplan_elim = opt_state.contains(OptFlags::COMM_SUBPLAN_ELIM);

#[cfg(feature = "cse")]
let comm_subexpr_elim = opt_state.contains(OptFlags::COMM_SUBEXPR_ELIM) && !eager;
let comm_subexpr_elim = opt_state.contains(OptFlags::COMM_SUBEXPR_ELIM);
#[cfg(not(feature = "cse"))]
let comm_subexpr_elim = false;

Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/unit/test_cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,3 +783,15 @@ def test_cse_chunks_18124() -> None:
)
.filter(pl.col("ts_diff") > 1)
).collect().shape == (4, 4)


def test_eager_cse_during_struct_expansion_18411() -> None:
df = pl.DataFrame({"foo": [0, 0, 0, 1, 1]})
vc = pl.col("foo").value_counts()
classes = vc.struct[0]
counts = vc.struct[1]
# Check if output is stable
assert (
df.select(pl.col("foo").replace(classes, counts))
== df.select(pl.col("foo").replace(classes, counts))
)["foo"].all()
Loading