diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index 5e95562033e6..226f548dd446 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -91,7 +91,7 @@ impl AnalyzerRule for MyAnalyzerRule { impl MyAnalyzerRule { fn analyze_plan(plan: LogicalPlan) -> Result { - plan.transform(&|plan| { + plan.transform_up_old(&|plan| { Ok(match plan { LogicalPlan::Filter(filter) => { let predicate = Self::analyze_expr(filter.predicate.clone())?; @@ -106,7 +106,7 @@ impl MyAnalyzerRule { } fn analyze_expr(expr: Expr) -> Result { - expr.transform(&|expr| { + expr.transform_up_old(&|expr| { // closure is invoked for all sub expressions Ok(match expr { Expr::Literal(ScalarValue::Int64(i)) => { @@ -161,7 +161,7 @@ impl OptimizerRule for MyOptimizerRule { /// use rewrite_expr to modify the expression tree. fn my_rewrite(expr: Expr) -> Result { - expr.transform(&|expr| { + expr.transform_up_old(&|expr| { // closure is invoked for all sub expressions Ok(match expr { Expr::Between(Between { diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 5f11c8cc1d11..9c484d6b52f7 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -23,10 +23,24 @@ use std::sync::Arc; use crate::Result; -/// Defines a visitable and rewriteable a tree node. This trait is -/// implemented for plans ([`ExecutionPlan`] and [`LogicalPlan`]) as -/// well as expression trees ([`PhysicalExpr`], [`Expr`]) in -/// DataFusion +/// Defines a tree node that can have children of the same type as the parent node. The +/// implementations must provide [`TreeNode::visit_children()`] and +/// [`TreeNode::transform_children()`] for visiting and changing the structure of the tree. +/// +/// [`TreeNode`] is implemented for plans ([`ExecutionPlan`] and [`LogicalPlan`]) as well +/// as expression trees ([`PhysicalExpr`], [`Expr`]) in DataFusion. +/// +/// Besides the children, each tree node can define links to embedded trees of the same +/// type. The root node of these trees are called inner children of a node. +/// +/// A logical plan of a query is a tree of [`LogicalPlan`] nodes, where each node can +/// contain multiple expression ([`Expr`]) trees. But expression tree nodes can contain +/// logical plans of subqueries, which are again trees of [`LogicalPlan`] nodes. The root +/// nodes of these subquery plans are the inner children of the containing query plan +/// node. +/// +/// Tree node implementations can provide [`TreeNode::visit_inner_children()`] for +/// visiting the structure of the inner tree. /// /// /// [`ExecutionPlan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html @@ -37,28 +51,40 @@ pub trait TreeNode: Sized + Clone { /// Returns all children of the TreeNode fn children_nodes(&self) -> Vec>; - /// Use preorder to iterate the node on the tree so that we can - /// stop fast for some cases. - /// - /// The `op` closure can be used to collect some info from the - /// tree node or do some checking for the tree node. - fn apply(&self, op: &mut F) -> Result + /// Applies `f` to the tree node, then to its inner children and then to its children + /// depending on the result of `f` in a preorder traversal. + /// See [`TreeNodeRecursion`] for more details on how the preorder traversal can be + /// controlled. + /// If an [`Err`] result is returned, recursion is stopped immediately. + fn visit_down(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - match op(self)? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - }; - - self.apply_children(&mut |node| node.apply(op)) + // Apply `f` on self. + f(self)? + // If it returns continue (not prune or stop or stop all) then continue + // traversal on inner children and children. + .and_then_on_continue(|| { + // Run the recursive `apply` on each inner children, but as they are + // unrelated root nodes of inner trees if any returns stop then continue + // with the next one. + self.visit_inner_children(&mut |c| c.visit_down(f)?.continue_on_stop())? + // Run the recursive `apply` on each children. + .and_then_on_continue(|| { + self.visit_children(&mut |c| c.visit_down(f)) + }) + })? + // Applying `f` on self might have returned prune, but we need to propagate + // continue. + .continue_on_prune() } - /// Visit the tree node using the given [TreeNodeVisitor] - /// It performs a depth first walk of an node and its children. + /// Uses a [`TreeNodeVisitor`] to visit the tree node, then its inner children and + /// then its children depending on the result of [`TreeNodeVisitor::pre_visit()`] and + /// [`TreeNodeVisitor::post_visit()`] in a traversal. + /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. + /// + /// If an [`Err`] result is returned, recursion is stopped immediately. /// /// For an node tree such as /// ```text @@ -77,56 +103,92 @@ pub trait TreeNode: Sized + Clone { /// post_visit(ParentNode) /// ``` /// - /// If an Err result is returned, recursion is stopped immediately - /// - /// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no - /// children of that node will be visited, nor is post_visit - /// called on that node. Details see [`TreeNodeVisitor`] - /// - /// If using the default [`TreeNodeVisitor::post_visit`] that does - /// nothing, [`Self::apply`] should be preferred. - fn visit>( + /// If using the default [`TreeNodeVisitor::post_visit()`] that does nothing, + /// [`Self::visit_down()`] should be preferred. + fn visit>( &self, visitor: &mut V, - ) -> Result { - match visitor.pre_visit(self)? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - }; + ) -> Result { + // Apply `pre_visit` on self. + visitor + .pre_visit(self)? + // If it returns continue (not prune or stop or stop all) then continue + // traversal on inner children and children. + .and_then_on_continue(|| { + // Run the recursive `visit` on each inner children, but as they are + // unrelated subquery plans if any returns stop then continue with the + // next one. + self.visit_inner_children(&mut |c| c.visit(visitor)?.continue_on_stop())? + // Run the recursive `visit` on each children. + .and_then_on_continue(|| { + self.visit_children(&mut |c| c.visit(visitor)) + })? + // Apply `post_visit` on self. + .and_then_on_continue(|| visitor.post_visit(self)) + })? + // Applying `pre_visit` or `post_visit` on self might have returned prune, + // but we need to propagate continue. + .continue_on_prune() + } - match self.apply_children(&mut |node| node.visit(visitor))? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } + fn transform>( + &mut self, + transformer: &mut T, + ) -> Result { + // Apply `pre_transform` on self. + transformer + .pre_transform(self)? + // If it returns continue (not prune or stop or stop all) then continue + // traversal on inner children and children. + .and_then_on_continue(|| + // Run the recursive `transform` on each children. + self + .transform_children(&mut |c| c.transform(transformer))? + // Apply `post_transform` on new self. + .and_then_on_continue(|| transformer.post_transform(self)))? + // Applying `pre_transform` or `post_transform` on self might have returned + // prune, but we need to propagate continue. + .continue_on_prune() + } - visitor.post_visit(self) + fn transform_down(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + // Apply `f` on self. + f(self)? + // If it returns continue (not prune or stop or stop all) then continue + // traversal on inner children and children. + .and_then_on_continue(|| + // Run the recursive `transform` on each children. + self.transform_children(&mut |c| c.transform_down(f)))? + // Applying `f` on self might have returned prune, but we need to propagate + // continue. + .continue_on_prune() } - /// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node tree. - /// When `op` does not apply to a given node, it is left unchanged. - /// The default tree traversal direction is transform_up(Postorder Traversal). - fn transform(self, op: &F) -> Result + fn transform_up(&mut self, f: &mut F) -> Result where - F: Fn(Self) -> Result>, + F: FnMut(&mut Self) -> Result, { - self.transform_up(op) + // Run the recursive `transform` on each children. + self.transform_children(&mut |c| c.transform_up(f))? + // Apply `f` on self. + .and_then_on_continue(|| f(self))? + // Applying `f` on self might have returned prune, but we need to propagate + // continue. + .continue_on_prune() } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its /// children(Preorder Traversal). /// When the `op` does not apply to a given node, it is left unchanged. - fn transform_down(self, op: &F) -> Result + fn transform_down_old(self, op: &F) -> Result where F: Fn(Self) -> Result>, { let after_op = op(self)?.into(); - after_op.map_children(|node| node.transform_down(op)) + after_op.map_children(|node| node.transform_down_old(op)) } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its @@ -143,11 +205,11 @@ pub trait TreeNode: Sized + Clone { /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its /// children and then itself(Postorder Traversal). /// When the `op` does not apply to a given node, it is left unchanged. - fn transform_up(self, op: &F) -> Result + fn transform_up_old(self, op: &F) -> Result where F: Fn(Self) -> Result>, { - let after_op_children = self.map_children(|node| node.transform_up(op))?; + let after_op_children = self.map_children(|node| node.transform_up_old(op))?; let new_node = op(after_op_children)?.into(); Ok(new_node) @@ -212,64 +274,121 @@ pub trait TreeNode: Sized + Clone { } } - /// Apply the closure `F` to the node's children - fn apply_children(&self, op: &mut F) -> Result + fn visit_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - for child in self.children_nodes() { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - Ok(VisitRecursion::Continue) + self.children_nodes() + .iter() + .for_each_till_continue(&mut |c| f(c)) + } + + /// Apply `f` to the node's inner children. + fn visit_inner_children(&self, _f: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + Ok(TreeNodeRecursion::Continue) } /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) fn map_children(self, transform: F) -> Result where F: FnMut(Self) -> Result; + + /// Apply `f` to the node's children. + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result; + + /// Convenience function to do a preorder traversal of the tree nodes with `f` that + /// can't fail. + fn for_each(&self, f: &mut F) + where + F: FnMut(&Self), + { + self.visit_down(&mut |n| { + f(n); + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + } + + /// Convenience function to collect the first non-empty value that `f` returns in a + /// preorder traversal. + fn collect_first(&self, f: &mut F) -> Option + where + F: FnMut(&Self) -> Option, + { + let mut res = None; + self.visit_down(&mut |n| { + res = f(n); + if res.is_some() { + Ok(TreeNodeRecursion::StopAll) + } else { + Ok(TreeNodeRecursion::Continue) + } + }) + .unwrap(); + res + } + + /// Convenience function to collect all values that `f` returns in a preorder + /// traversal. + fn collect(&self, f: &mut F) -> Vec + where + F: FnMut(&Self) -> Vec, + { + let mut res = vec![]; + self.visit_down(&mut |n| { + res.extend(f(n)); + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + res + } } -/// Implements the [visitor -/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively walking [`TreeNode`]s. +/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for +/// recursively visiting [`TreeNode`]s. /// -/// [`TreeNodeVisitor`] allows keeping the algorithms -/// separate from the code to traverse the structure of the `TreeNode` -/// tree and makes it easier to add new types of tree node and -/// algorithms. +/// [`TreeNodeVisitor`] allows keeping the algorithms separate from the code to traverse +/// the structure of the [`TreeNode`] tree and makes it easier to add new types of tree +/// node and algorithms. /// -/// When passed to[`TreeNode::visit`], [`TreeNodeVisitor::pre_visit`] -/// and [`TreeNodeVisitor::post_visit`] are invoked recursively -/// on an node tree. +/// When passed to [`TreeNode::visit()`], [`TreeNodeVisitor::pre_visit()`] and +/// [`TreeNodeVisitor::post_visit()`] are invoked recursively on an node tree. +/// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. /// -/// If an [`Err`] result is returned, recursion is stopped -/// immediately. -/// -/// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no -/// children of that tree node are visited, nor is post_visit -/// called on that tree node +/// If an [`Err`] result is returned, recursion is stopped immediately. +pub trait TreeNodeVisitor: Sized { + /// The node type which is visitable. + type Node: TreeNode; + + /// Invoked before any inner children or children of a node are visited. + fn pre_visit(&mut self, node: &Self::Node) -> Result; + + /// Invoked after all inner children and children of a node are visited. + fn post_visit(&mut self, _node: &Self::Node) -> Result; +} + +/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for +/// recursively transforming [`TreeNode`]s. /// -/// If [`VisitRecursion::Stop`] is returned on a call to post_visit, no -/// siblings of that tree node are visited, nor is post_visit -/// called on its parent tree node +/// When passed to [`TreeNode::transform()`], [`TreeNodeTransformer::pre_transform()`] and +/// [`TreeNodeTransformer::post_transform()`] are invoked recursively on an node tree. +/// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. /// -/// If [`VisitRecursion::Skip`] is returned on a call to pre_visit, no -/// children of that tree node are visited. -pub trait TreeNodeVisitor: Sized { +/// If an [`Err`] result is returned, recursion is stopped immediately. +pub trait TreeNodeTransformer: Sized { /// The node type which is visitable. - type N: TreeNode; + type Node: TreeNode; - /// Invoked before any children of `node` are visited. - fn pre_visit(&mut self, node: &Self::N) -> Result; + /// Invoked before any inner children or children of a node are modified. + fn pre_transform(&mut self, node: &mut Self::Node) -> Result; - /// Invoked after all children of `node` are visited. Default - /// implementation does nothing. - fn post_visit(&mut self, _node: &Self::N) -> Result { - Ok(VisitRecursion::Continue) - } + /// Invoked after all inner children and children of a node are modified. + fn post_transform(&mut self, node: &mut Self::Node) -> Result; } /// Trait for potentially recursively transform an [`TreeNode`] node @@ -303,15 +422,83 @@ pub enum RewriteRecursion { Skip, } -/// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::visit`]. +/// Controls how a [`TreeNode`] recursion should proceed for [`TreeNode::visit_down()`], +/// [`TreeNode::visit()`], [`TreeNode::transform_down()`], [`TreeNode::transform_up()`] +/// and [`TreeNode::transform()`]. #[derive(Debug)] -pub enum VisitRecursion { - /// Continue the visit to this node tree. +pub enum TreeNodeRecursion { + /// Continue the visit to the next node. Continue, - /// Keep recursive but skip applying op on the children - Skip, - /// Stop the visit to this node tree. + + /// Prune the current subtree. + /// If a preorder visit of a tree node returns [`TreeNodeRecursion::Prune`] then inner + /// children and children will not be visited and postorder visit of the node will not + /// be invoked. + Prune, + + /// Stop recursion on current tree. + /// If recursion runs on an inner tree then returning [`TreeNodeRecursion::Stop`] doesn't + /// stop recursion on the outer tree. Stop, + + /// Stop recursion on all (including outer) trees. + StopAll, +} + +impl TreeNodeRecursion { + /// Helper function to define behavior of a [`TreeNode`] recursion to continue with a + /// closure if the recursion so far resulted [`TreeNodeRecursion::Continue]`. + pub fn and_then_on_continue(self, f: F) -> Result + where + F: FnOnce() -> Result, + { + match self { + TreeNodeRecursion::Continue => f(), + o => Ok(o), + } + } + + fn continue_on_prune(self) -> Result { + Ok(match self { + TreeNodeRecursion::Prune => TreeNodeRecursion::Continue, + o => o, + }) + } + + pub fn fail_on_prune(self) -> Result { + Ok(match self { + TreeNodeRecursion::Prune => panic!("Recursion can't prune."), + o => o, + }) + } + + fn continue_on_stop(self) -> Result { + Ok(match self { + TreeNodeRecursion::Stop => TreeNodeRecursion::Continue, + o => o, + }) + } +} + +pub trait VisitRecursionIterator: Iterator { + fn for_each_till_continue(self, f: &mut F) -> Result + where + F: FnMut(Self::Item) -> Result; +} + +impl VisitRecursionIterator for I { + fn for_each_till_continue(self, f: &mut F) -> Result + where + F: FnMut(Self::Item) -> Result, + { + for i in self { + match f(i)? { + TreeNodeRecursion::Continue => {} + o => return Ok(o), + } + } + Ok(TreeNodeRecursion::Continue) + } } pub enum Transformed { @@ -374,4 +561,18 @@ impl TreeNode for Arc { Ok(self) } } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + let mut new_children = self.arc_children(); + if !new_children.is_empty() { + let tnr = new_children.iter_mut().for_each_till_continue(f)?; + *self = self.with_new_arc_children(self.clone(), new_children)?; + Ok(tnr) + } else { + Ok(TreeNodeRecursion::Continue) + } + } } diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 68de55e1a410..eecdf5c2b193 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -37,7 +37,7 @@ use crate::{error::Result, scalar::ScalarValue}; use super::PartitionedFile; use crate::datasource::listing::ListingTableUrl; use crate::execution::context::SessionState; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError}; use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility}; use datafusion_physical_expr::create_physical_expr; @@ -52,14 +52,14 @@ use object_store::{ObjectMeta, ObjectStore}; /// was performed pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { let mut is_applicable = true; - expr.apply(&mut |expr| { + expr.visit_down(&mut |expr| { match expr { Expr::Column(Column { ref name, .. }) => { is_applicable &= col_names.contains(name); if is_applicable { - Ok(VisitRecursion::Skip) + Ok(TreeNodeRecursion::Prune) } else { - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } Expr::Literal(_) @@ -88,27 +88,27 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::ScalarSubquery(_) | Expr::GetIndexedField { .. } | Expr::GroupingSet(_) - | Expr::Case { .. } => Ok(VisitRecursion::Continue), + | Expr::Case { .. } => Ok(TreeNodeRecursion::Continue), Expr::ScalarFunction(scalar_function) => { match &scalar_function.func_def { ScalarFunctionDefinition::BuiltIn(fun) => { match fun.volatility() { - Volatility::Immutable => Ok(VisitRecursion::Continue), + Volatility::Immutable => Ok(TreeNodeRecursion::Continue), // TODO: Stable functions could be `applicable`, but that would require access to the context Volatility::Stable | Volatility::Volatile => { is_applicable = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } } ScalarFunctionDefinition::UDF(fun) => { match fun.signature().volatility { - Volatility::Immutable => Ok(VisitRecursion::Continue), + Volatility::Immutable => Ok(TreeNodeRecursion::Continue), // TODO: Stable functions could be `applicable`, but that would require access to the context Volatility::Stable | Volatility::Volatile => { is_applicable = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } } @@ -128,7 +128,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::Wildcard { .. } | Expr::Placeholder(_) => { is_applicable = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } }) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index c51f2d132aad..086df6bb9e6b 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -38,7 +38,7 @@ use crate::{ use datafusion_common::{ alias::AliasGenerator, exec_err, not_impl_err, plan_datafusion_err, plan_err, - tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}, + tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}, }; use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ @@ -2090,9 +2090,9 @@ impl<'a> BadPlanVisitor<'a> { } impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, node: &Self::N) -> Result { + fn pre_visit(&mut self, node: &Self::Node) -> Result { match node { LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => { plan_err!("DDL not supported: {}", ddl.name()) @@ -2106,9 +2106,13 @@ impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> { LogicalPlan::Statement(stmt) if !self.options.allow_statements => { plan_err!("Statement not supported: {}", stmt.name()) } - _ => Ok(VisitRecursion::Continue), + _ => Ok(TreeNodeRecursion::Continue), } } + + fn post_visit(&mut self, _node: &Self::Node) -> Result { + Ok(TreeNodeRecursion::Continue) + } } #[cfg(test)] diff --git a/datafusion/core/src/physical_optimizer/coalesce_batches.rs b/datafusion/core/src/physical_optimizer/coalesce_batches.rs index 7b66ca529094..19a8701a1003 100644 --- a/datafusion/core/src/physical_optimizer/coalesce_batches.rs +++ b/datafusion/core/src/physical_optimizer/coalesce_batches.rs @@ -52,7 +52,7 @@ impl PhysicalOptimizerRule for CoalesceBatches { } let target_batch_size = config.execution.batch_size; - plan.transform_up(&|plan| { + plan.transform_up_old(&|plan| { let plan_any = plan.as_any(); // The goal here is to detect operators that could produce small batches and only // wrap those ones with a CoalesceBatchesExec operator. An alternate approach here diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 7359a6463059..09963edd5979 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -26,7 +26,7 @@ use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGro use crate::physical_plan::ExecutionPlan; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{AggregateExpr, PhysicalExpr}; @@ -48,27 +48,27 @@ impl CombinePartialFinalAggregate { impl PhysicalOptimizerRule for CombinePartialFinalAggregate { fn optimize( &self, - plan: Arc, + mut plan: Arc, _config: &ConfigOptions, ) -> Result> { - plan.transform_down(&|plan| { - let transformed = - plan.as_any() - .downcast_ref::() - .and_then(|agg_exec| { - if matches!( - agg_exec.mode(), - AggregateMode::Final | AggregateMode::FinalPartitioned - ) { - agg_exec - .input() - .as_any() - .downcast_ref::() - .and_then(|input_agg_exec| { - if matches!( - input_agg_exec.mode(), - AggregateMode::Partial - ) && can_combine( + plan.transform_down(&mut |plan| { + plan.clone() + .as_any() + .downcast_ref::() + .into_iter() + .for_each(|agg_exec| { + if matches!( + agg_exec.mode(), + AggregateMode::Final | AggregateMode::FinalPartitioned + ) { + agg_exec + .input() + .as_any() + .downcast_ref::() + .into_iter() + .for_each(|input_agg_exec| { + if matches!(input_agg_exec.mode(), AggregateMode::Partial) + && can_combine( ( agg_exec.group_by(), agg_exec.aggr_expr(), @@ -79,41 +79,34 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { input_agg_exec.aggr_expr(), input_agg_exec.filter_expr(), ), - ) { - let mode = - if agg_exec.mode() == &AggregateMode::Final { - AggregateMode::Single - } else { - AggregateMode::SinglePartitioned - }; - AggregateExec::try_new( - mode, - input_agg_exec.group_by().clone(), - input_agg_exec.aggr_expr().to_vec(), - input_agg_exec.filter_expr().to_vec(), - input_agg_exec.input().clone(), - input_agg_exec.input_schema(), - ) - .map(|combined_agg| { - combined_agg.with_limit(agg_exec.limit()) - }) - .ok() - .map(Arc::new) + ) + { + let mode = if agg_exec.mode() == &AggregateMode::Final + { + AggregateMode::Single } else { - None - } - }) - } else { - None - } - }); - - Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) - } else { - Transformed::No(plan) - }) - }) + AggregateMode::SinglePartitioned + }; + AggregateExec::try_new( + mode, + input_agg_exec.group_by().clone(), + input_agg_exec.aggr_expr().to_vec(), + input_agg_exec.filter_expr().to_vec(), + input_agg_exec.input().clone(), + input_agg_exec.input_schema(), + ) + .map(|combined_agg| { + combined_agg.with_limit(agg_exec.limit()) + }) + .into_iter() + .for_each(|p| *plan = Arc::new(p)) + } + }) + } + }); + Ok(TreeNodeRecursion::Continue) + })?; + Ok(plan) } fn name(&self) -> &str { @@ -178,7 +171,7 @@ fn normalize_group_exprs(group_exprs: GroupExprsRef) -> GroupExprs { fn discard_column_index(group_expr: Arc) -> Arc { group_expr .clone() - .transform(&|expr| { + .transform_up_old(&|expr| { let normalized_form: Option> = match expr.as_any().downcast_ref::() { Some(column) => Some(Arc::new(Column::new(column.name(), 0))), diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index bf5aa7d02272..7726c7804f6f 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -48,7 +48,9 @@ use crate::physical_plan::{ }; use arrow::compute::SortOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator, +}; use datafusion_expr::logical_plan::JoinType; use datafusion_physical_expr::expressions::{Column, NoOp}; use datafusion_physical_expr::utils::map_columns_before_projection; @@ -201,11 +203,11 @@ impl PhysicalOptimizerRule for EnforceDistribution { // Run a top-down process to adjust input key ordering recursively let plan_requirements = PlanWithKeyRequirements::new(plan); let adjusted = - plan_requirements.transform_down(&adjust_input_keys_ordering)?; + plan_requirements.transform_down_old(&adjust_input_keys_ordering)?; adjusted.plan } else { // Run a bottom-up process - plan.transform_up(&|plan| { + plan.transform_up_old(&|plan| { Ok(Transformed::Yes(reorder_join_keys_to_inputs(plan)?)) })? }; @@ -213,7 +215,7 @@ impl PhysicalOptimizerRule for EnforceDistribution { let distribution_context = DistributionContext::new(adjusted); // Distribution enforcement needs to be applied bottom-up. let distribution_context = - distribution_context.transform_up(&|distribution_context| { + distribution_context.transform_up_old(&|distribution_context| { ensure_distribution(distribution_context, config) })?; Ok(distribution_context.plan) @@ -1432,6 +1434,23 @@ impl TreeNode for DistributionContext { } Ok(self) } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + if !self.children_nodes.is_empty() { + let tnr = self.children_nodes.iter_mut().for_each_till_continue(f)?; + self.plan = with_new_children_if_necessary( + self.plan.clone(), + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); + Ok(tnr) + } else { + Ok(TreeNodeRecursion::Continue) + } + } } /// implement Display method for `DistributionContext` struct. @@ -1496,6 +1515,23 @@ impl TreeNode for PlanWithKeyRequirements { } Ok(self) } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + if !self.children.is_empty() { + let tnr = self.children.iter_mut().for_each_till_continue(f)?; + self.plan = with_new_children_if_necessary( + self.plan.clone(), + self.children.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); + Ok(tnr) + } else { + Ok(TreeNodeRecursion::Continue) + } + } } /// Since almost all of these tests explicitly use `ParquetExec` they only run with the parquet feature flag on diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index f609ddea66cf..042f34b93dca 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -58,7 +58,9 @@ use crate::physical_plan::{ with_new_children_if_necessary, Distribution, ExecutionPlan, InputOrderMode, }; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator, +}; use datafusion_common::{plan_err, DataFusionError}; use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}; use datafusion_physical_plan::repartition::RepartitionExec; @@ -168,6 +170,23 @@ impl TreeNode for PlanWithCorrespondingSort { } Ok(self) } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + if !self.children_nodes.is_empty() { + let tnr = self.children_nodes.iter_mut().for_each_till_continue(f)?; + self.plan = with_new_children_if_necessary( + self.plan.clone(), + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); + Ok(tnr) + } else { + Ok(TreeNodeRecursion::Continue) + } + } } /// This object is used within the [`EnforceSorting`] rule to track the closest @@ -249,6 +268,23 @@ impl TreeNode for PlanWithCorrespondingCoalescePartitions { } Ok(self) } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + if !self.children_nodes.is_empty() { + let tnr = self.children_nodes.iter_mut().for_each_till_continue(f)?; + self.plan = with_new_children_if_necessary( + self.plan.clone(), + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); + Ok(tnr) + } else { + Ok(TreeNodeRecursion::Continue) + } + } } /// The boolean flag `repartition_sorts` defined in the config indicates @@ -264,12 +300,12 @@ impl PhysicalOptimizerRule for EnforceSorting { let plan_requirements = PlanWithCorrespondingSort::new(plan); // Execute a bottom-up traversal to enforce sorting requirements, // remove unnecessary sorts, and optimize sort-sensitive operators: - let adjusted = plan_requirements.transform_up(&ensure_sorting)?; + let adjusted = plan_requirements.transform_up_old(&ensure_sorting)?; let new_plan = if config.optimizer.repartition_sorts { let plan_with_coalesce_partitions = PlanWithCorrespondingCoalescePartitions::new(adjusted.plan); let parallel = - plan_with_coalesce_partitions.transform_up(¶llelize_sorts)?; + plan_with_coalesce_partitions.transform_up_old(¶llelize_sorts)?; parallel.plan } else { adjusted.plan @@ -277,7 +313,7 @@ impl PhysicalOptimizerRule for EnforceSorting { let plan_with_pipeline_fixer = OrderPreservationContext::new(new_plan); let updated_plan = - plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| { + plan_with_pipeline_fixer.transform_up_old(&|plan_with_pipeline_fixer| { replace_with_order_preserving_variants( plan_with_pipeline_fixer, false, @@ -290,7 +326,7 @@ impl PhysicalOptimizerRule for EnforceSorting { // missed by the bottom-up traversal: let mut sort_pushdown = SortPushDown::new(updated_plan.plan); sort_pushdown.assign_initial_requirements(); - let adjusted = sort_pushdown.transform_down(&pushdown_sorts)?; + let adjusted = sort_pushdown.transform_down_old(&pushdown_sorts)?; Ok(adjusted.plan) } diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 6b2fe24acf00..66a27aa7cb6b 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -237,7 +237,8 @@ impl PhysicalOptimizerRule for JoinSelection { Box::new(hash_join_convert_symmetric_subrule), Box::new(hash_join_swap_subrule), ]; - let state = pipeline.transform_up(&|p| apply_subrules(p, &subrules, config))?; + let state = + pipeline.transform_up_old(&|p| apply_subrules(p, &subrules, config))?; // Next, we apply another subrule that tries to optimize joins using any // statistics their inputs might have. // - For a hash join with partition mode [`PartitionMode::Auto`], we will @@ -251,7 +252,7 @@ impl PhysicalOptimizerRule for JoinSelection { // side is the small side. let config = &config.optimizer; let collect_left_threshold = config.hash_join_single_partition_threshold; - state.plan.transform_up(&|plan| { + state.plan.transform_up_old(&|plan| { statistical_join_selection_subrule(plan, collect_left_threshold) }) } diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs index 540f9a6a132b..249537534ada 100644 --- a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs @@ -160,7 +160,7 @@ impl PhysicalOptimizerRule for LimitedDistinctAggregation { config: &ConfigOptions, ) -> Result> { let plan = if config.optimizer.enable_distinct_aggregation_soft_limit { - plan.transform_down(&|plan| { + plan.transform_down_old(&|plan| { Ok( if let Some(plan) = LimitedDistinctAggregation::transform_limit(plan.clone()) diff --git a/datafusion/core/src/physical_optimizer/output_requirements.rs b/datafusion/core/src/physical_optimizer/output_requirements.rs index 4d03840d3dd3..ac6e79096b63 100644 --- a/datafusion/core/src/physical_optimizer/output_requirements.rs +++ b/datafusion/core/src/physical_optimizer/output_requirements.rs @@ -196,7 +196,7 @@ impl PhysicalOptimizerRule for OutputRequirements { ) -> Result> { match self.mode { RuleMode::Add => require_top_ordering(plan), - RuleMode::Remove => plan.transform_up(&|plan| { + RuleMode::Remove => plan.transform_up_old(&|plan| { if let Some(sort_req) = plan.as_any().downcast_ref::() { diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index e281d0e7c23e..3200b4ff7c43 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -28,7 +28,9 @@ use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use datafusion_common::config::OptimizerOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator, +}; use datafusion_common::{plan_err, DataFusionError}; use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; use datafusion_physical_plan::joins::SymmetricHashJoinExec; @@ -53,7 +55,7 @@ impl PhysicalOptimizerRule for PipelineChecker { ) -> Result> { let pipeline = PipelineStatePropagator::new(plan); let state = pipeline - .transform_up(&|p| check_finiteness_requirements(p, &config.optimizer))?; + .transform_up_old(&|p| check_finiteness_requirements(p, &config.optimizer))?; Ok(state.plan) } @@ -114,6 +116,23 @@ impl TreeNode for PipelineStatePropagator { } Ok(self) } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + if !self.children.is_empty() { + let tnr = self.children.iter_mut().for_each_till_continue(f)?; + self.plan = with_new_children_if_necessary( + self.plan.clone(), + self.children.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); + Ok(tnr) + } else { + Ok(TreeNodeRecursion::Continue) + } + } } /// This function propagates finiteness information and rejects any plan with diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index d237a3e8607e..2ff9c3572af5 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -43,7 +43,7 @@ use crate::physical_plan::{Distribution, ExecutionPlan}; use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::JoinSide; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::{ @@ -72,7 +72,7 @@ impl PhysicalOptimizerRule for ProjectionPushdown { plan: Arc, _config: &ConfigOptions, ) -> Result> { - plan.transform_down(&remove_unnecessary_projections) + plan.transform_down_old(&remove_unnecessary_projections) } fn name(&self) -> &str { @@ -255,12 +255,12 @@ fn try_unifying_projections( // Collect the column references usage in the outer projection. projection.expr().iter().for_each(|(expr, _)| { - expr.apply(&mut |expr| { + expr.visit_down(&mut |expr| { Ok({ if let Some(column) = expr.as_any().downcast_ref::() { *column_ref_map.entry(column.clone()).or_default() += 1; } - VisitRecursion::Continue + TreeNodeRecursion::Continue }) }) .unwrap(); diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index fecbffdbb041..5055ec8c7e5c 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -810,7 +810,7 @@ fn rewrite_column_expr( column_old: &phys_expr::Column, column_new: &phys_expr::Column, ) -> Result> { - e.transform(&|expr| { + e.transform_up_old(&|expr| { if let Some(column) = expr.as_any().downcast_ref::() { if column == column_old { return Ok(Transformed::Yes(Arc::new(column_new.clone()))); diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index e49b358608aa..2c35b18538ac 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -30,7 +30,9 @@ use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator, +}; use datafusion_physical_plan::unbounded_output; /// For a given `plan`, this object carries the information one needs from its @@ -127,6 +129,23 @@ impl TreeNode for OrderPreservationContext { } Ok(self) } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + if !self.children_nodes.is_empty() { + let tnr = self.children_nodes.iter_mut().for_each_till_continue(f)?; + self.plan = with_new_children_if_necessary( + self.plan.clone(), + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); + Ok(tnr) + } else { + Ok(TreeNodeRecursion::Continue) + } + } } /// Calculates the updated plan by replacing operators that lose ordering @@ -395,7 +414,7 @@ mod tests { // Run the rule top-down let config = SessionConfig::new().with_prefer_existing_sort($PREFER_EXISTING_SORT); let plan_with_pipeline_fixer = OrderPreservationContext::new(physical_plan); - let parallel = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, config.options()))?; + let parallel = plan_with_pipeline_fixer.transform_up_old(&|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, config.options()))?; let optimized_physical_plan = parallel.plan; // Get string representation of the plan diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index 97ca47baf05f..729c367d7ec3 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -29,7 +29,9 @@ use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator, +}; use datafusion_common::{plan_err, DataFusionError, JoinSide, Result}; use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::Column; @@ -94,6 +96,23 @@ impl TreeNode for SortPushDown { } Ok(self) } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + if !self.children_nodes.is_empty() { + let tnr = self.children_nodes.iter_mut().for_each_till_continue(f)?; + self.plan = with_new_children_if_necessary( + self.plan.clone(), + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); + Ok(tnr) + } else { + Ok(TreeNodeRecursion::Continue) + } + } } pub(crate) fn pushdown_sorts( diff --git a/datafusion/core/src/physical_optimizer/topk_aggregation.rs b/datafusion/core/src/physical_optimizer/topk_aggregation.rs index dd0261420304..f00c44b3234f 100644 --- a/datafusion/core/src/physical_optimizer/topk_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/topk_aggregation.rs @@ -138,7 +138,7 @@ impl PhysicalOptimizerRule for TopKAggregation { config: &ConfigOptions, ) -> Result> { let plan = if config.optimizer.enable_topk_aggregation { - plan.transform_down(&|plan| { + plan.transform_down_old(&|plan| { Ok( if let Some(plan) = TopKAggregation::transform_sort(plan.clone()) { Transformed::Yes(plan) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index ebf4d3143c12..0b832c1fec17 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -31,11 +31,11 @@ use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{internal_err, DFSchema, OwnedTableReference}; use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue}; use std::collections::HashSet; -use std::fmt; use std::fmt::{Display, Formatter, Write}; use std::hash::{BuildHasher, Hash, Hasher}; use std::str::FromStr; use std::sync::Arc; +use std::{fmt, mem}; use crate::Signature; @@ -182,6 +182,12 @@ pub enum Expr { OuterReferenceColumn(DataType, Column), } +impl Default for Expr { + fn default() -> Self { + Expr::Literal(ScalarValue::Null) + } +} + /// Alias expression #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Alias { @@ -1054,11 +1060,11 @@ impl Expr { } /// Remove an alias from an expression if one exists. - pub fn unalias(self) -> Expr { - match self { - Expr::Alias(alias) => *alias.expr, - _ => self, + pub fn unalias(&mut self) -> &mut Self { + if let Expr::Alias(alias) = self { + *self = mem::take(alias.expr.as_mut()); } + self } /// Return `self IN ` if `negated` is false, otherwise @@ -1247,7 +1253,7 @@ impl Expr { /// For example, gicen an expression like ` = $0` will infer `$0` to /// have type `int32`. pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result { - self.transform(&|mut expr| { + self.transform_up_old(&|mut expr| { // Default to assuming the arguments are the same type if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 1f04c80833f0..a91cd408aed1 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -20,7 +20,7 @@ use crate::expr::Alias; use crate::logical_plan::Projection; use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeTransformer}; use datafusion_common::Result; use datafusion_common::{Column, DFSchema}; use std::collections::HashMap; @@ -33,7 +33,7 @@ pub use order_by::rewrite_sort_cols_by_aggs; /// Recursively call [`Column::normalize_with_schemas`] on all [`Column`] expressions /// in the `expr` expression tree. pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { - expr.transform(&|expr| { + expr.transform_up_old(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = LogicalPlanBuilder::normalize(plan, c)?; @@ -57,7 +57,7 @@ pub fn normalize_col_with_schemas( schemas: &[&Arc], using_columns: &[HashSet], ) -> Result { - expr.transform(&|expr| { + expr.transform_up_old(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = c.normalize_with_schemas(schemas, using_columns)?; @@ -75,7 +75,7 @@ pub fn normalize_col_with_schemas_and_ambiguity_check( schemas: &[&[&DFSchema]], using_columns: &[HashSet], ) -> Result { - expr.transform(&|expr| { + expr.transform_up_old(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = @@ -102,7 +102,7 @@ pub fn normalize_cols( /// Recursively replace all [`Column`] expressions in a given expression tree with /// `Column` expressions provided by the hash map argument. pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { - expr.transform(&|expr| { + expr.transform_up_old(&|expr| { Ok({ if let Expr::Column(c) = &expr { match replace_map.get(c) { @@ -122,7 +122,7 @@ pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Resul /// For example, if there were expressions like `foo.bar` this would /// rewrite it to just `bar`. pub fn unnormalize_col(expr: Expr) -> Expr { - expr.transform(&|expr| { + expr.transform_up_old(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = Column { @@ -164,7 +164,7 @@ pub fn unnormalize_cols(exprs: impl IntoIterator) -> Vec { /// Recursively remove all the ['OuterReferenceColumn'] and return the inside Column /// in the expression tree. pub fn strip_outer_reference(expr: Expr) -> Expr { - expr.transform(&|expr| { + expr.transform_up_old(&|expr| { Ok({ if let Expr::OuterReferenceColumn(_, col) = expr { Transformed::Yes(Expr::Column(col)) @@ -248,12 +248,12 @@ pub fn unalias(expr: Expr) -> Expr { /// /// This is important when optimizing plans to ensure the output /// schema of plan nodes don't change after optimization -pub fn rewrite_preserving_name(expr: Expr, rewriter: &mut R) -> Result +pub fn rewrite_preserving_name(mut expr: Expr, transformer: &mut R) -> Result where - R: TreeNodeRewriter, + R: TreeNodeTransformer, { let original_name = expr.name_for_alias()?; - let expr = expr.rewrite(rewriter)?; + expr.transform(transformer)?; expr.alias_if_changed(original_name) } @@ -263,7 +263,7 @@ mod test { use crate::expr::Sort; use crate::{col, lit, Cast}; use arrow::datatypes::DataType; - use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; + use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{DFField, DFSchema, ScalarValue}; use std::ops::Add; @@ -272,17 +272,17 @@ mod test { v: Vec, } - impl TreeNodeRewriter for RecordingRewriter { - type N = Expr; + impl TreeNodeTransformer for RecordingRewriter { + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn pre_transform(&mut self, expr: &mut Expr) -> Result { self.v.push(format!("Previsited {expr}")); - Ok(RewriteRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn mutate(&mut self, expr: Expr) -> Result { + fn post_transform(&mut self, expr: &mut Expr) -> Result { self.v.push(format!("Mutated {expr}")); - Ok(expr) + Ok(TreeNodeRecursion::Continue) } } @@ -305,11 +305,17 @@ mod test { }; // rewrites "foo" --> "bar" - let rewritten = col("state").eq(lit("foo")).transform(&transformer).unwrap(); + let rewritten = col("state") + .eq(lit("foo")) + .transform_up_old(&transformer) + .unwrap(); assert_eq!(rewritten, col("state").eq(lit("bar"))); // doesn't rewrite - let rewritten = col("state").eq(lit("baz")).transform(&transformer).unwrap(); + let rewritten = col("state") + .eq(lit("baz")) + .transform_up_old(&transformer) + .unwrap(); assert_eq!(rewritten, col("state").eq(lit("baz"))); } @@ -399,7 +405,8 @@ mod test { #[test] fn rewriter_visit() { let mut rewriter = RecordingRewriter::default(); - col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap(); + let mut expr = col("state").eq(lit("CO")); + expr.transform(&mut rewriter).unwrap(); assert_eq!( rewriter.v, @@ -439,22 +446,28 @@ mod test { /// rewrites `expr_from` to `rewrite_to` using /// `rewrite_preserving_name` verifying the result is `expected_expr` fn test_rewrite(expr_from: Expr, rewrite_to: Expr) { - struct TestRewriter { - rewrite_to: Expr, - } + struct TestTransformer {} + + impl TreeNodeTransformer for TestTransformer { + type Node = Expr; - impl TreeNodeRewriter for TestRewriter { - type N = Expr; + fn pre_transform( + &mut self, + _node: &mut Self::Node, + ) -> Result { + Ok(TreeNodeRecursion::Continue) + } - fn mutate(&mut self, _: Expr) -> Result { - Ok(self.rewrite_to.clone()) + fn post_transform( + &mut self, + _node: &mut Self::Node, + ) -> Result { + Ok(TreeNodeRecursion::Continue) } } - let mut rewriter = TestRewriter { - rewrite_to: rewrite_to.clone(), - }; - let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap(); + let mut transformer = TestTransformer {}; + let expr = rewrite_preserving_name(expr_from.clone(), &mut transformer).unwrap(); let original_name = match &expr_from { Expr::Sort(Sort { expr, .. }) => expr.display_name(), diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index c87a724d5646..e275487c4574 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -83,7 +83,7 @@ fn rewrite_in_terms_of_projection( ) -> Result { // assumption is that each item in exprs, such as "b + c" is // available as an output column named "b + c" - expr.transform(&|expr| { + expr.transform_up_old(&|expr| { // search for unnormalized names first such as "c1" (such as aliases) if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) { let col = Expr::Column( diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index cfc052cfc14c..28779a261497 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1240,9 +1240,10 @@ pub fn project_with_column_index( let alias_expr = expr .into_iter() .enumerate() - .map(|(i, e)| match e { + .map(|(i, mut e)| match e { Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => { - e.unalias().alias(schema.field(i).name()) + e.unalias(); + e.alias(schema.field(i).name()) } Expr::Column(Column { relation: _, diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 112dbf74dba1..2a8c4ce5912d 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -19,7 +19,7 @@ use crate::LogicalPlan; use arrow::datatypes::Schema; use datafusion_common::display::GraphvizBuilder; -use datafusion_common::tree_node::{TreeNodeVisitor, VisitRecursion}; +use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeVisitor}; use datafusion_common::DataFusionError; use std::fmt; @@ -49,12 +49,12 @@ impl<'a, 'b> IndentVisitor<'a, 'b> { } impl<'a, 'b> TreeNodeVisitor for IndentVisitor<'a, 'b> { - type N = LogicalPlan; + type Node = LogicalPlan; fn pre_visit( &mut self, plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { if self.indent > 0 { writeln!(self.f)?; } @@ -69,15 +69,15 @@ impl<'a, 'b> TreeNodeVisitor for IndentVisitor<'a, 'b> { } self.indent += 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } fn post_visit( &mut self, _plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { self.indent -= 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } @@ -171,12 +171,12 @@ impl<'a, 'b> GraphvizVisitor<'a, 'b> { } impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> { - type N = LogicalPlan; + type Node = LogicalPlan; fn pre_visit( &mut self, plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { let id = self.graphviz_builder.next_id(); // Create a new graph node for `plan` such as @@ -204,18 +204,18 @@ impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> { } self.parent_ids.push(id); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } fn post_visit( &mut self, _plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { // always be non-empty as pre_visit always pushes // So it should always be Ok(true) let res = self.parent_ids.pop(); res.ok_or(DataFusionError::Internal("Fail to format".to_string())) - .map(|_| VisitRecursion::Continue) + .map(|_| TreeNodeRecursion::Continue) } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index c0c520c4e211..3629379d2f6b 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -32,8 +32,7 @@ use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::{DmlStatement, Statement}; use crate::utils::{ enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, - grouping_set_expr_count, grouping_set_to_exprlist, inspect_expr_pre, - split_conjunction, + grouping_set_expr_count, grouping_set_to_exprlist, split_conjunction, }; use crate::{ build_join_schema, expr_vec_fmt, BinaryExpr, CreateMemoryTable, CreateView, Expr, @@ -43,8 +42,7 @@ use crate::{ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::{ - RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, TreeNodeVisitor, - VisitRecursion, + Transformed, TreeNode, TreeNodeRecursion, TreeNodeTransformer, VisitRecursionIterator, }; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, @@ -277,9 +275,9 @@ impl LogicalPlan { /// children pub fn expressions(self: &LogicalPlan) -> Vec { let mut exprs = vec![]; - self.inspect_expressions(|e| { + self.visit_expressions(&mut |e| { exprs.push(e.clone()); - Ok(()) as Result<()> + Ok(TreeNodeRecursion::Continue) }) // closure always returns OK .unwrap(); @@ -290,13 +288,13 @@ impl LogicalPlan { /// logical plan nodes and all its descendant nodes. pub fn all_out_ref_exprs(self: &LogicalPlan) -> Vec { let mut exprs = vec![]; - self.inspect_expressions(|e| { + self.visit_expressions(&mut |e| { find_out_reference_exprs(e).into_iter().for_each(|e| { if !exprs.contains(&e) { exprs.push(e) } }); - Ok(()) as Result<(), DataFusionError> + Ok(TreeNodeRecursion::Continue) }) // closure always returns OK .unwrap(); @@ -311,37 +309,41 @@ impl LogicalPlan { exprs } - /// Calls `f` on all expressions (non-recursively) in the current - /// logical plan node. This does not include expressions in any - /// children. - pub fn inspect_expressions(self: &LogicalPlan, mut f: F) -> Result<(), E> + /// Apply `f` on expressions of the plan node. + /// `f` is not allowed to return [`TreeNodeRecursion::Prune`]. + pub fn visit_expressions(&self, f: &mut F) -> Result where - F: FnMut(&Expr) -> Result<(), E>, + F: FnMut(&Expr) -> Result, { + let f = &mut |e: &Expr| f(e)?.fail_on_prune(); + match self { LogicalPlan::Projection(Projection { expr, .. }) => { - expr.iter().try_for_each(f) + expr.iter().for_each_till_continue(f) } LogicalPlan::Values(Values { values, .. }) => { - values.iter().flatten().try_for_each(f) + values.iter().flatten().for_each_till_continue(f) } LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate), LogicalPlan::Repartition(Repartition { partitioning_scheme, .. }) => match partitioning_scheme { - Partitioning::Hash(expr, _) => expr.iter().try_for_each(f), - Partitioning::DistributeBy(expr) => expr.iter().try_for_each(f), - Partitioning::RoundRobinBatch(_) => Ok(()), + Partitioning::Hash(expr, _) => expr.iter().for_each_till_continue(f), + Partitioning::DistributeBy(expr) => expr.iter().for_each_till_continue(f), + Partitioning::RoundRobinBatch(_) => Ok(TreeNodeRecursion::Continue), }, LogicalPlan::Window(Window { window_expr, .. }) => { - window_expr.iter().try_for_each(f) + window_expr.iter().for_each_till_continue(f) } LogicalPlan::Aggregate(Aggregate { group_expr, aggr_expr, .. - }) => group_expr.iter().chain(aggr_expr.iter()).try_for_each(f), + }) => group_expr + .iter() + .chain(aggr_expr.iter()) + .for_each_till_continue(f), // There are two part of expression for join, equijoin(on) and non-equijoin(filter). // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. // 2. the second part is non-equijoin(filter). @@ -349,22 +351,21 @@ impl LogicalPlan { on.iter() // it not ideal to create an expr here to analyze them, but could cache it on the Join itself .map(|(l, r)| Expr::eq(l.clone(), r.clone())) - .try_for_each(|e| f(&e))?; - - if let Some(filter) = filter.as_ref() { - f(filter) - } else { - Ok(()) - } + .for_each_till_continue(&mut |e| f(&e))? + .and_then_on_continue(|| filter.iter().for_each_till_continue(f)) } - LogicalPlan::Sort(Sort { expr, .. }) => expr.iter().try_for_each(f), + LogicalPlan::Sort(Sort { expr, .. }) => expr.iter().for_each_till_continue(f), LogicalPlan::Extension(extension) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs - extension.node.expressions().iter().try_for_each(f) + extension + .node + .expressions() + .iter() + .for_each_till_continue(f) } LogicalPlan::TableScan(TableScan { filters, .. }) => { - filters.iter().try_for_each(f) + filters.iter().for_each_till_continue(f) } LogicalPlan::Unnest(Unnest { column, .. }) => { f(&Expr::Column(column.clone())) @@ -378,7 +379,7 @@ impl LogicalPlan { .iter() .chain(select_expr.iter()) .chain(sort_expr.clone().unwrap_or(vec![]).iter()) - .try_for_each(f), + .for_each_till_continue(f), // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::Subquery(_) @@ -394,7 +395,7 @@ impl LogicalPlan { | LogicalPlan::Ddl(_) | LogicalPlan::Copy(_) | LogicalPlan::DescribeTable(_) - | LogicalPlan::Prepare(_) => Ok(()), + | LogicalPlan::Prepare(_) => Ok(TreeNodeRecursion::Continue), } } @@ -440,7 +441,7 @@ impl LogicalPlan { pub fn using_columns(&self) -> Result>, DataFusionError> { let mut using_columns: Vec> = vec![]; - self.apply(&mut |plan| { + self.visit_down(&mut |plan| { if let LogicalPlan::Join(Join { join_constraint: JoinConstraint::Using, on, @@ -456,7 +457,7 @@ impl LogicalPlan { })?; using_columns.push(columns); } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(using_columns) @@ -612,7 +613,7 @@ impl LogicalPlan { } LogicalPlan::Filter { .. } => { assert_eq!(1, expr.len()); - let predicate = expr.pop().unwrap(); + let mut predicate = expr.pop().unwrap(); // filter predicates should not contain aliased expressions so we remove any aliases // before this logic was added we would have aliases within filters such as for @@ -628,29 +629,39 @@ impl LogicalPlan { struct RemoveAliases {} - impl TreeNodeRewriter for RemoveAliases { - type N = Expr; + impl TreeNodeTransformer for RemoveAliases { + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn pre_transform( + &mut self, + expr: &mut Expr, + ) -> Result { match expr { Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::InSubquery(_) => { // subqueries could contain aliases so we don't recurse into those - Ok(RewriteRecursion::Stop) + Ok(TreeNodeRecursion::Prune) } - Expr::Alias(_) => Ok(RewriteRecursion::Mutate), - _ => Ok(RewriteRecursion::Continue), + Expr::Alias(_) => { + expr.unalias(); + Ok(TreeNodeRecursion::Prune) + } + _ => Ok(TreeNodeRecursion::Continue), } } - fn mutate(&mut self, expr: Expr) -> Result { - Ok(expr.unalias()) + fn post_transform( + &mut self, + expr: &mut Expr, + ) -> Result { + expr.unalias(); + Ok(TreeNodeRecursion::Continue) } } let mut remove_aliases = RemoveAliases {}; - let predicate = predicate.rewrite(&mut remove_aliases)?; + predicate.transform(&mut remove_aliases)?; Filter::try_new(predicate, Arc::new(inputs[0].clone())) .map(LogicalPlan::Filter) @@ -717,10 +728,10 @@ impl LogicalPlan { // The first part of expr is equi-exprs, // and the struct of each equi-expr is like `left-expr = right-expr`. assert_eq!(expr.len(), equi_expr_count); - let new_on:Vec<(Expr,Expr)> = expr.into_iter().map(|equi_expr| { + let new_on:Vec<(Expr,Expr)> = expr.into_iter().map(|mut equi_expr| { // SimplifyExpression rule may add alias to the equi_expr. - let unalias_expr = equi_expr.clone().unalias(); - if let Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) = unalias_expr { + equi_expr.unalias(); + if let Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) = equi_expr { Ok((*left, *right)) } else { internal_err!( @@ -1089,59 +1100,27 @@ impl LogicalPlan { | LogicalPlan::Extension(_) => None, } } -} -impl LogicalPlan { - /// applies `op` to any subqueries in the plan - pub(crate) fn apply_subqueries(&self, op: &mut F) -> datafusion_common::Result<()> + /// Apply `f` on the root nodes of subquery plans of the plan node. + /// `f` is not allowed to return [`TreeNodeRecursion::Prune`]. + pub fn visit_subqueries(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> datafusion_common::Result, + F: FnMut(&Self) -> Result, { - self.inspect_expressions(|expr| { - // recursively look for subqueries - inspect_expr_pre(expr, |expr| { - match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { - // use a synthetic plan so the collector sees a - // LogicalPlan::Subquery (even though it is - // actually a Subquery alias) - let synthetic_plan = LogicalPlan::Subquery(subquery.clone()); - synthetic_plan.apply(op)?; - } - _ => {} + self.visit_expressions(&mut |e| { + e.visit_down(&mut |e| match e { + Expr::Exists(Exists { subquery, .. }) + | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::ScalarSubquery(subquery) => { + // use a synthetic plan so the collector sees a + // LogicalPlan::Subquery (even though it is + // actually a Subquery alias) + let synthetic_plan = LogicalPlan::Subquery(subquery.clone()); + f(&synthetic_plan)?.fail_on_prune() } - Ok::<(), DataFusionError>(()) + _ => Ok(TreeNodeRecursion::Continue), }) - })?; - Ok(()) - } - - /// applies visitor to any subqueries in the plan - pub(crate) fn visit_subqueries(&self, v: &mut V) -> datafusion_common::Result<()> - where - V: TreeNodeVisitor, - { - self.inspect_expressions(|expr| { - // recursively look for subqueries - inspect_expr_pre(expr, |expr| { - match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { - // use a synthetic plan so the visitor sees a - // LogicalPlan::Subquery (even though it is - // actually a Subquery alias) - let synthetic_plan = LogicalPlan::Subquery(subquery.clone()); - synthetic_plan.visit(v)?; - } - _ => {} - } - Ok::<(), DataFusionError>(()) - }) - })?; - Ok(()) + }) } /// Return a `LogicalPlan` with all placeholders (e.g $1 $2, @@ -1177,9 +1156,9 @@ impl LogicalPlan { ) -> Result>, DataFusionError> { let mut param_types: HashMap> = HashMap::new(); - self.apply(&mut |plan| { - plan.inspect_expressions(|expr| { - expr.apply(&mut |expr| { + self.visit_down(&mut |plan| { + plan.visit_expressions(&mut |expr| { + expr.visit_down(&mut |expr| { if let Expr::Placeholder(Placeholder { id, data_type }) = expr { let prev = param_types.get(id); match (prev, data_type) { @@ -1194,11 +1173,9 @@ impl LogicalPlan { _ => {} } } - Ok(VisitRecursion::Continue) - })?; - Ok::<(), DataFusionError>(()) - })?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) + }) + }) })?; Ok(param_types) @@ -1210,7 +1187,7 @@ impl LogicalPlan { expr: Expr, param_values: &ParamValues, ) -> Result { - expr.transform(&|expr| { + expr.transform_up_old(&|expr| { match &expr { Expr::Placeholder(Placeholder { id, data_type }) => { let value = param_values @@ -2725,9 +2702,9 @@ digraph { } impl TreeNodeVisitor for OkVisitor { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { let s = match plan { LogicalPlan::Projection { .. } => "pre_visit Projection", LogicalPlan::Filter { .. } => "pre_visit Filter", @@ -2738,10 +2715,10 @@ digraph { }; self.strings.push(s.into()); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn post_visit(&mut self, plan: &LogicalPlan) -> Result { let s = match plan { LogicalPlan::Projection { .. } => "post_visit Projection", LogicalPlan::Filter { .. } => "post_visit Filter", @@ -2752,7 +2729,7 @@ digraph { }; self.strings.push(s.into()); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } @@ -2808,20 +2785,20 @@ digraph { } impl TreeNodeVisitor for StoppingVisitor { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_false_from_pre_in.dec() { - return Ok(VisitRecursion::Stop); + return Ok(TreeNodeRecursion::Stop); } self.inner.pre_visit(plan)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn post_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_false_from_post_in.dec() { - return Ok(VisitRecursion::Stop); + return Ok(TreeNodeRecursion::Stop); } self.inner.post_visit(plan) @@ -2877,9 +2854,9 @@ digraph { } impl TreeNodeVisitor for ErrorVisitor { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_error_from_pre_in.dec() { return not_impl_err!("Error in pre_visit"); } @@ -2887,7 +2864,7 @@ digraph { self.inner.pre_visit(plan) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn post_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_error_from_post_in.dec() { return not_impl_err!("Error in post_visit"); } @@ -3193,7 +3170,7 @@ digraph { // after transformation, because plan is not the same anymore, // the parent plan is built again with call to LogicalPlan::with_new_inputs -> with_new_exprs let plan = plan - .transform(&|plan| match plan { + .transform_up_old(&|plan| match plan { LogicalPlan::TableScan(table) => { let filter = Filter::try_new( external_filter.clone(), diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 56388be58b8a..f6af24efbad9 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -25,7 +25,7 @@ use crate::expr::{ use crate::{Expr, GetFieldAccess}; use std::borrow::Cow; -use datafusion_common::tree_node::TreeNode; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, VisitRecursionIterator}; use datafusion_common::{internal_err, DataFusionError, Result}; impl TreeNode for Expr { @@ -139,6 +139,90 @@ impl TreeNode for Expr { } } + fn visit_children(&self, f: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + match self { + Expr::Alias(Alias{expr,..}) + | Expr::Not(expr) + | Expr::IsNotNull(expr) + | Expr::IsTrue(expr) + | Expr::IsFalse(expr) + | Expr::IsUnknown(expr) + | Expr::IsNotTrue(expr) + | Expr::IsNotFalse(expr) + | Expr::IsNotUnknown(expr) + | Expr::IsNull(expr) + | Expr::Negative(expr) + | Expr::Cast(Cast { expr, .. }) + | Expr::TryCast(TryCast { expr, .. }) + | Expr::Sort(Sort { expr, .. }) + | Expr::InSubquery(InSubquery{ expr, .. }) => f(expr), + Expr::GetIndexedField(GetIndexedField { expr, field }) => { + f(expr)?.and_then_on_continue(|| match field { + GetFieldAccess::ListIndex {key} => { + f(key) + }, + GetFieldAccess::ListRange { start, stop} => { + f(start)?.and_then_on_continue(|| f(stop)) + } + GetFieldAccess::NamedStructField { name: _name } => Ok(TreeNodeRecursion::Continue) + }) + } + Expr::GroupingSet(GroupingSet::Rollup(exprs)) + | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.iter().for_each_till_continue(f), + Expr::ScalarFunction (ScalarFunction{ args, .. } ) => args.iter().for_each_till_continue(f), + Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { + lists_of_exprs.iter().flatten().for_each_till_continue(f) + } + | Expr::Column(_) + // Treat OuterReferenceColumn as a leaf expression + | Expr::OuterReferenceColumn(_, _) + | Expr::ScalarVariable(_, _) + | Expr::Literal(_) + | Expr::Exists { .. } + | Expr::ScalarSubquery(_) + | Expr::Wildcard {..} + | Expr::Placeholder (_) => Ok(TreeNodeRecursion::Continue), + Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { + f(left)? + .and_then_on_continue(|| f(right)) + } + Expr::Like(Like { expr, pattern, .. }) + | Expr::SimilarTo(Like { expr, pattern, .. }) => { + f(expr)? + .and_then_on_continue(|| f(pattern)) + } + Expr::Between(Between { expr, low, high, .. }) => { + f(expr)? + .and_then_on_continue(|| f(low))? + .and_then_on_continue(|| f(high)) + }, + Expr::Case( Case { expr, when_then_expr, else_expr }) => { + expr.as_deref().into_iter().for_each_till_continue(f)? + .and_then_on_continue(|| + when_then_expr.iter().for_each_till_continue(&mut |(w, t)| f(w)?.and_then_on_continue(|| f(t))))? + .and_then_on_continue(|| else_expr.as_deref().into_iter().for_each_till_continue(f)) + } + Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) + => { + args.iter().for_each_till_continue(f)? + .and_then_on_continue(|| filter.as_deref().into_iter().for_each_till_continue(f))? + .and_then_on_continue(|| order_by.iter().flatten().for_each_till_continue(f)) + } + Expr::WindowFunction(WindowFunction { args, partition_by, order_by, .. }) => { + args.iter().for_each_till_continue(f)? + .and_then_on_continue(|| partition_by.iter().for_each_till_continue(f))? + .and_then_on_continue(|| order_by.iter().for_each_till_continue(f)) + } + Expr::InList(InList { expr, list, .. }) => { + f(expr)? + .and_then_on_continue(|| list.iter().for_each_till_continue(f)) + } + } + } + fn map_children(self, transform: F) -> Result where F: FnMut(Self) -> Result, @@ -365,6 +449,92 @@ impl TreeNode for Expr { } }) } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + match self { + Expr::Alias(Alias { expr,.. }) + | Expr::Not(expr) + | Expr::IsNotNull(expr) + | Expr::IsTrue(expr) + | Expr::IsFalse(expr) + | Expr::IsUnknown(expr) + | Expr::IsNotTrue(expr) + | Expr::IsNotFalse(expr) + | Expr::IsNotUnknown(expr) + | Expr::IsNull(expr) + | Expr::Negative(expr) + | Expr::Cast(Cast { expr, .. }) + | Expr::TryCast(TryCast { expr, .. }) + | Expr::Sort(Sort { expr, .. }) + | Expr::InSubquery(InSubquery{ expr, .. }) => { + let x = expr; + f(x) + } + Expr::GetIndexedField(GetIndexedField { expr, field }) => { + f(expr)?.and_then_on_continue(|| match field { + GetFieldAccess::ListIndex {key} => { + f(key) + }, + GetFieldAccess::ListRange { start, stop} => { + f(start)?.and_then_on_continue(|| f(stop)) + } + GetFieldAccess::NamedStructField { name: _name } => Ok(TreeNodeRecursion::Continue) + }) + } + Expr::GroupingSet(GroupingSet::Rollup(exprs)) + | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.iter_mut().for_each_till_continue(f), + | Expr::ScalarFunction(ScalarFunction{ args, .. }) => args.iter_mut().for_each_till_continue(f), + Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { + lists_of_exprs.iter_mut().flatten().for_each_till_continue(f) + } + | Expr::Column(_) + // Treat OuterReferenceColumn as a leaf expression + | Expr::OuterReferenceColumn(_, _) + | Expr::ScalarVariable(_, _) + | Expr::Literal(_) + | Expr::Exists { .. } + | Expr::ScalarSubquery(_) + | Expr::Wildcard {..} + | Expr::Placeholder (_) => Ok(TreeNodeRecursion::Continue), + Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { + f(left)? + .and_then_on_continue(|| f(right)) + } + Expr::Like(Like { expr, pattern, .. }) + | Expr::SimilarTo(Like { expr, pattern, .. }) => { + f(expr)? + .and_then_on_continue(|| f(pattern)) + } + Expr::Between(Between { expr, low, high, .. }) => { + f(expr)? + .and_then_on_continue(|| f(low))? + .and_then_on_continue(|| f(high)) + }, + Expr::Case( Case { expr, when_then_expr, else_expr }) => { + expr.as_deref_mut().into_iter().for_each_till_continue(f)? + .and_then_on_continue(|| + when_then_expr.iter_mut().for_each_till_continue(&mut |(w, t)| f(w)?.and_then_on_continue(|| f(t))))? + .and_then_on_continue(|| else_expr.as_deref_mut().into_iter().for_each_till_continue(f)) + } + Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) => { + args.iter_mut().for_each_till_continue(f)? + .and_then_on_continue(|| filter.as_deref_mut().into_iter().for_each_till_continue(f))? + .and_then_on_continue(|| order_by.iter_mut().flatten().for_each_till_continue(f)) + } + Expr::WindowFunction(WindowFunction { args, partition_by, order_by, .. }) => { + args.iter_mut().for_each_till_continue(f)? + .and_then_on_continue(|| partition_by.iter_mut().for_each_till_continue(f))? + .and_then_on_continue(|| order_by.iter_mut().for_each_till_continue(f)) + } + Expr::InList(InList { expr, list, .. }) => { + f(expr)? + .and_then_on_continue(|| list.iter_mut().for_each_till_continue(f)) + } + } + } } fn transform_boxed(boxed_expr: Box, transform: &mut F) -> Result> diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index 208a8b57d7b0..8dcd3db30968 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -18,8 +18,8 @@ //! Tree node implementation for logical plan use crate::LogicalPlan; -use datafusion_common::tree_node::{TreeNodeVisitor, VisitRecursion}; -use datafusion_common::{tree_node::TreeNode, Result}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, VisitRecursionIterator}; +use datafusion_common::Result; use std::borrow::Cow; impl TreeNode for LogicalPlan { @@ -27,73 +27,11 @@ impl TreeNode for LogicalPlan { self.inputs().into_iter().map(Cow::Borrowed).collect() } - fn apply(&self, op: &mut F) -> Result + fn visit_inner_children(&self, f: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { - // Note, - // - // Compared to the default implementation, we need to invoke - // [`Self::apply_subqueries`] before visiting its children - match op(self)? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - }; - - self.apply_subqueries(op)?; - - self.apply_children(&mut |node| node.apply(op)) - } - - /// To use, define a struct that implements the trait [`TreeNodeVisitor`] and then invoke - /// [`LogicalPlan::visit`]. - /// - /// For example, for a logical plan like: - /// - /// ```text - /// Projection: id - /// Filter: state Eq Utf8(\"CO\")\ - /// CsvScan: employee.csv projection=Some([0, 3])"; - /// ``` - /// - /// The sequence of visit operations would be: - /// ```text - /// visitor.pre_visit(Projection) - /// visitor.pre_visit(Filter) - /// visitor.pre_visit(CsvScan) - /// visitor.post_visit(CsvScan) - /// visitor.post_visit(Filter) - /// visitor.post_visit(Projection) - /// ``` - fn visit>( - &self, - visitor: &mut V, - ) -> Result { - // Compared to the default implementation, we need to invoke - // [`Self::visit_subqueries`] before visiting its children - - match visitor.pre_visit(self)? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - }; - - self.visit_subqueries(visitor)?; - - match self.apply_children(&mut |node| node.visit(visitor))? { - VisitRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - - visitor.post_visit(self) + self.visit_subqueries(f) } fn map_children(self, transform: F) -> Result @@ -118,4 +56,24 @@ impl TreeNode for LogicalPlan { Ok(self) } } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + let old_children = self.inputs(); + let mut new_children = + old_children.iter().map(|&c| c.clone()).collect::>(); + let tnr = new_children.iter_mut().for_each_till_continue(f)?; + + // if any changes made, make a new child + if old_children + .iter() + .zip(new_children.iter()) + .any(|(c1, c2)| c1 != &c2) + { + *self = self.with_new_exprs(self.expressions(), new_children.as_slice())?; + } + Ok(tnr) + } } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index e3ecdf154e61..6e9dac033c10 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -31,7 +31,7 @@ use crate::{ }; use arrow::datatypes::{DataType, TimeUnit}; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::utils::get_at_indices; use datafusion_common::{ internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef, @@ -262,8 +262,8 @@ pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { /// Recursively walk an expression tree, collecting the unique set of columns /// referenced in the expression pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { - inspect_expr_pre(expr, |expr| { - match expr { + expr.visit_down(&mut |e| { + match e { Expr::Column(qc) => { accum.insert(qc.clone()); } @@ -304,8 +304,9 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::Placeholder(_) | Expr::OuterReferenceColumn { .. } => {} } - Ok(()) + Ok(TreeNodeRecursion::Continue) }) + .map(|_| ()) } /// Find excluded columns in the schema, if any @@ -656,44 +657,22 @@ where F: Fn(&Expr) -> bool, { let mut exprs = vec![]; - expr.apply(&mut |expr| { + expr.visit_down(&mut |expr| { if test_fn(expr) { if !(exprs.contains(expr)) { exprs.push(expr.clone()) } // stop recursing down this expr once we find a match - return Ok(VisitRecursion::Skip); + return Ok(TreeNodeRecursion::Prune); } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // pre_visit always returns OK, so this will always too .expect("no way to return error during recursion"); exprs } -/// Recursively inspect an [`Expr`] and all its children. -pub fn inspect_expr_pre(expr: &Expr, mut f: F) -> Result<(), E> -where - F: FnMut(&Expr) -> Result<(), E>, -{ - let mut err = Ok(()); - expr.apply(&mut |expr| { - if let Err(e) = f(expr) { - // save the error for later (it may not be a DataFusionError - err = Err(e); - Ok(VisitRecursion::Stop) - } else { - // keep going - Ok(VisitRecursion::Continue) - } - }) - // The closure always returns OK, so this will always too - .expect("no way to return error during recursion"); - - err -} - /// Returns a new logical plan based on the original one with inputs /// and expressions replaced. /// @@ -826,17 +805,14 @@ pub fn find_column_exprs(exprs: &[Expr]) -> Vec { .collect() } -pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec { - let mut exprs = vec![]; - inspect_expr_pre(e, |expr| { - if let Expr::Column(c) = expr { - exprs.push(c.clone()) +pub(crate) fn find_columns_referenced_by_expr(expr: &Expr) -> Vec { + expr.collect(&mut |e| { + if let Expr::Column(c) = e { + vec![c.clone()] + } else { + vec![] } - Ok(()) as Result<()> }) - // As the closure always returns Ok, this "can't" error - .expect("Unexpected error"); - exprs } /// Convert any `Expr` to an `Expr::Column`. @@ -853,26 +829,16 @@ pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result { /// Recursively walk an expression tree, collecting the column indexes /// referenced in the expression pub(crate) fn find_column_indexes_referenced_by_expr( - e: &Expr, + expr: &Expr, schema: &DFSchemaRef, ) -> Vec { - let mut indexes = vec![]; - inspect_expr_pre(e, |expr| { - match expr { - Expr::Column(qc) => { - if let Ok(idx) = schema.index_of_column(qc) { - indexes.push(idx); - } - } - Expr::Literal(_) => { - indexes.push(std::usize::MAX); - } - _ => {} + expr.collect(&mut |e| match e { + Expr::Column(qc) => schema.index_of_column(qc).into_iter().collect(), + Expr::Literal(_) => { + vec![std::usize::MAX] } - Ok(()) as Result<()> + _ => vec![], }) - .unwrap(); - indexes } /// can this data type be used in hash join equal conditions?? diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 953716713e41..2f0708b5cd7a 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -17,9 +17,13 @@ use crate::analyzer::AnalyzerRule; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, TreeNodeTransformer, +}; use datafusion_common::Result; -use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, InSubquery}; +use datafusion_expr::expr::{ + AggregateFunction, AggregateFunctionDefinition, Exists, InSubquery, WindowFunction, +}; use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::Expr::ScalarSubquery; @@ -43,7 +47,7 @@ impl CountWildcardRule { impl AnalyzerRule for CountWildcardRule { fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - plan.transform_down(&analyze_internal) + plan.transform_down_old(&analyze_internal) } fn name(&self) -> &str { @@ -114,108 +118,69 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { struct CountWildcardRewriter {} -impl TreeNodeRewriter for CountWildcardRewriter { - type N = Expr; +impl TreeNodeTransformer for CountWildcardRewriter { + type Node = Expr; + + fn pre_transform(&mut self, _node: &mut Self::Node) -> Result { + Ok(TreeNodeRecursion::Continue) + } - fn mutate(&mut self, old_expr: Expr) -> Result { - let new_expr = match old_expr.clone() { - Expr::WindowFunction(expr::WindowFunction { + fn post_transform(&mut self, expr: &mut Expr) -> Result { + match expr { + Expr::WindowFunction(WindowFunction { fun: expr::WindowFunctionDefinition::AggregateFunction( aggregate_function::AggregateFunction::Count, ), args, - partition_by, - order_by, - window_frame, - }) if args.len() == 1 => match args[0] { - Expr::Wildcard { qualifier: None } => { - Expr::WindowFunction(expr::WindowFunction { - fun: expr::WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Count, - ), - args: vec![lit(COUNT_STAR_EXPANSION)], - partition_by, - order_by, - window_frame, - }) + .. + }) if args.len() == 1 => { + if let Expr::Wildcard { qualifier: None } = args[0] { + args[0] = lit(COUNT_STAR_EXPANSION) } - - _ => old_expr, - }, + } Expr::AggregateFunction(AggregateFunction { func_def: AggregateFunctionDefinition::BuiltIn( aggregate_function::AggregateFunction::Count, ), args, - distinct, - filter, - order_by, - }) if args.len() == 1 => match args[0] { - Expr::Wildcard { qualifier: None } => { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Count, - vec![lit(COUNT_STAR_EXPANSION)], - distinct, - filter, - order_by, - )) + .. + }) if args.len() == 1 => { + if let Expr::Wildcard { qualifier: None } = args[0] { + args[0] = lit(COUNT_STAR_EXPANSION) } - _ => old_expr, - }, - - ScalarSubquery(Subquery { - subquery, - outer_ref_columns, - }) => { + } + ScalarSubquery(Subquery { subquery, .. }) => { let new_plan = subquery .as_ref() .clone() - .transform_down(&analyze_internal)?; - ScalarSubquery(Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns, - }) + .transform_down_old(&analyze_internal)?; + *subquery = Arc::new(new_plan); } Expr::InSubquery(InSubquery { - expr, - subquery, - negated, + subquery: Subquery { subquery, .. }, + .. }) => { let new_plan = subquery - .subquery .as_ref() .clone() - .transform_down(&analyze_internal)?; - - Expr::InSubquery(InSubquery::new( - expr, - Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - )) + .transform_down_old(&analyze_internal)?; + *subquery = Arc::new(new_plan); } - Expr::Exists(expr::Exists { subquery, negated }) => { + Expr::Exists(Exists { + subquery: Subquery { subquery, .. }, + .. + }) => { let new_plan = subquery - .subquery .as_ref() .clone() - .transform_down(&analyze_internal)?; - - Expr::Exists(expr::Exists { - subquery: Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - }) + .transform_down_old(&analyze_internal)?; + *subquery = Arc::new(new_plan); } - _ => old_expr, + _ => {} }; - Ok(new_expr) + Ok(TreeNodeRecursion::Continue) } } diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index 90af7aec8293..f2e00ed0763d 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -42,7 +42,7 @@ impl InlineTableScan { impl AnalyzerRule for InlineTableScan { fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - plan.transform_up(&analyze_internal) + plan.transform_up_old(&analyze_internal) } fn name(&self) -> &str { @@ -74,7 +74,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { Transformed::Yes(plan) } LogicalPlan::Filter(filter) => { - let new_expr = filter.predicate.transform(&rewrite_subquery)?; + let new_expr = filter.predicate.transform_up_old(&rewrite_subquery)?; Transformed::Yes(LogicalPlan::Filter(Filter::try_new( new_expr, filter.input, @@ -88,7 +88,7 @@ fn rewrite_subquery(expr: Expr) -> Result> { match expr { Expr::Exists(Exists { subquery, negated }) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?; + let new_plan = plan.transform_up_old(&analyze_internal)?; let subquery = subquery.with_plan(Arc::new(new_plan)); Ok(Transformed::Yes(Expr::Exists(Exists { subquery, negated }))) } @@ -98,7 +98,7 @@ fn rewrite_subquery(expr: Expr) -> Result> { negated, }) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?; + let new_plan = plan.transform_up_old(&analyze_internal)?; let subquery = subquery.with_plan(Arc::new(new_plan)); Ok(Transformed::Yes(Expr::InSubquery(InSubquery::new( expr, subquery, negated, @@ -106,7 +106,7 @@ fn rewrite_subquery(expr: Expr) -> Result> { } Expr::ScalarSubquery(subquery) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?; + let new_plan = plan.transform_up_old(&analyze_internal)?; let subquery = subquery.with_plan(Arc::new(new_plan)); Ok(Transformed::Yes(Expr::ScalarSubquery(subquery))) } diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 14d5ddf47378..25b2b1246062 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -27,11 +27,10 @@ use crate::analyzer::subquery::check_subquery_expr; use crate::analyzer::type_coercion::TypeCoercion; use crate::utils::log_plan; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::expr::Exists; use datafusion_expr::expr::InSubquery; -use datafusion_expr::utils::inspect_expr_pre; use datafusion_expr::{Expr, LogicalPlan}; use log::debug; use std::sync::Arc; @@ -117,21 +116,21 @@ impl Analyzer { /// Do necessary check and fail the invalid plan fn check_plan(plan: &LogicalPlan) -> Result<()> { - plan.apply(&mut |plan: &LogicalPlan| { - for expr in plan.expressions().iter() { + plan.visit_down(&mut |plan: &LogicalPlan| { + plan.visit_expressions(&mut |e| { // recursively look for subqueries - inspect_expr_pre(expr, |expr| match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { - check_subquery_expr(plan, &subquery.subquery, expr) + e.visit_down(&mut |e| { + match e { + Expr::Exists(Exists { subquery, .. }) + | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::ScalarSubquery(subquery) => { + check_subquery_expr(plan, &subquery.subquery, e)? + } + _ => {} } - _ => Ok(()), - })?; - } - - Ok(VisitRecursion::Continue) - })?; - - Ok(()) + Ok(TreeNodeRecursion::Continue) + }) + }) + }) + .map(|_| ()) } diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index 7c5b70b19af0..d676be9a1087 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -17,7 +17,7 @@ use crate::analyzer::check_plan; use crate::utils::collect_subquery_cols; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::expr_rewriter::strip_outer_reference; use datafusion_expr::utils::split_conjunction; @@ -144,9 +144,9 @@ fn check_inner_plan( // We want to support as many operators as possible inside the correlated subquery match inner_plan { LogicalPlan::Aggregate(_) => { - inner_plan.apply_children(&mut |plan| { + inner_plan.visit_children(&mut |plan| { check_inner_plan(plan, is_scalar, true, can_contain_outer_ref)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -169,9 +169,9 @@ fn check_inner_plan( } LogicalPlan::Window(window) => { check_mixed_out_refer_in_window(window)?; - inner_plan.apply_children(&mut |plan| { + inner_plan.visit_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -186,9 +186,9 @@ fn check_inner_plan( | LogicalPlan::Values(_) | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) => { - inner_plan.apply_children(&mut |plan| { + inner_plan.visit_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -199,14 +199,14 @@ fn check_inner_plan( .. }) => match join_type { JoinType::Inner => { - inner_plan.apply_children(&mut |plan| { + inner_plan.visit_children(&mut |plan| { check_inner_plan( plan, is_scalar, is_aggregate, can_contain_outer_ref, )?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -219,9 +219,9 @@ fn check_inner_plan( check_inner_plan(right, is_scalar, is_aggregate, can_contain_outer_ref) } JoinType::Full => { - inner_plan.apply_children(&mut |plan| { + inner_plan.visit_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, false)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -281,7 +281,7 @@ fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan { fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result> { let mut exprs = vec![]; - inner_plan.apply(&mut |plan| { + inner_plan.visit_down(&mut |plan| { if let LogicalPlan::Filter(Filter { predicate, .. }) = plan { let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate) .into_iter() @@ -290,9 +290,9 @@ fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result> { correlated .into_iter() .for_each(|expr| exprs.push(strip_outer_reference(expr.clone()))); - return Ok(VisitRecursion::Continue); + return Ok(TreeNodeRecursion::Continue); } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(exprs) } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 4d54dad99670..520df618c5ea 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -17,12 +17,13 @@ //! Optimizer rule for type validation and coercion +use std::mem; use std::sync::Arc; use arrow::datatypes::{DataType, IntervalUnit}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; +use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeTransformer}; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -44,7 +45,6 @@ use datafusion_expr::type_coercion::other::{ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ - is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, type_coercion, AggregateFunction, BuiltinScalarFunction, Expr, ExprSchemable, LogicalPlan, Operator, Projection, ScalarFunctionDefinition, Signature, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -125,40 +125,24 @@ pub(crate) struct TypeCoercionRewriter { pub(crate) schema: DFSchemaRef, } -impl TreeNodeRewriter for TypeCoercionRewriter { - type N = Expr; +impl TreeNodeTransformer for TypeCoercionRewriter { + type Node = Expr; - fn pre_visit(&mut self, _expr: &Expr) -> Result { - Ok(RewriteRecursion::Continue) + fn pre_transform(&mut self, _expr: &mut Expr) -> Result { + Ok(TreeNodeRecursion::Continue) } - fn mutate(&mut self, expr: Expr) -> Result { + fn post_transform(&mut self, expr: &mut Expr) -> Result { match expr { - Expr::ScalarSubquery(Subquery { - subquery, - outer_ref_columns, + Expr::ScalarSubquery(Subquery { subquery, .. }) + | Expr::Exists(Exists { + subquery: Subquery { subquery, .. }, + .. }) => { - let new_plan = analyze_internal(&self.schema, &subquery)?; - Ok(Expr::ScalarSubquery(Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns, - })) - } - Expr::Exists(Exists { subquery, negated }) => { - let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; - Ok(Expr::Exists(Exists { - subquery: Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - })) + let new_plan = analyze_internal(&self.schema, subquery)?; + *subquery = Arc::new(new_plan); } - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => { + Expr::InSubquery(InSubquery { expr, subquery, .. }) => { let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; let expr_type = expr.get_type(&self.schema)?; let subquery_type = new_plan.schema().field(0).data_type(); @@ -166,53 +150,31 @@ impl TreeNodeRewriter for TypeCoercionRewriter { "expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery" ), )?; + **expr = mem::take(expr.as_mut()).cast_to(&common_type, &self.schema)?; let new_subquery = Subquery { subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, + outer_ref_columns: mem::take(&mut subquery.outer_ref_columns), }; - Ok(Expr::InSubquery(InSubquery::new( - Box::new(expr.cast_to(&common_type, &self.schema)?), - cast_subquery(new_subquery, &common_type)?, - negated, - ))) - } - Expr::IsTrue(expr) => { - let expr = is_true(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsNotTrue(expr) => { - let expr = is_not_true(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsFalse(expr) => { - let expr = is_false(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) + *subquery = cast_subquery(new_subquery, &common_type)?; } - Expr::IsNotFalse(expr) => { - let expr = - is_not_false(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsUnknown(expr) => { - let expr = is_unknown(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsNotUnknown(expr) => { - let expr = - is_not_unknown(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) + Expr::IsTrue(expr) + | Expr::IsNotTrue(expr) + | Expr::IsFalse(expr) + | Expr::IsNotFalse(expr) + | Expr::IsUnknown(expr) + | Expr::IsNotUnknown(expr) => { + **expr = get_casted_expr_for_bool_op(expr, &self.schema)? } Expr::Like(Like { - negated, expr, pattern, - escape_char, case_insensitive, + .. }) => { let left_type = expr.get_type(&self.schema)?; let right_type = pattern.get_type(&self.schema)?; let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| { - let op_name = if case_insensitive { + let op_name = if *case_insensitive { "ILIKE" } else { "LIKE" @@ -221,35 +183,21 @@ impl TreeNodeRewriter for TypeCoercionRewriter { "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression" ) })?; - let expr = Box::new(expr.cast_to(&coerced_type, &self.schema)?); - let pattern = Box::new(pattern.cast_to(&coerced_type, &self.schema)?); - let expr = Expr::Like(Like::new( - negated, - expr, - pattern, - escape_char, - case_insensitive, - )); - Ok(expr) + **expr = mem::take(expr.as_mut()).cast_to(&coerced_type, &self.schema)?; + **pattern = + mem::take(pattern.as_mut()).cast_to(&coerced_type, &self.schema)?; } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { let (left_type, right_type) = get_input_types( &left.get_type(&self.schema)?, - &op, + op, &right.get_type(&self.schema)?, )?; - - Ok(Expr::BinaryExpr(BinaryExpr::new( - Box::new(left.cast_to(&left_type, &self.schema)?), - op, - Box::new(right.cast_to(&right_type, &self.schema)?), - ))) + **left = mem::take(left.as_mut()).cast_to(&left_type, &self.schema)?; + **right = mem::take(right.as_mut()).cast_to(&right_type, &self.schema)?; } Expr::Between(Between { - expr, - negated, - low, - high, + expr, low, high, .. }) => { let expr_type = expr.get_type(&self.schema)?; let low_type = low.get_type(&self.schema)?; @@ -273,19 +221,13 @@ impl TreeNodeRewriter for TypeCoercionRewriter { "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression" )) })?; - let expr = Expr::Between(Between::new( - Box::new(expr.cast_to(&coercion_type, &self.schema)?), - negated, - Box::new(low.cast_to(&coercion_type, &self.schema)?), - Box::new(high.cast_to(&coercion_type, &self.schema)?), - )); - Ok(expr) + **expr = + mem::take(expr.as_mut()).cast_to(&coercion_type, &self.schema)?; + **low = mem::take(low.as_mut()).cast_to(&coercion_type, &self.schema)?; + **high = + mem::take(high.as_mut()).cast_to(&coercion_type, &self.schema)?; } - Expr::InList(InList { - expr, - list, - negated, - }) => { + Expr::InList(InList { expr, list, .. }) => { let expr_data_type = expr.get_type(&self.schema)?; let list_data_types = list .iter() @@ -296,28 +238,21 @@ impl TreeNodeRewriter for TypeCoercionRewriter { match result_type { None => plan_err!( "Can not find compatible types to compare {expr_data_type:?} with {list_data_types:?}" - ), + )?, Some(coerced_type) => { // find the coerced type - let cast_expr = expr.cast_to(&coerced_type, &self.schema)?; - let cast_list_expr = list - .into_iter() - .map(|list_expr| { - list_expr.cast_to(&coerced_type, &self.schema) - }) - .collect::>>()?; - let expr = Expr::InList(InList ::new( - Box::new(cast_expr), - cast_list_expr, - negated, - )); - Ok(expr) + **expr = mem::take(expr.as_mut()).cast_to(&coerced_type, &self.schema)?; + list.iter_mut() + .try_for_each(|list_expr| { + mem::take(list_expr).cast_to(&coerced_type, &self.schema).map(|r| *list_expr = r) + })?; } } } - Expr::Case(case) => { - let case = coerce_case_expression(case, &self.schema)?; - Ok(Expr::Case(case)) + Expr::Case(_) => { + if let Expr::Case(case) = mem::take(expr) { + *expr = Expr::Case(coerce_case_expression(case, &self.schema)?); + } } Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { ScalarFunctionDefinition::BuiltIn(fun) => { @@ -326,12 +261,9 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, &fun.signature(), )?; - let new_args = coerce_arguments_for_fun( - new_args.as_slice(), - &self.schema, - &fun, - )?; - Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) + let new_args = + coerce_arguments_for_fun(new_args.as_slice(), &self.schema, fun)?; + *args = new_args } ScalarFunctionDefinition::UDF(fun) => { let new_expr = coerce_arguments_for_signature( @@ -339,30 +271,23 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, fun.signature(), )?; - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_expr))) + *args = new_expr } ScalarFunctionDefinition::Name(_) => { - internal_err!("Function `Expr` with name should be resolved.") + internal_err!("Function `Expr` with name should be resolved.")? } }, Expr::AggregateFunction(expr::AggregateFunction { - func_def, - args, - distinct, - filter, - order_by, + func_def, args, .. }) => match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { let new_expr = coerce_agg_exprs_for_signature( - &fun, - &args, + fun, + args, &self.schema, &fun.signature(), )?; - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, new_expr, distinct, filter, order_by, - )); - Ok(expr) + *args = new_expr } AggregateFunctionDefinition::UDF(fun) => { let new_expr = coerce_arguments_for_signature( @@ -370,48 +295,47 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, fun.signature(), )?; - let expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( - fun, new_expr, false, filter, order_by, - )); - Ok(expr) + *args = new_expr } AggregateFunctionDefinition::Name(_) => { - internal_err!("Function `Expr` with name should be resolved.") + internal_err!("Function `Expr` with name should be resolved.")? } }, - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - }) => { - let window_frame = - coerce_window_frame(window_frame, &self.schema, &order_by)?; - - let args = match &fun { - expr::WindowFunctionDefinition::AggregateFunction(fun) => { - coerce_agg_exprs_for_signature( - fun, - &args, - &self.schema, - &fun.signature(), - )? - } - _ => args, - }; - - let expr = Expr::WindowFunction(WindowFunction::new( + Expr::WindowFunction(_) => { + if let Expr::WindowFunction(WindowFunction { fun, args, partition_by, order_by, window_frame, - )); - Ok(expr) + .. + }) = mem::take(expr) + { + let window_frame = + coerce_window_frame(window_frame, &self.schema, &order_by)?; + let args = match &fun { + expr::WindowFunctionDefinition::AggregateFunction(fun) => { + coerce_agg_exprs_for_signature( + fun, + &args, + &self.schema, + &fun.signature(), + )? + } + _ => args, + }; + *expr = Expr::WindowFunction(WindowFunction::new( + fun, + args, + partition_by, + order_by, + window_frame, + )); + } } - expr => Ok(expr), + _ => {} } + Ok(TreeNodeRecursion::Continue) } } @@ -1229,7 +1153,7 @@ mod test { None, ), ))); - let expr = Expr::ScalarFunction(ScalarFunction::new( + let mut expr = Expr::ScalarFunction(ScalarFunction::new( BuiltinScalarFunction::MakeArray, vec![val.clone()], )); @@ -1244,8 +1168,8 @@ mod test { )], std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema }; - let result = expr.rewrite(&mut rewriter)?; + let mut transformer = TypeCoercionRewriter { schema }; + expr.transform(&mut transformer)?; let schema = Arc::new(DFSchema::new_with_metadata( vec![DFField::new_unqualified( @@ -1266,7 +1190,7 @@ mod test { vec![expected_casted_expr], )); - assert_eq!(result, expected); + assert_eq!(expr, expected); Ok(()) } @@ -1277,33 +1201,33 @@ mod test { vec![DFField::new_unqualified("a", DataType::Int64, true)], std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema }; - let expr = is_true(lit(12i32).gt(lit(13i64))); + let mut transformer = TypeCoercionRewriter { schema }; + let mut expr = is_true(lit(12i32).gt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; - assert_eq!(expected, result); + expr.transform(&mut transformer)?; + assert_eq!(expected, expr); // eq let schema = Arc::new(DFSchema::new_with_metadata( vec![DFField::new_unqualified("a", DataType::Int64, true)], std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema }; - let expr = is_true(lit(12i32).eq(lit(13i64))); + let mut transformer = TypeCoercionRewriter { schema }; + let mut expr = is_true(lit(12i32).eq(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; - assert_eq!(expected, result); + expr.transform(&mut transformer)?; + assert_eq!(expected, expr); // lt let schema = Arc::new(DFSchema::new_with_metadata( vec![DFField::new_unqualified("a", DataType::Int64, true)], std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema }; - let expr = is_true(lit(12i32).lt(lit(13i64))); + let mut transfomer = TypeCoercionRewriter { schema }; + let mut expr = is_true(lit(12i32).lt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; - assert_eq!(expected, result); + expr.transform(&mut transfomer)?; + assert_eq!(expected, expr); Ok(()) } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 1e089257c61a..d6cad22eb7e2 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -24,7 +24,7 @@ use crate::{utils, OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{ - RewriteRecursion, TreeNode, TreeNodeRewriter, TreeNodeVisitor, VisitRecursion, + RewriteRecursion, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{ internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, @@ -612,18 +612,18 @@ impl ExprIdentifierVisitor<'_> { } impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, _expr: &Expr) -> Result { + fn pre_visit(&mut self, _expr: &Expr) -> Result { self.visit_stack .push(VisitRecord::EnterMark(self.node_count)); self.node_count += 1; // put placeholder self.id_array.push((0, "".to_string())); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, expr: &Expr) -> Result { + fn post_visit(&mut self, expr: &Expr) -> Result { self.series_number += 1; let (idx, sub_expr_desc) = self.pop_enter_mark(); @@ -632,7 +632,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { self.id_array[idx].0 = self.series_number; let desc = Self::desc_expr(expr); self.visit_stack.push(VisitRecord::ExprItem(desc)); - return Ok(VisitRecursion::Continue); + return Ok(TreeNodeRecursion::Continue); } let mut desc = Self::desc_expr(expr); desc.push_str(&sub_expr_desc); @@ -646,7 +646,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { .entry(desc) .or_insert_with(|| (expr.clone(), 0, data_type)) .1 += 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index b1000f042c98..3df604a62c8a 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -370,7 +370,7 @@ fn agg_exprs_evaluation_result_on_empty_batch( expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result<()> { for e in agg_expr.iter() { - let result_expr = e.clone().transform_up(&|expr| { + let mut result_expr = e.clone().transform_up_old(&|expr| { let new_expr = match expr { Expr::AggregateFunction(expr::AggregateFunction { func_def, .. }) => { match func_def { @@ -396,7 +396,7 @@ fn agg_exprs_evaluation_result_on_empty_batch( Ok(new_expr) })?; - let result_expr = result_expr.unalias(); + result_expr.unalias(); let props = ExecutionProps::new(); let info = SimplifyContext::new(&props).with_schema(schema.clone()); let simplifier = ExprSimplifier::new(info); @@ -415,7 +415,7 @@ fn proj_exprs_evaluation_result_on_empty_batch( expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result<()> { for expr in proj_expr.iter() { - let result_expr = expr.clone().transform_up(&|expr| { + let result_expr = expr.clone().transform_up_old(&|expr| { if let Expr::Column(Column { name, .. }) = &expr { if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { Ok(Transformed::Yes(result_expr.clone())) @@ -448,7 +448,7 @@ fn filter_exprs_evaluation_result_on_empty_batch( input_expr_result_map_for_count_bug: &ExprResultMap, expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result> { - let result_expr = filter_expr.clone().transform_up(&|expr| { + let result_expr = filter_expr.clone().transform_up_old(&|expr| { if let Expr::Column(Column { name, .. }) = &expr { if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { Ok(Transformed::Yes(result_expr.clone())) diff --git a/datafusion/optimizer/src/plan_signature.rs b/datafusion/optimizer/src/plan_signature.rs index 07f495a7262d..97a56f85ef96 100644 --- a/datafusion/optimizer/src/plan_signature.rs +++ b/datafusion/optimizer/src/plan_signature.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use std::{ collections::hash_map::DefaultHasher, hash::{Hash, Hasher}, @@ -73,9 +73,9 @@ impl LogicalPlanSignature { /// Get total number of [`LogicalPlan`]s in the plan. fn get_node_number(plan: &LogicalPlan) -> NonZeroUsize { let mut node_number = 0; - plan.apply(&mut |_plan| { + plan.visit_down(&mut |_plan| { node_number += 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // Closure always return Ok .unwrap(); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 4eb925ac0629..d0d6d5e51c34 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::{ internal_err, plan_datafusion_err, Column, DFSchema, DFSchemaRef, DataFusionError, JoinConstraint, Result, @@ -217,11 +217,11 @@ fn can_pushdown_join_predicate(predicate: &Expr, schema: &DFSchema) -> Result Result { let mut is_evaluate = true; - predicate.apply(&mut |expr| match expr { + predicate.visit_down(&mut |expr| match expr { Expr::Column(_) | Expr::Literal(_) | Expr::Placeholder(_) - | Expr::ScalarVariable(_, _) => Ok(VisitRecursion::Skip), + | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Prune), Expr::Exists { .. } | Expr::InSubquery(_) | Expr::ScalarSubquery(_) @@ -231,7 +231,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { .. }) => { is_evaluate = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } Expr::Alias(_) | Expr::BinaryExpr(_) @@ -253,7 +253,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::Cast(_) | Expr::TryCast(_) | Expr::ScalarFunction(..) - | Expr::InList { .. } => Ok(VisitRecursion::Continue), + | Expr::InList { .. } => Ok(TreeNodeRecursion::Continue), Expr::Sort(_) | Expr::AggregateFunction(_) | Expr::WindowFunction(_) @@ -1016,7 +1016,7 @@ pub fn replace_cols_by_name( e: Expr, replace_map: &HashMap, ) -> Result { - e.transform_up(&|expr| { + e.transform_up_old(&|expr| { Ok(if let Expr::Column(c) = &expr { match replace_map.get(&c.flat_name()) { Some(new_c) => Transformed::Yes(new_c.clone()), @@ -1031,29 +1031,29 @@ pub fn replace_cols_by_name( /// check whether the expression is volatile predicates fn is_volatile_expression(e: &Expr) -> bool { let mut is_volatile = false; - e.apply(&mut |expr| { + e.visit_down(&mut |expr| { Ok(match expr { Expr::ScalarFunction(f) => match &f.func_def { ScalarFunctionDefinition::BuiltIn(fun) if fun.volatility() == Volatility::Volatile => { is_volatile = true; - VisitRecursion::Stop + TreeNodeRecursion::Stop } ScalarFunctionDefinition::UDF(fun) if fun.signature().volatility == Volatility::Volatile => { is_volatile = true; - VisitRecursion::Stop + TreeNodeRecursion::Stop } ScalarFunctionDefinition::Name(_) => { return internal_err!( "Function `Expr` with name should be resolved." ); } - _ => VisitRecursion::Continue, + _ => TreeNodeRecursion::Continue, }, - _ => VisitRecursion::Continue, + _ => TreeNodeRecursion::Continue, }) }) .unwrap(); @@ -1063,17 +1063,17 @@ fn is_volatile_expression(e: &Expr) -> bool { /// check whether the expression uses the columns in `check_map`. fn contain(e: &Expr, check_map: &HashMap) -> bool { let mut is_contain = false; - e.apply(&mut |expr| { + e.visit_down(&mut |expr| { Ok(if let Expr::Column(c) = &expr { match check_map.get(&c.flat_name()) { Some(_) => { is_contain = true; - VisitRecursion::Stop + TreeNodeRecursion::Stop } - None => VisitRecursion::Continue, + None => TreeNodeRecursion::Continue, } } else { - VisitRecursion::Continue + TreeNodeRecursion::Continue }) }) .unwrap(); diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 34ed4a9475cb..378e01f4bfa9 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -87,7 +87,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { { if !expr_check_map.is_empty() { rewrite_expr = - rewrite_expr.clone().transform_up(&|expr| { + rewrite_expr.clone().transform_up_old(&|expr| { if let Expr::Column(col) = &expr { if let Some(map_expr) = expr_check_map.get(&col.name) @@ -141,8 +141,9 @@ impl OptimizerRule for ScalarSubqueryToJoin { if let Some(rewrite_expr) = expr_to_rewrite_expr_map.get(expr) { - let new_expr = - rewrite_expr.clone().transform_up(&|expr| { + let new_expr = rewrite_expr + .clone() + .transform_up_old(&|expr| { if let Expr::Column(col) = &expr { if let Some(map_expr) = expr_check_map.get(&col.name) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 7d09aec7e748..c091ef7581c4 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -158,10 +158,11 @@ impl ExprSimplifier { // rather than creating an DFSchemaRef coerces rather than doing // it manually. // https://github.com/apache/arrow-datafusion/issues/3793 - pub fn coerce(&self, expr: Expr, schema: DFSchemaRef) -> Result { + pub fn coerce(&self, mut expr: Expr, schema: DFSchemaRef) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.rewrite(&mut expr_rewrite) + expr.transform(&mut expr_rewrite)?; + Ok(expr) } /// Input guarantees about the values of columns. diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 91603e82a54f..dfe4d1fa9ab8 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -24,17 +24,16 @@ use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; -use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; +use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeTransformer}; use datafusion_common::{ internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast}; use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::merge_schema; -use datafusion_expr::{ - binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, -}; +use datafusion_expr::{lit, Expr, ExprSchemable, LogicalPlan, Operator}; use std::cmp::Ordering; +use std::mem; use std::sync::Arc; /// [`UnwrapCastInComparison`] attempts to remove casts from @@ -126,21 +125,19 @@ struct UnwrapCastExprRewriter { schema: DFSchemaRef, } -impl TreeNodeRewriter for UnwrapCastExprRewriter { - type N = Expr; +impl TreeNodeTransformer for UnwrapCastExprRewriter { + type Node = Expr; - fn pre_visit(&mut self, _expr: &Expr) -> Result { - Ok(RewriteRecursion::Continue) + fn pre_transform(&mut self, _expr: &mut Expr) -> Result { + Ok(TreeNodeRecursion::Continue) } - fn mutate(&mut self, expr: Expr) -> Result { - match &expr { + fn post_transform(&mut self, expr: &mut Expr) -> Result { + match expr { // For case: // try_cast/cast(expr as data_type) op literal // literal op try_cast/cast(expr as data_type) Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let left = left.as_ref().clone(); - let right = right.as_ref().clone(); let left_type = left.get_type(&self.schema)?; let right_type = right.get_type(&self.schema)?; // Because the plan has been done the type coercion, the left and right must be equal @@ -148,7 +145,7 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { && is_support_data_type(&right_type) && is_comparison_op(op) { - match (&left, &right) { + match (left.as_mut(), right.as_mut()) { ( Expr::Literal(left_lit_value), Expr::TryCast(TryCast { expr, .. }) @@ -161,11 +158,8 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { try_cast_literal_to_type(left_lit_value, &expr_type)?; if let Some(value) = casted_scalar_value { // unwrap the cast/try_cast for the right expr - return Ok(binary_expr( - lit(value), - *op, - expr.as_ref().clone(), - )); + **left = lit(value); + **right = mem::take(expr.as_mut()); } } ( @@ -180,49 +174,42 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { try_cast_literal_to_type(right_lit_value, &expr_type)?; if let Some(value) = casted_scalar_value { // unwrap the cast/try_cast for the left expr - return Ok(binary_expr( - expr.as_ref().clone(), - *op, - lit(value), - )); + **left = mem::take(expr.as_mut()); + **right = lit(value); } } (_, _) => { // do nothing } - }; + } } - // return the new binary op - Ok(binary_expr(left, *op, right)) } // For case: // try_cast/cast(expr as left_type) in (expr1,expr2,expr3) Expr::InList(InList { expr: left_expr, list, - negated, + .. }) => { - if let Some( - Expr::TryCast(TryCast { - expr: internal_left_expr, - .. - }) - | Expr::Cast(Cast { - expr: internal_left_expr, - .. - }), - ) = Some(left_expr.as_ref()) + if let Expr::TryCast(TryCast { + expr: internal_left_expr, + .. + }) + | Expr::Cast(Cast { + expr: internal_left_expr, + .. + }) = left_expr.as_ref() { let internal_left = internal_left_expr.as_ref().clone(); let internal_left_type = internal_left.get_type(&self.schema); if internal_left_type.is_err() { // error data type - return Ok(expr); + return Ok(TreeNodeRecursion::Continue); } let internal_left_type = internal_left_type?; if !is_support_data_type(&internal_left_type) { // not supported data type - return Ok(expr); + return Ok(TreeNodeRecursion::Continue); } let right_exprs = list .iter() @@ -256,19 +243,16 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { } }) .collect::>>(); - match right_exprs { - Ok(right_exprs) => { - Ok(in_list(internal_left, right_exprs, *negated)) - } - Err(_) => Ok(expr), + if let Ok(right_exprs) = right_exprs { + **left_expr = internal_left; + *list = right_exprs; } - } else { - Ok(expr) } } // TODO: handle other expr type and dfs visit them - _ => Ok(expr), + _ => {} } + Ok(TreeNodeRecursion::Continue) } } @@ -730,11 +714,12 @@ mod tests { assert_eq!(optimize_test(expr_lt, &schema), expected); } - fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { + fn optimize_test(mut expr: Expr, schema: &DFSchemaRef) -> Expr { let mut expr_rewriter = UnwrapCastExprRewriter { schema: schema.clone(), }; - expr.rewrite(&mut expr_rewriter).unwrap() + expr.transform(&mut expr_rewriter).unwrap(); + expr } fn expr_test_schema() -> DFSchemaRef { diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index f0bd1740d5d2..18e673c8fce5 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -260,7 +260,7 @@ impl EquivalenceGroup { /// class it matches with (if any). pub fn normalize_expr(&self, expr: Arc) -> Arc { expr.clone() - .transform(&|expr| { + .transform_down_old(&|expr| { for cls in self.iter() { if cls.contains(&expr) { return Ok(Transformed::Yes(cls.canonical_expr().unwrap())); diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index 387dce2cdc8b..2016119c39dc 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -47,7 +47,7 @@ pub fn add_offset_to_expr( expr: Arc, offset: usize, ) -> Arc { - expr.transform_down(&|e| match e.as_any().downcast_ref::() { + expr.transform_down_old(&|e| match e.as_any().downcast_ref::() { Some(col) => Ok(Transformed::Yes(Arc::new(Column::new( col.name(), offset + col.index(), diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index 0f92b2c2f431..557358f6ecb4 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -58,7 +58,7 @@ impl ProjectionMapping { let target_expr = Arc::new(Column::new(name, expr_idx)) as _; expression .clone() - .transform_down(&|e| match e.as_any().downcast_ref::() { + .transform_down_old(&|e| match e.as_any().downcast_ref::() { Some(col) => { // Sometimes, an expression and its name in the input_schema // doesn't match. This can cause problems, so we make sure diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index 31c1cf61193a..933bb17fb4bc 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -777,7 +777,7 @@ impl EquivalenceProperties { /// the given expression. pub fn get_expr_ordering(&self, expr: Arc) -> ExprOrdering { ExprOrdering::new(expr.clone()) - .transform_up(&|expr| Ok(update_ordering(expr, self))) + .transform_up_old(&|expr| Ok(update_ordering(expr, self))) // Guaranteed to always return `Ok`. .unwrap() } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 52fb85657f4e..8958c71c585e 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -944,7 +944,7 @@ mod tests { let expr2 = expr .clone() - .transform(&|e| { + .transform_up_old(&|e| { let transformed = match e.as_any().downcast_ref::() { Some(lit_value) => match lit_value.value() { @@ -965,7 +965,7 @@ mod tests { let expr3 = expr .clone() - .transform_down(&|e| { + .transform_down_old(&|e| { let transformed = match e.as_any().downcast_ref::() { Some(lit_value) => match lit_value.value() { diff --git a/datafusion/physical-expr/src/sort_properties.rs b/datafusion/physical-expr/src/sort_properties.rs index 0205f85dced4..aaeefd0fa73f 100644 --- a/datafusion/physical-expr/src/sort_properties.rs +++ b/datafusion/physical-expr/src/sort_properties.rs @@ -21,7 +21,8 @@ use std::{ops::Neg, sync::Arc}; use arrow_schema::SortOptions; use crate::PhysicalExpr; -use datafusion_common::tree_node::TreeNode; + +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, VisitRecursionIterator}; use datafusion_common::Result; /// To propagate [`SortOptions`] across the [`PhysicalExpr`], it is insufficient @@ -191,4 +192,11 @@ impl TreeNode for ExprOrdering { } Ok(self) } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + self.children.iter_mut().for_each_till_continue(f) + } } diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 64a62dc7820d..65949c640606 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -29,7 +29,7 @@ use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; use arrow::datatypes::SchemaRef; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeRewriter, VisitRecursion, + Transformed, TreeNode, TreeNodeRecursion, TreeNodeTransformer, VisitRecursionIterator, }; use datafusion_common::Result; use datafusion_expr::Operator; @@ -169,9 +169,16 @@ impl TreeNode for ExprTreeNode { .collect::>>()?; Ok(self) } + + fn transform_children(&mut self, f: &mut F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + self.child_nodes.iter_mut().for_each_till_continue(f) + } } -/// This struct facilitates the [TreeNodeRewriter] mechanism to convert a +/// This struct facilitates the [TreeNodeTransformer] mechanism to convert a /// [PhysicalExpr] tree into a DAEG (i.e. an expression DAG) by collecting /// identical expressions in one node. Caller specifies the node type in the /// DAEG via the `constructor` argument, which constructs nodes in the DAEG @@ -185,16 +192,21 @@ struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> Result< constructor: &'a F, } -impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeRewriter +impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeTransformer for PhysicalExprDAEGBuilder<'a, T, F> { - type N = ExprTreeNode; + type Node = ExprTreeNode; + + fn pre_transform(&mut self, _node: &mut Self::Node) -> Result { + Ok(TreeNodeRecursion::Continue) + } + // This method mutates an expression node by transforming it to a physical expression // and adding it to the graph. The method returns the mutated expression node. - fn mutate( + fn post_transform( &mut self, - mut node: ExprTreeNode, - ) -> Result> { + node: &mut ExprTreeNode, + ) -> Result { // Get the expression associated with the input expression node. let expr = &node.expr; @@ -206,7 +218,7 @@ impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeRewriter // add edges to its child nodes. Add the visited expression to the vector // of visited expressions and return the newly created node index. None => { - let node_idx = self.graph.add_node((self.constructor)(&node)?); + let node_idx = self.graph.add_node((self.constructor)(node)?); for expr_node in node.child_nodes.iter() { self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0); } @@ -217,7 +229,7 @@ impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeRewriter // Set the data field of the input expression node to the corresponding node index. node.data = Some(node_idx); // Return the mutated expression node. - Ok(node) + Ok(TreeNodeRecursion::Continue) } } @@ -238,7 +250,8 @@ where constructor, }; // Use the builder to transform the expression tree node into a DAG. - let root = init.rewrite(&mut builder)?; + let mut root = init; + root.transform(&mut builder)?; // Return a tuple containing the root node index and the DAG. Ok((root.data.unwrap(), builder.graph)) } @@ -246,13 +259,13 @@ where /// Recursively extract referenced [`Column`]s within a [`PhysicalExpr`]. pub fn collect_columns(expr: &Arc) -> HashSet { let mut columns = HashSet::::new(); - expr.apply(&mut |expr| { + expr.visit_down(&mut |expr| { if let Some(column) = expr.as_any().downcast_ref::() { if !columns.iter().any(|c| c.eq(column)) { columns.insert(column.clone()); } } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // pre_visit always returns OK, so this will always too .expect("no way to return error during recursion"); @@ -266,7 +279,7 @@ pub fn reassign_predicate_columns( schema: &SchemaRef, ignore_not_found: bool, ) -> Result> { - pred.transform_down(&|expr| { + pred.transform_down_old(&|expr| { let expr_any = expr.as_any(); if let Some(column) = expr_any.downcast_ref::() { diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 9a4c98927683..6ba3cc75e572 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -284,7 +284,7 @@ pub fn convert_sort_expr_with_filter_schema( if all_columns_are_included { // Since we are sure that one to one column mapping includes all columns, we convert // the sort expression into a filter expression. - let converted_filter_expr = expr.transform_up(&|p| { + let converted_filter_expr = expr.transform_up_old(&|p| { convert_filter_columns(p.as_ref(), &column_map).map(|transformed| { match transformed { Some(transformed) => Transformed::Yes(transformed), diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 616a2fc74932..76c270ffe463 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -33,7 +33,7 @@ use std::collections::HashMap; /// Make a best-effort attempt at resolving all columns in the expression tree pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result { - expr.clone().transform_up(&|nested_expr| { + expr.clone().transform_up_old(&|nested_expr| { match nested_expr { Expr::Column(col) => { let field = plan.schema().field_from_column(&col)?; @@ -66,7 +66,7 @@ pub(crate) fn rebase_expr( base_exprs: &[Expr], plan: &LogicalPlan, ) -> Result { - expr.clone().transform_up(&|nested_expr| { + expr.clone().transform_up_old(&|nested_expr| { if base_exprs.contains(&nested_expr) { Ok(Transformed::Yes(expr_as_column_expr(&nested_expr, plan)?)) } else { @@ -170,16 +170,17 @@ pub(crate) fn resolve_aliases_to_exprs( expr: &Expr, aliases: &HashMap, ) -> Result { - expr.clone().transform_up(&|nested_expr| match nested_expr { - Expr::Column(c) if c.relation.is_none() => { - if let Some(aliased_expr) = aliases.get(&c.name) { - Ok(Transformed::Yes(aliased_expr.clone())) - } else { - Ok(Transformed::No(Expr::Column(c))) + expr.clone() + .transform_up_old(&|nested_expr| match nested_expr { + Expr::Column(c) if c.relation.is_none() => { + if let Some(aliased_expr) = aliases.get(&c.name) { + Ok(Transformed::Yes(aliased_expr.clone())) + } else { + Ok(Transformed::No(Expr::Column(c))) + } } - } - _ => Ok(Transformed::No(nested_expr)), - }) + _ => Ok(Transformed::No(nested_expr)), + }) } /// given a slice of window expressions sharing the same sort key, find their common partition