Skip to content

Commit

Permalink
Auto merge of rust-lang#120168 - dingxiangfei2009:coroutine-upvar, r=…
Browse files Browse the repository at this point in the history
…<try>

Relocate coroutine upvars into Unresumed state

Related to rust-lang#62958

This PR is an attempt to address the async/coroutine size issue by allowing independent def-use/liveness analysis on individual upvars in coroutines. It has appeared to address partially the size doubling issue introduced by the use of upvars.

However, there are caveats detailed in the following list that I would like to address before turning this draft in.
- The treatment here towards the `ty::Coroutine` in MIR passes is unfortunately "messier" than my liking, which is something I definitely want to change. I propose to promote upvars into `Body<'tcx>` along with `local_decls`, so that we can safely handle them safely. I would happily open a new separate PR to improve the upvar management.
- It is not a generic solution, yet. For instance, we are still doubling the size in the example of rust-lang#62958. If we insert a pass before MIR type analysis to remove unnecessary drops, which we can, that particular size doubling will be solved. However, if a `Future` upvar is alive across more than one yield points, that upvar is still ineligible. It makes sense because we would like to minimize moving of variant fields. How to handle these upvars is not the focus of this PR for now.

Out of expectation of possible change in the high level plan, I am keeping this as a draft in hope of invoking conversations. 🙇

cc `@pnkfelix` for the context.
  • Loading branch information
bors committed Apr 10, 2024
2 parents b14d8b2 + a68481f commit 6b8fde6
Show file tree
Hide file tree
Showing 60 changed files with 1,078 additions and 431 deletions.
3 changes: 1 addition & 2 deletions compiler/rustc_abi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1523,8 +1523,7 @@ pub struct LayoutS<FieldIdx: Idx, VariantIdx: Idx> {

/// Encodes information about multi-variant layouts.
/// Even with `Multiple` variants, a layout still has its own fields! Those are then
/// shared between all variants. One of them will be the discriminant,
/// but e.g. coroutines can have more.
/// shared between all variants. One of them will be the discriminant.
///
/// To access all fields of this layout, both `fields` and the fields of the active variant
/// must be taken into account.
Expand Down
31 changes: 13 additions & 18 deletions compiler/rustc_borrowck/src/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -816,22 +816,18 @@ impl<'a, 'b, 'tcx> TypeVerifier<'a, 'b, 'tcx> {
}),
};
}
ty::CoroutineClosure(_, args) => {
return match args.as_coroutine_closure().upvar_tys().get(field.index()) {
ty::CoroutineClosure(_def_id, args) => {
let upvar_tys = args.as_coroutine_closure().upvar_tys();
return match upvar_tys.get(field.index()) {
Some(&ty) => Ok(ty),
None => Err(FieldAccessError::OutOfRange {
field_count: args.as_coroutine_closure().upvar_tys().len(),
}),
None => Err(FieldAccessError::OutOfRange { field_count: upvar_tys.len() }),
};
}
ty::Coroutine(_, args) => {
// Only prefix fields (upvars and current state) are
// accessible without a variant index.
return match args.as_coroutine().prefix_tys().get(field.index()) {
Some(ty) => Ok(*ty),
None => Err(FieldAccessError::OutOfRange {
field_count: args.as_coroutine().prefix_tys().len(),
}),
ty::Coroutine(_def_id, args) => {
let upvar_tys = args.as_coroutine().upvar_tys();
return match upvar_tys.get(field.index()) {
Some(&ty) => Ok(ty),
None => Err(FieldAccessError::OutOfRange { field_count: upvar_tys.len() }),
};
}
ty::Tuple(tys) => {
Expand Down Expand Up @@ -1905,11 +1901,10 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
// It doesn't make sense to look at a field beyond the prefix;
// these require a variant index, and are not initialized in
// aggregate rvalues.
match args.as_coroutine().prefix_tys().get(field_index.as_usize()) {
let upvar_tys = args.as_coroutine().upvar_tys();
match upvar_tys.get(field_index.as_usize()) {
Some(ty) => Ok(*ty),
None => Err(FieldAccessError::OutOfRange {
field_count: args.as_coroutine().prefix_tys().len(),
}),
None => Err(FieldAccessError::OutOfRange { field_count: upvar_tys.len() }),
}
}
AggregateKind::CoroutineClosure(_, args) => {
Expand Down Expand Up @@ -2534,7 +2529,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {

self.prove_aggregate_predicates(aggregate_kind, location);

if *aggregate_kind == AggregateKind::Tuple {
if matches!(aggregate_kind, AggregateKind::Tuple) {
// tuple rvalue field type is always the type of the op. Nothing to check here.
return;
}
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_codegen_cranelift/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,9 @@ fn codegen_stmt<'tcx>(
let variant_dest = lval.downcast_variant(fx, variant_index);
(variant_index, variant_dest, active_field_index)
}
mir::AggregateKind::Coroutine(_, _) => {
(FIRST_VARIANT, lval.downcast_variant(fx, FIRST_VARIANT), None)
}
_ => (FIRST_VARIANT, lval, None),
};
if active_field_index.is_some() {
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_llvm/src/debuginfo/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,7 @@ fn build_upvar_field_di_nodes<'ll, 'tcx>(
closure_or_coroutine_di_node: &'ll DIType,
) -> SmallVec<&'ll DIType> {
let (&def_id, up_var_tys) = match closure_or_coroutine_ty.kind() {
ty::Coroutine(def_id, args) => (def_id, args.as_coroutine().prefix_tys()),
ty::Coroutine(def_id, args) => (def_id, args.as_coroutine().upvar_tys()),
ty::Closure(def_id, args) => (def_id, args.as_closure().upvar_tys()),
ty::CoroutineClosure(def_id, args) => (def_id, args.as_coroutine_closure().upvar_tys()),
_ => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -686,12 +686,12 @@ fn build_union_fields_for_direct_tag_coroutine<'ll, 'tcx>(
let coroutine_layout =
cx.tcx.coroutine_layout(coroutine_def_id, coroutine_args.kind_ty()).unwrap();

let common_upvar_names = cx.tcx.closure_saved_names_of_captured_variables(coroutine_def_id);
let variant_range = coroutine_args.variant_range(coroutine_def_id, cx.tcx);
let variant_count = (variant_range.start.as_u32()..variant_range.end.as_u32()).len();

let tag_base_type = tag_base_type(cx, coroutine_type_and_layout);

let common_upvar_names = cx.tcx.closure_saved_names_of_captured_variables(coroutine_def_id);
let variant_names_type_di_node = build_variant_names_type_di_node(
cx,
coroutine_type_di_node,
Expand Down
29 changes: 4 additions & 25 deletions compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ pub fn build_coroutine_variant_struct_type_di_node<'ll, 'tcx>(
coroutine_type_and_layout: TyAndLayout<'tcx>,
coroutine_type_di_node: &'ll DIType,
coroutine_layout: &CoroutineLayout<'tcx>,
common_upvar_names: &IndexSlice<FieldIdx, Symbol>,
_common_upvar_names: &IndexSlice<FieldIdx, Symbol>,
) -> &'ll DIType {
let variant_name = CoroutineArgs::variant_name(variant_index);
let unique_type_id = UniqueTypeId::for_enum_variant_struct_type(
Expand All @@ -337,7 +337,7 @@ pub fn build_coroutine_variant_struct_type_di_node<'ll, 'tcx>(

let variant_layout = coroutine_type_and_layout.for_variant(cx, variant_index);

let coroutine_args = match coroutine_type_and_layout.ty.kind() {
let _coroutine_args = match coroutine_type_and_layout.ty.kind() {
ty::Coroutine(_, args) => args.as_coroutine(),
_ => unreachable!(),
};
Expand All @@ -355,7 +355,7 @@ pub fn build_coroutine_variant_struct_type_di_node<'ll, 'tcx>(
),
|cx, variant_struct_type_di_node| {
// Fields that just belong to this variant/state
let state_specific_fields: SmallVec<_> = (0..variant_layout.fields.count())
(0..variant_layout.fields.count())
.map(|field_index| {
let coroutine_saved_local = coroutine_layout.variant_fields[variant_index]
[FieldIdx::from_usize(field_index)];
Expand All @@ -377,28 +377,7 @@ pub fn build_coroutine_variant_struct_type_di_node<'ll, 'tcx>(
type_di_node(cx, field_type),
)
})
.collect();

// Fields that are common to all states
let common_fields: SmallVec<_> = coroutine_args
.prefix_tys()
.iter()
.zip(common_upvar_names)
.enumerate()
.map(|(index, (upvar_ty, upvar_name))| {
build_field_di_node(
cx,
variant_struct_type_di_node,
upvar_name.as_str(),
cx.size_and_align_of(upvar_ty),
coroutine_type_and_layout.fields.offset(index),
DIFlags::FlagZero,
type_di_node(cx, upvar_ty),
)
})
.collect();

state_specific_fields.into_iter().chain(common_fields).collect()
.collect()
},
|cx| build_generic_type_param_di_nodes(cx, coroutine_type_and_layout.ty),
)
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_codegen_ssa/src/mir/rvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
let variant_dest = dest.project_downcast(bx, variant_index);
(variant_index, variant_dest, active_field_index)
}
mir::AggregateKind::Coroutine(_, _) => {
(FIRST_VARIANT, dest.project_downcast(bx, FIRST_VARIANT), None)
}
_ => (FIRST_VARIANT, dest, None),
};
if active_field_index.is_some() {
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_const_eval/src/interpret/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,9 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
let variant_dest = self.project_downcast(dest, variant_index)?;
(variant_index, variant_dest, active_field_index)
}
mir::AggregateKind::Coroutine(_def_id, _args) => {
(FIRST_VARIANT, self.project_downcast(dest, FIRST_VARIANT)?, None)
}
_ => (FIRST_VARIANT, dest.clone(), None),
};
if active_field_index.is_some() {
Expand Down
12 changes: 5 additions & 7 deletions compiler/rustc_const_eval/src/transform/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -743,14 +743,12 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
};

ty::EarlyBinder::bind(f_ty.ty).instantiate(self.tcx, args)
} else if let Some(&ty) = args.as_coroutine().upvar_tys().get(f.as_usize())
{
ty
} else {
let Some(&f_ty) = args.as_coroutine().prefix_tys().get(f.index())
else {
fail_out_of_bounds(self, location);
return;
};

f_ty
fail_out_of_bounds(self, location);
return;
};

check_equal(self, location, f_ty);
Expand Down
9 changes: 9 additions & 0 deletions compiler/rustc_middle/src/mir/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ pub struct CoroutineLayout<'tcx> {
/// await).
pub variant_source_info: IndexVec<VariantIdx, SourceInfo>,

/// The starting index of upvars.
pub upvar_start: CoroutineSavedLocal,

/// Which saved locals are storage-live at the same time. Locals that do not
/// have conflicts with each other are allowed to overlap in the computed
/// layout.
Expand Down Expand Up @@ -101,6 +104,7 @@ impl Debug for CoroutineLayout<'_> {
}

fmt.debug_struct("CoroutineLayout")
.field("upvar_start", &self.upvar_start)
.field("field_tys", &MapPrinter::new(self.field_tys.iter_enumerated()))
.field(
"variant_fields",
Expand All @@ -110,7 +114,12 @@ impl Debug for CoroutineLayout<'_> {
.map(|(k, v)| (GenVariantPrinter(k), OneLinePrinter(v))),
),
)
.field("field_names", &MapPrinter::new(self.field_names.iter_enumerated()))
.field("storage_conflicts", &self.storage_conflicts)
.field(
"variant_source_info",
&MapPrinter::new(self.variant_source_info.iter_enumerated()),
)
.finish()
}
}
Expand Down
6 changes: 5 additions & 1 deletion compiler/rustc_middle/src/mir/tcx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ impl<'tcx> PlaceTy<'tcx> {
T: ::std::fmt::Debug + Copy,
{
if self.variant_index.is_some() && !matches!(elem, ProjectionElem::Field(..)) {
bug!("cannot use non field projection on downcasted place")
bug!(
"cannot use non field projection on downcasted place from {:?} (variant {:?}), got {elem:?}",
self.ty,
self.variant_index
)
}
let answer = match *elem {
ProjectionElem::Deref => {
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/ty/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,7 @@ where
if i == tag_field {
return TyMaybeWithLayout::TyAndLayout(tag_layout(tag));
}
TyMaybeWithLayout::Ty(args.as_coroutine().prefix_tys()[i])
bug!("coroutine has no prefix field");
}
},

Expand Down
9 changes: 1 addition & 8 deletions compiler/rustc_middle/src/ty/sty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ impl<'tcx> CoroutineArgs<'tcx> {
witness: witness.expect_ty(),
tupled_upvars_ty: tupled_upvars_ty.expect_ty(),
},
_ => bug!("coroutine args missing synthetics"),
_ => bug!("coroutine args missing synthetics, got {:?}", self.args),
}
}

Expand Down Expand Up @@ -762,13 +762,6 @@ impl<'tcx> CoroutineArgs<'tcx> {
})
})
}

/// This is the types of the fields of a coroutine which are not stored in a
/// variant.
#[inline]
pub fn prefix_tys(self) -> &'tcx List<Ty<'tcx>> {
self.upvar_tys()
}
}

#[derive(Debug, Copy, Clone, HashStable)]
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_mir_dataflow/src/framework/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ where
}

None if dump_enabled(tcx, A::NAME, def_id) => {
create_dump_file(tcx, ".dot", false, A::NAME, &pass_name.unwrap_or("-----"), body)?
create_dump_file(tcx, ".dot", true, A::NAME, &pass_name.unwrap_or("-----"), body)?
}

_ => return (Ok(()), results),
Expand Down
Loading

0 comments on commit 6b8fde6

Please sign in to comment.