diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index 9dfc238ab9e8..226f548dd446 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -91,7 +91,7 @@ impl AnalyzerRule for MyAnalyzerRule { impl MyAnalyzerRule { fn analyze_plan(plan: LogicalPlan) -> Result { - plan.transform_up(&|plan| { + plan.transform_up_old(&|plan| { Ok(match plan { LogicalPlan::Filter(filter) => { let predicate = Self::analyze_expr(filter.predicate.clone())?; @@ -106,7 +106,7 @@ impl MyAnalyzerRule { } fn analyze_expr(expr: Expr) -> Result { - expr.transform_up(&|expr| { + expr.transform_up_old(&|expr| { // closure is invoked for all sub expressions Ok(match expr { Expr::Literal(ScalarValue::Int64(i)) => { @@ -161,7 +161,7 @@ impl OptimizerRule for MyOptimizerRule { /// use rewrite_expr to modify the expression tree. fn my_rewrite(expr: Expr) -> Result { - expr.transform_up(&|expr| { + expr.transform_up_old(&|expr| { // closure is invoked for all sub expressions Ok(match expr { Expr::Between(Between { diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 39d691a9dcea..e15c43543a22 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -57,19 +57,19 @@ pub trait TreeNode: Sized { F: FnMut(&Self) -> Result, { // Apply `f` on self. - f(self) + f(self)? // If it returns continue (not prune or stop or stop all) then continue // traversal on inner children and children. .and_then_on_continue(|| { // Run the recursive `apply` on each inner children, but as they are // unrelated root nodes of inner trees if any returns stop then continue // with the next one. - self.apply_inner_children(&mut |c| c.visit_down(f).continue_on_stop()) + self.apply_inner_children(&mut |c| c.visit_down(f)?.continue_on_stop())? // Run the recursive `apply` on each children. .and_then_on_continue(|| { self.apply_children(&mut |c| c.visit_down(f)) }) - }) + })? // Applying `f` on self might have returned prune, but we need to propagate // continue. .continue_on_prune() @@ -107,21 +107,21 @@ pub trait TreeNode: Sized { ) -> Result { // Apply `pre_visit` on self. visitor - .pre_visit(self) + .pre_visit(self)? // If it returns continue (not prune or stop or stop all) then continue // traversal on inner children and children. .and_then_on_continue(|| { // Run the recursive `visit` on each inner children, but as they are // unrelated subquery plans if any returns stop then continue with the // next one. - self.apply_inner_children(&mut |c| c.visit(visitor).continue_on_stop()) + self.apply_inner_children(&mut |c| c.visit(visitor)?.continue_on_stop())? // Run the recursive `visit` on each children. .and_then_on_continue(|| { self.apply_children(&mut |c| c.visit(visitor)) - }) + })? // Apply `post_visit` on self. .and_then_on_continue(|| visitor.post_visit(self)) - }) + })? // Applying `pre_visit` or `post_visit` on self might have returned prune, // but we need to propagate continue. .continue_on_prune() @@ -133,31 +133,144 @@ pub trait TreeNode: Sized { ) -> Result { // Apply `pre_transform` on self. transformer - .pre_transform(self) + .pre_transform(self)? // If it returns continue (not prune or stop or stop all) then continue // traversal on inner children and children. .and_then_on_continue(|| // Run the recursive `transform` on each children. self - .transform_children(&mut |c| c.transform(transformer)) + .transform_children(&mut |c| c.transform(transformer))? 
             // Apply `post_transform` on new self.
-            .and_then_on_continue(|| {
-                transformer.post_transform(self)
-            }))
+            .and_then_on_continue(|| transformer.post_transform(self)))?
         // Applying `pre_transform` or `post_transform` on self might have returned
         // prune, but we need to propagate continue.
         .continue_on_prune()
     }
 
+    fn transform_with_payload<FD, PD, FU, PU>(
+        &mut self,
+        f_down: &mut FD,
+        payload_down: Option<PD>,
+        f_up: &mut FU,
+    ) -> Result<(TreeNodeRecursion, Option<PU>)>
+    where
+        FD: FnMut(&mut Self, Option<PD>) -> Result<(TreeNodeRecursion, Vec<PD>)>,
+        FU: FnMut(&mut Self, Vec<PU>) -> Result<(TreeNodeRecursion, PU)>,
+    {
+        // Apply `f_down` on self.
+        let (tnr, new_payload_down) = f_down(self, payload_down)?;
+        let mut new_payload_down_iter = new_payload_down.into_iter();
+        // If it returns continue (not prune or stop or stop all) then continue traversal
+        // on inner children and children.
+        let mut new_payload_up = None;
+        tnr.and_then_on_continue(|| {
+            // Run the recursive `transform` on each children.
+            let mut payload_up = vec![];
+            let tnr = self.transform_children(&mut |c| {
+                let (tnr, p) =
+                    c.transform_with_payload(f_down, new_payload_down_iter.next(), f_up)?;
+                p.into_iter().for_each(|p| payload_up.push(p));
+                Ok(tnr)
+            })?;
+            // Apply `f_up` on self.
+            tnr.and_then_on_continue(|| {
+                let (tnr, np) = f_up(self, payload_up)?;
+                new_payload_up = Some(np);
+                Ok(tnr)
+            })
+        })?
+        // Applying `f_down` or `f_up` on self might have returned prune, but we need to propagate
+        // continue.
+        .continue_on_prune()
+        .map(|tnr| (tnr, new_payload_up))
+    }
+
+    fn transform_down<F>(&mut self, f: &mut F) -> Result<TreeNodeRecursion>
+    where
+        F: FnMut(&mut Self) -> Result<TreeNodeRecursion>,
+    {
+        // Apply `f` on self.
+        f(self)?
+        // If it returns continue (not prune or stop or stop all) then continue
+        // traversal on inner children and children.
+        .and_then_on_continue(||
+            // Run the recursive `transform` on each children.
+            self.transform_children(&mut |c| c.transform_down(f)))?
+        // Applying `f` on self might have returned prune, but we need to propagate
+        // continue.
+        .continue_on_prune()
+    }
+
+    fn transform_down_with_payload<F, P>(
+        &mut self,
+        f: &mut F,
+        payload: P,
+    ) -> Result<TreeNodeRecursion>
+    where
+        F: FnMut(&mut Self, P) -> Result<(TreeNodeRecursion, Vec<P>)>,
+    {
+        // Apply `f` on self.
+        let (tnr, new_payload) = f(self, payload)?;
+        let mut new_payload_iter = new_payload.into_iter();
+        // If it returns continue (not prune or stop or stop all) then continue
+        // traversal on inner children and children.
+        tnr.and_then_on_continue(||
+            // Run the recursive `transform` on each children.
+            self.transform_children(&mut |c| c.transform_down_with_payload(f, new_payload_iter.next().unwrap())))?
+        // Applying `f` on self might have returned prune, but we need to propagate
+        // continue.
+        .continue_on_prune()
+    }
+
+    fn transform_up<F>(&mut self, f: &mut F) -> Result<TreeNodeRecursion>
+    where
+        F: FnMut(&mut Self) -> Result<TreeNodeRecursion>,
+    {
+        // Run the recursive `transform` on each children.
+        self.transform_children(&mut |c| c.transform_up(f))?
+        // Apply `f` on self.
+        .and_then_on_continue(|| f(self))?
+        // Applying `f` on self might have returned prune, but we need to propagate
+        // continue.
+        .continue_on_prune()
+    }
+
+    fn transform_up_with_payload<F, P>(
+        &mut self,
+        f: &mut F,
+    ) -> Result<(TreeNodeRecursion, Option<P>)>
+    where
+        F: FnMut(&mut Self, Vec<P>) -> Result<(TreeNodeRecursion, P)>,
+    {
+        // Run the recursive `transform` on each children.
+        let mut payload = vec![];
+        let tnr = self.transform_children(&mut |c| {
+            let (tnr, p) = c.transform_up_with_payload(f)?;
+            p.into_iter().for_each(|p| payload.push(p));
+            Ok(tnr)
+        })?;
+        let mut new_payload = None;
+        // Apply `f` on self.
+        tnr.and_then_on_continue(|| {
+            let (tnr, np) = f(self, payload)?;
+            new_payload = Some(np);
+            Ok(tnr)
+        })?
+        // Applying `f` on self might have returned prune, but we need to propagate
+        // continue.
+        .continue_on_prune()
+        .map(|tnr| (tnr, new_payload))
+    }
+
     /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its
     /// children(Preorder Traversal).
     /// When the `op` does not apply to a given node, it is left unchanged.
-    fn transform_down<F>(self, op: &F) -> Result<Self>
+    fn transform_down_old<F>(self, op: &F) -> Result<Self>
     where
         F: Fn(Self) -> Result<Transformed<Self>>,
     {
         let after_op = op(self)?.into();
-        after_op.map_children(|node| node.transform_down(op))
+        after_op.map_children(|node| node.transform_down_old(op))
     }
 
     /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its
@@ -174,11 +287,11 @@ pub trait TreeNode: Sized {
     /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its
     /// children and then itself(Postorder Traversal).
    /// When the `op` does not apply to a given node, it is left unchanged.
-    fn transform_up<F>(self, op: &F) -> Result<Self>
+    fn transform_up_old<F>(self, op: &F) -> Result<Self>
     where
         F: Fn(Self) -> Result<Transformed<Self>>,
     {
-        let after_op_children = self.map_children(|node| node.transform_up(op))?;
+        let after_op_children = self.map_children(|node| node.transform_up_old(op))?;
 
         let new_node = op(after_op_children)?.into();
         Ok(new_node)
@@ -402,63 +515,35 @@ pub enum TreeNodeRecursion {
 }
 
 impl TreeNodeRecursion {
-    fn continue_on_prune(self) -> TreeNodeRecursion {
-        match self {
-            TreeNodeRecursion::Prune => TreeNodeRecursion::Continue,
-            o => o,
-        }
-    }
-
-    fn fail_on_prune(self) -> TreeNodeRecursion {
-        match self {
-            TreeNodeRecursion::Prune => panic!("Recursion can't prune."),
-            o => o,
-        }
-    }
-
-    fn continue_on_stop(self) -> TreeNodeRecursion {
-        match self {
-            TreeNodeRecursion::Stop => TreeNodeRecursion::Continue,
-            o => o,
-        }
-    }
-}
-
-/// This helper trait provide functions to control recursion on
-/// [`Result<TreeNodeRecursion>`].
-pub trait TreeNodeRecursionResult: Sized {
-    fn and_then_on_continue<F>(self, f: F) -> Result<TreeNodeRecursion>
-    where
-        F: FnOnce() -> Result<TreeNodeRecursion>;
-
-    fn continue_on_prune(self) -> Result<TreeNodeRecursion>;
-
-    fn fail_on_prune(self) -> Result<TreeNodeRecursion>;
-
-    fn continue_on_stop(self) -> Result<TreeNodeRecursion>;
-}
-
-impl TreeNodeRecursionResult for Result<TreeNodeRecursion> {
-    fn and_then_on_continue<F>(self, f: F) -> Result<TreeNodeRecursion>
+    pub fn and_then_on_continue<F>(self, f: F) -> Result<TreeNodeRecursion>
     where
         F: FnOnce() -> Result<TreeNodeRecursion>,
    {
-        match self?
{ + match self { TreeNodeRecursion::Continue => f(), o => Ok(o), } } - fn continue_on_prune(self) -> Result { - self.map(|tnr| tnr.continue_on_prune()) + pub fn continue_on_prune(self) -> Result { + Ok(match self { + TreeNodeRecursion::Prune => TreeNodeRecursion::Continue, + o => o, + }) } - fn fail_on_prune(self) -> Result { - self.map(|tnr| tnr.fail_on_prune()) + pub fn fail_on_prune(self) -> Result { + Ok(match self { + TreeNodeRecursion::Prune => panic!("Recursion can't prune."), + o => o, + }) } - fn continue_on_stop(self) -> Result { - self.map(|tnr| tnr.continue_on_stop()) + pub fn continue_on_stop(self) -> Result { + Ok(match self { + TreeNodeRecursion::Stop => TreeNodeRecursion::Continue, + o => o, + }) } } diff --git a/datafusion/core/src/physical_optimizer/coalesce_batches.rs b/datafusion/core/src/physical_optimizer/coalesce_batches.rs index 7b66ca529094..19a8701a1003 100644 --- a/datafusion/core/src/physical_optimizer/coalesce_batches.rs +++ b/datafusion/core/src/physical_optimizer/coalesce_batches.rs @@ -52,7 +52,7 @@ impl PhysicalOptimizerRule for CoalesceBatches { } let target_batch_size = config.execution.batch_size; - plan.transform_up(&|plan| { + plan.transform_up_old(&|plan| { let plan_any = plan.as_any(); // The goal here is to detect operators that could produce small batches and only // wrap those ones with a CoalesceBatchesExec operator. An alternate approach here diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 5878650a49e3..09963edd5979 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -26,7 +26,7 @@ use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGro use crate::physical_plan::ExecutionPlan; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{AggregateExpr, PhysicalExpr}; @@ -48,27 +48,27 @@ impl CombinePartialFinalAggregate { impl PhysicalOptimizerRule for CombinePartialFinalAggregate { fn optimize( &self, - plan: Arc, + mut plan: Arc, _config: &ConfigOptions, ) -> Result> { - plan.transform_down(&|plan| { - let transformed = - plan.as_any() - .downcast_ref::() - .and_then(|agg_exec| { - if matches!( - agg_exec.mode(), - AggregateMode::Final | AggregateMode::FinalPartitioned - ) { - agg_exec - .input() - .as_any() - .downcast_ref::() - .and_then(|input_agg_exec| { - if matches!( - input_agg_exec.mode(), - AggregateMode::Partial - ) && can_combine( + plan.transform_down(&mut |plan| { + plan.clone() + .as_any() + .downcast_ref::() + .into_iter() + .for_each(|agg_exec| { + if matches!( + agg_exec.mode(), + AggregateMode::Final | AggregateMode::FinalPartitioned + ) { + agg_exec + .input() + .as_any() + .downcast_ref::() + .into_iter() + .for_each(|input_agg_exec| { + if matches!(input_agg_exec.mode(), AggregateMode::Partial) + && can_combine( ( agg_exec.group_by(), agg_exec.aggr_expr(), @@ -79,41 +79,34 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { input_agg_exec.aggr_expr(), input_agg_exec.filter_expr(), ), - ) { - let mode = - if agg_exec.mode() == &AggregateMode::Final { - AggregateMode::Single - } else { - AggregateMode::SinglePartitioned - }; - AggregateExec::try_new( - mode, - 
input_agg_exec.group_by().clone(), - input_agg_exec.aggr_expr().to_vec(), - input_agg_exec.filter_expr().to_vec(), - input_agg_exec.input().clone(), - input_agg_exec.input_schema(), - ) - .map(|combined_agg| { - combined_agg.with_limit(agg_exec.limit()) - }) - .ok() - .map(Arc::new) + ) + { + let mode = if agg_exec.mode() == &AggregateMode::Final + { + AggregateMode::Single } else { - None - } - }) - } else { - None - } - }); - - Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) - } else { - Transformed::No(plan) - }) - }) + AggregateMode::SinglePartitioned + }; + AggregateExec::try_new( + mode, + input_agg_exec.group_by().clone(), + input_agg_exec.aggr_expr().to_vec(), + input_agg_exec.filter_expr().to_vec(), + input_agg_exec.input().clone(), + input_agg_exec.input_schema(), + ) + .map(|combined_agg| { + combined_agg.with_limit(agg_exec.limit()) + }) + .into_iter() + .for_each(|p| *plan = Arc::new(p)) + } + }) + } + }); + Ok(TreeNodeRecursion::Continue) + })?; + Ok(plan) } fn name(&self) -> &str { @@ -178,7 +171,7 @@ fn normalize_group_exprs(group_exprs: GroupExprsRef) -> GroupExprs { fn discard_column_index(group_expr: Arc) -> Arc { group_expr .clone() - .transform_up(&|expr| { + .transform_up_old(&|expr| { let normalized_form: Option> = match expr.as_any().downcast_ref::() { Some(column) => Some(Arc::new(Column::new(column.name(), 0))), diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 9392d443e150..b54ec2d6a7f0 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -200,11 +200,11 @@ impl PhysicalOptimizerRule for EnforceDistribution { // Run a top-down process to adjust input key ordering recursively let plan_requirements = PlanWithKeyRequirements::new(plan); let adjusted = - plan_requirements.transform_down(&adjust_input_keys_ordering)?; + plan_requirements.transform_down_old(&adjust_input_keys_ordering)?; adjusted.plan } else { // Run a bottom-up process - plan.transform_up(&|plan| { + plan.transform_up_old(&|plan| { Ok(Transformed::Yes(reorder_join_keys_to_inputs(plan)?)) })? }; @@ -212,7 +212,7 @@ impl PhysicalOptimizerRule for EnforceDistribution { let distribution_context = DistributionContext::new(adjusted); // Distribution enforcement needs to be applied bottom-up. 
let distribution_context = - distribution_context.transform_up(&|distribution_context| { + distribution_context.transform_up_old(&|distribution_context| { ensure_distribution(distribution_context, config) })?; Ok(distribution_context.plan) diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 9a57a030fcc6..7512f6e8aa2c 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -41,7 +41,7 @@ use crate::error::Result; use crate::physical_optimizer::replace_with_order_preserving_variants::{ replace_with_order_preserving_variants, OrderPreservationContext, }; -use crate::physical_optimizer::sort_pushdown::{pushdown_sorts, SortPushDown}; +use crate::physical_optimizer::sort_pushdown::pushdown_requirement_to_children; use crate::physical_optimizer::utils::{ add_sort_above, is_coalesce_partitions, is_limit, is_repartition, is_sort, is_sort_preserving_merge, is_union, is_window, ExecTree, @@ -340,19 +340,19 @@ impl PhysicalOptimizerRule for EnforceSorting { let plan_requirements = PlanWithCorrespondingSort::new(plan); // Execute a bottom-up traversal to enforce sorting requirements, // remove unnecessary sorts, and optimize sort-sensitive operators: - let adjusted = plan_requirements.transform_up(&ensure_sorting)?; + let adjusted = plan_requirements.transform_up_old(&ensure_sorting)?; let new_plan = if config.optimizer.repartition_sorts { let plan_with_coalesce_partitions = PlanWithCorrespondingCoalescePartitions::new(adjusted.plan); let parallel = - plan_with_coalesce_partitions.transform_up(¶llelize_sorts)?; + plan_with_coalesce_partitions.transform_up_old(¶llelize_sorts)?; parallel.plan } else { adjusted.plan }; let plan_with_pipeline_fixer = OrderPreservationContext::new(new_plan); - let updated_plan = - plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| { + let mut updated_plan = + plan_with_pipeline_fixer.transform_up_old(&|plan_with_pipeline_fixer| { replace_with_order_preserving_variants( plan_with_pipeline_fixer, false, @@ -363,9 +363,64 @@ impl PhysicalOptimizerRule for EnforceSorting { // Execute a top-down traversal to exploit sort push-down opportunities // missed by the bottom-up traversal: - let sort_pushdown = SortPushDown::init(updated_plan.plan); - let adjusted = sort_pushdown.transform_down(&pushdown_sorts)?; - Ok(adjusted.plan) + updated_plan.plan.transform_down_with_payload( + &mut |plan, required_ordering: Option>| { + let parent_required = required_ordering.as_deref().unwrap_or(&[]); + if let Some(sort_exec) = plan.as_any().downcast_ref::() { + let new_plan = if !plan + .equivalence_properties() + .ordering_satisfy_requirement(parent_required) + { + // If the current plan is a SortExec, modify it to satisfy parent requirements: + let mut new_plan = sort_exec.input().clone(); + add_sort_above(&mut new_plan, parent_required, sort_exec.fetch()); + new_plan + } else { + plan.clone() + }; + let required_ordering = new_plan + .output_ordering() + .map(PhysicalSortRequirement::from_sort_exprs) + .unwrap_or_default(); + // Since new_plan is a SortExec, we can safely get the 0th index. + let child = new_plan.children().swap_remove(0); + if let Some(adjusted) = + pushdown_requirement_to_children(&child, &required_ordering)? 
+ { + *plan = child; + Ok((TreeNodeRecursion::Continue, adjusted)) + } else { + *plan = new_plan; + // Can not push down requirements + Ok((TreeNodeRecursion::Continue, plan.required_input_ordering())) + } + } else { + // Executors other than SortExec + if plan + .equivalence_properties() + .ordering_satisfy_requirement(parent_required) + { + // Satisfies parent requirements, immediately return. + return Ok(( + TreeNodeRecursion::Continue, + plan.required_input_ordering(), + )); + } + // Can not satisfy the parent requirements, check whether the requirements can be pushed down: + if let Some(adjusted) = + pushdown_requirement_to_children(plan, parent_required)? + { + Ok((TreeNodeRecursion::Continue, adjusted)) + } else { + // Can not push down requirements, add new SortExec: + add_sort_above(plan, parent_required, None); + Ok((TreeNodeRecursion::Continue, plan.required_input_ordering())) + } + } + }, + None, + )?; + Ok(updated_plan.plan) } fn name(&self) -> &str { diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 6b2fe24acf00..66a27aa7cb6b 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -237,7 +237,8 @@ impl PhysicalOptimizerRule for JoinSelection { Box::new(hash_join_convert_symmetric_subrule), Box::new(hash_join_swap_subrule), ]; - let state = pipeline.transform_up(&|p| apply_subrules(p, &subrules, config))?; + let state = + pipeline.transform_up_old(&|p| apply_subrules(p, &subrules, config))?; // Next, we apply another subrule that tries to optimize joins using any // statistics their inputs might have. // - For a hash join with partition mode [`PartitionMode::Auto`], we will @@ -251,7 +252,7 @@ impl PhysicalOptimizerRule for JoinSelection { // side is the small side. 
let config = &config.optimizer; let collect_left_threshold = config.hash_join_single_partition_threshold; - state.plan.transform_up(&|plan| { + state.plan.transform_up_old(&|plan| { statistical_join_selection_subrule(plan, collect_left_threshold) }) } diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs index 540f9a6a132b..249537534ada 100644 --- a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs @@ -160,7 +160,7 @@ impl PhysicalOptimizerRule for LimitedDistinctAggregation { config: &ConfigOptions, ) -> Result> { let plan = if config.optimizer.enable_distinct_aggregation_soft_limit { - plan.transform_down(&|plan| { + plan.transform_down_old(&|plan| { Ok( if let Some(plan) = LimitedDistinctAggregation::transform_limit(plan.clone()) diff --git a/datafusion/core/src/physical_optimizer/output_requirements.rs b/datafusion/core/src/physical_optimizer/output_requirements.rs index f8bf3bb965e8..c817f6b4ad35 100644 --- a/datafusion/core/src/physical_optimizer/output_requirements.rs +++ b/datafusion/core/src/physical_optimizer/output_requirements.rs @@ -192,7 +192,7 @@ impl PhysicalOptimizerRule for OutputRequirements { ) -> Result> { match self.mode { RuleMode::Add => require_top_ordering(plan), - RuleMode::Remove => plan.transform_up(&|plan| { + RuleMode::Remove => plan.transform_up_old(&|plan| { if let Some(sort_req) = plan.as_any().downcast_ref::() { diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index 122ce7171bd3..9176bab57656 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -54,7 +54,7 @@ impl PhysicalOptimizerRule for PipelineChecker { ) -> Result> { let pipeline = PipelineStatePropagator::new(plan); let state = pipeline - .transform_up(&|p| check_finiteness_requirements(p, &config.optimizer))?; + .transform_up_old(&|p| check_finiteness_requirements(p, &config.optimizer))?; Ok(state.plan) } diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index e2b290f3f5ce..6b7e139e711a 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -72,7 +72,7 @@ impl PhysicalOptimizerRule for ProjectionPushdown { plan: Arc, _config: &ConfigOptions, ) -> Result> { - plan.transform_down(&remove_unnecessary_projections) + plan.transform_down_old(&remove_unnecessary_projections) } fn name(&self) -> &str { diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 2423ccc4c32e..41a3a9397bbc 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -678,7 +678,7 @@ fn rewrite_column_expr( column_old: &phys_expr::Column, column_new: &phys_expr::Column, ) -> Result> { - e.transform_up(&|expr| { + e.transform_up_old(&|expr| { if let Some(column) = expr.as_any().downcast_ref::() { if column == column_old { return Ok(Transformed::Yes(Arc::new(column_new.clone()))); diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 
21602487640f..4b001d67aca9 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -344,7 +344,7 @@ mod tests { // let optimized_physical_plan = physical_plan.transform_down(&replace_repartition_execs)?; let config = SessionConfig::new().with_prefer_existing_sort($ALLOW_BOUNDED); let plan_with_pipeline_fixer = OrderPreservationContext::new(physical_plan); - let parallel = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, config.options()))?; + let parallel = plan_with_pipeline_fixer.transform_up_old(&|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, config.options()))?; let optimized_physical_plan = parallel.plan; // Get string representation of the plan diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index 4b06218df9e9..d06adb82c83e 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -18,19 +18,15 @@ use std::sync::Arc; use crate::physical_optimizer::utils::{ - add_sort_above, is_limit, is_sort_preserving_merge, is_union, is_window, + is_limit, is_sort_preserving_merge, is_union, is_window, }; use crate::physical_plan::filter::FilterExec; use crate::physical_plan::joins::utils::calculate_join_output_ordering; use crate::physical_plan::joins::{HashJoinExec, SortMergeJoinExec}; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; -use crate::physical_plan::sorts::sort::SortExec; -use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; +use crate::physical_plan::ExecutionPlan; -use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator, -}; use datafusion_common::{plan_err, DataFusionError, JoinSide, Result}; use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::Column; @@ -38,162 +34,7 @@ use datafusion_physical_expr::{ LexRequirementRef, PhysicalSortExpr, PhysicalSortRequirement, }; -use itertools::izip; - -/// This is a "data class" we use within the [`EnforceSorting`] rule to push -/// down [`SortExec`] in the plan. In some cases, we can reduce the total -/// computational cost by pushing down `SortExec`s through some executors. -/// -/// [`EnforceSorting`]: crate::physical_optimizer::enforce_sorting::EnforceSorting -#[derive(Debug, Clone)] -pub(crate) struct SortPushDown { - /// Current plan - pub plan: Arc, - /// Parent required sort ordering - required_ordering: Option>, - /// The adjusted request sort ordering to children. - /// By default they are the same as the plan's required input ordering, but can be adjusted based on parent required sort ordering properties. 
- adjusted_request_ordering: Vec>>, -} - -impl SortPushDown { - pub fn init(plan: Arc) -> Self { - let request_ordering = plan.required_input_ordering(); - SortPushDown { - plan, - required_ordering: None, - adjusted_request_ordering: request_ordering, - } - } - - pub fn children(&self) -> Vec { - izip!( - self.plan.children().into_iter(), - self.adjusted_request_ordering.clone().into_iter(), - ) - .map(|(child, from_parent)| { - let child_request_ordering = child.required_input_ordering(); - SortPushDown { - plan: child, - required_ordering: from_parent, - adjusted_request_ordering: child_request_ordering, - } - }) - .collect() - } -} - -impl TreeNode for SortPushDown { - fn apply_children(&self, f: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - self.children().iter().for_each_till_continue(f) - } - - fn map_children(mut self, transform: F) -> Result - where - F: FnMut(Self) -> Result, - { - let children = self.children(); - if !children.is_empty() { - let children_plans = children - .into_iter() - .map(transform) - .map(|r| r.map(|s| s.plan)) - .collect::>>()?; - - match with_new_children_if_necessary(self.plan, children_plans)? { - Transformed::Yes(plan) | Transformed::No(plan) => { - self.plan = plan; - } - } - }; - Ok(self) - } - - fn transform_children(&mut self, f: &mut F) -> Result - where - F: FnMut(&mut Self) -> Result, - { - let mut children = self.children(); - if !children.is_empty() { - let tnr = children.iter_mut().for_each_till_continue(f)?; - let children_plans = children.into_iter().map(|c| c.plan).collect(); - self.plan = - with_new_children_if_necessary(self.plan.clone(), children_plans)?.into(); - Ok(tnr) - } else { - Ok(TreeNodeRecursion::Continue) - } - } -} - -pub(crate) fn pushdown_sorts( - requirements: SortPushDown, -) -> Result> { - let plan = &requirements.plan; - let parent_required = requirements.required_ordering.as_deref().unwrap_or(&[]); - if let Some(sort_exec) = plan.as_any().downcast_ref::() { - let new_plan = if !plan - .equivalence_properties() - .ordering_satisfy_requirement(parent_required) - { - // If the current plan is a SortExec, modify it to satisfy parent requirements: - let mut new_plan = sort_exec.input().clone(); - add_sort_above(&mut new_plan, parent_required, sort_exec.fetch()); - new_plan - } else { - requirements.plan - }; - let required_ordering = new_plan - .output_ordering() - .map(PhysicalSortRequirement::from_sort_exprs) - .unwrap_or_default(); - // Since new_plan is a SortExec, we can safely get the 0th index. - let child = new_plan.children().swap_remove(0); - if let Some(adjusted) = - pushdown_requirement_to_children(&child, &required_ordering)? - { - // Can push down requirements - Ok(Transformed::Yes(SortPushDown { - plan: child, - required_ordering: None, - adjusted_request_ordering: adjusted, - })) - } else { - // Can not push down requirements - Ok(Transformed::Yes(SortPushDown::init(new_plan))) - } - } else { - // Executors other than SortExec - if plan - .equivalence_properties() - .ordering_satisfy_requirement(parent_required) - { - // Satisfies parent requirements, immediately return. - return Ok(Transformed::Yes(SortPushDown { - required_ordering: None, - ..requirements - })); - } - // Can not satisfy the parent requirements, check whether the requirements can be pushed down: - if let Some(adjusted) = pushdown_requirement_to_children(plan, parent_required)? 
{ - Ok(Transformed::Yes(SortPushDown { - plan: requirements.plan, - required_ordering: None, - adjusted_request_ordering: adjusted, - })) - } else { - // Can not push down requirements, add new SortExec: - let mut new_plan = requirements.plan; - add_sort_above(&mut new_plan, parent_required, None); - Ok(Transformed::Yes(SortPushDown::init(new_plan))) - } - } -} - -fn pushdown_requirement_to_children( +pub fn pushdown_requirement_to_children( plan: &Arc, parent_required: LexRequirementRef, ) -> Result>>>> { diff --git a/datafusion/core/src/physical_optimizer/topk_aggregation.rs b/datafusion/core/src/physical_optimizer/topk_aggregation.rs index dd0261420304..f00c44b3234f 100644 --- a/datafusion/core/src/physical_optimizer/topk_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/topk_aggregation.rs @@ -138,7 +138,7 @@ impl PhysicalOptimizerRule for TopKAggregation { config: &ConfigOptions, ) -> Result> { let plan = if config.optimizer.enable_topk_aggregation { - plan.transform_down(&|plan| { + plan.transform_down_old(&|plan| { Ok( if let Some(plan) = TopKAggregation::transform_sort(plan.clone()) { Transformed::Yes(plan) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 5369d502113b..93ec2f369b41 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1150,7 +1150,7 @@ impl Expr { /// For example, gicen an expression like ` = $0` will infer `$0` to /// have type `int32`. pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result { - self.transform_up(&|mut expr| { + self.transform_up_old(&|mut expr| { // Default to assuming the arguments are the same type if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index cbdeb16f99b2..a91cd408aed1 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -33,7 +33,7 @@ pub use order_by::rewrite_sort_cols_by_aggs; /// Recursively call [`Column::normalize_with_schemas`] on all [`Column`] expressions /// in the `expr` expression tree. pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { - expr.transform_up(&|expr| { + expr.transform_up_old(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = LogicalPlanBuilder::normalize(plan, c)?; @@ -57,7 +57,7 @@ pub fn normalize_col_with_schemas( schemas: &[&Arc], using_columns: &[HashSet], ) -> Result { - expr.transform_up(&|expr| { + expr.transform_up_old(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = c.normalize_with_schemas(schemas, using_columns)?; @@ -75,7 +75,7 @@ pub fn normalize_col_with_schemas_and_ambiguity_check( schemas: &[&[&DFSchema]], using_columns: &[HashSet], ) -> Result { - expr.transform_up(&|expr| { + expr.transform_up_old(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = @@ -102,7 +102,7 @@ pub fn normalize_cols( /// Recursively replace all [`Column`] expressions in a given expression tree with /// `Column` expressions provided by the hash map argument. pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { - expr.transform_up(&|expr| { + expr.transform_up_old(&|expr| { Ok({ if let Expr::Column(c) = &expr { match replace_map.get(c) { @@ -122,7 +122,7 @@ pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Resul /// For example, if there were expressions like `foo.bar` this would /// rewrite it to just `bar`. 
pub fn unnormalize_col(expr: Expr) -> Expr { - expr.transform_up(&|expr| { + expr.transform_up_old(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = Column { @@ -164,7 +164,7 @@ pub fn unnormalize_cols(exprs: impl IntoIterator) -> Vec { /// Recursively remove all the ['OuterReferenceColumn'] and return the inside Column /// in the expression tree. pub fn strip_outer_reference(expr: Expr) -> Expr { - expr.transform_up(&|expr| { + expr.transform_up_old(&|expr| { Ok({ if let Expr::OuterReferenceColumn(_, col) = expr { Transformed::Yes(Expr::Column(col)) @@ -307,14 +307,14 @@ mod test { // rewrites "foo" --> "bar" let rewritten = col("state") .eq(lit("foo")) - .transform_up(&transformer) + .transform_up_old(&transformer) .unwrap(); assert_eq!(rewritten, col("state").eq(lit("bar"))); // doesn't rewrite let rewritten = col("state") .eq(lit("baz")) - .transform_up(&transformer) + .transform_up_old(&transformer) .unwrap(); assert_eq!(rewritten, col("state").eq(lit("baz"))); } diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index 1e7efcafd04d..e275487c4574 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -83,7 +83,7 @@ fn rewrite_in_terms_of_projection( ) -> Result { // assumption is that each item in exprs, such as "b + c" is // available as an output column named "b + c" - expr.transform_up(&|expr| { + expr.transform_up_old(&|expr| { // search for unnormalized names first such as "c1" (such as aliases) if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) { let col = Expr::Column( diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 3d8a8356f397..013a7b673265 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -42,8 +42,7 @@ use crate::{ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeRecursion, TreeNodeRecursionResult, - TreeNodeTransformer, VisitRecursionIterator, + Transformed, TreeNode, TreeNodeRecursion, TreeNodeTransformer, VisitRecursionIterator, }; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, @@ -316,7 +315,7 @@ impl LogicalPlan { where F: FnMut(&Expr) -> Result, { - let f = &mut |e: &Expr| f(e).fail_on_prune(); + let f = &mut |e: &Expr| f(e)?.fail_on_prune(); match self { LogicalPlan::Projection(Projection { expr, .. }) => { @@ -352,7 +351,7 @@ impl LogicalPlan { on.iter() // it not ideal to create an expr here to analyze them, but could cache it on the Join itself .map(|(l, r)| Expr::eq(l.clone(), r.clone())) - .for_each_till_continue(&mut |e| f(&e)) + .for_each_till_continue(&mut |e| f(&e))? .and_then_on_continue(|| filter.iter().for_each_till_continue(f)) } LogicalPlan::Sort(Sort { expr, .. 
}) => expr.iter().for_each_till_continue(f), @@ -1154,7 +1153,7 @@ impl LogicalPlan { // LogicalPlan::Subquery (even though it is // actually a Subquery alias) let synthetic_plan = LogicalPlan::Subquery(subquery.clone()); - f(&synthetic_plan).fail_on_prune() + f(&synthetic_plan)?.fail_on_prune() } _ => Ok(TreeNodeRecursion::Continue), }) @@ -1225,7 +1224,7 @@ impl LogicalPlan { expr: Expr, param_values: &ParamValues, ) -> Result { - expr.transform_up(&|expr| { + expr.transform_up_old(&|expr| { match &expr { Expr::Placeholder(Placeholder { id, data_type }) => { let value = @@ -3195,7 +3194,7 @@ digraph { // after transformation, because plan is not the same anymore, // the parent plan is built again with call to LogicalPlan::with_new_inputs -> with_new_exprs let plan = plan - .transform_up(&|plan| match plan { + .transform_up_old(&|plan| match plan { LogicalPlan::TableScan(table) => { let filter = Filter::try_new( external_filter.clone(), diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 8ec4a94204b0..de407063d78f 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -24,9 +24,7 @@ use crate::expr::{ }; use crate::{Expr, GetFieldAccess}; -use datafusion_common::tree_node::{ - TreeNode, TreeNodeRecursion, TreeNodeRecursionResult, VisitRecursionIterator, -}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, VisitRecursionIterator}; use datafusion_common::{internal_err, DataFusionError, Result}; impl TreeNode for Expr { @@ -51,12 +49,12 @@ impl TreeNode for Expr { | Expr::Sort(Sort { expr, .. }) | Expr::InSubquery(InSubquery{ expr, .. }) => f(expr), Expr::GetIndexedField(GetIndexedField { expr, field }) => { - f(expr).and_then_on_continue(|| match field { + f(expr)?.and_then_on_continue(|| match field { GetFieldAccess::ListIndex {key} => { f(key) }, GetFieldAccess::ListRange { start, stop} => { - f(start).and_then_on_continue(|| f(stop)) + f(start)?.and_then_on_continue(|| f(stop)) } GetFieldAccess::NamedStructField { name: _name } => Ok(TreeNodeRecursion::Continue) }) @@ -78,38 +76,38 @@ impl TreeNode for Expr { | Expr::Wildcard {..} | Expr::Placeholder (_) => Ok(TreeNodeRecursion::Continue), Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { - f(left) + f(left)? .and_then_on_continue(|| f(right)) } Expr::Like(Like { expr, pattern, .. }) | Expr::SimilarTo(Like { expr, pattern, .. }) => { - f(expr) + f(expr)? .and_then_on_continue(|| f(pattern)) } Expr::Between(Between { expr, low, high, .. }) => { - f(expr) - .and_then_on_continue(|| f(low)) + f(expr)? + .and_then_on_continue(|| f(low))? .and_then_on_continue(|| f(high)) }, Expr::Case( Case { expr, when_then_expr, else_expr }) => { - expr.as_deref().into_iter().for_each_till_continue(f) + expr.as_deref().into_iter().for_each_till_continue(f)? .and_then_on_continue(|| - when_then_expr.iter().for_each_till_continue(&mut |(w, t)| f(w).and_then_on_continue(|| f(t)))) + when_then_expr.iter().for_each_till_continue(&mut |(w, t)| f(w)?.and_then_on_continue(|| f(t))))? .and_then_on_continue(|| else_expr.as_deref().into_iter().for_each_till_continue(f)) } Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) => { - args.iter().for_each_till_continue(f) - .and_then_on_continue(|| filter.as_deref().into_iter().for_each_till_continue(f)) + args.iter().for_each_till_continue(f)? + .and_then_on_continue(|| filter.as_deref().into_iter().for_each_till_continue(f))? 
.and_then_on_continue(|| order_by.iter().flatten().for_each_till_continue(f)) } Expr::WindowFunction(WindowFunction { args, partition_by, order_by, .. }) => { - args.iter().for_each_till_continue(f) - .and_then_on_continue(|| partition_by.iter().for_each_till_continue(f)) + args.iter().for_each_till_continue(f)? + .and_then_on_continue(|| partition_by.iter().for_each_till_continue(f))? .and_then_on_continue(|| order_by.iter().for_each_till_continue(f)) } Expr::InList(InList { expr, list, .. }) => { - f(expr) + f(expr)? .and_then_on_continue(|| list.iter().for_each_till_continue(f)) } } @@ -362,14 +360,17 @@ impl TreeNode for Expr { | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) | Expr::Sort(Sort { expr, .. }) - | Expr::InSubquery(InSubquery{ expr, .. }) => f(expr), + | Expr::InSubquery(InSubquery{ expr, .. }) => { + let x = expr; + f(x) + } Expr::GetIndexedField(GetIndexedField { expr, field }) => { - f(expr).and_then_on_continue(|| match field { + f(expr)?.and_then_on_continue(|| match field { GetFieldAccess::ListIndex {key} => { f(key) }, GetFieldAccess::ListRange { start, stop} => { - f(start).and_then_on_continue(|| f(stop)) + f(start)?.and_then_on_continue(|| f(stop)) } GetFieldAccess::NamedStructField { name: _name } => Ok(TreeNodeRecursion::Continue) }) @@ -391,37 +392,37 @@ impl TreeNode for Expr { | Expr::Wildcard {..} | Expr::Placeholder (_) => Ok(TreeNodeRecursion::Continue), Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { - f(left) + f(left)? .and_then_on_continue(|| f(right)) } Expr::Like(Like { expr, pattern, .. }) | Expr::SimilarTo(Like { expr, pattern, .. }) => { - f(expr) + f(expr)? .and_then_on_continue(|| f(pattern)) } Expr::Between(Between { expr, low, high, .. }) => { - f(expr) - .and_then_on_continue(|| f(low)) + f(expr)? + .and_then_on_continue(|| f(low))? .and_then_on_continue(|| f(high)) }, Expr::Case( Case { expr, when_then_expr, else_expr }) => { - expr.as_deref_mut().into_iter().for_each_till_continue(f) + expr.as_deref_mut().into_iter().for_each_till_continue(f)? .and_then_on_continue(|| - when_then_expr.iter_mut().for_each_till_continue(&mut |(w, t)| f(w).and_then_on_continue(|| f(t)))) + when_then_expr.iter_mut().for_each_till_continue(&mut |(w, t)| f(w)?.and_then_on_continue(|| f(t))))? .and_then_on_continue(|| else_expr.as_deref_mut().into_iter().for_each_till_continue(f)) } Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) => { - args.iter_mut().for_each_till_continue(f) - .and_then_on_continue(|| filter.as_deref_mut().into_iter().for_each_till_continue(f)) + args.iter_mut().for_each_till_continue(f)? + .and_then_on_continue(|| filter.as_deref_mut().into_iter().for_each_till_continue(f))? .and_then_on_continue(|| order_by.iter_mut().flatten().for_each_till_continue(f)) } Expr::WindowFunction(WindowFunction { args, partition_by, order_by, .. }) => { - args.iter_mut().for_each_till_continue(f) - .and_then_on_continue(|| partition_by.iter_mut().for_each_till_continue(f)) + args.iter_mut().for_each_till_continue(f)? + .and_then_on_continue(|| partition_by.iter_mut().for_each_till_continue(f))? .and_then_on_continue(|| order_by.iter_mut().for_each_till_continue(f)) } Expr::InList(InList { expr, list, .. }) => { - f(expr) + f(expr)? 
.and_then_on_continue(|| list.iter_mut().for_each_till_continue(f)) } } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 17b1ad8cc73f..73309c1882dc 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -47,7 +47,7 @@ impl CountWildcardRule { impl AnalyzerRule for CountWildcardRule { fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - plan.transform_down(&analyze_internal) + plan.transform_down_old(&analyze_internal) } fn name(&self) -> &str { @@ -155,7 +155,7 @@ impl TreeNodeTransformer for CountWildcardRewriter { let new_plan = subquery .as_ref() .clone() - .transform_down(&analyze_internal)?; + .transform_down_old(&analyze_internal)?; *subquery = Arc::new(new_plan); } Expr::InSubquery(InSubquery { @@ -165,7 +165,7 @@ impl TreeNodeTransformer for CountWildcardRewriter { let new_plan = subquery .as_ref() .clone() - .transform_down(&analyze_internal)?; + .transform_down_old(&analyze_internal)?; *subquery = Arc::new(new_plan); } Expr::Exists(Exists { @@ -175,7 +175,7 @@ impl TreeNodeTransformer for CountWildcardRewriter { let new_plan = subquery .as_ref() .clone() - .transform_down(&analyze_internal)?; + .transform_down_old(&analyze_internal)?; *subquery = Arc::new(new_plan); } _ => {} diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index a418fbf5537b..f2e00ed0763d 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -42,7 +42,7 @@ impl InlineTableScan { impl AnalyzerRule for InlineTableScan { fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - plan.transform_up(&analyze_internal) + plan.transform_up_old(&analyze_internal) } fn name(&self) -> &str { @@ -74,7 +74,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { Transformed::Yes(plan) } LogicalPlan::Filter(filter) => { - let new_expr = filter.predicate.transform_up(&rewrite_subquery)?; + let new_expr = filter.predicate.transform_up_old(&rewrite_subquery)?; Transformed::Yes(LogicalPlan::Filter(Filter::try_new( new_expr, filter.input, @@ -88,7 +88,7 @@ fn rewrite_subquery(expr: Expr) -> Result> { match expr { Expr::Exists(Exists { subquery, negated }) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?; + let new_plan = plan.transform_up_old(&analyze_internal)?; let subquery = subquery.with_plan(Arc::new(new_plan)); Ok(Transformed::Yes(Expr::Exists(Exists { subquery, negated }))) } @@ -98,7 +98,7 @@ fn rewrite_subquery(expr: Expr) -> Result> { negated, }) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?; + let new_plan = plan.transform_up_old(&analyze_internal)?; let subquery = subquery.with_plan(Arc::new(new_plan)); Ok(Transformed::Yes(Expr::InSubquery(InSubquery::new( expr, subquery, negated, @@ -106,7 +106,7 @@ fn rewrite_subquery(expr: Expr) -> Result> { } Expr::ScalarSubquery(subquery) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?; + let new_plan = plan.transform_up_old(&analyze_internal)?; let subquery = subquery.with_plan(Arc::new(new_plan)); Ok(Transformed::Yes(Expr::ScalarSubquery(subquery))) } diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 
a68f374d9fe6..3df604a62c8a 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -370,7 +370,7 @@ fn agg_exprs_evaluation_result_on_empty_batch( expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result<()> { for e in agg_expr.iter() { - let mut result_expr = e.clone().transform_up(&|expr| { + let mut result_expr = e.clone().transform_up_old(&|expr| { let new_expr = match expr { Expr::AggregateFunction(expr::AggregateFunction { func_def, .. }) => { match func_def { @@ -415,7 +415,7 @@ fn proj_exprs_evaluation_result_on_empty_batch( expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result<()> { for expr in proj_expr.iter() { - let result_expr = expr.clone().transform_up(&|expr| { + let result_expr = expr.clone().transform_up_old(&|expr| { if let Expr::Column(Column { name, .. }) = &expr { if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { Ok(Transformed::Yes(result_expr.clone())) @@ -448,7 +448,7 @@ fn filter_exprs_evaluation_result_on_empty_batch( input_expr_result_map_for_count_bug: &ExprResultMap, expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result> { - let result_expr = filter_expr.clone().transform_up(&|expr| { + let result_expr = filter_expr.clone().transform_up_old(&|expr| { if let Expr::Column(Column { name, .. }) = &expr { if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { Ok(Transformed::Yes(result_expr.clone())) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index cb9c06154bad..7a58944e5ac9 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -961,7 +961,7 @@ pub fn replace_cols_by_name( e: Expr, replace_map: &HashMap, ) -> Result { - e.transform_up(&|expr| { + e.transform_up_old(&|expr| { Ok(if let Expr::Column(c) = &expr { match replace_map.get(&c.flat_name()) { Some(new_c) => Transformed::Yes(new_c.clone()), diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 34ed4a9475cb..378e01f4bfa9 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -87,7 +87,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { { if !expr_check_map.is_empty() { rewrite_expr = - rewrite_expr.clone().transform_up(&|expr| { + rewrite_expr.clone().transform_up_old(&|expr| { if let Expr::Column(col) = &expr { if let Some(map_expr) = expr_check_map.get(&col.name) @@ -141,8 +141,9 @@ impl OptimizerRule for ScalarSubqueryToJoin { if let Some(rewrite_expr) = expr_to_rewrite_expr_map.get(expr) { - let new_expr = - rewrite_expr.clone().transform_up(&|expr| { + let new_expr = rewrite_expr + .clone() + .transform_up_old(&|expr| { if let Expr::Column(col) = &expr { if let Some(map_expr) = expr_check_map.get(&col.name) diff --git a/datafusion/physical-expr/src/equivalence.rs b/datafusion/physical-expr/src/equivalence.rs index 4899d69bad58..78e232a6bb59 100644 --- a/datafusion/physical-expr/src/equivalence.rs +++ b/datafusion/physical-expr/src/equivalence.rs @@ -30,7 +30,7 @@ use crate::{ use arrow::datatypes::SchemaRef; use arrow_schema::SortOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::{JoinSide, JoinType, Result}; use indexmap::IndexSet; @@ -169,10 +169,10 @@ impl ProjectionMapping { .enumerate() .map(|(expr_idx, 
(expression, name))| { let target_expr = Arc::new(Column::new(name, expr_idx)) as _; - expression - .clone() - .transform_down(&|e| match e.as_any().downcast_ref::() { - Some(col) => { + let mut source_expr = expression.clone(); + source_expr + .transform_down(&mut |e| { + if let Some(col) = e.as_any().downcast_ref::() { // Sometimes, an expression and its name in the input_schema // doesn't match. This can cause problems, so we make sure // that the expression name matches with the name in `input_schema`. @@ -181,11 +181,11 @@ impl ProjectionMapping { let matching_input_field = input_schema.field(idx); let matching_input_column = Column::new(matching_input_field.name(), idx); - Ok(Transformed::Yes(Arc::new(matching_input_column))) + *e = Arc::new(matching_input_column) } - None => Ok(Transformed::No(e)), + Ok(TreeNodeRecursion::Continue) }) - .map(|source_expr| (source_expr, target_expr)) + .map(|_| (source_expr, target_expr)) }) .collect::>>() .map(|map| Self { map }) @@ -352,7 +352,7 @@ impl EquivalenceGroup { /// class it matches with (if any). pub fn normalize_expr(&self, expr: Arc) -> Arc { expr.clone() - .transform_up(&|expr| { + .transform_up_old(&|expr| { for cls in self.iter() { if cls.contains(&expr) { return Ok(Transformed::Yes(cls.canonical_expr().unwrap())); @@ -752,7 +752,7 @@ pub fn add_offset_to_expr( expr: Arc, offset: usize, ) -> Arc { - expr.transform_down(&|e| match e.as_any().downcast_ref::() { + expr.transform_down_old(&|e| match e.as_any().downcast_ref::() { Some(col) => Ok(Transformed::Yes(Arc::new(Column::new( col.name(), offset + col.index(), @@ -1517,7 +1517,7 @@ impl EquivalenceProperties { /// the given expression. pub fn get_expr_ordering(&self, expr: Arc) -> ExprOrdering { ExprOrdering::new(expr.clone()) - .transform_up(&|expr| Ok(update_ordering(expr, self))) + .transform_up_old(&|expr| Ok(update_ordering(expr, self))) // Guaranteed to always return `Ok`. 
.unwrap() } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index d637cf1e54e6..8958c71c585e 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -944,7 +944,7 @@ mod tests { let expr2 = expr .clone() - .transform_up(&|e| { + .transform_up_old(&|e| { let transformed = match e.as_any().downcast_ref::() { Some(lit_value) => match lit_value.value() { @@ -965,7 +965,7 @@ mod tests { let expr3 = expr .clone() - .transform_down(&|e| { + .transform_down_old(&|e| { let transformed = match e.as_any().downcast_ref::() { Some(lit_value) => match lit_value.value() { diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index cf4d0a077e6f..d07a960b0f71 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -282,7 +282,7 @@ pub fn reassign_predicate_columns( schema: &SchemaRef, ignore_not_found: bool, ) -> Result> { - pred.transform_down(&|expr| { + pred.transform_down_old(&|expr| { let expr_any = expr.as_any(); if let Some(column) = expr_any.downcast_ref::() { diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 64a976a1e39f..5be90a6e3bed 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -281,7 +281,7 @@ pub fn convert_sort_expr_with_filter_schema( if all_columns_are_included { // Since we are sure that one to one column mapping includes all columns, we convert // the sort expression into a filter expression. - let converted_filter_expr = expr.transform_up(&|p| { + let converted_filter_expr = expr.transform_up_old(&|p| { convert_filter_columns(p.as_ref(), &column_map).map(|transformed| { match transformed { Some(transformed) => Transformed::Yes(transformed), diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 616a2fc74932..76c270ffe463 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -33,7 +33,7 @@ use std::collections::HashMap; /// Make a best-effort attempt at resolving all columns in the expression tree pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result { - expr.clone().transform_up(&|nested_expr| { + expr.clone().transform_up_old(&|nested_expr| { match nested_expr { Expr::Column(col) => { let field = plan.schema().field_from_column(&col)?; @@ -66,7 +66,7 @@ pub(crate) fn rebase_expr( base_exprs: &[Expr], plan: &LogicalPlan, ) -> Result { - expr.clone().transform_up(&|nested_expr| { + expr.clone().transform_up_old(&|nested_expr| { if base_exprs.contains(&nested_expr) { Ok(Transformed::Yes(expr_as_column_expr(&nested_expr, plan)?)) } else { @@ -170,16 +170,17 @@ pub(crate) fn resolve_aliases_to_exprs( expr: &Expr, aliases: &HashMap, ) -> Result { - expr.clone().transform_up(&|nested_expr| match nested_expr { - Expr::Column(c) if c.relation.is_none() => { - if let Some(aliased_expr) = aliases.get(&c.name) { - Ok(Transformed::Yes(aliased_expr.clone())) - } else { - Ok(Transformed::No(Expr::Column(c))) + expr.clone() + .transform_up_old(&|nested_expr| match nested_expr { + Expr::Column(c) if c.relation.is_none() => { + if let Some(aliased_expr) = aliases.get(&c.name) { + Ok(Transformed::Yes(aliased_expr.clone())) + } else { + Ok(Transformed::No(Expr::Column(c))) + } } - } - _ => Ok(Transformed::No(nested_expr)), - }) + _ => 
Ok(Transformed::No(nested_expr)), + }) } /// given a slice of window expressions sharing the same sort key, find their common partition
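
Reviewer note: below is a minimal, standalone sketch of the traversal contract behind the new `transform_up_with_payload` added to `datafusion/common/src/tree_node.rs` in this patch. The toy `Node` type, the two-variant `Recursion` enum, and the height-annotating closure are illustrative assumptions only; the real API is fallible (returns `Result`), recurses via `transform_children`, and also has `Prune` semantics. What the sketch shows is the contract itself: children are transformed first, their payloads are collected into a `Vec`, and only then does `f` run on the (mutable) parent.

#[derive(Debug)]
struct Node {
    name: String,
    children: Vec<Node>,
}

#[derive(Clone, Copy, PartialEq, Debug)]
enum Recursion {
    Continue,
    Stop,
}

fn transform_up_with_payload<P>(
    node: &mut Node,
    f: &mut impl FnMut(&mut Node, Vec<P>) -> (Recursion, P),
) -> (Recursion, Option<P>) {
    // Post-order: recurse into the children first, collecting their payloads.
    let mut payloads = vec![];
    for child in node.children.iter_mut() {
        let (tnr, p) = transform_up_with_payload(child, f);
        payloads.extend(p);
        if tnr == Recursion::Stop {
            // A child requested a full stop, so `f` is not applied on self.
            return (Recursion::Stop, None);
        }
    }
    // Then apply `f` on self with the payloads coming up from the children.
    let (tnr, p) = f(node, payloads);
    (tnr, Some(p))
}

fn main() {
    let mut tree = Node {
        name: "root".into(),
        children: vec![
            Node { name: "a".into(), children: vec![] },
            Node { name: "b".into(), children: vec![] },
        ],
    };
    // Payload flowing up = subtree height; each node is renamed in place, which
    // is what the `&mut Self` in the new API enables without rebuilding nodes.
    let (_, height) = transform_up_with_payload(&mut tree, &mut |n, hs: Vec<usize>| {
        let h = hs.iter().max().map_or(0, |m| m + 1);
        n.name = format!("{}@{}", n.name, h);
        (Recursion::Continue, h)
    });
    assert_eq!(height, Some(1));
    assert_eq!(tree.name, "root@1");
}

Because the new API takes `&mut Self`, rewrites happen in place; the renamed `transform_up_old` style instead rebuilds nodes through `map_children` and `Transformed`.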
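
Reviewer note: the top-down dual, sketched under the same toy types as above with the same simplifications (no `Result`, no `Prune`). It mirrors the `new_payload_iter.next().unwrap()` contract visible in `transform_down_with_payload` above: `f` must return exactly one payload per child. This is the shape `EnforceSorting` now uses to thread parent ordering requirements down the plan in place of the removed `SortPushDown` wrapper node.

fn transform_down_with_payload<P>(
    node: &mut Node,
    f: &mut impl FnMut(&mut Node, P) -> (Recursion, Vec<P>),
    payload: P,
) -> Recursion {
    // Pre-order: apply `f` on self first; it returns one payload per child.
    let (tnr, child_payloads) = f(node, payload);
    if tnr == Recursion::Stop {
        return Recursion::Stop;
    }
    let mut it = child_payloads.into_iter();
    for child in node.children.iter_mut() {
        // Mirrors the `unwrap` in the patch: exactly one payload per child.
        if transform_down_with_payload(child, f, it.next().unwrap()) == Recursion::Stop {
            return Recursion::Stop;
        }
    }
    Recursion::Continue
}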