Commit

- refactor EnforceDistribution using transform_down_with_payload()
peter-toth committed Dec 19, 2023
1 parent 5c61470 commit 8fa80e7
Showing 1 changed file with 88 additions and 163 deletions.
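For readers new to the API this commit adopts: `transform_down_with_payload()` drives a pre-order (top-down) traversal that threads a payload from each parent node to its children, so per-node state no longer has to live on a wrapper tree. The following is a minimal, self-contained sketch of the pattern on a toy tree; the `Node` type and the free-function shape are illustrative assumptions rather than the DataFusion API, and the sketch omits the `TreeNodeRecursion` flag that the real closure also returns so a node can prune or halt the traversal.

```rust
// A minimal sketch of payload threading, assuming a toy tree type in place
// of `Arc<dyn ExecutionPlan>`; only the traversal shape mirrors the commit.

/// Stand-in for a plan node.
#[derive(Debug)]
struct Node {
    label: String,
    children: Vec<Node>,
}

/// Pre-order traversal: visit `node` with `payload`, let the closure rewrite
/// the node in place, then recurse with the per-child payloads it returns.
fn transform_down_with_payload<P, F>(
    node: &mut Node,
    f: &mut F,
    payload: P,
) -> Result<(), String>
where
    F: FnMut(&mut Node, P) -> Result<Vec<P>, String>,
{
    let child_payloads = f(node, payload)?;
    if child_payloads.len() != node.children.len() {
        return Err("closure must return one payload per child".to_string());
    }
    for (child, p) in node.children.iter_mut().zip(child_payloads) {
        transform_down_with_payload(child, f, p)?;
    }
    Ok(())
}

fn main() -> Result<(), String> {
    let mut tree = Node {
        label: "join".into(),
        children: vec![
            Node { label: "scan_a".into(), children: vec![] },
            Node { label: "scan_b".into(), children: vec![] },
        ],
    };
    // The payload here is just the depth handed down by the parent; in the
    // commit below it is the parent's `RequiredKeyOrdering`.
    transform_down_with_payload(
        &mut tree,
        &mut |n, depth: usize| {
            n.label = format!("{depth}:{}", n.label);
            Ok(vec![depth + 1; n.children.len()])
        },
        0,
    )?;
    println!("{tree:?}"); // prints 0:join with children 1:scan_a and 1:scan_b
    Ok(())
}
```

In the diff below, the payload is `RequiredKeyOrdering` and the closure is `adjust_input_keys_ordering`, which rewrites each plan node in place and returns one requirement per child: the role previously played by the `request_key_ordering` field of the now-removed `PlanWithKeyRequirements` wrapper.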
251 changes: 88 additions & 163 deletions datafusion/core/src/physical_optimizer/enforce_distribution.rs
@@ -191,17 +191,15 @@ impl EnforceDistribution {
impl PhysicalOptimizerRule for EnforceDistribution {
fn optimize(
&self,
plan: Arc<dyn ExecutionPlan>,
mut plan: Arc<dyn ExecutionPlan>,
config: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
let top_down_join_key_reordering = config.optimizer.top_down_join_key_reordering;

let adjusted = if top_down_join_key_reordering {
// Run a top-down process to adjust input key ordering recursively
let plan_requirements = PlanWithKeyRequirements::new(plan);
let adjusted =
plan_requirements.transform_down_old(&adjust_input_keys_ordering)?;
adjusted.plan
plan.transform_down_with_payload(&mut adjust_input_keys_ordering, None)?;
plan
} else {
// Run a bottom-up process
plan.transform_up_old(&|plan| {
@@ -269,12 +267,15 @@ impl PhysicalOptimizerRule for EnforceDistribution {
/// 4) If the current plan is a Projection, transform the requirements to the columns before the Projection and push down the requirements
/// 5) For other types of operators, by default, push down the parent requirements to children.
///
type RequiredKeyOrdering = Option<Vec<Arc<dyn PhysicalExpr>>>;

fn adjust_input_keys_ordering(
requirements: PlanWithKeyRequirements,
) -> Result<Transformed<PlanWithKeyRequirements>> {
let parent_required = requirements.required_key_ordering.clone();
let plan_any = requirements.plan.as_any();
let transformed = if let Some(HashJoinExec {
plan: &mut Arc<dyn ExecutionPlan>,
required_key_ordering: RequiredKeyOrdering,
) -> Result<(TreeNodeRecursion, Vec<RequiredKeyOrdering>)> {
let parent_required = required_key_ordering.unwrap_or_default().clone();
let plan_any = plan.as_any();
if let Some(HashJoinExec {
left,
right,
on,
@@ -299,13 +300,15 @@ fn adjust_input_keys_ordering(
*null_equals_null,
)?) as Arc<dyn ExecutionPlan>)
};
Some(reorder_partitioned_join_keys(
requirements.plan.clone(),
let (new_plan, request_key_ordering) = reorder_partitioned_join_keys(
plan.clone(),
&parent_required,
on,
vec![],
&join_constructor,
)?)
)?;
*plan = new_plan;
Ok((TreeNodeRecursion::Continue, request_key_ordering))
}
PartitionMode::CollectLeft => {
let new_right_request = match join_type {
@@ -323,30 +326,28 @@ fn adjust_input_keys_ordering(
};

// Push down requirements to the right side
Some(PlanWithKeyRequirements {
plan: requirements.plan.clone(),
required_key_ordering: vec![],
request_key_ordering: vec![None, new_right_request],
})
Ok((TreeNodeRecursion::Continue, vec![None, new_right_request]))
}
PartitionMode::Auto => {
// Cannot satisfy; clear the current requirements and generate new empty requirements
Some(PlanWithKeyRequirements::new(requirements.plan.clone()))
Ok((
TreeNodeRecursion::Continue,
vec![None; plan.children().len()],
))
}
}
} else if let Some(CrossJoinExec { left, .. }) =
plan_any.downcast_ref::<CrossJoinExec>()
{
let left_columns_len = left.schema().fields().len();
// Push down requirements to the right side
Some(PlanWithKeyRequirements {
plan: requirements.plan.clone(),
required_key_ordering: vec![],
request_key_ordering: vec![
Ok((
TreeNodeRecursion::Continue,
vec![
None,
shift_right_required(&parent_required, left_columns_len),
],
})
))
} else if let Some(SortMergeJoinExec {
left,
right,
@@ -368,26 +369,38 @@ fn adjust_input_keys_ordering(
*null_equals_null,
)?) as Arc<dyn ExecutionPlan>)
};
Some(reorder_partitioned_join_keys(
requirements.plan.clone(),
let (new_plan, request_key_ordering) = reorder_partitioned_join_keys(
plan.clone(),
&parent_required,
on,
sort_options.clone(),
&join_constructor,
)?)
)?;
*plan = new_plan;
Ok((TreeNodeRecursion::Continue, request_key_ordering))
} else if let Some(aggregate_exec) = plan_any.downcast_ref::<AggregateExec>() {
if !parent_required.is_empty() {
match aggregate_exec.mode() {
AggregateMode::FinalPartitioned => Some(reorder_aggregate_keys(
requirements.plan.clone(),
&parent_required,
aggregate_exec,
)?),
_ => Some(PlanWithKeyRequirements::new(requirements.plan.clone())),
AggregateMode::FinalPartitioned => {
let (new_plan, request_key_ordering) = reorder_aggregate_keys(
plan.clone(),
&parent_required,
aggregate_exec,
)?;
*plan = new_plan;
Ok((TreeNodeRecursion::Continue, request_key_ordering))
}
_ => Ok((
TreeNodeRecursion::Continue,
vec![None; plan.children().len()],
)),
}
} else {
// Keep everything unchanged
None
Ok((
TreeNodeRecursion::Continue,
vec![None; plan.children().len()],
))
}
} else if let Some(proj) = plan_any.downcast_ref::<ProjectionExec>() {
let expr = proj.expr();
@@ -396,34 +409,33 @@ fn adjust_input_keys_ordering(
// Construct a mapping from new name to the original Column
let new_required = map_columns_before_projection(&parent_required, expr);
if new_required.len() == parent_required.len() {
Some(PlanWithKeyRequirements {
plan: requirements.plan.clone(),
required_key_ordering: vec![],
request_key_ordering: vec![Some(new_required.clone())],
})
Ok((
TreeNodeRecursion::Continue,
vec![Some(new_required.clone())],
))
} else {
// Cannot satisfy; clear the current requirements and generate new empty requirements
Some(PlanWithKeyRequirements::new(requirements.plan.clone()))
Ok((
TreeNodeRecursion::Continue,
vec![None; plan.children().len()],
))
}
} else if plan_any.downcast_ref::<RepartitionExec>().is_some()
|| plan_any.downcast_ref::<CoalescePartitionsExec>().is_some()
|| plan_any.downcast_ref::<WindowAggExec>().is_some()
{
Some(PlanWithKeyRequirements::new(requirements.plan.clone()))
Ok((
TreeNodeRecursion::Continue,
vec![None; plan.children().len()],
))
} else {
// By default, push down the parent requirements to children
let children_len = requirements.plan.children().len();
Some(PlanWithKeyRequirements {
plan: requirements.plan.clone(),
required_key_ordering: vec![],
request_key_ordering: vec![Some(parent_required.clone()); children_len],
})
};
Ok(if let Some(transformed) = transformed {
Transformed::Yes(transformed)
} else {
Transformed::No(requirements)
})
let children_len = plan.children().len();
Ok((
TreeNodeRecursion::Continue,
vec![Some(parent_required.clone()); children_len],
))
}
}

fn reorder_partitioned_join_keys<F>(
@@ -432,7 +444,7 @@ fn reorder_partitioned_join_keys<F>(
on: &[(Column, Column)],
sort_options: Vec<SortOptions>,
join_constructor: &F,
) -> Result<PlanWithKeyRequirements>
) -> Result<(Arc<dyn ExecutionPlan>, Vec<RequiredKeyOrdering>)>
where
F: Fn((Vec<(Column, Column)>, Vec<SortOptions>)) -> Result<Arc<dyn ExecutionPlan>>,
{
@@ -455,35 +467,29 @@ where
new_sort_options.push(sort_options[new_positions[idx]])
}

Ok(PlanWithKeyRequirements {
plan: join_constructor((new_join_on, new_sort_options))?,
required_key_ordering: vec![],
request_key_ordering: vec![Some(left_keys), Some(right_keys)],
})
Ok((
join_constructor((new_join_on, new_sort_options))?,
vec![Some(left_keys), Some(right_keys)],
))
} else {
Ok(PlanWithKeyRequirements {
plan: join_plan,
required_key_ordering: vec![],
request_key_ordering: vec![Some(left_keys), Some(right_keys)],
})
Ok((join_plan, vec![Some(left_keys), Some(right_keys)]))
}
} else {
Ok(PlanWithKeyRequirements {
plan: join_plan,
required_key_ordering: vec![],
request_key_ordering: vec![
Ok((
join_plan,
vec![
Some(join_key_pairs.left_keys),
Some(join_key_pairs.right_keys),
],
})
))
}
}

fn reorder_aggregate_keys(
agg_plan: Arc<dyn ExecutionPlan>,
parent_required: &[Arc<dyn PhysicalExpr>],
agg_exec: &AggregateExec,
) -> Result<PlanWithKeyRequirements> {
) -> Result<(Arc<dyn ExecutionPlan>, Vec<RequiredKeyOrdering>)> {
let output_columns = agg_exec
.group_by()
.expr()
@@ -501,11 +507,15 @@ fn reorder_aggregate_keys(
|| !agg_exec.group_by().null_expr().is_empty()
|| physical_exprs_equal(&output_exprs, parent_required)
{
Ok(PlanWithKeyRequirements::new(agg_plan))
let request_key_ordering = vec![None; agg_plan.children().len()];
Ok((agg_plan, request_key_ordering))
} else {
let new_positions = expected_expr_positions(&output_exprs, parent_required);
match new_positions {
None => Ok(PlanWithKeyRequirements::new(agg_plan)),
None => {
let request_key_ordering = vec![None; agg_plan.children().len()];
Ok((agg_plan, request_key_ordering))
}
Some(positions) => {
let new_partial_agg = if let Some(agg_exec) =
agg_exec.input().as_any().downcast_ref::<AggregateExec>()
@@ -577,11 +587,13 @@ fn reorder_aggregate_keys(
.push((Arc::new(Column::new(name, idx)) as _, name.clone()))
}
// TODO: merge adjacent Projections if there are any
Ok(PlanWithKeyRequirements::new(Arc::new(
ProjectionExec::try_new(proj_exprs, new_final_agg)?,
)))
let new_plan =
Arc::new(ProjectionExec::try_new(proj_exprs, new_final_agg)?);
let request_key_ordering = vec![None; new_plan.children().len()];
Ok((new_plan, request_key_ordering))
} else {
Ok(PlanWithKeyRequirements::new(agg_plan))
let request_key_ordering = vec![None; agg_plan.children().len()];
Ok((agg_plan, request_key_ordering))
}
}
}
@@ -1539,93 +1551,6 @@ struct JoinKeyPairs {
right_keys: Vec<Arc<dyn PhysicalExpr>>,
}

#[derive(Debug, Clone)]
struct PlanWithKeyRequirements {
plan: Arc<dyn ExecutionPlan>,
/// Parent required key ordering
required_key_ordering: Vec<Arc<dyn PhysicalExpr>>,
/// The request key ordering to children
request_key_ordering: Vec<Option<Vec<Arc<dyn PhysicalExpr>>>>,
}

impl PlanWithKeyRequirements {
fn new(plan: Arc<dyn ExecutionPlan>) -> Self {
let children_len = plan.children().len();
PlanWithKeyRequirements {
plan,
required_key_ordering: vec![],
request_key_ordering: vec![None; children_len],
}
}

fn children(&self) -> Vec<PlanWithKeyRequirements> {
let plan_children = self.plan.children();
assert_eq!(plan_children.len(), self.request_key_ordering.len());
plan_children
.into_iter()
.zip(self.request_key_ordering.clone())
.map(|(child, required)| {
let from_parent = required.unwrap_or_default();
let length = child.children().len();
PlanWithKeyRequirements {
plan: child,
required_key_ordering: from_parent,
request_key_ordering: vec![None; length],
}
})
.collect()
}
}

impl TreeNode for PlanWithKeyRequirements {
fn apply_children<F>(&self, f: &mut F) -> Result<TreeNodeRecursion>
where
F: FnMut(&Self) -> Result<TreeNodeRecursion>,
{
self.children().iter().for_each_till_continue(f)
}

fn map_children<F>(self, transform: F) -> Result<Self>
where
F: FnMut(Self) -> Result<Self>,
{
let children = self.children();
if !children.is_empty() {
let new_children: Result<Vec<_>> =
children.into_iter().map(transform).collect();

let children_plans = new_children?
.into_iter()
.map(|child| child.plan)
.collect::<Vec<_>>();
let new_plan = with_new_children_if_necessary(self.plan, children_plans)?;
Ok(PlanWithKeyRequirements {
plan: new_plan.into(),
required_key_ordering: self.required_key_ordering,
request_key_ordering: self.request_key_ordering,
})
} else {
Ok(self)
}
}

fn transform_children<F>(&mut self, f: &mut F) -> Result<TreeNodeRecursion>
where
F: FnMut(&mut Self) -> Result<TreeNodeRecursion>,
{
let mut children = self.children();
if !children.is_empty() {
let tnr = children.iter_mut().for_each_till_continue(f)?;
let children_plans = children.into_iter().map(|c| c.plan).collect();
self.plan =
with_new_children_if_necessary(self.plan.clone(), children_plans)?.into();
Ok(tnr)
} else {
Ok(TreeNodeRecursion::Continue)
}
}
}

/// Since almost all of these tests explicitly use `ParquetExec`, they only run with the parquet feature flag on
#[cfg(feature = "parquet")]
#[cfg(test)]
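The payoff of the payload-threading approach shows up at the end of the diff: the `PlanWithKeyRequirements` wrapper, its `children()` helper, and its hand-written `TreeNode` implementation are deleted outright, since the parent's requirement now arrives as an argument and the per-child requests are returned from the closure instead of being stashed on wrapper nodes. That removal accounts for a large share of the commit's 163 deleted lines.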
