Skip to content

Commit

Permalink
- refactor transform_down() and transform_up() to work on mutable…
Browse files Browse the repository at this point in the history
… `TreeNode`s and use them in a few examples

- add `transform_down_with_payload()`, `transform_up_with_payload()`, `transform_with_payload()` and use it in `EnforceSorting` as an example
  • Loading branch information
peter-toth committed Dec 19, 2023
1 parent c0990de commit 5c61470
Show file tree
Hide file tree
Showing 30 changed files with 368 additions and 391 deletions.
6 changes: 3 additions & 3 deletions datafusion-examples/examples/rewrite_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl AnalyzerRule for MyAnalyzerRule {

impl MyAnalyzerRule {
fn analyze_plan(plan: LogicalPlan) -> Result<LogicalPlan> {
plan.transform_up(&|plan| {
plan.transform_up_old(&|plan| {
Ok(match plan {
LogicalPlan::Filter(filter) => {
let predicate = Self::analyze_expr(filter.predicate.clone())?;
Expand All @@ -106,7 +106,7 @@ impl MyAnalyzerRule {
}

fn analyze_expr(expr: Expr) -> Result<Expr> {
expr.transform_up(&|expr| {
expr.transform_up_old(&|expr| {
// closure is invoked for all sub expressions
Ok(match expr {
Expr::Literal(ScalarValue::Int64(i)) => {
Expand Down Expand Up @@ -161,7 +161,7 @@ impl OptimizerRule for MyOptimizerRule {

/// use rewrite_expr to modify the expression tree.
fn my_rewrite(expr: Expr) -> Result<Expr> {
expr.transform_up(&|expr| {
expr.transform_up_old(&|expr| {
// closure is invoked for all sub expressions
Ok(match expr {
Expr::Between(Between {
Expand Down
207 changes: 146 additions & 61 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,19 @@ pub trait TreeNode: Sized {
F: FnMut(&Self) -> Result<TreeNodeRecursion>,
{
// Apply `f` on self.
f(self)
f(self)?
// If it returns continue (not prune or stop or stop all) then continue
// traversal on inner children and children.
.and_then_on_continue(|| {
// Run the recursive `apply` on each inner children, but as they are
// unrelated root nodes of inner trees if any returns stop then continue
// with the next one.
self.apply_inner_children(&mut |c| c.visit_down(f).continue_on_stop())
self.apply_inner_children(&mut |c| c.visit_down(f)?.continue_on_stop())?
// Run the recursive `apply` on each children.
.and_then_on_continue(|| {
self.apply_children(&mut |c| c.visit_down(f))
})
})
})?
// Applying `f` on self might have returned prune, but we need to propagate
// continue.
.continue_on_prune()
Expand Down Expand Up @@ -107,21 +107,21 @@ pub trait TreeNode: Sized {
) -> Result<TreeNodeRecursion> {
// Apply `pre_visit` on self.
visitor
.pre_visit(self)
.pre_visit(self)?
// If it returns continue (not prune or stop or stop all) then continue
// traversal on inner children and children.
.and_then_on_continue(|| {
// Run the recursive `visit` on each inner children, but as they are
// unrelated subquery plans if any returns stop then continue with the
// next one.
self.apply_inner_children(&mut |c| c.visit(visitor).continue_on_stop())
self.apply_inner_children(&mut |c| c.visit(visitor)?.continue_on_stop())?
// Run the recursive `visit` on each children.
.and_then_on_continue(|| {
self.apply_children(&mut |c| c.visit(visitor))
})
})?
// Apply `post_visit` on self.
.and_then_on_continue(|| visitor.post_visit(self))
})
})?
// Applying `pre_visit` or `post_visit` on self might have returned prune,
// but we need to propagate continue.
.continue_on_prune()
Expand All @@ -133,31 +133,144 @@ pub trait TreeNode: Sized {
) -> Result<TreeNodeRecursion> {
// Apply `pre_transform` on self.
transformer
.pre_transform(self)
.pre_transform(self)?
// If it returns continue (not prune or stop or stop all) then continue
// traversal on inner children and children.
.and_then_on_continue(||
// Run the recursive `transform` on each children.
self
.transform_children(&mut |c| c.transform(transformer))
.transform_children(&mut |c| c.transform(transformer))?
// Apply `post_transform` on new self.
.and_then_on_continue(|| {
transformer.post_transform(self)
}))
.and_then_on_continue(|| transformer.post_transform(self)))?
// Applying `pre_transform` or `post_transform` on self might have returned
// prune, but we need to propagate continue.
.continue_on_prune()
}

fn transform_with_payload<FD, PD, FU, PU>(
&mut self,
f_down: &mut FD,
payload_down: Option<PD>,
f_up: &mut FU,
) -> Result<(TreeNodeRecursion, Option<PU>)>
where
FD: FnMut(&mut Self, Option<PD>) -> Result<(TreeNodeRecursion, Vec<PD>)>,
FU: FnMut(&mut Self, Vec<PU>) -> Result<(TreeNodeRecursion, PU)>,
{
// Apply `f_down` on self.
let (tnr, new_payload_down) = f_down(self, payload_down)?;
let mut new_payload_down_iter = new_payload_down.into_iter();
// If it returns continue (not prune or stop or stop all) then continue traversal
// on inner children and children.
let mut new_payload_up = None;
tnr.and_then_on_continue(|| {
// Run the recursive `transform` on each children.
let mut payload_up = vec![];
let tnr = self.transform_children(&mut |c| {
let (tnr, p) =
c.transform_with_payload(f_down, new_payload_down_iter.next(), f_up)?;
p.into_iter().for_each(|p| payload_up.push(p));
Ok(tnr)
})?;
// Apply `f_up` on self.
tnr.and_then_on_continue(|| {
let (tnr, np) = f_up(self, payload_up)?;
new_payload_up = Some(np);
Ok(tnr)
})
})?
// Applying `f_down` or `f_up` on self might have returned prune, but we need to propagate
// continue.
.continue_on_prune()
.map(|tnr| (tnr, new_payload_up))
}

fn transform_down<F>(&mut self, f: &mut F) -> Result<TreeNodeRecursion>
where
F: FnMut(&mut Self) -> Result<TreeNodeRecursion>,
{
// Apply `f` on self.
f(self)?
// If it returns continue (not prune or stop or stop all) then continue
// traversal on inner children and children.
.and_then_on_continue(||
// Run the recursive `transform` on each children.
self.transform_children(&mut |c| c.transform_down(f)))?
// Applying `f` on self might have returned prune, but we need to propagate
// continue.
.continue_on_prune()
}

fn transform_down_with_payload<F, P>(
&mut self,
f: &mut F,
payload: P,
) -> Result<TreeNodeRecursion>
where
F: FnMut(&mut Self, P) -> Result<(TreeNodeRecursion, Vec<P>)>,
{
// Apply `f` on self.
let (tnr, new_payload) = f(self, payload)?;
let mut new_payload_iter = new_payload.into_iter();
// If it returns continue (not prune or stop or stop all) then continue
// traversal on inner children and children.
tnr.and_then_on_continue(||
// Run the recursive `transform` on each children.
self.transform_children(&mut |c| c.transform_down_with_payload(f, new_payload_iter.next().unwrap())))?
// Applying `f` on self might have returned prune, but we need to propagate
// continue.
.continue_on_prune()
}

fn transform_up<F>(&mut self, f: &mut F) -> Result<TreeNodeRecursion>
where
F: FnMut(&mut Self) -> Result<TreeNodeRecursion>,
{
// Run the recursive `transform` on each children.
self.transform_children(&mut |c| c.transform_up(f))?
// Apply `f` on self.
.and_then_on_continue(|| f(self))?
// Applying `f` on self might have returned prune, but we need to propagate
// continue.
.continue_on_prune()
}

fn transform_up_with_payload<F, P>(
&mut self,
f: &mut F,
) -> Result<(TreeNodeRecursion, Option<P>)>
where
F: FnMut(&mut Self, Vec<P>) -> Result<(TreeNodeRecursion, P)>,
{
// Run the recursive `transform` on each children.
let mut payload = vec![];
let tnr = self.transform_children(&mut |c| {
let (tnr, p) = c.transform_up_with_payload(f)?;
p.into_iter().for_each(|p| payload.push(p));
Ok(tnr)
})?;
let mut new_payload = None;
// Apply `f` on self.
tnr.and_then_on_continue(|| {
let (tnr, np) = f(self, payload)?;
new_payload = Some(np);
Ok(tnr)
})?
// Applying `f` on self might have returned prune, but we need to propagate
// continue.
.continue_on_prune()
.map(|tnr| (tnr, new_payload))
}

/// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its
/// children(Preorder Traversal).
/// When the `op` does not apply to a given node, it is left unchanged.
fn transform_down<F>(self, op: &F) -> Result<Self>
fn transform_down_old<F>(self, op: &F) -> Result<Self>
where
F: Fn(Self) -> Result<Transformed<Self>>,
{
let after_op = op(self)?.into();
after_op.map_children(|node| node.transform_down(op))
after_op.map_children(|node| node.transform_down_old(op))
}

/// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its
Expand All @@ -174,11 +287,11 @@ pub trait TreeNode: Sized {
/// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its
/// children and then itself(Postorder Traversal).
/// When the `op` does not apply to a given node, it is left unchanged.
fn transform_up<F>(self, op: &F) -> Result<Self>
fn transform_up_old<F>(self, op: &F) -> Result<Self>
where
F: Fn(Self) -> Result<Transformed<Self>>,
{
let after_op_children = self.map_children(|node| node.transform_up(op))?;
let after_op_children = self.map_children(|node| node.transform_up_old(op))?;

let new_node = op(after_op_children)?.into();
Ok(new_node)
Expand Down Expand Up @@ -402,63 +515,35 @@ pub enum TreeNodeRecursion {
}

impl TreeNodeRecursion {
fn continue_on_prune(self) -> TreeNodeRecursion {
match self {
TreeNodeRecursion::Prune => TreeNodeRecursion::Continue,
o => o,
}
}

fn fail_on_prune(self) -> TreeNodeRecursion {
match self {
TreeNodeRecursion::Prune => panic!("Recursion can't prune."),
o => o,
}
}

fn continue_on_stop(self) -> TreeNodeRecursion {
match self {
TreeNodeRecursion::Stop => TreeNodeRecursion::Continue,
o => o,
}
}
}

/// This helper trait provide functions to control recursion on
/// [`Result<TreeNodeRecursion>`].
pub trait TreeNodeRecursionResult: Sized {
fn and_then_on_continue<F>(self, f: F) -> Result<TreeNodeRecursion>
where
F: FnOnce() -> Result<TreeNodeRecursion>;

fn continue_on_prune(self) -> Result<TreeNodeRecursion>;

fn fail_on_prune(self) -> Result<TreeNodeRecursion>;

fn continue_on_stop(self) -> Result<TreeNodeRecursion>;
}

impl TreeNodeRecursionResult for Result<TreeNodeRecursion> {
fn and_then_on_continue<F>(self, f: F) -> Result<TreeNodeRecursion>
pub fn and_then_on_continue<F>(self, f: F) -> Result<TreeNodeRecursion>
where
F: FnOnce() -> Result<TreeNodeRecursion>,
{
match self? {
match self {
TreeNodeRecursion::Continue => f(),
o => Ok(o),
}
}

fn continue_on_prune(self) -> Result<TreeNodeRecursion> {
self.map(|tnr| tnr.continue_on_prune())
pub fn continue_on_prune(self) -> Result<TreeNodeRecursion> {
Ok(match self {
TreeNodeRecursion::Prune => TreeNodeRecursion::Continue,
o => o,
})
}

fn fail_on_prune(self) -> Result<TreeNodeRecursion> {
self.map(|tnr| tnr.fail_on_prune())
pub fn fail_on_prune(self) -> Result<TreeNodeRecursion> {
Ok(match self {
TreeNodeRecursion::Prune => panic!("Recursion can't prune."),
o => o,
})
}

fn continue_on_stop(self) -> Result<TreeNodeRecursion> {
self.map(|tnr| tnr.continue_on_stop())
pub fn continue_on_stop(self) -> Result<TreeNodeRecursion> {
Ok(match self {
TreeNodeRecursion::Stop => TreeNodeRecursion::Continue,
o => o,
})
}
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_optimizer/coalesce_batches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl PhysicalOptimizerRule for CoalesceBatches {
}

let target_batch_size = config.execution.batch_size;
plan.transform_up(&|plan| {
plan.transform_up_old(&|plan| {
let plan_any = plan.as_any();
// The goal here is to detect operators that could produce small batches and only
// wrap those ones with a CoalesceBatchesExec operator. An alternate approach here
Expand Down
Loading

0 comments on commit 5c61470

Please sign in to comment.