Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor TreeNode recursions #7942

Closed
wants to merge 11 commits into from
6 changes: 3 additions & 3 deletions datafusion-examples/examples/rewrite_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl AnalyzerRule for MyAnalyzerRule {

impl MyAnalyzerRule {
fn analyze_plan(plan: LogicalPlan) -> Result<LogicalPlan> {
plan.transform(&|plan| {
plan.transform_up_old(&|plan| {
Ok(match plan {
LogicalPlan::Filter(filter) => {
let predicate = Self::analyze_expr(filter.predicate.clone())?;
Expand All @@ -106,7 +106,7 @@ impl MyAnalyzerRule {
}

fn analyze_expr(expr: Expr) -> Result<Expr> {
expr.transform(&|expr| {
expr.transform_up_old(&|expr| {
// closure is invoked for all sub expressions
Ok(match expr {
Expr::Literal(ScalarValue::Int64(i)) => {
Expand Down Expand Up @@ -161,7 +161,7 @@ impl OptimizerRule for MyOptimizerRule {

/// use rewrite_expr to modify the expression tree.
fn my_rewrite(expr: Expr) -> Result<Expr> {
expr.transform(&|expr| {
expr.transform_up_old(&|expr| {
// closure is invoked for all sub expressions
Ok(match expr {
Expr::Between(Between {
Expand Down
407 changes: 304 additions & 103 deletions datafusion/common/src/tree_node.rs

Large diffs are not rendered by default.

20 changes: 10 additions & 10 deletions datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use crate::{error::Result, scalar::ScalarValue};
use super::PartitionedFile;
use crate::datasource::listing::ListingTableUrl;
use crate::execution::context::SessionState;
use datafusion_common::tree_node::{TreeNode, VisitRecursion};
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError};
use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility};
use datafusion_physical_expr::create_physical_expr;
Expand All @@ -52,14 +52,14 @@ use object_store::{ObjectMeta, ObjectStore};
/// was performed
pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
let mut is_applicable = true;
expr.apply(&mut |expr| {
expr.visit_down(&mut |expr| {
match expr {
Expr::Column(Column { ref name, .. }) => {
is_applicable &= col_names.contains(name);
if is_applicable {
Ok(VisitRecursion::Skip)
Ok(TreeNodeRecursion::Prune)
} else {
Ok(VisitRecursion::Stop)
Ok(TreeNodeRecursion::Stop)
}
}
Expr::Literal(_)
Expand Down Expand Up @@ -88,27 +88,27 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
| Expr::ScalarSubquery(_)
| Expr::GetIndexedField { .. }
| Expr::GroupingSet(_)
| Expr::Case { .. } => Ok(VisitRecursion::Continue),
| Expr::Case { .. } => Ok(TreeNodeRecursion::Continue),

Expr::ScalarFunction(scalar_function) => {
match &scalar_function.func_def {
ScalarFunctionDefinition::BuiltIn(fun) => {
match fun.volatility() {
Volatility::Immutable => Ok(VisitRecursion::Continue),
Volatility::Immutable => Ok(TreeNodeRecursion::Continue),
// TODO: Stable functions could be `applicable`, but that would require access to the context
Volatility::Stable | Volatility::Volatile => {
is_applicable = false;
Ok(VisitRecursion::Stop)
Ok(TreeNodeRecursion::Stop)
}
}
}
ScalarFunctionDefinition::UDF(fun) => {
match fun.signature().volatility {
Volatility::Immutable => Ok(VisitRecursion::Continue),
Volatility::Immutable => Ok(TreeNodeRecursion::Continue),
// TODO: Stable functions could be `applicable`, but that would require access to the context
Volatility::Stable | Volatility::Volatile => {
is_applicable = false;
Ok(VisitRecursion::Stop)
Ok(TreeNodeRecursion::Stop)
}
}
}
Expand All @@ -128,7 +128,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
| Expr::Wildcard { .. }
| Expr::Placeholder(_) => {
is_applicable = false;
Ok(VisitRecursion::Stop)
Ok(TreeNodeRecursion::Stop)
}
}
})
Expand Down
12 changes: 8 additions & 4 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use crate::{
use datafusion_common::{
alias::AliasGenerator,
exec_err, not_impl_err, plan_datafusion_err, plan_err,
tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion},
tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor},
};
use datafusion_execution::registry::SerializerRegistry;
use datafusion_expr::{
Expand Down Expand Up @@ -2090,9 +2090,9 @@ impl<'a> BadPlanVisitor<'a> {
}

impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> {
type N = LogicalPlan;
type Node = LogicalPlan;

fn pre_visit(&mut self, node: &Self::N) -> Result<VisitRecursion> {
fn pre_visit(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
match node {
LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => {
plan_err!("DDL not supported: {}", ddl.name())
Expand All @@ -2106,9 +2106,13 @@ impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> {
LogicalPlan::Statement(stmt) if !self.options.allow_statements => {
plan_err!("Statement not supported: {}", stmt.name())
}
_ => Ok(VisitRecursion::Continue),
_ => Ok(TreeNodeRecursion::Continue),
}
}

fn post_visit(&mut self, _node: &Self::Node) -> Result<TreeNodeRecursion> {
Ok(TreeNodeRecursion::Continue)
}
}

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_optimizer/coalesce_batches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl PhysicalOptimizerRule for CoalesceBatches {
}

let target_batch_size = config.execution.batch_size;
plan.transform_up(&|plan| {
plan.transform_up_old(&|plan| {
let plan_any = plan.as_any();
// The goal here is to detect operators that could produce small batches and only
// wrap those ones with a CoalesceBatchesExec operator. An alternate approach here
Expand Down
103 changes: 48 additions & 55 deletions datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGro
use crate::physical_plan::ExecutionPlan;

use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::{AggregateExpr, PhysicalExpr};

Expand All @@ -48,27 +48,27 @@ impl CombinePartialFinalAggregate {
impl PhysicalOptimizerRule for CombinePartialFinalAggregate {
fn optimize(
&self,
plan: Arc<dyn ExecutionPlan>,
mut plan: Arc<dyn ExecutionPlan>,
_config: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
plan.transform_down(&|plan| {
let transformed =
plan.as_any()
.downcast_ref::<AggregateExec>()
.and_then(|agg_exec| {
if matches!(
agg_exec.mode(),
AggregateMode::Final | AggregateMode::FinalPartitioned
) {
agg_exec
.input()
.as_any()
.downcast_ref::<AggregateExec>()
.and_then(|input_agg_exec| {
if matches!(
input_agg_exec.mode(),
AggregateMode::Partial
) && can_combine(
plan.transform_down(&mut |plan| {
plan.clone()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is certainly nice to avoid the clone

.as_any()
.downcast_ref::<AggregateExec>()
.into_iter()
.for_each(|agg_exec| {
if matches!(
agg_exec.mode(),
AggregateMode::Final | AggregateMode::FinalPartitioned
) {
agg_exec
.input()
.as_any()
.downcast_ref::<AggregateExec>()
.into_iter()
.for_each(|input_agg_exec| {
if matches!(input_agg_exec.mode(), AggregateMode::Partial)
&& can_combine(
(
agg_exec.group_by(),
agg_exec.aggr_expr(),
Expand All @@ -79,41 +79,34 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate {
input_agg_exec.aggr_expr(),
input_agg_exec.filter_expr(),
),
) {
let mode =
if agg_exec.mode() == &AggregateMode::Final {
AggregateMode::Single
} else {
AggregateMode::SinglePartitioned
};
AggregateExec::try_new(
mode,
input_agg_exec.group_by().clone(),
input_agg_exec.aggr_expr().to_vec(),
input_agg_exec.filter_expr().to_vec(),
input_agg_exec.input().clone(),
input_agg_exec.input_schema(),
)
.map(|combined_agg| {
combined_agg.with_limit(agg_exec.limit())
})
.ok()
.map(Arc::new)
)
{
let mode = if agg_exec.mode() == &AggregateMode::Final
{
AggregateMode::Single
} else {
None
}
})
} else {
None
}
});

Ok(if let Some(transformed) = transformed {
Transformed::Yes(transformed)
} else {
Transformed::No(plan)
})
})
AggregateMode::SinglePartitioned
};
AggregateExec::try_new(
mode,
input_agg_exec.group_by().clone(),
input_agg_exec.aggr_expr().to_vec(),
input_agg_exec.filter_expr().to_vec(),
input_agg_exec.input().clone(),
input_agg_exec.input_schema(),
)
.map(|combined_agg| {
combined_agg.with_limit(agg_exec.limit())
})
.into_iter()
.for_each(|p| *plan = Arc::new(p))
}
})
}
});
Ok(TreeNodeRecursion::Continue)
})?;
Ok(plan)
}

fn name(&self) -> &str {
Expand Down Expand Up @@ -178,7 +171,7 @@ fn normalize_group_exprs(group_exprs: GroupExprsRef) -> GroupExprs {
fn discard_column_index(group_expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
group_expr
.clone()
.transform(&|expr| {
.transform_up_old(&|expr| {
let normalized_form: Option<Arc<dyn PhysicalExpr>> =
match expr.as_any().downcast_ref::<Column>() {
Some(column) => Some(Arc::new(Column::new(column.name(), 0))),
Expand Down
44 changes: 40 additions & 4 deletions datafusion/core/src/physical_optimizer/enforce_distribution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ use crate::physical_plan::{
};

use arrow::compute::SortOptions;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::tree_node::{
Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator,
};
use datafusion_expr::logical_plan::JoinType;
use datafusion_physical_expr::expressions::{Column, NoOp};
use datafusion_physical_expr::utils::map_columns_before_projection;
Expand Down Expand Up @@ -201,19 +203,19 @@ impl PhysicalOptimizerRule for EnforceDistribution {
// Run a top-down process to adjust input key ordering recursively
let plan_requirements = PlanWithKeyRequirements::new(plan);
let adjusted =
plan_requirements.transform_down(&adjust_input_keys_ordering)?;
plan_requirements.transform_down_old(&adjust_input_keys_ordering)?;
adjusted.plan
} else {
// Run a bottom-up process
plan.transform_up(&|plan| {
plan.transform_up_old(&|plan| {
Ok(Transformed::Yes(reorder_join_keys_to_inputs(plan)?))
})?
};

let distribution_context = DistributionContext::new(adjusted);
// Distribution enforcement needs to be applied bottom-up.
let distribution_context =
distribution_context.transform_up(&|distribution_context| {
distribution_context.transform_up_old(&|distribution_context| {
ensure_distribution(distribution_context, config)
})?;
Ok(distribution_context.plan)
Expand Down Expand Up @@ -1432,6 +1434,23 @@ impl TreeNode for DistributionContext {
}
Ok(self)
}

/// Applies `f` to each child context, then rebuilds this node's `plan`
/// from the (possibly rewritten) children's plans.
///
/// Returns the recursion state produced while visiting the children; a
/// leaf (no children) yields `TreeNodeRecursion::Continue` immediately.
fn transform_children<F>(&mut self, f: &mut F) -> Result<TreeNodeRecursion>
where
    F: FnMut(&mut Self) -> Result<TreeNodeRecursion>,
{
    // Leaf node: nothing to transform, keep the traversal going.
    if self.children_nodes.is_empty() {
        return Ok(TreeNodeRecursion::Continue);
    }
    // Run `f` over the children; `for_each_till_continue` takes care of
    // short-circuiting according to the returned `TreeNodeRecursion`.
    let recursion = self.children_nodes.iter_mut().for_each_till_continue(f)?;
    // Re-assemble this node's plan from the rewritten child plans.
    let child_plans: Vec<_> = self
        .children_nodes
        .iter()
        .map(|node| node.plan.clone())
        .collect();
    self.plan = with_new_children_if_necessary(self.plan.clone(), child_plans)?.into();
    Ok(recursion)
}
}

/// implement Display method for `DistributionContext` struct.
Expand Down Expand Up @@ -1496,6 +1515,23 @@ impl TreeNode for PlanWithKeyRequirements {
}
Ok(self)
}

/// Applies `f` to each child requirement node, then rebuilds this node's
/// `plan` from the (possibly rewritten) children's plans.
///
/// Returns the recursion state produced while visiting the children; a
/// leaf (no children) yields `TreeNodeRecursion::Continue` immediately.
fn transform_children<F>(&mut self, f: &mut F) -> Result<TreeNodeRecursion>
where
    F: FnMut(&mut Self) -> Result<TreeNodeRecursion>,
{
    // Leaf node: nothing to transform, keep the traversal going.
    if self.children.is_empty() {
        return Ok(TreeNodeRecursion::Continue);
    }
    // Run `f` over the children; `for_each_till_continue` takes care of
    // short-circuiting according to the returned `TreeNodeRecursion`.
    let recursion = self.children.iter_mut().for_each_till_continue(f)?;
    // Re-assemble this node's plan from the rewritten child plans.
    let child_plans: Vec<_> = self
        .children
        .iter()
        .map(|child| child.plan.clone())
        .collect();
    self.plan = with_new_children_if_necessary(self.plan.clone(), child_plans)?.into();
    Ok(recursion)
}
}

/// Since almost all of these tests explicitly use `ParquetExec` they only run with the parquet feature flag on
Expand Down
Loading