diff --git a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs index d6ec7d64926c3..ccb229616e855 100644 --- a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs +++ b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs @@ -126,59 +126,33 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody { let mut field_remapping = UnordMap::default(); - // One parent capture may correspond to several child captures if we end up - // refining the set of captures via edition-2021 precise captures. We want to - // match up any number of child captures with one parent capture, so we keep - // peeking off this `Peekable` until the child doesn't match anymore. - let mut parent_captures = - tcx.closure_captures(parent_def_id).iter().copied().enumerate().peekable(); - // Make sure we use every field at least once, b/c why are we capturing something - // if it's not used in the inner coroutine. - let mut field_used_at_least_once = false; - - for (child_field_idx, child_capture) in tcx + let mut child_captures = tcx .closure_captures(coroutine_def_id) .iter() .copied() // By construction we capture all the args first. .skip(num_args) .enumerate() - { - loop { - let Some(&(parent_field_idx, parent_capture)) = parent_captures.peek() else { - bug!("we ran out of parent captures!") - }; + .peekable(); - let PlaceBase::Upvar(parent_base) = parent_capture.place.base else { - bug!("expected capture to be an upvar"); - }; - let PlaceBase::Upvar(child_base) = child_capture.place.base else { - bug!("expected capture to be an upvar"); - }; + // One parent capture may correspond to several child captures if we end up + // refining the set of captures via edition-2021 precise captures. We want to + // match up any number of child captures with one parent capture, so we keep + // peeking off this `Peekable` until the child doesn't match anymore. + for (parent_field_idx, parent_capture) in + tcx.closure_captures(parent_def_id).iter().copied().enumerate() + { + // Make sure we use every field at least once, b/c why are we capturing something + // if it's not used in the inner coroutine. + let mut field_used_at_least_once = false; - assert!( - child_capture.place.projections.len() >= parent_capture.place.projections.len() - ); - // A parent matches a child they share the same prefix of projections. - // The child may have more, if it is capturing sub-fields out of - // something that is captured by-move in the parent closure. - if parent_base.var_path.hir_id != child_base.var_path.hir_id - || !std::iter::zip( - &child_capture.place.projections, - &parent_capture.place.projections, - ) - .all(|(child, parent)| child.kind == parent.kind) - { - // Make sure the field was used at least once. - assert!( - field_used_at_least_once, - "we captured {parent_capture:#?} but it was not used in the child coroutine?" - ); - field_used_at_least_once = false; - // Skip this field. - let _ = parent_captures.next().unwrap(); - continue; - } + // A parent matches a child if they share the same prefix of projections. + // The child may have more, if it is capturing sub-fields out of + // something that is captured by-move in the parent closure. + while child_captures.peek().map_or(false, |(_, child_capture)| { + child_prefix_matches_parent_projections(parent_capture, child_capture) + }) { + let (child_field_idx, child_capture) = child_captures.next().unwrap(); // Store this set of additional projections (fields and derefs). // We need to re-apply them later. @@ -221,15 +195,15 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody { ); field_used_at_least_once = true; - break; } - } - // Pop the last parent capture - if field_used_at_least_once { - let _ = parent_captures.next().unwrap(); + // Make sure the field was used at least once. + assert!( + field_used_at_least_once, + "we captured {parent_capture:#?} but it was not used in the child coroutine?" + ); } - assert_eq!(parent_captures.next(), None, "leftover parent captures?"); + assert_eq!(child_captures.next(), None, "leftover child captures?"); if coroutine_kind == ty::ClosureKind::FnOnce { assert_eq!(field_remapping.len(), tcx.closure_captures(parent_def_id).len()); @@ -251,6 +225,7 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody { let mut by_move_body = body.clone(); MakeByMoveBody { tcx, field_remapping, by_move_coroutine_ty }.visit_body(&mut by_move_body); dump_mir(tcx, false, "coroutine_by_move", &0, &by_move_body, |_, _| Ok(())); + // FIXME: use query feeding to generate the body right here and then only store the `DefId` of the new body. by_move_body.source = mir::MirSource::from_instance(InstanceDef::CoroutineKindShim { coroutine_def_id: coroutine_def_id.to_def_id(), }); @@ -258,6 +233,23 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody { } } +fn child_prefix_matches_parent_projections( + parent_capture: &ty::CapturedPlace<'_>, + child_capture: &ty::CapturedPlace<'_>, +) -> bool { + let PlaceBase::Upvar(parent_base) = parent_capture.place.base else { + bug!("expected capture to be an upvar"); + }; + let PlaceBase::Upvar(child_base) = child_capture.place.base else { + bug!("expected capture to be an upvar"); + }; + + assert!(child_capture.place.projections.len() >= parent_capture.place.projections.len()); + parent_base.var_path.hir_id == child_base.var_path.hir_id + && std::iter::zip(&child_capture.place.projections, &parent_capture.place.projections) + .all(|(child, parent)| child.kind == parent.kind) +} + struct MakeByMoveBody<'tcx> { tcx: TyCtxt<'tcx>, field_remapping: UnordMap, bool, &'tcx [Projection<'tcx>])>,