Skip to content

Commit

Permalink
Move some cheaper checks earlier
Browse files Browse the repository at this point in the history
  • Loading branch information
DianQK committed Mar 3, 2024
1 parent 834683b commit 73ef0c0
Showing 1 changed file with 54 additions and 52 deletions.
106 changes: 54 additions & 52 deletions compiler/rustc_mir_transform/src/early_otherwise_branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
trace!("running EarlyOtherwiseBranch on {:?}", body.source);

let mut should_cleanup = false;
let mut should_apply_patch = false;
let mut patch = MirPatch::new(body);

// Also consider newly generated bbs in the same pass
for i in 0..body.basic_blocks.len() {
Expand All @@ -112,7 +113,7 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {

trace!("SUCCESS: found optimization possibility to apply: {:?}", &opt_data);

should_cleanup = true;
should_apply_patch = true;

let TerminatorKind::SwitchInt { discr: parent_op, targets: parent_targets } =
&bbs[parent].terminator().kind
Expand All @@ -129,8 +130,6 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
let statements_before = bbs[parent].statements.len();
let parent_end = Location { block: parent, statement_index: statements_before };

let mut patch = MirPatch::new(body);

let (second_discriminant_temp, second_operand) = if opt_data.hoist_discriminant {
// create temp to store second discriminant in, `_s` in example above
let second_discriminant_temp =
Expand Down Expand Up @@ -242,13 +241,12 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
);
}
}

patch.apply(body);
}

// Since this optimization adds new basic blocks and invalidates others,
// clean up the cfg to make it nicer for other passes
if should_cleanup {
if should_apply_patch {
patch.apply(body);
simplify_cfg(body);
}
}
Expand All @@ -275,19 +273,15 @@ fn evaluate_candidate<'tcx>(
return None;
};
let parent_ty = parent_discr.ty(body.local_decls(), tcx);
let (_, child) = targets.iter().next()?;
let child_terminator = &bbs[child].terminator();
let TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr } =
&child_terminator.kind
let mut targets_iter = targets.iter();
let (_, first_child) = targets_iter.next()?;
let first_child_terminator = &bbs[first_child].terminator();
let TerminatorKind::SwitchInt { targets: first_child_targets, discr: first_child_discr } =
&first_child_terminator.kind
else {
return None;
};
let child_ty = child_discr.ty(body.local_decls(), tcx);
if bbs[child].statements.len() > 1 {
return None;
}
let hoist_discriminant = bbs[child].statements.len() == 1;
let child_place = if hoist_discriminant {
let hoist_discriminant = if bbs[first_child].statements.len() == 1 {
if !bbs[targets.otherwise()].is_empty_unreachable() {
// Someone could write code like this:
// ```rust
Expand Down Expand Up @@ -320,7 +314,44 @@ fn evaluate_candidate<'tcx>(
// So we need the `otherwise` branch has no statements and an unreachable terminator.
return None;
}
let Some(StatementKind::Assign(boxed)) = &bbs[child].statements.first().map(|x| &x.kind)
true
} else if bbs[first_child].statements.is_empty() {
false
} else {
return None;
};
let destination = if hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable() {
first_child_targets.otherwise()
} else {
if first_child_targets.otherwise() != targets.otherwise() {
return None;
}
targets.otherwise()
};
while let Some((_, child)) = targets_iter.next() {
let child_branch = &bbs[child];
// In order for the optimization to be correct, the branch must...
// ...have exactly one or empty statement
if (hoist_discriminant && child_branch.statements.len() != 1)
|| (!hoist_discriminant && !child_branch.statements.is_empty())
{
return None;
}
// ...terminate on a `SwitchInt` that invalidates that local
let TerminatorKind::SwitchInt { targets: child_targets, .. } =
&child_branch.terminator().kind
else {
return None;
};
if child_targets.otherwise() != destination {
return None;
}
// Make sure there are only two branches.
}
let child_ty = first_child_discr.ty(body.local_decls(), tcx);
let child_place = if hoist_discriminant {
let Some(StatementKind::Assign(boxed)) =
&bbs[first_child].statements.first().map(|x| &x.kind)
else {
return None;
};
Expand All @@ -329,26 +360,17 @@ fn evaluate_candidate<'tcx>(
};
*child_place
} else {
let TerminatorKind::SwitchInt { discr, .. } = &bbs[child].terminator().kind else {
let TerminatorKind::SwitchInt { discr, .. } = &bbs[first_child].terminator().kind else {
return None;
};
let Operand::Copy(child_place) = discr else {
return None;
};
*child_place
};
let destination = if hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable() {
child_targets.otherwise()
} else {
targets.otherwise()
};

let TerminatorKind::SwitchInt { targets: child_targets, .. } = &bbs[child].terminator().kind
else {
return None;
};
// Verify that the optimization is legal for each branch
let Some((may_same_target_value, _)) = child_targets.iter().next() else {
let Some((may_same_target_value, _)) = first_child_targets.iter().next() else {
return None;
};
let mut same_target_value = Some(may_same_target_value);
Expand All @@ -357,7 +379,6 @@ fn evaluate_candidate<'tcx>(
&bbs[child],
may_same_target_value,
child_place,
destination,
hoist_discriminant,
) {
same_target_value = None;
Expand All @@ -369,13 +390,7 @@ fn evaluate_candidate<'tcx>(
return None;
}
for (value, child) in targets.iter() {
if !verify_candidate_branch(
&bbs[child],
value,
child_place,
destination,
hoist_discriminant,
) {
if !verify_candidate_branch(&bbs[child], value, child_place, hoist_discriminant) {
return None;
}
}
Expand All @@ -384,7 +399,7 @@ fn evaluate_candidate<'tcx>(
destination,
child_place,
child_ty,
child_source: child_terminator.source_info,
child_source: first_child_terminator.source_info,
hoist_discriminant,
same_target_value,
})
Expand All @@ -394,20 +409,11 @@ fn verify_candidate_branch<'tcx>(
branch: &BasicBlockData<'tcx>,
value: u128,
place: Place<'tcx>,
destination: BasicBlock,
hoist_discriminant: bool,
) -> bool {
// In order for the optimization to be correct, the branch must...
// ...have exactly one statement
if (hoist_discriminant && branch.statements.len() != 1)
|| (!hoist_discriminant && !branch.statements.is_empty())
{
return false;
}
// ...terminate on a `SwitchInt` that invalidates that local
let TerminatorKind::SwitchInt { discr: switch_op, targets, .. } = &branch.terminator().kind
else {
return false;
unreachable!()
};
if hoist_discriminant {
// ...assign the discriminant of `place` in that statement
Expand All @@ -428,10 +434,6 @@ fn verify_candidate_branch<'tcx>(
return false;
}
}
// ...fall through to `destination` if the switch misses
if destination != targets.otherwise() {
return false;
}
// ...have a branch for value `value`
let mut iter = targets.iter();
let Some((target_value, _)) = iter.next() else {
Expand Down

0 comments on commit 73ef0c0

Please sign in to comment.