diff --git a/compiler/rustc_ast/src/ast.rs b/compiler/rustc_ast/src/ast.rs index d0acf0b560c74..0b06af223653c 100644 --- a/compiler/rustc_ast/src/ast.rs +++ b/compiler/rustc_ast/src/ast.rs @@ -1329,7 +1329,7 @@ pub struct Closure { pub binder: ClosureBinder, pub capture_clause: CaptureBy, pub constness: Const, - pub coro_kind: Option, + pub coroutine_kind: Option, pub movability: Movability, pub fn_decl: P, pub body: P, @@ -1534,6 +1534,7 @@ pub enum ExprKind { pub enum GenBlockKind { Async, Gen, + AsyncGen, } impl fmt::Display for GenBlockKind { @@ -1547,6 +1548,7 @@ impl GenBlockKind { match self { GenBlockKind::Async => "async", GenBlockKind::Gen => "gen", + GenBlockKind::AsyncGen => "async gen", } } } @@ -2431,10 +2433,12 @@ pub enum Unsafe { /// Iterator`. #[derive(Copy, Clone, Encodable, Decodable, Debug)] pub enum CoroutineKind { - /// `async`, which evaluates to `impl Future` + /// `async`, which returns an `impl Future` Async { span: Span, closure_id: NodeId, return_impl_trait_id: NodeId }, - /// `gen`, which evaluates to `impl Iterator` + /// `gen`, which returns an `impl Iterator` Gen { span: Span, closure_id: NodeId, return_impl_trait_id: NodeId }, + /// `async gen`, which returns an `impl AsyncIterator` + AsyncGen { span: Span, closure_id: NodeId, return_impl_trait_id: NodeId }, } impl CoroutineKind { @@ -2451,7 +2455,10 @@ impl CoroutineKind { pub fn return_id(self) -> (NodeId, Span) { match self { CoroutineKind::Async { return_impl_trait_id, span, .. } - | CoroutineKind::Gen { return_impl_trait_id, span, .. } => (return_impl_trait_id, span), + | CoroutineKind::Gen { return_impl_trait_id, span, .. } + | CoroutineKind::AsyncGen { return_impl_trait_id, span, .. } => { + (return_impl_trait_id, span) + } } } } @@ -2856,7 +2863,7 @@ pub struct FnHeader { /// The `unsafe` keyword, if any pub unsafety: Unsafe, /// Whether this is `async`, `gen`, or nothing. - pub coro_kind: Option, + pub coroutine_kind: Option, /// The `const` keyword, if any pub constness: Const, /// The `extern` keyword and corresponding ABI string, if any @@ -2866,9 +2873,9 @@ pub struct FnHeader { impl FnHeader { /// Does this function header have any qualifiers or is it empty? pub fn has_qualifiers(&self) -> bool { - let Self { unsafety, coro_kind, constness, ext } = self; + let Self { unsafety, coroutine_kind, constness, ext } = self; matches!(unsafety, Unsafe::Yes(_)) - || coro_kind.is_some() + || coroutine_kind.is_some() || matches!(constness, Const::Yes(_)) || !matches!(ext, Extern::None) } @@ -2876,7 +2883,12 @@ impl FnHeader { impl Default for FnHeader { fn default() -> FnHeader { - FnHeader { unsafety: Unsafe::No, coro_kind: None, constness: Const::No, ext: Extern::None } + FnHeader { + unsafety: Unsafe::No, + coroutine_kind: None, + constness: Const::No, + ext: Extern::None, + } } } diff --git a/compiler/rustc_ast/src/mut_visit.rs b/compiler/rustc_ast/src/mut_visit.rs index 342f5530b40fa..41c4e024447e8 100644 --- a/compiler/rustc_ast/src/mut_visit.rs +++ b/compiler/rustc_ast/src/mut_visit.rs @@ -121,8 +121,8 @@ pub trait MutVisitor: Sized { noop_visit_fn_decl(d, self); } - fn visit_coro_kind(&mut self, a: &mut CoroutineKind) { - noop_visit_coro_kind(a, self); + fn visit_coroutine_kind(&mut self, a: &mut CoroutineKind) { + noop_visit_coroutine_kind(a, self); } fn visit_closure_binder(&mut self, b: &mut ClosureBinder) { @@ -871,10 +871,11 @@ pub fn noop_visit_closure_binder(binder: &mut ClosureBinder, vis: } } -pub fn noop_visit_coro_kind(coro_kind: &mut CoroutineKind, vis: &mut T) { - match coro_kind { +pub fn noop_visit_coroutine_kind(coroutine_kind: &mut CoroutineKind, vis: &mut T) { + match coroutine_kind { CoroutineKind::Async { span, closure_id, return_impl_trait_id } - | CoroutineKind::Gen { span, closure_id, return_impl_trait_id } => { + | CoroutineKind::Gen { span, closure_id, return_impl_trait_id } + | CoroutineKind::AsyncGen { span, closure_id, return_impl_trait_id } => { vis.visit_span(span); vis.visit_id(closure_id); vis.visit_id(return_impl_trait_id); @@ -1171,9 +1172,9 @@ fn visit_const_item( } pub fn noop_visit_fn_header(header: &mut FnHeader, vis: &mut T) { - let FnHeader { unsafety, coro_kind, constness, ext: _ } = header; + let FnHeader { unsafety, coroutine_kind, constness, ext: _ } = header; visit_constness(constness, vis); - coro_kind.as_mut().map(|coro_kind| vis.visit_coro_kind(coro_kind)); + coroutine_kind.as_mut().map(|coroutine_kind| vis.visit_coroutine_kind(coroutine_kind)); visit_unsafety(unsafety, vis); } @@ -1407,7 +1408,7 @@ pub fn noop_visit_expr( binder, capture_clause, constness, - coro_kind, + coroutine_kind, movability: _, fn_decl, body, @@ -1416,7 +1417,7 @@ pub fn noop_visit_expr( }) => { vis.visit_closure_binder(binder); visit_constness(constness, vis); - coro_kind.as_mut().map(|coro_kind| vis.visit_coro_kind(coro_kind)); + coroutine_kind.as_mut().map(|coroutine_kind| vis.visit_coroutine_kind(coroutine_kind)); vis.visit_capture_by(capture_clause); vis.visit_fn_decl(fn_decl); vis.visit_expr(body); diff --git a/compiler/rustc_ast/src/visit.rs b/compiler/rustc_ast/src/visit.rs index 6b290fdfcc91e..ce5214efaca70 100644 --- a/compiler/rustc_ast/src/visit.rs +++ b/compiler/rustc_ast/src/visit.rs @@ -861,7 +861,7 @@ pub fn walk_expr<'a, V: Visitor<'a>>(visitor: &mut V, expression: &'a Expr) { ExprKind::Closure(box Closure { binder, capture_clause, - coro_kind: _, + coroutine_kind: _, constness: _, movability: _, fn_decl, diff --git a/compiler/rustc_ast_lowering/src/expr.rs b/compiler/rustc_ast_lowering/src/expr.rs index 835ff36acbc05..c287c65ff3627 100644 --- a/compiler/rustc_ast_lowering/src/expr.rs +++ b/compiler/rustc_ast_lowering/src/expr.rs @@ -14,6 +14,7 @@ use rustc_ast::*; use rustc_data_structures::stack::ensure_sufficient_stack; use rustc_hir as hir; use rustc_hir::def::{DefKind, Res}; +use rustc_middle::span_bug; use rustc_session::errors::report_lit_error; use rustc_span::source_map::{respan, Spanned}; use rustc_span::symbol::{kw, sym, Ident, Symbol}; @@ -196,22 +197,19 @@ impl<'hir> LoweringContext<'_, 'hir> { binder, capture_clause, constness, - coro_kind, + coroutine_kind, movability, fn_decl, body, fn_decl_span, fn_arg_span, - }) => match coro_kind { - Some( - CoroutineKind::Async { closure_id, .. } - | CoroutineKind::Gen { closure_id, .. }, - ) => self.lower_expr_async_closure( + }) => match coroutine_kind { + Some(coroutine_kind) => self.lower_expr_coroutine_closure( binder, *capture_clause, e.id, hir_id, - *closure_id, + *coroutine_kind, fn_decl, body, *fn_decl_span, @@ -325,6 +323,15 @@ impl<'hir> LoweringContext<'_, 'hir> { hir::CoroutineSource::Block, |this| this.with_new_scopes(e.span, |this| this.lower_block_expr(block)), ), + ExprKind::Gen(capture_clause, block, GenBlockKind::AsyncGen) => self + .make_async_gen_expr( + *capture_clause, + e.id, + None, + e.span, + hir::CoroutineSource::Block, + |this| this.with_new_scopes(e.span, |this| this.lower_block_expr(block)), + ), ExprKind::Yield(opt_expr) => self.lower_expr_yield(e.span, opt_expr.as_deref()), ExprKind::Err => hir::ExprKind::Err( self.tcx.sess.span_delayed_bug(e.span, "lowered ExprKind::Err"), @@ -736,6 +743,87 @@ impl<'hir> LoweringContext<'_, 'hir> { })) } + /// Lower a `async gen` construct to a generator that implements `AsyncIterator`. + /// + /// This results in: + /// + /// ```text + /// static move? |_task_context| -> () { + /// + /// } + /// ``` + pub(super) fn make_async_gen_expr( + &mut self, + capture_clause: CaptureBy, + closure_node_id: NodeId, + _yield_ty: Option>, + span: Span, + async_coroutine_source: hir::CoroutineSource, + body: impl FnOnce(&mut Self) -> hir::Expr<'hir>, + ) -> hir::ExprKind<'hir> { + let output = hir::FnRetTy::DefaultReturn(self.lower_span(span)); + + // Resume argument type: `ResumeTy` + let unstable_span = self.mark_span_with_reason( + DesugaringKind::Async, + span, + Some(self.allow_gen_future.clone()), + ); + let resume_ty = hir::QPath::LangItem(hir::LangItem::ResumeTy, unstable_span); + let input_ty = hir::Ty { + hir_id: self.next_id(), + kind: hir::TyKind::Path(resume_ty), + span: unstable_span, + }; + + // The closure/coroutine `FnDecl` takes a single (resume) argument of type `input_ty`. + let fn_decl = self.arena.alloc(hir::FnDecl { + inputs: arena_vec![self; input_ty], + output, + c_variadic: false, + implicit_self: hir::ImplicitSelfKind::None, + lifetime_elision_allowed: false, + }); + + // Lower the argument pattern/ident. The ident is used again in the `.await` lowering. + let (pat, task_context_hid) = self.pat_ident_binding_mode( + span, + Ident::with_dummy_span(sym::_task_context), + hir::BindingAnnotation::MUT, + ); + let param = hir::Param { + hir_id: self.next_id(), + pat, + ty_span: self.lower_span(span), + span: self.lower_span(span), + }; + let params = arena_vec![self; param]; + + let body = self.lower_body(move |this| { + this.coroutine_kind = Some(hir::CoroutineKind::AsyncGen(async_coroutine_source)); + + let old_ctx = this.task_context; + this.task_context = Some(task_context_hid); + let res = body(this); + this.task_context = old_ctx; + (params, res) + }); + + // `static |_task_context| -> { body }`: + hir::ExprKind::Closure(self.arena.alloc(hir::Closure { + def_id: self.local_def_id(closure_node_id), + binder: hir::ClosureBinder::Default, + capture_clause, + bound_generic_params: &[], + fn_decl, + body, + fn_decl_span: self.lower_span(span), + fn_arg_span: None, + movability: Some(hir::Movability::Static), + constness: hir::Constness::NotConst, + })) + } + /// Forwards a possible `#[track_caller]` annotation from `outer_hir_id` to /// `inner_hir_id` in case the `async_fn_track_caller` feature is enabled. pub(super) fn maybe_forward_track_caller( @@ -785,15 +873,18 @@ impl<'hir> LoweringContext<'_, 'hir> { /// ``` fn lower_expr_await(&mut self, await_kw_span: Span, expr: &Expr) -> hir::ExprKind<'hir> { let full_span = expr.span.to(await_kw_span); - match self.coroutine_kind { - Some(hir::CoroutineKind::Async(_)) => {} + + let is_async_gen = match self.coroutine_kind { + Some(hir::CoroutineKind::Async(_)) => false, + Some(hir::CoroutineKind::AsyncGen(_)) => true, Some(hir::CoroutineKind::Coroutine) | Some(hir::CoroutineKind::Gen(_)) | None => { return hir::ExprKind::Err(self.tcx.sess.emit_err(AwaitOnlyInAsyncFnAndBlocks { await_kw_span, item_span: self.current_item, })); } - } + }; + let span = self.mark_span_with_reason(DesugaringKind::Await, await_kw_span, None); let gen_future_span = self.mark_span_with_reason( DesugaringKind::Await, @@ -882,12 +973,19 @@ impl<'hir> LoweringContext<'_, 'hir> { self.stmt_expr(span, match_expr) }; - // task_context = yield (); + // Depending on `async` of `async gen`: + // async - task_context = yield (); + // async gen - task_context = yield ASYNC_GEN_PENDING; let yield_stmt = { - let unit = self.expr_unit(span); + let yielded = if is_async_gen { + self.arena.alloc(self.expr_lang_item_path(span, hir::LangItem::AsyncGenPending)) + } else { + self.expr_unit(span) + }; + let yield_expr = self.expr( span, - hir::ExprKind::Yield(unit, hir::YieldSource::Await { expr: Some(expr_hir_id) }), + hir::ExprKind::Yield(yielded, hir::YieldSource::Await { expr: Some(expr_hir_id) }), ); let yield_expr = self.arena.alloc(yield_expr); @@ -997,7 +1095,11 @@ impl<'hir> LoweringContext<'_, 'hir> { } Some(movability) } - Some(hir::CoroutineKind::Gen(_)) | Some(hir::CoroutineKind::Async(_)) => { + Some( + hir::CoroutineKind::Gen(_) + | hir::CoroutineKind::Async(_) + | hir::CoroutineKind::AsyncGen(_), + ) => { panic!("non-`async`/`gen` closure body turned `async`/`gen` during lowering"); } None => { @@ -1024,18 +1126,22 @@ impl<'hir> LoweringContext<'_, 'hir> { (binder, params) } - fn lower_expr_async_closure( + fn lower_expr_coroutine_closure( &mut self, binder: &ClosureBinder, capture_clause: CaptureBy, closure_id: NodeId, closure_hir_id: hir::HirId, - inner_closure_id: NodeId, + coroutine_kind: CoroutineKind, decl: &FnDecl, body: &Expr, fn_decl_span: Span, fn_arg_span: Span, ) -> hir::ExprKind<'hir> { + let CoroutineKind::Async { closure_id: inner_closure_id, .. } = coroutine_kind else { + span_bug!(fn_decl_span, "`async gen` and `gen` closures are not supported, yet"); + }; + if let &ClosureBinder::For { span, .. } = binder { self.tcx.sess.emit_err(NotSupportedForLifetimeBinderAsyncClosure { span }); } @@ -1504,8 +1610,9 @@ impl<'hir> LoweringContext<'_, 'hir> { } fn lower_expr_yield(&mut self, span: Span, opt_expr: Option<&Expr>) -> hir::ExprKind<'hir> { - match self.coroutine_kind { - Some(hir::CoroutineKind::Gen(_)) => {} + let is_async_gen = match self.coroutine_kind { + Some(hir::CoroutineKind::Gen(_)) => false, + Some(hir::CoroutineKind::AsyncGen(_)) => true, Some(hir::CoroutineKind::Async(_)) => { return hir::ExprKind::Err( self.tcx.sess.emit_err(AsyncCoroutinesNotSupported { span }), @@ -1521,14 +1628,24 @@ impl<'hir> LoweringContext<'_, 'hir> { ) .emit(); } - self.coroutine_kind = Some(hir::CoroutineKind::Coroutine) + self.coroutine_kind = Some(hir::CoroutineKind::Coroutine); + false } - } + }; - let expr = + let mut yielded = opt_expr.as_ref().map(|x| self.lower_expr(x)).unwrap_or_else(|| self.expr_unit(span)); - hir::ExprKind::Yield(expr, hir::YieldSource::Yield) + if is_async_gen { + // yield async_gen_ready($expr); + yielded = self.expr_call_lang_item_fn( + span, + hir::LangItem::AsyncGenReady, + std::slice::from_ref(yielded), + ); + } + + hir::ExprKind::Yield(yielded, hir::YieldSource::Yield) } /// Desugar `ExprForLoop` from: `[opt_ident]: for in ` into: diff --git a/compiler/rustc_ast_lowering/src/item.rs b/compiler/rustc_ast_lowering/src/item.rs index 80854c8a6c08b..9d1f2684c394d 100644 --- a/compiler/rustc_ast_lowering/src/item.rs +++ b/compiler/rustc_ast_lowering/src/item.rs @@ -206,19 +206,25 @@ impl<'hir> LoweringContext<'_, 'hir> { // `impl Future` here because lower_body // only cares about the input argument patterns in the function // declaration (decl), not the return types. - let coro_kind = header.coro_kind; + let coroutine_kind = header.coroutine_kind; let body_id = this.lower_maybe_coroutine_body( span, hir_id, decl, - coro_kind, + coroutine_kind, body.as_deref(), ); let itctx = ImplTraitContext::Universal; let (generics, decl) = this.lower_generics(generics, header.constness, id, &itctx, |this| { - this.lower_fn_decl(decl, id, *fn_sig_span, FnDeclKind::Fn, coro_kind) + this.lower_fn_decl( + decl, + id, + *fn_sig_span, + FnDeclKind::Fn, + coroutine_kind, + ) }); let sig = hir::FnSig { decl, @@ -734,7 +740,7 @@ impl<'hir> LoweringContext<'_, 'hir> { sig, i.id, FnDeclKind::Trait, - sig.header.coro_kind, + sig.header.coroutine_kind, ); (generics, hir::TraitItemKind::Fn(sig, hir::TraitFn::Required(names)), false) } @@ -743,7 +749,7 @@ impl<'hir> LoweringContext<'_, 'hir> { i.span, hir_id, &sig.decl, - sig.header.coro_kind, + sig.header.coroutine_kind, Some(body), ); let (generics, sig) = self.lower_method_sig( @@ -751,7 +757,7 @@ impl<'hir> LoweringContext<'_, 'hir> { sig, i.id, FnDeclKind::Trait, - sig.header.coro_kind, + sig.header.coroutine_kind, ); (generics, hir::TraitItemKind::Fn(sig, hir::TraitFn::Provided(body_id)), true) } @@ -844,7 +850,7 @@ impl<'hir> LoweringContext<'_, 'hir> { i.span, hir_id, &sig.decl, - sig.header.coro_kind, + sig.header.coroutine_kind, body.as_deref(), ); let (generics, sig) = self.lower_method_sig( @@ -852,7 +858,7 @@ impl<'hir> LoweringContext<'_, 'hir> { sig, i.id, if self.is_in_trait_impl { FnDeclKind::Impl } else { FnDeclKind::Inherent }, - sig.header.coro_kind, + sig.header.coroutine_kind, ); (generics, hir::ImplItemKind::Fn(sig, body_id)) @@ -1023,17 +1029,16 @@ impl<'hir> LoweringContext<'_, 'hir> { span: Span, fn_id: hir::HirId, decl: &FnDecl, - coro_kind: Option, + coroutine_kind: Option, body: Option<&Block>, ) -> hir::BodyId { - let (Some(coro_kind), Some(body)) = (coro_kind, body) else { + let (Some(coroutine_kind), Some(body)) = (coroutine_kind, body) else { return self.lower_fn_body_block(span, decl, body); }; - let closure_id = match coro_kind { - CoroutineKind::Async { closure_id, .. } | CoroutineKind::Gen { closure_id, .. } => { - closure_id - } - }; + // FIXME(gen_blocks): Introduce `closure_id` method and remove ALL destructuring. + let (CoroutineKind::Async { closure_id, .. } + | CoroutineKind::Gen { closure_id, .. } + | CoroutineKind::AsyncGen { closure_id, .. }) = coroutine_kind; self.lower_body(|this| { let mut parameters: Vec> = Vec::new(); @@ -1200,7 +1205,8 @@ impl<'hir> LoweringContext<'_, 'hir> { this.expr_block(body) }; - let coroutine_expr = match coro_kind { + // FIXME(gen_blocks): Consider unifying the `make_*_expr` functions. + let coroutine_expr = match coroutine_kind { CoroutineKind::Async { .. } => this.make_async_expr( CaptureBy::Value { move_kw: rustc_span::DUMMY_SP }, closure_id, @@ -1217,6 +1223,14 @@ impl<'hir> LoweringContext<'_, 'hir> { hir::CoroutineSource::Fn, mkbody, ), + CoroutineKind::AsyncGen { .. } => this.make_async_gen_expr( + CaptureBy::Value { move_kw: rustc_span::DUMMY_SP }, + closure_id, + None, + body.span, + hir::CoroutineSource::Fn, + mkbody, + ), }; let hir_id = this.lower_node_id(closure_id); @@ -1233,19 +1247,19 @@ impl<'hir> LoweringContext<'_, 'hir> { sig: &FnSig, id: NodeId, kind: FnDeclKind, - coro_kind: Option, + coroutine_kind: Option, ) -> (&'hir hir::Generics<'hir>, hir::FnSig<'hir>) { let header = self.lower_fn_header(sig.header); let itctx = ImplTraitContext::Universal; let (generics, decl) = self.lower_generics(generics, sig.header.constness, id, &itctx, |this| { - this.lower_fn_decl(&sig.decl, id, sig.span, kind, coro_kind) + this.lower_fn_decl(&sig.decl, id, sig.span, kind, coroutine_kind) }); (generics, hir::FnSig { header, decl, span: self.lower_span(sig.span) }) } fn lower_fn_header(&mut self, h: FnHeader) -> hir::FnHeader { - let asyncness = if let Some(CoroutineKind::Async { span, .. }) = h.coro_kind { + let asyncness = if let Some(CoroutineKind::Async { span, .. }) = h.coroutine_kind { hir::IsAsync::Async(span) } else { hir::IsAsync::NotAsync diff --git a/compiler/rustc_ast_lowering/src/lib.rs b/compiler/rustc_ast_lowering/src/lib.rs index 5dda8f5a6a328..753650f732410 100644 --- a/compiler/rustc_ast_lowering/src/lib.rs +++ b/compiler/rustc_ast_lowering/src/lib.rs @@ -132,6 +132,7 @@ struct LoweringContext<'a, 'hir> { allow_try_trait: Lrc<[Symbol]>, allow_gen_future: Lrc<[Symbol]>, + allow_async_iterator: Lrc<[Symbol]>, /// Mapping from generics `def_id`s to TAIT generics `def_id`s. /// For each captured lifetime (e.g., 'a), we create a new lifetime parameter that is a generic @@ -176,6 +177,8 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { } else { [sym::gen_future].into() }, + // FIXME(gen_blocks): how does `closure_track_caller` + allow_async_iterator: [sym::gen_future, sym::async_iterator].into(), generics_def_id_map: Default::default(), host_param_id: None, } @@ -1900,13 +1903,18 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { fn_span: Span, ) -> hir::FnRetTy<'hir> { let span = self.lower_span(fn_span); - let opaque_ty_span = self.mark_span_with_reason(DesugaringKind::Async, span, None); - let opaque_ty_node_id = match coro { - CoroutineKind::Async { return_impl_trait_id, .. } - | CoroutineKind::Gen { return_impl_trait_id, .. } => return_impl_trait_id, + let (opaque_ty_node_id, allowed_features) = match coro { + CoroutineKind::Async { return_impl_trait_id, .. } => (return_impl_trait_id, None), + CoroutineKind::Gen { return_impl_trait_id, .. } => (return_impl_trait_id, None), + CoroutineKind::AsyncGen { return_impl_trait_id, .. } => { + (return_impl_trait_id, Some(self.allow_async_iterator.clone())) + } }; + let opaque_ty_span = + self.mark_span_with_reason(DesugaringKind::Async, span, allowed_features); + let captured_lifetimes: Vec<_> = self .resolver .take_extra_lifetime_params(opaque_ty_node_id) @@ -1925,7 +1933,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { let bound = this.lower_coroutine_fn_output_type_to_bound( output, coro, - span, + opaque_ty_span, ImplTraitContext::ReturnPositionOpaqueTy { origin: hir::OpaqueTyOrigin::FnReturn(fn_def_id), fn_kind, @@ -1944,7 +1952,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { &mut self, output: &FnRetTy, coro: CoroutineKind, - span: Span, + opaque_ty_span: Span, nested_impl_trait_context: ImplTraitContext, ) -> hir::GenericBound<'hir> { // Compute the `T` in `Future` from the return type. @@ -1960,20 +1968,21 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { // "<$assoc_ty_name = T>" let (assoc_ty_name, trait_lang_item) = match coro { - CoroutineKind::Async { .. } => (hir::FN_OUTPUT_NAME, hir::LangItem::Future), - CoroutineKind::Gen { .. } => (hir::ITERATOR_ITEM_NAME, hir::LangItem::Iterator), + CoroutineKind::Async { .. } => (sym::Output, hir::LangItem::Future), + CoroutineKind::Gen { .. } => (sym::Item, hir::LangItem::Iterator), + CoroutineKind::AsyncGen { .. } => (sym::Item, hir::LangItem::AsyncIterator), }; let future_args = self.arena.alloc(hir::GenericArgs { args: &[], - bindings: arena_vec![self; self.assoc_ty_binding(assoc_ty_name, span, output_ty)], + bindings: arena_vec![self; self.assoc_ty_binding(assoc_ty_name, opaque_ty_span, output_ty)], parenthesized: hir::GenericArgsParentheses::No, span_ext: DUMMY_SP, }); hir::GenericBound::LangItemTrait( trait_lang_item, - self.lower_span(span), + opaque_ty_span, self.next_id(), future_args, ) diff --git a/compiler/rustc_ast_lowering/src/path.rs b/compiler/rustc_ast_lowering/src/path.rs index 7ab0805d08667..efd80af5ef4ae 100644 --- a/compiler/rustc_ast_lowering/src/path.rs +++ b/compiler/rustc_ast_lowering/src/path.rs @@ -389,7 +389,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { FnRetTy::Default(_) => self.arena.alloc(self.ty_tup(*span, &[])), }; let args = smallvec![GenericArg::Type(self.arena.alloc(self.ty_tup(*inputs_span, inputs)))]; - let binding = self.assoc_ty_binding(hir::FN_OUTPUT_NAME, output_ty.span, output_ty); + let binding = self.assoc_ty_binding(sym::Output, output_ty.span, output_ty); ( GenericArgsCtor { args, diff --git a/compiler/rustc_ast_passes/src/ast_validation.rs b/compiler/rustc_ast_passes/src/ast_validation.rs index 554ed36b814e6..0644c4cd6be4c 100644 --- a/compiler/rustc_ast_passes/src/ast_validation.rs +++ b/compiler/rustc_ast_passes/src/ast_validation.rs @@ -1271,14 +1271,15 @@ impl<'a> Visitor<'a> for AstValidator<'a> { // Functions cannot both be `const async` or `const gen` if let Some(&FnHeader { constness: Const::Yes(cspan), - coro_kind: - Some( - CoroutineKind::Async { span: aspan, .. } - | CoroutineKind::Gen { span: aspan, .. }, - ), + coroutine_kind: Some(coro_kind), .. }) = fk.header() { + let aspan = match coro_kind { + CoroutineKind::Async { span: aspan, .. } + | CoroutineKind::Gen { span: aspan, .. } + | CoroutineKind::AsyncGen { span: aspan, .. } => aspan, + }; // FIXME(gen_blocks): Report a different error for `const gen` self.err_handler().emit_err(errors::ConstAndAsync { spans: vec![cspan, aspan], diff --git a/compiler/rustc_ast_pretty/src/pprust/state.rs b/compiler/rustc_ast_pretty/src/pprust/state.rs index 1ad28ffbf2bb6..ff36e6c284526 100644 --- a/compiler/rustc_ast_pretty/src/pprust/state.rs +++ b/compiler/rustc_ast_pretty/src/pprust/state.rs @@ -1490,14 +1490,18 @@ impl<'a> State<'a> { } } - fn print_coro_kind(&mut self, coro_kind: ast::CoroutineKind) { - match coro_kind { + fn print_coroutine_kind(&mut self, coroutine_kind: ast::CoroutineKind) { + match coroutine_kind { ast::CoroutineKind::Gen { .. } => { self.word_nbsp("gen"); } ast::CoroutineKind::Async { .. } => { self.word_nbsp("async"); } + ast::CoroutineKind::AsyncGen { .. } => { + self.word_nbsp("async"); + self.word_nbsp("gen"); + } } } @@ -1690,7 +1694,7 @@ impl<'a> State<'a> { fn print_fn_header_info(&mut self, header: ast::FnHeader) { self.print_constness(header.constness); - header.coro_kind.map(|coro_kind| self.print_coro_kind(coro_kind)); + header.coroutine_kind.map(|coroutine_kind| self.print_coroutine_kind(coroutine_kind)); self.print_unsafety(header.unsafety); match header.ext { diff --git a/compiler/rustc_ast_pretty/src/pprust/state/expr.rs b/compiler/rustc_ast_pretty/src/pprust/state/expr.rs index 0e6c3628aacb6..5397278bbb1c6 100644 --- a/compiler/rustc_ast_pretty/src/pprust/state/expr.rs +++ b/compiler/rustc_ast_pretty/src/pprust/state/expr.rs @@ -413,7 +413,7 @@ impl<'a> State<'a> { binder, capture_clause, constness, - coro_kind, + coroutine_kind, movability, fn_decl, body, @@ -423,7 +423,7 @@ impl<'a> State<'a> { self.print_closure_binder(binder); self.print_constness(*constness); self.print_movability(*movability); - coro_kind.map(|coro_kind| self.print_coro_kind(coro_kind)); + coroutine_kind.map(|coroutine_kind| self.print_coroutine_kind(coroutine_kind)); self.print_capture_clause(*capture_clause); self.print_fn_params_and_ret(fn_decl, true); diff --git a/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs b/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs index 7bcad92ff337f..7e62bb9793d50 100644 --- a/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs +++ b/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs @@ -2517,12 +2517,23 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> { CoroutineKind::Gen(kind) => match kind { CoroutineSource::Block => "gen block", CoroutineSource::Closure => "gen closure", - _ => bug!("gen block/closure expected, but gen function found."), + CoroutineSource::Fn => { + bug!("gen block/closure expected, but gen function found.") + } + }, + CoroutineKind::AsyncGen(kind) => match kind { + CoroutineSource::Block => "async gen block", + CoroutineSource::Closure => "async gen closure", + CoroutineSource::Fn => { + bug!("gen block/closure expected, but gen function found.") + } }, CoroutineKind::Async(async_kind) => match async_kind { CoroutineSource::Block => "async block", CoroutineSource::Closure => "async closure", - _ => bug!("async block/closure expected, but async function found."), + CoroutineSource::Fn => { + bug!("async block/closure expected, but async function found.") + } }, CoroutineKind::Coroutine => "coroutine", }, diff --git a/compiler/rustc_borrowck/src/diagnostics/region_name.rs b/compiler/rustc_borrowck/src/diagnostics/region_name.rs index 977a5d5d50d61..a17c3bc3a78c5 100644 --- a/compiler/rustc_borrowck/src/diagnostics/region_name.rs +++ b/compiler/rustc_borrowck/src/diagnostics/region_name.rs @@ -684,7 +684,7 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> { hir::FnRetTy::Return(hir_ty) => (fn_decl.output.span(), Some(hir_ty)), }; let mir_description = match hir.body(body).coroutine_kind { - Some(hir::CoroutineKind::Async(gen)) => match gen { + Some(hir::CoroutineKind::Async(src)) => match src { hir::CoroutineSource::Block => " of async block", hir::CoroutineSource::Closure => " of async closure", hir::CoroutineSource::Fn => { @@ -701,7 +701,7 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> { " of async function" } }, - Some(hir::CoroutineKind::Gen(gen)) => match gen { + Some(hir::CoroutineKind::Gen(src)) => match src { hir::CoroutineSource::Block => " of gen block", hir::CoroutineSource::Closure => " of gen closure", hir::CoroutineSource::Fn => { @@ -715,6 +715,21 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> { " of gen function" } }, + + Some(hir::CoroutineKind::AsyncGen(src)) => match src { + hir::CoroutineSource::Block => " of async gen block", + hir::CoroutineSource::Closure => " of async gen closure", + hir::CoroutineSource::Fn => { + let parent_item = + hir.get_by_def_id(hir.get_parent_item(mir_hir_id).def_id); + let output = &parent_item + .fn_decl() + .expect("coroutine lowered from async gen fn should be in fn") + .output; + span = output.span(); + " of async gen function" + } + }, Some(hir::CoroutineKind::Coroutine) => " of coroutine", None => " of closure", }; diff --git a/compiler/rustc_builtin_macros/src/test.rs b/compiler/rustc_builtin_macros/src/test.rs index 81433155ecfd0..794be25955d63 100644 --- a/compiler/rustc_builtin_macros/src/test.rs +++ b/compiler/rustc_builtin_macros/src/test.rs @@ -541,12 +541,30 @@ fn check_test_signature( return Err(sd.emit_err(errors::TestBadFn { span: i.span, cause: span, kind: "unsafe" })); } - if let Some(ast::CoroutineKind::Async { span, .. }) = f.sig.header.coro_kind { - return Err(sd.emit_err(errors::TestBadFn { span: i.span, cause: span, kind: "async" })); - } - - if let Some(ast::CoroutineKind::Gen { span, .. }) = f.sig.header.coro_kind { - return Err(sd.emit_err(errors::TestBadFn { span: i.span, cause: span, kind: "gen" })); + if let Some(coro_kind) = f.sig.header.coroutine_kind { + match coro_kind { + ast::CoroutineKind::Async { span, .. } => { + return Err(sd.emit_err(errors::TestBadFn { + span: i.span, + cause: span, + kind: "async", + })); + } + ast::CoroutineKind::Gen { span, .. } => { + return Err(sd.emit_err(errors::TestBadFn { + span: i.span, + cause: span, + kind: "gen", + })); + } + ast::CoroutineKind::AsyncGen { span, .. } => { + return Err(sd.emit_err(errors::TestBadFn { + span: i.span, + cause: span, + kind: "async gen", + })); + } + } } // If the termination trait is active, the compiler will check that the output diff --git a/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs b/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs index 8630e5623e168..dda30046bfbad 100644 --- a/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs +++ b/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs @@ -566,6 +566,9 @@ fn coroutine_kind_label(coroutine_kind: Option) -> &'static str { Some(CoroutineKind::Async(CoroutineSource::Block)) => "async_block", Some(CoroutineKind::Async(CoroutineSource::Closure)) => "async_closure", Some(CoroutineKind::Async(CoroutineSource::Fn)) => "async_fn", + Some(CoroutineKind::AsyncGen(CoroutineSource::Block)) => "async_gen_block", + Some(CoroutineKind::AsyncGen(CoroutineSource::Closure)) => "async_gen_closure", + Some(CoroutineKind::AsyncGen(CoroutineSource::Fn)) => "async_gen_fn", Some(CoroutineKind::Coroutine) => "coroutine", None => "closure", } diff --git a/compiler/rustc_codegen_ssa/src/mir/locals.rs b/compiler/rustc_codegen_ssa/src/mir/locals.rs index 378c540132207..7db260c9f5bd8 100644 --- a/compiler/rustc_codegen_ssa/src/mir/locals.rs +++ b/compiler/rustc_codegen_ssa/src/mir/locals.rs @@ -43,7 +43,11 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { let local = mir::Local::from_usize(local); let expected_ty = self.monomorphize(self.mir.local_decls[local].ty); if expected_ty != op.layout.ty { - warn!("Unexpected initial operand type. See the issues/114858"); + warn!( + "Unexpected initial operand type: expected {expected_ty:?}, found {:?}.\ + See .", + op.layout.ty + ); } } } diff --git a/compiler/rustc_expand/src/build.rs b/compiler/rustc_expand/src/build.rs index 853554b2dcd90..86f555fa08bcf 100644 --- a/compiler/rustc_expand/src/build.rs +++ b/compiler/rustc_expand/src/build.rs @@ -547,7 +547,7 @@ impl<'a> ExtCtxt<'a> { binder: ast::ClosureBinder::NotPresent, capture_clause: ast::CaptureBy::Ref, constness: ast::Const::No, - coro_kind: None, + coroutine_kind: None, movability: ast::Movability::Movable, fn_decl, body, diff --git a/compiler/rustc_hir/src/hir.rs b/compiler/rustc_hir/src/hir.rs index 3414b2f2412a9..01508375b1ace 100644 --- a/compiler/rustc_hir/src/hir.rs +++ b/compiler/rustc_hir/src/hir.rs @@ -1356,12 +1356,16 @@ impl<'hir> Body<'hir> { /// The type of source expression that caused this coroutine to be created. #[derive(Clone, PartialEq, Eq, Debug, Copy, Hash, HashStable_Generic, Encodable, Decodable)] pub enum CoroutineKind { - /// An explicit `async` block or the body of an async function. + /// An explicit `async` block or the body of an `async` function. Async(CoroutineSource), /// An explicit `gen` block or the body of a `gen` function. Gen(CoroutineSource), + /// An explicit `async gen` block or the body of an `async gen` function, + /// which is able to both `yield` and `.await`. + AsyncGen(CoroutineSource), + /// A coroutine literal created via a `yield` inside a closure. Coroutine, } @@ -1386,6 +1390,14 @@ impl fmt::Display for CoroutineKind { } k.fmt(f) } + CoroutineKind::AsyncGen(k) => { + if f.alternate() { + f.write_str("`async gen` ")?; + } else { + f.write_str("async gen ")? + } + k.fmt(f) + } } } } @@ -2081,17 +2093,6 @@ impl fmt::Display for YieldSource { } } -impl From for YieldSource { - fn from(kind: CoroutineKind) -> Self { - match kind { - // Guess based on the kind of the current coroutine. - CoroutineKind::Coroutine => Self::Yield, - CoroutineKind::Async(_) => Self::Await { expr: None }, - CoroutineKind::Gen(_) => Self::Yield, - } - } -} - // N.B., if you change this, you'll probably want to change the corresponding // type structure in middle/ty.rs as well. #[derive(Debug, Clone, Copy, HashStable_Generic)] @@ -2271,11 +2272,6 @@ pub enum ImplItemKind<'hir> { Type(&'hir Ty<'hir>), } -/// The name of the associated type for `Fn` return types. -pub const FN_OUTPUT_NAME: Symbol = sym::Output; -/// The name of the associated type for `Iterator` item types. -pub const ITERATOR_ITEM_NAME: Symbol = sym::Item; - /// Bind a type to an associated type (i.e., `A = Foo`). /// /// Bindings like `A: Debug` are represented as a special type `A = diff --git a/compiler/rustc_hir/src/lang_items.rs b/compiler/rustc_hir/src/lang_items.rs index 60f1449c177cb..b0b53bb7478d5 100644 --- a/compiler/rustc_hir/src/lang_items.rs +++ b/compiler/rustc_hir/src/lang_items.rs @@ -212,6 +212,7 @@ language_item_table! { Iterator, sym::iterator, iterator_trait, Target::Trait, GenericRequirement::Exact(0); Future, sym::future_trait, future_trait, Target::Trait, GenericRequirement::Exact(0); + AsyncIterator, sym::async_iterator, async_iterator_trait, Target::Trait, GenericRequirement::Exact(0); CoroutineState, sym::coroutine_state, coroutine_state, Target::Enum, GenericRequirement::None; Coroutine, sym::coroutine, coroutine_trait, Target::Trait, GenericRequirement::Minimum(1); Unpin, sym::unpin, unpin_trait, Target::Trait, GenericRequirement::None; @@ -294,6 +295,10 @@ language_item_table! { PollReady, sym::Ready, poll_ready_variant, Target::Variant, GenericRequirement::None; PollPending, sym::Pending, poll_pending_variant, Target::Variant, GenericRequirement::None; + AsyncGenReady, sym::AsyncGenReady, async_gen_ready, Target::Method(MethodKind::Inherent), GenericRequirement::Exact(1); + AsyncGenPending, sym::AsyncGenPending, async_gen_pending, Target::AssocConst, GenericRequirement::Exact(1); + AsyncGenFinished, sym::AsyncGenFinished, async_gen_finished, Target::AssocConst, GenericRequirement::Exact(1); + // FIXME(swatinem): the following lang items are used for async lowering and // should become obsolete eventually. ResumeTy, sym::ResumeTy, resume_ty, Target::Struct, GenericRequirement::None; diff --git a/compiler/rustc_hir_typeck/src/check.rs b/compiler/rustc_hir_typeck/src/check.rs index 0cd1ae2dbfe5c..19b566ff9fa64 100644 --- a/compiler/rustc_hir_typeck/src/check.rs +++ b/compiler/rustc_hir_typeck/src/check.rs @@ -67,6 +67,28 @@ pub(super) fn check_fn<'a, 'tcx>( fcx.require_type_is_sized(yield_ty, span, traits::SizedYieldType); yield_ty } + // HACK(-Ztrait-solver=next): In the *old* trait solver, we must eagerly + // guide inference on the yield type so that we can handle `AsyncIterator` + // in this block in projection correctly. In the new trait solver, it is + // not a problem. + hir::CoroutineKind::AsyncGen(..) => { + let yield_ty = fcx.next_ty_var(TypeVariableOrigin { + kind: TypeVariableOriginKind::TypeInference, + span, + }); + fcx.require_type_is_sized(yield_ty, span, traits::SizedYieldType); + + Ty::new_adt( + tcx, + tcx.adt_def(tcx.require_lang_item(hir::LangItem::Poll, Some(span))), + tcx.mk_args(&[Ty::new_adt( + tcx, + tcx.adt_def(tcx.require_lang_item(hir::LangItem::Option, Some(span))), + tcx.mk_args(&[yield_ty.into()]), + ) + .into()]), + ) + } hir::CoroutineKind::Async(..) => Ty::new_unit(tcx), }; diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs index 1f2bd92a15c0d..df840aaa57884 100644 --- a/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs +++ b/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs @@ -763,6 +763,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { let args = self.fresh_args_for_item(span, def_id); let ty = item_ty.instantiate(self.tcx, args); + self.write_args(hir_id, args); self.write_resolution(hir_id, Ok((def_kind, def_id))); let code = match lang_item { diff --git a/compiler/rustc_lint/src/early.rs b/compiler/rustc_lint/src/early.rs index 7c4f81a4c3970..80c6feaa26936 100644 --- a/compiler/rustc_lint/src/early.rs +++ b/compiler/rustc_lint/src/early.rs @@ -162,11 +162,10 @@ impl<'a, T: EarlyLintPass> ast_visit::Visitor<'a> for EarlyContextAndPass<'a, T> // Explicitly check for lints associated with 'closure_id', since // it does not have a corresponding AST node if let ast_visit::FnKind::Fn(_, _, sig, _, _, _) = fk { - if let Some( - ast::CoroutineKind::Async { closure_id, .. } - | ast::CoroutineKind::Gen { closure_id, .. }, - ) = sig.header.coro_kind - { + if let Some(coro_kind) = sig.header.coroutine_kind { + let (ast::CoroutineKind::Async { closure_id, .. } + | ast::CoroutineKind::Gen { closure_id, .. } + | ast::CoroutineKind::AsyncGen { closure_id, .. }) = coro_kind; self.check_id(closure_id); } } @@ -227,13 +226,13 @@ impl<'a, T: EarlyLintPass> ast_visit::Visitor<'a> for EarlyContextAndPass<'a, T> // it does not have a corresponding AST node match e.kind { ast::ExprKind::Closure(box ast::Closure { - coro_kind: - Some( - ast::CoroutineKind::Async { closure_id, .. } - | ast::CoroutineKind::Gen { closure_id, .. }, - ), - .. - }) => self.check_id(closure_id), + coroutine_kind: Some(coro_kind), .. + }) => { + let (ast::CoroutineKind::Async { closure_id, .. } + | ast::CoroutineKind::Gen { closure_id, .. } + | ast::CoroutineKind::AsyncGen { closure_id, .. }) = coro_kind; + self.check_id(closure_id); + } _ => {} } lint_callback!(self, check_expr_post, e); diff --git a/compiler/rustc_middle/src/mir/terminator.rs b/compiler/rustc_middle/src/mir/terminator.rs index 9a6ac6ff57a4e..aa4cb36c5cef2 100644 --- a/compiler/rustc_middle/src/mir/terminator.rs +++ b/compiler/rustc_middle/src/mir/terminator.rs @@ -150,11 +150,17 @@ impl AssertKind { RemainderByZero(_) => "attempt to calculate the remainder with a divisor of zero", ResumedAfterReturn(CoroutineKind::Coroutine) => "coroutine resumed after completion", ResumedAfterReturn(CoroutineKind::Async(_)) => "`async fn` resumed after completion", + ResumedAfterReturn(CoroutineKind::AsyncGen(_)) => { + "`async gen fn` resumed after completion" + } ResumedAfterReturn(CoroutineKind::Gen(_)) => { "`gen fn` should just keep returning `None` after completion" } ResumedAfterPanic(CoroutineKind::Coroutine) => "coroutine resumed after panicking", ResumedAfterPanic(CoroutineKind::Async(_)) => "`async fn` resumed after panicking", + ResumedAfterPanic(CoroutineKind::AsyncGen(_)) => { + "`async gen fn` resumed after panicking" + } ResumedAfterPanic(CoroutineKind::Gen(_)) => { "`gen fn` should just keep returning `None` after panicking" } @@ -245,6 +251,7 @@ impl AssertKind { DivisionByZero(_) => middle_assert_divide_by_zero, RemainderByZero(_) => middle_assert_remainder_by_zero, ResumedAfterReturn(CoroutineKind::Async(_)) => middle_assert_async_resume_after_return, + ResumedAfterReturn(CoroutineKind::AsyncGen(_)) => todo!(), ResumedAfterReturn(CoroutineKind::Gen(_)) => { bug!("gen blocks can be resumed after they return and will keep returning `None`") } @@ -252,6 +259,7 @@ impl AssertKind { middle_assert_coroutine_resume_after_return } ResumedAfterPanic(CoroutineKind::Async(_)) => middle_assert_async_resume_after_panic, + ResumedAfterPanic(CoroutineKind::AsyncGen(_)) => todo!(), ResumedAfterPanic(CoroutineKind::Gen(_)) => middle_assert_gen_resume_after_panic, ResumedAfterPanic(CoroutineKind::Coroutine) => { middle_assert_coroutine_resume_after_panic diff --git a/compiler/rustc_middle/src/traits/select.rs b/compiler/rustc_middle/src/traits/select.rs index 96ed1a4d0be1f..e8e2907eb33a9 100644 --- a/compiler/rustc_middle/src/traits/select.rs +++ b/compiler/rustc_middle/src/traits/select.rs @@ -144,10 +144,14 @@ pub enum SelectionCandidate<'tcx> { /// generated for an async construct. FutureCandidate, - /// Implementation of an `Iterator` trait by one of the generator types - /// generated for a gen construct. + /// Implementation of an `Iterator` trait by one of the coroutine types + /// generated for a `gen` construct. IteratorCandidate, + /// Implementation of an `AsyncIterator` trait by one of the coroutine types + /// generated for a `async gen` construct. + AsyncIteratorCandidate, + /// Implementation of a `Fn`-family trait by one of the anonymous /// types generated for a fn pointer type (e.g., `fn(int) -> int`) FnPointerCandidate { diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs index 3012434ad3fb8..6ebfe778e7f21 100644 --- a/compiler/rustc_middle/src/ty/context.rs +++ b/compiler/rustc_middle/src/ty/context.rs @@ -825,11 +825,16 @@ impl<'tcx> TyCtxt<'tcx> { matches!(self.coroutine_kind(def_id), Some(hir::CoroutineKind::Coroutine)) } - /// Returns `true` if the node pointed to by `def_id` is a coroutine for a gen construct. + /// Returns `true` if the node pointed to by `def_id` is a coroutine for a `gen` construct. pub fn coroutine_is_gen(self, def_id: DefId) -> bool { matches!(self.coroutine_kind(def_id), Some(hir::CoroutineKind::Gen(_))) } + /// Returns `true` if the node pointed to by `def_id` is a coroutine for a `async gen` construct. + pub fn coroutine_is_async_gen(self, def_id: DefId) -> bool { + matches!(self.coroutine_kind(def_id), Some(hir::CoroutineKind::AsyncGen(_))) + } + pub fn stability(self) -> &'tcx stability::Index { self.stability_index(()) } diff --git a/compiler/rustc_middle/src/ty/util.rs b/compiler/rustc_middle/src/ty/util.rs index 52c3529d2b4aa..b7c3edee9e59e 100644 --- a/compiler/rustc_middle/src/ty/util.rs +++ b/compiler/rustc_middle/src/ty/util.rs @@ -732,6 +732,7 @@ impl<'tcx> TyCtxt<'tcx> { DefKind::Closure if let Some(coroutine_kind) = self.coroutine_kind(def_id) => { match coroutine_kind { rustc_hir::CoroutineKind::Async(..) => "async closure", + rustc_hir::CoroutineKind::AsyncGen(..) => "async gen closure", rustc_hir::CoroutineKind::Coroutine => "coroutine", rustc_hir::CoroutineKind::Gen(..) => "gen closure", } @@ -752,6 +753,7 @@ impl<'tcx> TyCtxt<'tcx> { DefKind::Closure if let Some(coroutine_kind) = self.coroutine_kind(def_id) => { match coroutine_kind { rustc_hir::CoroutineKind::Async(..) => "an", + rustc_hir::CoroutineKind::AsyncGen(..) => "an", rustc_hir::CoroutineKind::Coroutine => "a", rustc_hir::CoroutineKind::Gen(..) => "a", } diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs index 79a1509531d66..2b591abb05d66 100644 --- a/compiler/rustc_mir_transform/src/coroutine.rs +++ b/compiler/rustc_mir_transform/src/coroutine.rs @@ -66,9 +66,9 @@ use rustc_index::{Idx, IndexVec}; use rustc_middle::mir::dump_mir; use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; use rustc_middle::mir::*; +use rustc_middle::ty::CoroutineArgs; use rustc_middle::ty::InstanceDef; -use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt}; -use rustc_middle::ty::{CoroutineArgs, GenericArgsRef}; +use rustc_middle::ty::{self, Ty, TyCtxt}; use rustc_mir_dataflow::impls::{ MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive, }; @@ -225,8 +225,6 @@ struct SuspensionPoint<'tcx> { struct TransformVisitor<'tcx> { tcx: TyCtxt<'tcx>, coroutine_kind: hir::CoroutineKind, - state_adt_ref: AdtDef<'tcx>, - state_args: GenericArgsRef<'tcx>, // The type of the discriminant in the coroutine struct discr_ty: Ty<'tcx>, @@ -245,21 +243,34 @@ struct TransformVisitor<'tcx> { always_live_locals: BitSet, // The original RETURN_PLACE local - new_ret_local: Local, + old_ret_local: Local, + + old_yield_ty: Ty<'tcx>, + + old_ret_ty: Ty<'tcx>, } impl<'tcx> TransformVisitor<'tcx> { fn insert_none_ret_block(&self, body: &mut Body<'tcx>) -> BasicBlock { - let block = BasicBlock::new(body.basic_blocks.len()); + assert!(matches!(self.coroutine_kind, CoroutineKind::Gen(_))); + let block = BasicBlock::new(body.basic_blocks.len()); let source_info = SourceInfo::outermost(body.span); + let option_def_id = self.tcx.require_lang_item(LangItem::Option, None); - let (kind, idx) = self.coroutine_state_adt_and_variant_idx(true); - assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0); let statements = vec![Statement { kind: StatementKind::Assign(Box::new(( Place::return_place(), - Rvalue::Aggregate(Box::new(kind), IndexVec::new()), + Rvalue::Aggregate( + Box::new(AggregateKind::Adt( + option_def_id, + VariantIdx::from_usize(0), + self.tcx.mk_args(&[self.old_yield_ty.into()]), + None, + None, + )), + IndexVec::new(), + ), ))), source_info, }]; @@ -273,23 +284,6 @@ impl<'tcx> TransformVisitor<'tcx> { block } - fn coroutine_state_adt_and_variant_idx( - &self, - is_return: bool, - ) -> (AggregateKind<'tcx>, VariantIdx) { - let idx = VariantIdx::new(match (is_return, self.coroutine_kind) { - (true, hir::CoroutineKind::Coroutine) => 1, // CoroutineState::Complete - (false, hir::CoroutineKind::Coroutine) => 0, // CoroutineState::Yielded - (true, hir::CoroutineKind::Async(_)) => 0, // Poll::Ready - (false, hir::CoroutineKind::Async(_)) => 1, // Poll::Pending - (true, hir::CoroutineKind::Gen(_)) => 0, // Option::None - (false, hir::CoroutineKind::Gen(_)) => 1, // Option::Some - }); - - let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None); - (kind, idx) - } - // Make a `CoroutineState` or `Poll` variant assignment. // // `core::ops::CoroutineState` only has single element tuple variants, @@ -302,51 +296,119 @@ impl<'tcx> TransformVisitor<'tcx> { is_return: bool, statements: &mut Vec>, ) { - let (kind, idx) = self.coroutine_state_adt_and_variant_idx(is_return); - - match self.coroutine_kind { - // `Poll::Pending` + let rvalue = match self.coroutine_kind { CoroutineKind::Async(_) => { - if !is_return { - assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0); - - // FIXME(swatinem): assert that `val` is indeed unit? - statements.push(Statement { - kind: StatementKind::Assign(Box::new(( - Place::return_place(), - Rvalue::Aggregate(Box::new(kind), IndexVec::new()), - ))), - source_info, - }); - return; + let poll_def_id = self.tcx.require_lang_item(LangItem::Poll, None); + let args = self.tcx.mk_args(&[self.old_ret_ty.into()]); + if is_return { + // Poll::Ready(val) + Rvalue::Aggregate( + Box::new(AggregateKind::Adt( + poll_def_id, + VariantIdx::from_usize(0), + args, + None, + None, + )), + IndexVec::from_raw(vec![val]), + ) + } else { + // Poll::Pending + Rvalue::Aggregate( + Box::new(AggregateKind::Adt( + poll_def_id, + VariantIdx::from_usize(1), + args, + None, + None, + )), + IndexVec::new(), + ) } } - // `Option::None` CoroutineKind::Gen(_) => { + let option_def_id = self.tcx.require_lang_item(LangItem::Option, None); + let args = self.tcx.mk_args(&[self.old_yield_ty.into()]); if is_return { - assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0); - - statements.push(Statement { - kind: StatementKind::Assign(Box::new(( - Place::return_place(), - Rvalue::Aggregate(Box::new(kind), IndexVec::new()), - ))), - source_info, - }); - return; + // None + Rvalue::Aggregate( + Box::new(AggregateKind::Adt( + option_def_id, + VariantIdx::from_usize(0), + args, + None, + None, + )), + IndexVec::new(), + ) + } else { + // Some(val) + Rvalue::Aggregate( + Box::new(AggregateKind::Adt( + option_def_id, + VariantIdx::from_usize(1), + args, + None, + None, + )), + IndexVec::from_raw(vec![val]), + ) } } - CoroutineKind::Coroutine => {} - } - - // else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)`, `CoroutineState::Complete(x)`, or `Option::Some(x)` - assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1); + CoroutineKind::AsyncGen(_) => { + if is_return { + let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() }; + let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() }; + let yield_ty = args.type_at(0); + Rvalue::Use(Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + const_: Const::Unevaluated( + UnevaluatedConst::new( + self.tcx.require_lang_item(LangItem::AsyncGenFinished, None), + self.tcx.mk_args(&[yield_ty.into()]), + ), + self.old_yield_ty, + ), + user_ty: None, + }))) + } else { + Rvalue::Use(val) + } + } + CoroutineKind::Coroutine => { + let coroutine_state_def_id = + self.tcx.require_lang_item(LangItem::CoroutineState, None); + let args = self.tcx.mk_args(&[self.old_yield_ty.into(), self.old_ret_ty.into()]); + if is_return { + // CoroutineState::Complete(val) + Rvalue::Aggregate( + Box::new(AggregateKind::Adt( + coroutine_state_def_id, + VariantIdx::from_usize(1), + args, + None, + None, + )), + IndexVec::from_raw(vec![val]), + ) + } else { + // CoroutineState::Yielded(val) + Rvalue::Aggregate( + Box::new(AggregateKind::Adt( + coroutine_state_def_id, + VariantIdx::from_usize(0), + args, + None, + None, + )), + IndexVec::from_raw(vec![val]), + ) + } + } + }; statements.push(Statement { - kind: StatementKind::Assign(Box::new(( - Place::return_place(), - Rvalue::Aggregate(Box::new(kind), [val].into()), - ))), + kind: StatementKind::Assign(Box::new((Place::return_place(), rvalue))), source_info, }); } @@ -420,7 +482,7 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> { let ret_val = match data.terminator().kind { TerminatorKind::Return => { - Some((true, None, Operand::Move(Place::from(self.new_ret_local)), None)) + Some((true, None, Operand::Move(Place::from(self.old_ret_local)), None)) } TerminatorKind::Yield { ref value, resume, resume_arg, drop } => { Some((false, Some((resume, resume_arg)), value.clone(), drop)) @@ -1331,7 +1393,8 @@ fn create_coroutine_resume_function<'tcx>( if can_return { let block = match coroutine_kind { - CoroutineKind::Async(_) | CoroutineKind::Coroutine => { + // FIXME(gen_blocks): Should `async gen` yield `None` when resumed once again? + CoroutineKind::Async(_) | CoroutineKind::AsyncGen(_) | CoroutineKind::Coroutine => { insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind)) } CoroutineKind::Gen(_) => transform.insert_none_ret_block(body), @@ -1493,10 +1556,11 @@ pub(crate) fn mir_coroutine_witnesses<'tcx>( impl<'tcx> MirPass<'tcx> for StateTransform { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - let Some(yield_ty) = body.yield_ty() else { + let Some(old_yield_ty) = body.yield_ty() else { // This only applies to coroutines return; }; + let old_ret_ty = body.return_ty(); assert!(body.coroutine_drop().is_none()); @@ -1519,38 +1583,42 @@ impl<'tcx> MirPass<'tcx> for StateTransform { }; let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_))); + let is_async_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::AsyncGen(_))); let is_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Gen(_))); - let (state_adt_ref, state_args) = match body.coroutine_kind().unwrap() { + let new_ret_ty = match body.coroutine_kind().unwrap() { CoroutineKind::Async(_) => { // Compute Poll let poll_did = tcx.require_lang_item(LangItem::Poll, None); let poll_adt_ref = tcx.adt_def(poll_did); - let poll_args = tcx.mk_args(&[body.return_ty().into()]); - (poll_adt_ref, poll_args) + let poll_args = tcx.mk_args(&[old_ret_ty.into()]); + Ty::new_adt(tcx, poll_adt_ref, poll_args) } CoroutineKind::Gen(_) => { // Compute Option let option_did = tcx.require_lang_item(LangItem::Option, None); let option_adt_ref = tcx.adt_def(option_did); - let option_args = tcx.mk_args(&[body.yield_ty().unwrap().into()]); - (option_adt_ref, option_args) + let option_args = tcx.mk_args(&[old_yield_ty.into()]); + Ty::new_adt(tcx, option_adt_ref, option_args) + } + CoroutineKind::AsyncGen(_) => { + // The yield ty is already `Poll>` + old_yield_ty } CoroutineKind::Coroutine => { // Compute CoroutineState let state_did = tcx.require_lang_item(LangItem::CoroutineState, None); let state_adt_ref = tcx.adt_def(state_did); - let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]); - (state_adt_ref, state_args) + let state_args = tcx.mk_args(&[old_yield_ty.into(), old_ret_ty.into()]); + Ty::new_adt(tcx, state_adt_ref, state_args) } }; - let ret_ty = Ty::new_adt(tcx, state_adt_ref, state_args); - // We rename RETURN_PLACE which has type mir.return_ty to new_ret_local + // We rename RETURN_PLACE which has type mir.return_ty to old_ret_local // RETURN_PLACE then is a fresh unused local with type ret_ty. - let new_ret_local = replace_local(RETURN_PLACE, ret_ty, body, tcx); + let old_ret_local = replace_local(RETURN_PLACE, new_ret_ty, body, tcx); // Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies. - if is_async_kind { + if is_async_kind || is_async_gen_kind { transform_async_context(tcx, body); } @@ -1564,9 +1632,10 @@ impl<'tcx> MirPass<'tcx> for StateTransform { } else { body.local_decls[resume_local].ty }; - let new_resume_local = replace_local(resume_local, resume_ty, body, tcx); + let old_resume_local = replace_local(resume_local, resume_ty, body, tcx); - // When first entering the coroutine, move the resume argument into its new local. + // When first entering the coroutine, move the resume argument into its old local + // (which is now a generator interior). let source_info = SourceInfo::outermost(body.span); let stmts = &mut body.basic_blocks_mut()[START_BLOCK].statements; stmts.insert( @@ -1574,7 +1643,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform { Statement { source_info, kind: StatementKind::Assign(Box::new(( - new_resume_local.into(), + old_resume_local.into(), Rvalue::Use(Operand::Move(resume_local.into())), ))), }, @@ -1610,14 +1679,14 @@ impl<'tcx> MirPass<'tcx> for StateTransform { let mut transform = TransformVisitor { tcx, coroutine_kind: body.coroutine_kind().unwrap(), - state_adt_ref, - state_args, remap, storage_liveness, always_live_locals, suspension_points: Vec::new(), - new_ret_local, + old_ret_local, discr_ty, + old_ret_ty, + old_yield_ty, }; transform.visit_body(body); diff --git a/compiler/rustc_parse/messages.ftl b/compiler/rustc_parse/messages.ftl index 083e26651c13a..363b8f4bfb9cc 100644 --- a/compiler/rustc_parse/messages.ftl +++ b/compiler/rustc_parse/messages.ftl @@ -23,8 +23,6 @@ parse_async_block_in_2015 = `async` blocks are only allowed in Rust 2018 or late parse_async_fn_in_2015 = `async fn` is not permitted in Rust 2015 .label = to use `async fn`, switch to Rust 2018 or later -parse_async_gen_fn = `async gen` functions are not supported - parse_async_move_block_in_2015 = `async move` blocks are only allowed in Rust 2018 or later parse_async_move_order_incorrect = the order of `move` and `async` is incorrect diff --git a/compiler/rustc_parse/src/errors.rs b/compiler/rustc_parse/src/errors.rs index 45f950db5c309..bc53ab83439d1 100644 --- a/compiler/rustc_parse/src/errors.rs +++ b/compiler/rustc_parse/src/errors.rs @@ -562,13 +562,6 @@ pub(crate) struct GenFn { pub span: Span, } -#[derive(Diagnostic)] -#[diag(parse_async_gen_fn)] -pub(crate) struct AsyncGenFn { - #[primary_span] - pub span: Span, -} - #[derive(Diagnostic)] #[diag(parse_comma_after_base_struct)] #[note] diff --git a/compiler/rustc_parse/src/parser/expr.rs b/compiler/rustc_parse/src/parser/expr.rs index 42257054f49d3..406a6def019ef 100644 --- a/compiler/rustc_parse/src/parser/expr.rs +++ b/compiler/rustc_parse/src/parser/expr.rs @@ -1442,20 +1442,21 @@ impl<'a> Parser<'a> { } else if this.token.uninterpolated_span().at_least_rust_2018() { // `Span:.at_least_rust_2018()` is somewhat expensive; don't get it repeatedly. if this.check_keyword(kw::Async) { - if this.is_gen_block(kw::Async) { - // Check for `async {` and `async move {`. + // FIXME(gen_blocks): Parse `gen async` and suggest swap + if this.is_gen_block(kw::Async, 0) { + // Check for `async {` and `async move {`, + // or `async gen {` and `async gen move {`. this.parse_gen_block() } else { this.parse_expr_closure() } - } else if this.eat_keyword(kw::Await) { + } else if this.token.uninterpolated_span().at_least_rust_2024() + && (this.is_gen_block(kw::Gen, 0) + || (this.check_keyword(kw::Async) && this.is_gen_block(kw::Gen, 1))) + { + this.parse_gen_block() + } else if this.eat_keyword_noexpect(kw::Await) { this.recover_incorrect_await_syntax(lo, this.prev_token.span) - } else if this.token.uninterpolated_span().at_least_rust_2024() { - if this.is_gen_block(kw::Gen) { - this.parse_gen_block() - } else { - this.parse_expr_lit() - } } else { this.parse_expr_lit() } @@ -2234,8 +2235,8 @@ impl<'a> Parser<'a> { let movability = if self.eat_keyword(kw::Static) { Movability::Static } else { Movability::Movable }; - let asyncness = if self.token.uninterpolated_span().at_least_rust_2018() { - self.parse_asyncness(Case::Sensitive) + let coroutine_kind = if self.token.uninterpolated_span().at_least_rust_2018() { + self.parse_coroutine_kind(Case::Sensitive) } else { None }; @@ -2261,9 +2262,17 @@ impl<'a> Parser<'a> { } }; - if let Some(CoroutineKind::Async { span, .. }) = asyncness { - // Feature-gate `async ||` closures. - self.sess.gated_spans.gate(sym::async_closure, span); + match coroutine_kind { + Some(CoroutineKind::Async { span, .. }) => { + // Feature-gate `async ||` closures. + self.sess.gated_spans.gate(sym::async_closure, span); + } + Some(CoroutineKind::Gen { span, .. }) | Some(CoroutineKind::AsyncGen { span, .. }) => { + // Feature-gate `gen ||` and `async gen ||` closures. + // FIXME(gen_blocks): This perhaps should be a different gate. + self.sess.gated_spans.gate(sym::gen_blocks, span); + } + None => {} } if self.token.kind == TokenKind::Semi @@ -2284,7 +2293,7 @@ impl<'a> Parser<'a> { binder, capture_clause, constness, - coro_kind: asyncness, + coroutine_kind, movability, fn_decl, body, @@ -3207,7 +3216,7 @@ impl<'a> Parser<'a> { fn parse_gen_block(&mut self) -> PResult<'a, P> { let lo = self.token.span; let kind = if self.eat_keyword(kw::Async) { - GenBlockKind::Async + if self.eat_keyword(kw::Gen) { GenBlockKind::AsyncGen } else { GenBlockKind::Async } } else { assert!(self.eat_keyword(kw::Gen)); self.sess.gated_spans.gate(sym::gen_blocks, lo.to(self.token.span)); @@ -3219,22 +3228,26 @@ impl<'a> Parser<'a> { Ok(self.mk_expr_with_attrs(lo.to(self.prev_token.span), kind, attrs)) } - fn is_gen_block(&self, kw: Symbol) -> bool { - self.token.is_keyword(kw) + fn is_gen_block(&self, kw: Symbol, lookahead: usize) -> bool { + self.is_keyword_ahead(lookahead, &[kw]) && (( // `async move {` - self.is_keyword_ahead(1, &[kw::Move]) - && self.look_ahead(2, |t| { + self.is_keyword_ahead(lookahead + 1, &[kw::Move]) + && self.look_ahead(lookahead + 2, |t| { *t == token::OpenDelim(Delimiter::Brace) || t.is_whole_block() }) ) || ( // `async {` - self.look_ahead(1, |t| { + self.look_ahead(lookahead + 1, |t| { *t == token::OpenDelim(Delimiter::Brace) || t.is_whole_block() }) )) } + pub(super) fn is_async_gen_block(&self) -> bool { + self.token.is_keyword(kw::Async) && self.is_gen_block(kw::Gen, 1) + } + fn is_certainly_not_a_block(&self) -> bool { self.look_ahead(1, |t| t.is_ident()) && ( diff --git a/compiler/rustc_parse/src/parser/item.rs b/compiler/rustc_parse/src/parser/item.rs index 086e8d5cf9b7f..d22cc04d18206 100644 --- a/compiler/rustc_parse/src/parser/item.rs +++ b/compiler/rustc_parse/src/parser/item.rs @@ -2359,8 +2359,10 @@ impl<'a> Parser<'a> { || case == Case::Insensitive && t.is_non_raw_ident_where(|i| quals.iter().any(|qual| qual.as_str() == i.name.as_str().to_lowercase())) ) - // Rule out unsafe extern block. - && !self.is_unsafe_foreign_mod()) + // Rule out `unsafe extern {`. + && !self.is_unsafe_foreign_mod() + // Rule out `async gen {` and `async gen move {` + && !self.is_async_gen_block()) }) // `extern ABI fn` || self.check_keyword_case(kw::Extern, case) @@ -2392,10 +2394,7 @@ impl<'a> Parser<'a> { let constness = self.parse_constness(case); let async_start_sp = self.token.span; - let asyncness = self.parse_asyncness(case); - - let _gen_start_sp = self.token.span; - let genness = self.parse_genness(case); + let coroutine_kind = self.parse_coroutine_kind(case); let unsafe_start_sp = self.token.span; let unsafety = self.parse_unsafety(case); @@ -2403,7 +2402,7 @@ impl<'a> Parser<'a> { let ext_start_sp = self.token.span; let ext = self.parse_extern(case); - if let Some(CoroutineKind::Async { span, .. }) = asyncness { + if let Some(CoroutineKind::Async { span, .. }) = coroutine_kind { if span.is_rust_2015() { self.sess.emit_err(errors::AsyncFnIn2015 { span, @@ -2412,16 +2411,11 @@ impl<'a> Parser<'a> { } } - if let Some(CoroutineKind::Gen { span, .. }) = genness { - self.sess.gated_spans.gate(sym::gen_blocks, span); - } - - if let ( - Some(CoroutineKind::Async { span: async_span, .. }), - Some(CoroutineKind::Gen { span: gen_span, .. }), - ) = (asyncness, genness) - { - self.sess.emit_err(errors::AsyncGenFn { span: async_span.to(gen_span) }); + match coroutine_kind { + Some(CoroutineKind::Gen { span, .. }) | Some(CoroutineKind::AsyncGen { span, .. }) => { + self.sess.gated_spans.gate(sym::gen_blocks, span); + } + Some(CoroutineKind::Async { .. }) | None => {} } if !self.eat_keyword_case(kw::Fn, case) { @@ -2440,7 +2434,7 @@ impl<'a> Parser<'a> { // We may be able to recover let mut recover_constness = constness; - let mut recover_asyncness = asyncness; + let mut recover_coroutine_kind = coroutine_kind; let mut recover_unsafety = unsafety; // This will allow the machine fix to directly place the keyword in the correct place or to indicate // that the keyword is already present and the second instance should be removed. @@ -2453,15 +2447,24 @@ impl<'a> Parser<'a> { } } } else if self.check_keyword(kw::Async) { - match asyncness { + match coroutine_kind { Some(CoroutineKind::Async { span, .. }) => { Some(WrongKw::Duplicated(span)) } + Some(CoroutineKind::AsyncGen { span, .. }) => { + Some(WrongKw::Duplicated(span)) + } Some(CoroutineKind::Gen { .. }) => { - panic!("not sure how to recover here") + recover_coroutine_kind = Some(CoroutineKind::AsyncGen { + span: self.token.span, + closure_id: DUMMY_NODE_ID, + return_impl_trait_id: DUMMY_NODE_ID, + }); + // FIXME(gen_blocks): This span is wrong, didn't want to think about it. + Some(WrongKw::Misplaced(unsafe_start_sp)) } None => { - recover_asyncness = Some(CoroutineKind::Async { + recover_coroutine_kind = Some(CoroutineKind::Async { span: self.token.span, closure_id: DUMMY_NODE_ID, return_impl_trait_id: DUMMY_NODE_ID, @@ -2559,7 +2562,7 @@ impl<'a> Parser<'a> { return Ok(FnHeader { constness: recover_constness, unsafety: recover_unsafety, - coro_kind: recover_asyncness, + coroutine_kind: recover_coroutine_kind, ext, }); } @@ -2569,13 +2572,7 @@ impl<'a> Parser<'a> { } } - let coro_kind = match asyncness { - Some(CoroutineKind::Async { .. }) => asyncness, - Some(CoroutineKind::Gen { .. }) => unreachable!("asycness cannot be Gen"), - None => genness, - }; - - Ok(FnHeader { constness, unsafety, coro_kind, ext }) + Ok(FnHeader { constness, unsafety, coroutine_kind, ext }) } /// Parses the parameter list and result type of a function declaration. diff --git a/compiler/rustc_parse/src/parser/mod.rs b/compiler/rustc_parse/src/parser/mod.rs index 2816386cbad9f..7a306823ed498 100644 --- a/compiler/rustc_parse/src/parser/mod.rs +++ b/compiler/rustc_parse/src/parser/mod.rs @@ -1125,23 +1125,30 @@ impl<'a> Parser<'a> { } /// Parses asyncness: `async` or nothing. - fn parse_asyncness(&mut self, case: Case) -> Option { + fn parse_coroutine_kind(&mut self, case: Case) -> Option { + let span = self.token.uninterpolated_span(); if self.eat_keyword_case(kw::Async, case) { - let span = self.prev_token.uninterpolated_span(); - Some(CoroutineKind::Async { - span, - closure_id: DUMMY_NODE_ID, - return_impl_trait_id: DUMMY_NODE_ID, - }) - } else { - None - } - } - - /// Parses genness: `gen` or nothing. - fn parse_genness(&mut self, case: Case) -> Option { - if self.token.span.at_least_rust_2024() && self.eat_keyword_case(kw::Gen, case) { - let span = self.prev_token.uninterpolated_span(); + // FIXME(gen_blocks): Do we want to unconditionally parse `gen` and then + // error if edition <= 2024, like we do with async and edition <= 2018? + if self.token.uninterpolated_span().at_least_rust_2024() + && self.eat_keyword_case(kw::Gen, case) + { + let gen_span = self.prev_token.uninterpolated_span(); + Some(CoroutineKind::AsyncGen { + span: span.to(gen_span), + closure_id: DUMMY_NODE_ID, + return_impl_trait_id: DUMMY_NODE_ID, + }) + } else { + Some(CoroutineKind::Async { + span, + closure_id: DUMMY_NODE_ID, + return_impl_trait_id: DUMMY_NODE_ID, + }) + } + } else if self.token.uninterpolated_span().at_least_rust_2024() + && self.eat_keyword_case(kw::Gen, case) + { Some(CoroutineKind::Gen { span, closure_id: DUMMY_NODE_ID, diff --git a/compiler/rustc_parse/src/parser/ty.rs b/compiler/rustc_parse/src/parser/ty.rs index f349140e8c347..da8cc05ff66e8 100644 --- a/compiler/rustc_parse/src/parser/ty.rs +++ b/compiler/rustc_parse/src/parser/ty.rs @@ -598,7 +598,7 @@ impl<'a> Parser<'a> { tokens: None, }; let span_start = self.token.span; - let ast::FnHeader { ext, unsafety, constness, coro_kind } = + let ast::FnHeader { ext, unsafety, constness, coroutine_kind } = self.parse_fn_front_matter(&inherited_vis, Case::Sensitive)?; if self.may_recover() && self.token.kind == TokenKind::Lt { self.recover_fn_ptr_with_generics(lo, &mut params, param_insertion_point)?; @@ -611,7 +611,7 @@ impl<'a> Parser<'a> { // cover it. self.sess.emit_err(FnPointerCannotBeConst { span: whole_span, qualifier: span }); } - if let Some(ast::CoroutineKind::Async { span, .. }) = coro_kind { + if let Some(ast::CoroutineKind::Async { span, .. }) = coroutine_kind { self.sess.emit_err(FnPointerCannotBeAsync { span: whole_span, qualifier: span }); } // FIXME(gen_blocks): emit a similar error for `gen fn()` diff --git a/compiler/rustc_resolve/src/def_collector.rs b/compiler/rustc_resolve/src/def_collector.rs index ab5d3b368eb8a..186dd28b142e6 100644 --- a/compiler/rustc_resolve/src/def_collector.rs +++ b/compiler/rustc_resolve/src/def_collector.rs @@ -156,29 +156,33 @@ impl<'a, 'b, 'tcx> visit::Visitor<'a> for DefCollector<'a, 'b, 'tcx> { fn visit_fn(&mut self, fn_kind: FnKind<'a>, span: Span, _: NodeId) { if let FnKind::Fn(_, _, sig, _, generics, body) = fn_kind { - if let Some( - CoroutineKind::Async { closure_id, .. } | CoroutineKind::Gen { closure_id, .. }, - ) = sig.header.coro_kind - { - self.visit_generics(generics); - - // For async functions, we need to create their inner defs inside of a - // closure to match their desugared representation. Besides that, - // we must mirror everything that `visit::walk_fn` below does. - self.visit_fn_header(&sig.header); - for param in &sig.decl.inputs { - self.visit_param(param); - } - self.visit_fn_ret_ty(&sig.decl.output); - // If this async fn has no body (i.e. it's an async fn signature in a trait) - // then the closure_def will never be used, and we should avoid generating a - // def-id for it. - if let Some(body) = body { - let closure_def = - self.create_def(closure_id, kw::Empty, DefKind::Closure, span); - self.with_parent(closure_def, |this| this.visit_block(body)); + match sig.header.coroutine_kind { + Some( + CoroutineKind::Async { closure_id, .. } + | CoroutineKind::Gen { closure_id, .. } + | CoroutineKind::AsyncGen { closure_id, .. }, + ) => { + self.visit_generics(generics); + + // For async functions, we need to create their inner defs inside of a + // closure to match their desugared representation. Besides that, + // we must mirror everything that `visit::walk_fn` below does. + self.visit_fn_header(&sig.header); + for param in &sig.decl.inputs { + self.visit_param(param); + } + self.visit_fn_ret_ty(&sig.decl.output); + // If this async fn has no body (i.e. it's an async fn signature in a trait) + // then the closure_def will never be used, and we should avoid generating a + // def-id for it. + if let Some(body) = body { + let closure_def = + self.create_def(closure_id, kw::Empty, DefKind::Closure, span); + self.with_parent(closure_def, |this| this.visit_block(body)); + } + return; } - return; + None => {} } } @@ -284,10 +288,11 @@ impl<'a, 'b, 'tcx> visit::Visitor<'a> for DefCollector<'a, 'b, 'tcx> { // Async closures desugar to closures inside of closures, so // we must create two defs. let closure_def = self.create_def(expr.id, kw::Empty, DefKind::Closure, expr.span); - match closure.coro_kind { + match closure.coroutine_kind { Some( CoroutineKind::Async { closure_id, .. } - | CoroutineKind::Gen { closure_id, .. }, + | CoroutineKind::Gen { closure_id, .. } + | CoroutineKind::AsyncGen { closure_id, .. }, ) => self.create_def(closure_id, kw::Empty, DefKind::Closure, expr.span), None => closure_def, } diff --git a/compiler/rustc_resolve/src/late.rs b/compiler/rustc_resolve/src/late.rs index 07c8c036c9ebb..9c96e9a9bd728 100644 --- a/compiler/rustc_resolve/src/late.rs +++ b/compiler/rustc_resolve/src/late.rs @@ -916,8 +916,10 @@ impl<'a: 'ast, 'ast, 'tcx> Visitor<'ast> for LateResolutionVisitor<'a, '_, 'ast, &sig.decl.output, ); - if let Some((coro_node_id, _)) = - sig.header.coro_kind.map(|coro_kind| coro_kind.return_id()) + if let Some((coro_node_id, _)) = sig + .header + .coroutine_kind + .map(|coroutine_kind| coroutine_kind.return_id()) { this.record_lifetime_params_for_impl_trait(coro_node_id); } @@ -942,8 +944,10 @@ impl<'a: 'ast, 'ast, 'tcx> Visitor<'ast> for LateResolutionVisitor<'a, '_, 'ast, this.visit_generics(generics); let declaration = &sig.decl; - let coro_node_id = - sig.header.coro_kind.map(|coro_kind| coro_kind.return_id()); + let coro_node_id = sig + .header + .coroutine_kind + .map(|coroutine_kind| coroutine_kind.return_id()); this.with_lifetime_rib( LifetimeRibKind::AnonymousCreateParameter { @@ -4294,7 +4298,7 @@ impl<'a: 'ast, 'b, 'ast, 'tcx> LateResolutionVisitor<'a, 'b, 'ast, 'tcx> { // // Similarly, `gen |x| ...` gets desugared to `|x| gen {...}`, so we handle that too. ExprKind::Closure(box ast::Closure { - coro_kind: Some(_), + coroutine_kind: Some(_), ref fn_decl, ref body, .. diff --git a/compiler/rustc_smir/src/rustc_smir/convert/mod.rs b/compiler/rustc_smir/src/rustc_smir/convert/mod.rs index 9c0b2b29bca71..ce575f269caa1 100644 --- a/compiler/rustc_smir/src/rustc_smir/convert/mod.rs +++ b/compiler/rustc_smir/src/rustc_smir/convert/mod.rs @@ -56,6 +56,7 @@ impl<'tcx> Stable<'tcx> for rustc_hir::CoroutineKind { stable_mir::mir::CoroutineKind::Gen(source.stable(tables)) } CoroutineKind::Coroutine => stable_mir::mir::CoroutineKind::Coroutine, + CoroutineKind::AsyncGen(_) => todo!(), } } } diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 5c1e703837a85..7b9b7b8529356 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -139,6 +139,9 @@ symbols! { AssertParamIsClone, AssertParamIsCopy, AssertParamIsEq, + AsyncGenFinished, + AsyncGenPending, + AsyncGenReady, AtomicBool, AtomicI128, AtomicI16, @@ -423,6 +426,7 @@ symbols! { async_closure, async_fn_in_trait, async_fn_track_caller, + async_iterator, atomic, atomic_mod, atomics, @@ -1200,6 +1204,7 @@ symbols! { pointer, pointer_like, poll, + poll_next, post_dash_lto: "post-lto", powerpc_target_feature, powf32, diff --git a/compiler/rustc_trait_selection/src/solve/assembly/mod.rs b/compiler/rustc_trait_selection/src/solve/assembly/mod.rs index 201fade5ad795..62d62bdfd114d 100644 --- a/compiler/rustc_trait_selection/src/solve/assembly/mod.rs +++ b/compiler/rustc_trait_selection/src/solve/assembly/mod.rs @@ -207,6 +207,11 @@ pub(super) trait GoalKind<'tcx>: goal: Goal<'tcx, Self>, ) -> QueryResult<'tcx>; + fn consider_builtin_async_iterator_candidate( + ecx: &mut EvalCtxt<'_, 'tcx>, + goal: Goal<'tcx, Self>, + ) -> QueryResult<'tcx>; + /// A coroutine (that doesn't come from an `async` or `gen` desugaring) is known to /// implement `Coroutine`, given the resume, yield, /// and return types of the coroutine computed during type-checking. @@ -565,6 +570,8 @@ impl<'tcx> EvalCtxt<'_, 'tcx> { G::consider_builtin_future_candidate(self, goal) } else if lang_items.iterator_trait() == Some(trait_def_id) { G::consider_builtin_iterator_candidate(self, goal) + } else if lang_items.async_iterator_trait() == Some(trait_def_id) { + G::consider_builtin_async_iterator_candidate(self, goal) } else if lang_items.coroutine_trait() == Some(trait_def_id) { G::consider_builtin_coroutine_candidate(self, goal) } else if lang_items.discriminant_kind_trait() == Some(trait_def_id) { diff --git a/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs b/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs index 867a520915f45..2fe51b400ec21 100644 --- a/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs +++ b/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs @@ -510,6 +510,40 @@ impl<'tcx> assembly::GoalKind<'tcx> for NormalizesTo<'tcx> { ) } + fn consider_builtin_async_iterator_candidate( + ecx: &mut EvalCtxt<'_, 'tcx>, + goal: Goal<'tcx, Self>, + ) -> QueryResult<'tcx> { + let self_ty = goal.predicate.self_ty(); + let ty::Coroutine(def_id, args, _) = *self_ty.kind() else { + return Err(NoSolution); + }; + + // Coroutines are not AsyncIterators unless they come from `gen` desugaring + let tcx = ecx.tcx(); + if !tcx.coroutine_is_async_gen(def_id) { + return Err(NoSolution); + } + + ecx.probe_misc_candidate("builtin AsyncIterator kind").enter(|ecx| { + // Take `AsyncIterator` and turn it into the corresponding + // coroutine yield ty `Poll>`. + let expected_ty = Ty::new_adt( + tcx, + tcx.adt_def(tcx.require_lang_item(LangItem::Poll, None)), + tcx.mk_args(&[Ty::new_adt( + tcx, + tcx.adt_def(tcx.require_lang_item(LangItem::Option, None)), + tcx.mk_args(&[goal.predicate.term.into()]), + ) + .into()]), + ); + let yield_ty = args.as_coroutine().yield_ty(); + ecx.eq(goal.param_env, expected_ty, yield_ty)?; + ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes) + }) + } + fn consider_builtin_coroutine_candidate( ecx: &mut EvalCtxt<'_, 'tcx>, goal: Goal<'tcx, Self>, diff --git a/compiler/rustc_trait_selection/src/solve/trait_goals.rs b/compiler/rustc_trait_selection/src/solve/trait_goals.rs index 95712da3c5e82..5807f7c6153cf 100644 --- a/compiler/rustc_trait_selection/src/solve/trait_goals.rs +++ b/compiler/rustc_trait_selection/src/solve/trait_goals.rs @@ -370,6 +370,30 @@ impl<'tcx> assembly::GoalKind<'tcx> for TraitPredicate<'tcx> { ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes) } + fn consider_builtin_async_iterator_candidate( + ecx: &mut EvalCtxt<'_, 'tcx>, + goal: Goal<'tcx, Self>, + ) -> QueryResult<'tcx> { + if goal.predicate.polarity != ty::ImplPolarity::Positive { + return Err(NoSolution); + } + + let ty::Coroutine(def_id, _, _) = *goal.predicate.self_ty().kind() else { + return Err(NoSolution); + }; + + // Coroutines are not iterators unless they come from `gen` desugaring + let tcx = ecx.tcx(); + if !tcx.coroutine_is_async_gen(def_id) { + return Err(NoSolution); + } + + // Gen coroutines unconditionally implement `Iterator` + // Technically, we need to check that the iterator output type is Sized, + // but that's already proven by the coroutines being WF. + ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes) + } + fn consider_builtin_coroutine_candidate( ecx: &mut EvalCtxt<'_, 'tcx>, goal: Goal<'tcx, Self>, diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs index 6b231a30ea78c..7bf37cf79806a 100644 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs +++ b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs @@ -2587,6 +2587,23 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> { CoroutineKind::Async(CoroutineSource::Closure) => { format!("future created by async closure is not {trait_name}") } + CoroutineKind::AsyncGen(CoroutineSource::Fn) => self + .tcx + .parent(coroutine_did) + .as_local() + .map(|parent_did| self.tcx.local_def_id_to_hir_id(parent_did)) + .and_then(|parent_hir_id| hir.opt_name(parent_hir_id)) + .map(|name| { + format!("async iterator returned by `{name}` is not {trait_name}") + })?, + CoroutineKind::AsyncGen(CoroutineSource::Block) => { + format!("async iterator created by async gen block is not {trait_name}") + } + CoroutineKind::AsyncGen(CoroutineSource::Closure) => { + format!( + "async iterator created by async gen closure is not {trait_name}" + ) + } CoroutineKind::Gen(CoroutineSource::Fn) => self .tcx .parent(coroutine_did) @@ -3127,7 +3144,9 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> { let what = match self.tcx.coroutine_kind(coroutine_def_id) { None | Some(hir::CoroutineKind::Coroutine) - | Some(hir::CoroutineKind::Gen(_)) => "yield", + | Some(hir::CoroutineKind::Gen(_)) + // FIXME(gen_blocks): This could be yield or await... + | Some(hir::CoroutineKind::AsyncGen(_)) => "yield", Some(hir::CoroutineKind::Async(..)) => "await", }; err.note(format!( diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs index 1f94fbaf9f8ab..8fa0dceda8742 100644 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs +++ b/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs @@ -1921,6 +1921,9 @@ impl<'tcx> InferCtxtPrivExt<'tcx> for TypeErrCtxt<'_, 'tcx> { hir::CoroutineKind::Async(hir::CoroutineSource::Block) => "an async block", hir::CoroutineKind::Async(hir::CoroutineSource::Fn) => "an async function", hir::CoroutineKind::Async(hir::CoroutineSource::Closure) => "an async closure", + hir::CoroutineKind::AsyncGen(hir::CoroutineSource::Block) => "an async gen block", + hir::CoroutineKind::AsyncGen(hir::CoroutineSource::Fn) => "an async gen function", + hir::CoroutineKind::AsyncGen(hir::CoroutineSource::Closure) => "an async gen closure", hir::CoroutineKind::Gen(hir::CoroutineSource::Block) => "a gen block", hir::CoroutineKind::Gen(hir::CoroutineSource::Fn) => "a gen function", hir::CoroutineKind::Gen(hir::CoroutineSource::Closure) => "a gen closure", diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs index 5b0829b57325b..a08e35b566f44 100644 --- a/compiler/rustc_trait_selection/src/traits/project.rs +++ b/compiler/rustc_trait_selection/src/traits/project.rs @@ -1823,11 +1823,18 @@ fn assemble_candidates_from_impls<'cx, 'tcx>( let self_ty = selcx.infcx.shallow_resolve(obligation.predicate.self_ty()); let lang_items = selcx.tcx().lang_items(); - if [lang_items.coroutine_trait(), lang_items.future_trait(), lang_items.iterator_trait()].contains(&Some(trait_ref.def_id)) - || selcx.tcx().fn_trait_kind_from_def_id(trait_ref.def_id).is_some() + if [ + lang_items.coroutine_trait(), + lang_items.future_trait(), + lang_items.iterator_trait(), + lang_items.async_iterator_trait(), + lang_items.fn_trait(), + lang_items.fn_mut_trait(), + lang_items.fn_once_trait(), + ].contains(&Some(trait_ref.def_id)) { true - } else if lang_items.discriminant_kind_trait() == Some(trait_ref.def_id) { + }else if lang_items.discriminant_kind_trait() == Some(trait_ref.def_id) { match self_ty.kind() { ty::Bool | ty::Char @@ -2042,6 +2049,8 @@ fn confirm_select_candidate<'cx, 'tcx>( confirm_future_candidate(selcx, obligation, data) } else if lang_items.iterator_trait() == Some(trait_def_id) { confirm_iterator_candidate(selcx, obligation, data) + } else if lang_items.async_iterator_trait() == Some(trait_def_id) { + confirm_async_iterator_candidate(selcx, obligation, data) } else if selcx.tcx().fn_trait_kind_from_def_id(trait_def_id).is_some() { if obligation.predicate.self_ty().is_closure() { confirm_closure_candidate(selcx, obligation, data) @@ -2203,6 +2212,57 @@ fn confirm_iterator_candidate<'cx, 'tcx>( .with_addl_obligations(obligations) } +fn confirm_async_iterator_candidate<'cx, 'tcx>( + selcx: &mut SelectionContext<'cx, 'tcx>, + obligation: &ProjectionTyObligation<'tcx>, + nested: Vec>, +) -> Progress<'tcx> { + let ty::Coroutine(_, args, _) = + selcx.infcx.shallow_resolve(obligation.predicate.self_ty()).kind() + else { + unreachable!() + }; + let gen_sig = args.as_coroutine().sig(); + let Normalized { value: gen_sig, obligations } = normalize_with_depth( + selcx, + obligation.param_env, + obligation.cause.clone(), + obligation.recursion_depth + 1, + gen_sig, + ); + + debug!(?obligation, ?gen_sig, ?obligations, "confirm_async_iterator_candidate"); + + let tcx = selcx.tcx(); + let iter_def_id = tcx.require_lang_item(LangItem::AsyncIterator, None); + + let (trait_ref, yield_ty) = super::util::async_iterator_trait_ref_and_outputs( + tcx, + iter_def_id, + obligation.predicate.self_ty(), + gen_sig, + ); + + debug_assert_eq!(tcx.associated_item(obligation.predicate.def_id).name, sym::Item); + + let ty::Adt(_poll_adt, args) = *yield_ty.kind() else { + bug!(); + }; + let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { + bug!(); + }; + let item_ty = args.type_at(0); + + let predicate = ty::ProjectionPredicate { + projection_ty: ty::AliasTy::new(tcx, obligation.predicate.def_id, trait_ref.args), + term: item_ty.into(), + }; + + confirm_param_env_candidate(selcx, obligation, ty::Binder::dummy(predicate), false) + .with_addl_obligations(nested) + .with_addl_obligations(obligations) +} + fn confirm_builtin_candidate<'cx, 'tcx>( selcx: &mut SelectionContext<'cx, 'tcx>, obligation: &ProjectionTyObligation<'tcx>, diff --git a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs index 367de517af2be..c7d0ab7164411 100644 --- a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs +++ b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs @@ -112,6 +112,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { self.assemble_future_candidates(obligation, &mut candidates); } else if lang_items.iterator_trait() == Some(def_id) { self.assemble_iterator_candidates(obligation, &mut candidates); + } else if lang_items.async_iterator_trait() == Some(def_id) { + self.assemble_async_iterator_candidates(obligation, &mut candidates); } self.assemble_closure_candidates(obligation, &mut candidates); @@ -258,6 +260,34 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { } } + fn assemble_async_iterator_candidates( + &mut self, + obligation: &PolyTraitObligation<'tcx>, + candidates: &mut SelectionCandidateSet<'tcx>, + ) { + let self_ty = obligation.self_ty().skip_binder(); + if let ty::Coroutine(did, args, _) = *self_ty.kind() { + // gen constructs get lowered to a special kind of coroutine that + // should directly `impl AsyncIterator`. + if self.tcx().coroutine_is_async_gen(did) { + debug!(?self_ty, ?obligation, "assemble_iterator_candidates",); + + // Can only confirm this candidate if we have constrained + // the `Yield` type to at least `Poll>`.. + let ty::Adt(_poll_def, args) = *args.as_coroutine().yield_ty().kind() else { + candidates.ambiguous = true; + return; + }; + let ty::Adt(_option_def, _) = *args.type_at(0).kind() else { + candidates.ambiguous = true; + return; + }; + + candidates.vec.push(AsyncIteratorCandidate); + } + } + } + /// Checks for the artificial impl that the compiler will create for an obligation like `X : /// FnMut<..>` where `X` is a closure type. /// diff --git a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs index 8567f4f0e70e3..4a342a7f6b146 100644 --- a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs +++ b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs @@ -98,6 +98,11 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { ImplSource::Builtin(BuiltinImplSource::Misc, vtable_iterator) } + AsyncIteratorCandidate => { + let vtable_iterator = self.confirm_async_iterator_candidate(obligation)?; + ImplSource::Builtin(BuiltinImplSource::Misc, vtable_iterator) + } + FnPointerCandidate { is_const } => { let data = self.confirm_fn_pointer_candidate(obligation, is_const)?; ImplSource::Builtin(BuiltinImplSource::Misc, data) @@ -813,6 +818,35 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { Ok(nested) } + fn confirm_async_iterator_candidate( + &mut self, + obligation: &PolyTraitObligation<'tcx>, + ) -> Result>, SelectionError<'tcx>> { + // Okay to skip binder because the args on coroutine types never + // touch bound regions, they just capture the in-scope + // type/region parameters. + let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder()); + let ty::Coroutine(coroutine_def_id, args, _) = *self_ty.kind() else { + bug!("closure candidate for non-closure {:?}", obligation); + }; + + debug!(?obligation, ?coroutine_def_id, ?args, "confirm_async_iterator_candidate"); + + let gen_sig = args.as_coroutine().sig(); + + let (trait_ref, _) = super::util::async_iterator_trait_ref_and_outputs( + self.tcx(), + obligation.predicate.def_id(), + obligation.predicate.no_bound_vars().expect("iterator has no bound vars").self_ty(), + gen_sig, + ); + + let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?; + debug!(?trait_ref, ?nested, "iterator candidate obligations"); + + Ok(nested) + } + #[instrument(skip(self), level = "debug")] fn confirm_closure_candidate( &mut self, diff --git a/compiler/rustc_trait_selection/src/traits/select/mod.rs b/compiler/rustc_trait_selection/src/traits/select/mod.rs index 6ba379467dac7..7f31a2529f5c2 100644 --- a/compiler/rustc_trait_selection/src/traits/select/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/select/mod.rs @@ -1875,6 +1875,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> { | CoroutineCandidate | FutureCandidate | IteratorCandidate + | AsyncIteratorCandidate | FnPointerCandidate { .. } | BuiltinObjectCandidate | BuiltinUnsizeCandidate @@ -1904,6 +1905,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> { | CoroutineCandidate | FutureCandidate | IteratorCandidate + | AsyncIteratorCandidate | FnPointerCandidate { .. } | BuiltinObjectCandidate | BuiltinUnsizeCandidate @@ -1939,6 +1941,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> { | CoroutineCandidate | FutureCandidate | IteratorCandidate + | AsyncIteratorCandidate | FnPointerCandidate { .. } | BuiltinObjectCandidate | BuiltinUnsizeCandidate @@ -1954,6 +1957,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> { | CoroutineCandidate | FutureCandidate | IteratorCandidate + | AsyncIteratorCandidate | FnPointerCandidate { .. } | BuiltinObjectCandidate | BuiltinUnsizeCandidate @@ -2061,6 +2065,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> { | CoroutineCandidate | FutureCandidate | IteratorCandidate + | AsyncIteratorCandidate | FnPointerCandidate { .. } | BuiltinObjectCandidate | BuiltinUnsizeCandidate @@ -2072,6 +2077,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> { | CoroutineCandidate | FutureCandidate | IteratorCandidate + | AsyncIteratorCandidate | FnPointerCandidate { .. } | BuiltinObjectCandidate | BuiltinUnsizeCandidate diff --git a/compiler/rustc_trait_selection/src/traits/util.rs b/compiler/rustc_trait_selection/src/traits/util.rs index 5574badf23803..98da3bc2fe9fa 100644 --- a/compiler/rustc_trait_selection/src/traits/util.rs +++ b/compiler/rustc_trait_selection/src/traits/util.rs @@ -308,6 +308,17 @@ pub fn iterator_trait_ref_and_outputs<'tcx>( (trait_ref, sig.yield_ty) } +pub fn async_iterator_trait_ref_and_outputs<'tcx>( + tcx: TyCtxt<'tcx>, + async_iterator_def_id: DefId, + self_ty: Ty<'tcx>, + sig: ty::GenSig<'tcx>, +) -> (ty::TraitRef<'tcx>, Ty<'tcx>) { + assert!(!self_ty.has_escaping_bound_vars()); + let trait_ref = ty::TraitRef::new(tcx, async_iterator_def_id, [self_ty]); + (trait_ref, sig.yield_ty) +} + pub fn impl_item_is_final(tcx: TyCtxt<'_>, assoc_item: &ty::AssocItem) -> bool { assoc_item.defaultness(tcx).is_final() && tcx.defaultness(assoc_item.container_id(tcx)).is_final() diff --git a/compiler/rustc_ty_utils/src/abi.rs b/compiler/rustc_ty_utils/src/abi.rs index a58e98ce99c4c..a5f11ca23e124 100644 --- a/compiler/rustc_ty_utils/src/abi.rs +++ b/compiler/rustc_ty_utils/src/abi.rs @@ -119,9 +119,9 @@ fn fn_sig_for_fn_abi<'tcx>( // unlike for all other coroutine kinds. env_ty } - hir::CoroutineKind::Async(_) | hir::CoroutineKind::Coroutine => { - Ty::new_adt(tcx, pin_adt_ref, pin_args) - } + hir::CoroutineKind::Async(_) + | hir::CoroutineKind::AsyncGen(_) + | hir::CoroutineKind::Coroutine => Ty::new_adt(tcx, pin_adt_ref, pin_args), }; // The `FnSig` and the `ret_ty` here is for a coroutines main @@ -168,6 +168,30 @@ fn fn_sig_for_fn_abi<'tcx>( (None, ret_ty) } + hir::CoroutineKind::AsyncGen(_) => { + // The signature should be + // `AsyncIterator::poll_next(_, &mut Context<'_>) -> Poll>` + assert_eq!(sig.return_ty, tcx.types.unit); + + // Yield type is already `Poll>` + let ret_ty = sig.yield_ty; + + // We have to replace the `ResumeTy` that is used for type and borrow checking + // with `&mut Context<'_>` which is used in codegen. + #[cfg(debug_assertions)] + { + if let ty::Adt(resume_ty_adt, _) = sig.resume_ty.kind() { + let expected_adt = + tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None)); + assert_eq!(*resume_ty_adt, expected_adt); + } else { + panic!("expected `ResumeTy`, found `{:?}`", sig.resume_ty); + }; + } + let context_mut_ref = Ty::new_task_context(tcx); + + (Some(context_mut_ref), ret_ty) + } hir::CoroutineKind::Coroutine => { // The signature should be `Coroutine::resume(_, Resume) -> CoroutineState` let state_did = tcx.require_lang_item(LangItem::CoroutineState, None); diff --git a/compiler/rustc_ty_utils/src/instance.rs b/compiler/rustc_ty_utils/src/instance.rs index a0f01d9eca979..f1c9bb23e5d6e 100644 --- a/compiler/rustc_ty_utils/src/instance.rs +++ b/compiler/rustc_ty_utils/src/instance.rs @@ -271,6 +271,21 @@ fn resolve_associated_item<'tcx>( debug_assert!(tcx.defaultness(trait_item_id).has_value()); Some(Instance::new(trait_item_id, rcvr_args)) } + } else if Some(trait_ref.def_id) == lang_items.async_iterator_trait() { + let ty::Coroutine(coroutine_def_id, args, _) = *rcvr_args.type_at(0).kind() else { + bug!() + }; + + if cfg!(debug_assertions) && tcx.item_name(trait_item_id) != sym::poll_next { + span_bug!( + tcx.def_span(coroutine_def_id), + "no definition for `{trait_ref}::{}` for built-in coroutine type", + tcx.item_name(trait_item_id) + ) + } + + // `AsyncIterator::poll_next` is generated by the compiler. + Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args }) } else if Some(trait_ref.def_id) == lang_items.coroutine_trait() { let ty::Coroutine(coroutine_def_id, args, _) = *rcvr_args.type_at(0).kind() else { bug!() diff --git a/library/core/src/async_iter/async_iter.rs b/library/core/src/async_iter/async_iter.rs index 12a47f9fc7626..8a45bd36f7a29 100644 --- a/library/core/src/async_iter/async_iter.rs +++ b/library/core/src/async_iter/async_iter.rs @@ -13,6 +13,7 @@ use crate::task::{Context, Poll}; #[unstable(feature = "async_iterator", issue = "79024")] #[must_use = "async iterators do nothing unless polled"] #[doc(alias = "Stream")] +#[cfg_attr(not(bootstrap), lang = "async_iterator")] pub trait AsyncIterator { /// The type of items yielded by the async iterator. type Item; @@ -109,3 +110,27 @@ where (**self).size_hint() } } + +#[unstable(feature = "async_gen_internals", issue = "none")] +impl Poll> { + /// A helper function for internal desugaring -- produces `Ready(Some(t))`, + /// which corresponds to the async iterator yielding a value. + #[unstable(feature = "async_gen_internals", issue = "none")] + #[cfg_attr(not(bootstrap), lang = "AsyncGenReady")] + pub fn async_gen_ready(t: T) -> Self { + Poll::Ready(Some(t)) + } + + /// A helper constant for internal desugaring -- produces `Pending`, + /// which corresponds to the async iterator pending on an `.await`. + #[unstable(feature = "async_gen_internals", issue = "none")] + #[cfg_attr(not(bootstrap), lang = "AsyncGenPending")] + // FIXME(gen_blocks): This probably could be deduplicated. + pub const PENDING: Self = Poll::Pending; + + /// A helper constant for internal desugaring -- produces `Ready(None)`, + /// which corresponds to the async iterator finishing its iteration. + #[unstable(feature = "async_gen_internals", issue = "none")] + #[cfg_attr(not(bootstrap), lang = "AsyncGenFinished")] + pub const FINISHED: Self = Poll::Ready(None); +} diff --git a/src/tools/clippy/clippy_lints/src/doc/needless_doctest_main.rs b/src/tools/clippy/clippy_lints/src/doc/needless_doctest_main.rs index 640d4a069ec78..e019523e60987 100644 --- a/src/tools/clippy/clippy_lints/src/doc/needless_doctest_main.rs +++ b/src/tools/clippy/clippy_lints/src/doc/needless_doctest_main.rs @@ -69,7 +69,7 @@ pub fn check( if !ignore { get_test_spans(&item, &mut test_attr_spans); } - let is_async = matches!(sig.header.coro_kind, Some(CoroutineKind::Async { .. })); + let is_async = matches!(sig.header.coroutine_kind, Some(CoroutineKind::Async { .. })); let returns_nothing = match &sig.decl.output { FnRetTy::Default(..) => true, FnRetTy::Ty(ty) if ty.kind.is_unit() => true, diff --git a/src/tools/clippy/clippy_utils/src/ast_utils.rs b/src/tools/clippy/clippy_utils/src/ast_utils.rs index 5972278f32fe4..47237de4fefda 100644 --- a/src/tools/clippy/clippy_utils/src/ast_utils.rs +++ b/src/tools/clippy/clippy_utils/src/ast_utils.rs @@ -188,7 +188,7 @@ pub fn eq_expr(l: &Expr, r: &Expr) -> bool { Closure(box ast::Closure { binder: lb, capture_clause: lc, - coro_kind: la, + coroutine_kind: la, movability: lm, fn_decl: lf, body: le, @@ -197,7 +197,7 @@ pub fn eq_expr(l: &Expr, r: &Expr) -> bool { Closure(box ast::Closure { binder: rb, capture_clause: rc, - coro_kind: ra, + coroutine_kind: ra, movability: rm, fn_decl: rf, body: re, @@ -563,10 +563,11 @@ pub fn eq_fn_sig(l: &FnSig, r: &FnSig) -> bool { eq_fn_decl(&l.decl, &r.decl) && eq_fn_header(&l.header, &r.header) } -fn eq_opt_coro_kind(l: Option, r: Option) -> bool { +fn eq_opt_coroutine_kind(l: Option, r: Option) -> bool { match (l, r) { (Some(CoroutineKind::Async { .. }), Some(CoroutineKind::Async { .. })) - | (Some(CoroutineKind::Gen { .. }), Some(CoroutineKind::Gen { .. })) => true, + | (Some(CoroutineKind::Gen { .. }), Some(CoroutineKind::Gen { .. })) + | (Some(CoroutineKind::AsyncGen { .. }), Some(CoroutineKind::AsyncGen { .. })) => true, (None, None) => true, _ => false, } @@ -574,7 +575,7 @@ fn eq_opt_coro_kind(l: Option, r: Option) -> bool pub fn eq_fn_header(l: &FnHeader, r: &FnHeader) -> bool { matches!(l.unsafety, Unsafe::No) == matches!(r.unsafety, Unsafe::No) - && eq_opt_coro_kind(l.coro_kind, r.coro_kind) + && eq_opt_coroutine_kind(l.coroutine_kind, r.coroutine_kind) && matches!(l.constness, Const::No) == matches!(r.constness, Const::No) && eq_ext(&l.ext, &r.ext) } diff --git a/src/tools/rustfmt/src/closures.rs b/src/tools/rustfmt/src/closures.rs index c1ce87eadcb99..f698f494ae538 100644 --- a/src/tools/rustfmt/src/closures.rs +++ b/src/tools/rustfmt/src/closures.rs @@ -29,7 +29,7 @@ pub(crate) fn rewrite_closure( binder: &ast::ClosureBinder, constness: ast::Const, capture: ast::CaptureBy, - coro_kind: &Option, + coroutine_kind: &Option, movability: ast::Movability, fn_decl: &ast::FnDecl, body: &ast::Expr, @@ -40,7 +40,16 @@ pub(crate) fn rewrite_closure( debug!("rewrite_closure {:?}", body); let (prefix, extra_offset) = rewrite_closure_fn_decl( - binder, constness, capture, coro_kind, movability, fn_decl, body, span, context, shape, + binder, + constness, + capture, + coroutine_kind, + movability, + fn_decl, + body, + span, + context, + shape, )?; // 1 = space between `|...|` and body. let body_shape = shape.offset_left(extra_offset)?; @@ -233,7 +242,7 @@ fn rewrite_closure_fn_decl( binder: &ast::ClosureBinder, constness: ast::Const, capture: ast::CaptureBy, - coro_kind: &Option, + coroutine_kind: &Option, movability: ast::Movability, fn_decl: &ast::FnDecl, body: &ast::Expr, @@ -263,9 +272,10 @@ fn rewrite_closure_fn_decl( } else { "" }; - let coro = match coro_kind { + let coro = match coroutine_kind { Some(ast::CoroutineKind::Async { .. }) => "async ", Some(ast::CoroutineKind::Gen { .. }) => "gen ", + Some(ast::CoroutineKind::AsyncGen { .. }) => "async gen ", None => "", }; let mover = if matches!(capture, ast::CaptureBy::Value { .. }) { @@ -343,7 +353,7 @@ pub(crate) fn rewrite_last_closure( ref binder, constness, capture_clause, - ref coro_kind, + ref coroutine_kind, movability, ref fn_decl, ref body, @@ -364,7 +374,7 @@ pub(crate) fn rewrite_last_closure( binder, constness, capture_clause, - coro_kind, + coroutine_kind, movability, fn_decl, body, diff --git a/src/tools/rustfmt/src/expr.rs b/src/tools/rustfmt/src/expr.rs index 4515c27be374a..a68bd6694ba62 100644 --- a/src/tools/rustfmt/src/expr.rs +++ b/src/tools/rustfmt/src/expr.rs @@ -212,7 +212,7 @@ pub(crate) fn format_expr( &cl.binder, cl.constness, cl.capture_clause, - &cl.coro_kind, + &cl.coroutine_kind, cl.movability, &cl.fn_decl, &cl.body, diff --git a/src/tools/rustfmt/src/items.rs b/src/tools/rustfmt/src/items.rs index 4dff65f8cd0a6..a4256730f19da 100644 --- a/src/tools/rustfmt/src/items.rs +++ b/src/tools/rustfmt/src/items.rs @@ -287,7 +287,7 @@ pub(crate) struct FnSig<'a> { decl: &'a ast::FnDecl, generics: &'a ast::Generics, ext: ast::Extern, - coro_kind: Cow<'a, Option>, + coroutine_kind: Cow<'a, Option>, constness: ast::Const, defaultness: ast::Defaultness, unsafety: ast::Unsafe, @@ -302,7 +302,7 @@ impl<'a> FnSig<'a> { ) -> FnSig<'a> { FnSig { unsafety: method_sig.header.unsafety, - coro_kind: Cow::Borrowed(&method_sig.header.coro_kind), + coroutine_kind: Cow::Borrowed(&method_sig.header.coroutine_kind), constness: method_sig.header.constness, defaultness: ast::Defaultness::Final, ext: method_sig.header.ext, @@ -328,7 +328,7 @@ impl<'a> FnSig<'a> { generics, ext: fn_sig.header.ext, constness: fn_sig.header.constness, - coro_kind: Cow::Borrowed(&fn_sig.header.coro_kind), + coroutine_kind: Cow::Borrowed(&fn_sig.header.coroutine_kind), defaultness, unsafety: fn_sig.header.unsafety, visibility: vis, @@ -343,8 +343,8 @@ impl<'a> FnSig<'a> { result.push_str(&*format_visibility(context, self.visibility)); result.push_str(format_defaultness(self.defaultness)); result.push_str(format_constness(self.constness)); - self.coro_kind - .map(|coro_kind| result.push_str(format_coro(&coro_kind))); + self.coroutine_kind + .map(|coroutine_kind| result.push_str(format_coro(&coroutine_kind))); result.push_str(format_unsafety(self.unsafety)); result.push_str(&format_extern( self.ext, diff --git a/src/tools/rustfmt/src/utils.rs b/src/tools/rustfmt/src/utils.rs index 5805e417c04ad..7d7bbf1152905 100644 --- a/src/tools/rustfmt/src/utils.rs +++ b/src/tools/rustfmt/src/utils.rs @@ -75,10 +75,11 @@ pub(crate) fn format_visibility( } #[inline] -pub(crate) fn format_coro(coro_kind: &ast::CoroutineKind) -> &'static str { - match coro_kind { +pub(crate) fn format_coro(coroutine_kind: &ast::CoroutineKind) -> &'static str { + match coroutine_kind { ast::CoroutineKind::Async { .. } => "async ", ast::CoroutineKind::Gen { .. } => "gen ", + ast::CoroutineKind::AsyncGen { .. } => "async gen ", } } diff --git a/tests/ui-fulldeps/pprust-expr-roundtrip.rs b/tests/ui-fulldeps/pprust-expr-roundtrip.rs index 9e581620ec1b4..fe5333643edf8 100644 --- a/tests/ui-fulldeps/pprust-expr-roundtrip.rs +++ b/tests/ui-fulldeps/pprust-expr-roundtrip.rs @@ -132,7 +132,7 @@ fn iter_exprs(depth: usize, f: &mut dyn FnMut(P)) { binder: ClosureBinder::NotPresent, capture_clause: CaptureBy::Value { move_kw: DUMMY_SP }, constness: Const::No, - coro_kind: None, + coroutine_kind: None, movability: Movability::Movable, fn_decl: decl.clone(), body: e, diff --git a/tests/ui/coroutine/async_gen_fn.e2024.stderr b/tests/ui/coroutine/async_gen_fn.e2024.stderr new file mode 100644 index 0000000000000..d24cdbbc30d2f --- /dev/null +++ b/tests/ui/coroutine/async_gen_fn.e2024.stderr @@ -0,0 +1,12 @@ +error[E0658]: gen blocks are experimental + --> $DIR/async_gen_fn.rs:4:1 + | +LL | async gen fn foo() {} + | ^^^^^^^^^ + | + = note: see issue #117078 for more information + = help: add `#![feature(gen_blocks)]` to the crate attributes to enable + +error: aborting due to 1 previous error + +For more information about this error, try `rustc --explain E0658`. diff --git a/tests/ui/coroutine/async_gen_fn.none.stderr b/tests/ui/coroutine/async_gen_fn.none.stderr new file mode 100644 index 0000000000000..7950251a75daa --- /dev/null +++ b/tests/ui/coroutine/async_gen_fn.none.stderr @@ -0,0 +1,18 @@ +error[E0670]: `async fn` is not permitted in Rust 2015 + --> $DIR/async_gen_fn.rs:4:1 + | +LL | async gen fn foo() {} + | ^^^^^ to use `async fn`, switch to Rust 2018 or later + | + = help: pass `--edition 2021` to `rustc` + = note: for more on editions, read https://doc.rust-lang.org/edition-guide + +error: expected one of `extern`, `fn`, or `unsafe`, found `gen` + --> $DIR/async_gen_fn.rs:4:7 + | +LL | async gen fn foo() {} + | ^^^ expected one of `extern`, `fn`, or `unsafe` + +error: aborting due to 2 previous errors + +For more information about this error, try `rustc --explain E0670`. diff --git a/tests/ui/coroutine/async_gen_fn.rs b/tests/ui/coroutine/async_gen_fn.rs index f8860e07f6cc8..20564106f992a 100644 --- a/tests/ui/coroutine/async_gen_fn.rs +++ b/tests/ui/coroutine/async_gen_fn.rs @@ -1,11 +1,9 @@ -// edition: 2024 -// compile-flags: -Zunstable-options -#![feature(gen_blocks)] - -// async generators are not yet supported, so this test makes sure they make some kind of reasonable -// error. +// revisions: e2024 none +//[e2024] compile-flags: --edition 2024 -Zunstable-options async gen fn foo() {} -//~^ `async gen` functions are not supported +//[none]~^ ERROR: `async fn` is not permitted in Rust 2015 +//[none]~| ERROR: expected one of `extern`, `fn`, or `unsafe`, found `gen` +//[e2024]~^^^ ERROR: gen blocks are experimental fn main() {} diff --git a/tests/ui/coroutine/async_gen_fn.stderr b/tests/ui/coroutine/async_gen_fn.stderr deleted file mode 100644 index 6857ebe6c7901..0000000000000 --- a/tests/ui/coroutine/async_gen_fn.stderr +++ /dev/null @@ -1,8 +0,0 @@ -error: `async gen` functions are not supported - --> $DIR/async_gen_fn.rs:8:1 - | -LL | async gen fn foo() {} - | ^^^^^^^^^ - -error: aborting due to 1 previous error - diff --git a/tests/ui/coroutine/async_gen_fn_iter.rs b/tests/ui/coroutine/async_gen_fn_iter.rs new file mode 100644 index 0000000000000..6f8f3feb87e92 --- /dev/null +++ b/tests/ui/coroutine/async_gen_fn_iter.rs @@ -0,0 +1,96 @@ +// edition: 2024 +// compile-flags: -Zunstable-options +// run-pass + +#![feature(gen_blocks, async_iterator)] + +// make sure that a ridiculously simple async gen fn works as an iterator. + +async fn pause() { + // this doesn't actually do anything, lol +} + +async fn one() -> i32 { + 1 +} + +async fn two() -> i32 { + 2 +} + +async gen fn foo() -> i32 { + yield one().await; + pause().await; + yield two().await; + pause().await; + yield 3; + pause().await; +} + +async fn async_main() { + let mut iter = std::pin::pin!(foo()); + assert_eq!(iter.next().await, Some(1)); + assert_eq!(iter.as_mut().next().await, Some(2)); + assert_eq!(iter.as_mut().next().await, Some(3)); + assert_eq!(iter.as_mut().next().await, None); +} + +// ------------------------------------------------------------------------- // +// Implementation Details Below... + +use std::pin::Pin; +use std::task::*; +use std::async_iter::AsyncIterator; +use std::future::Future; + +trait AsyncIterExt { + fn next(&mut self) -> Next<'_, Self>; +} + +impl AsyncIterExt for T { + fn next(&mut self) -> Next<'_, Self> { + Next { s: self } + } +} + +struct Next<'s, S: ?Sized> { + s: &'s mut S, +} + +impl<'s, S: AsyncIterator> Future for Next<'s, S> where S: Unpin { + type Output = Option; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut *self.s).poll_next(cx) + } +} + +pub fn noop_waker() -> Waker { + let raw = RawWaker::new(std::ptr::null(), &NOOP_WAKER_VTABLE); + + // SAFETY: the contracts for RawWaker and RawWakerVTable are upheld + unsafe { Waker::from_raw(raw) } +} + +const NOOP_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop); + +unsafe fn noop_clone(_p: *const ()) -> RawWaker { + RawWaker::new(std::ptr::null(), &NOOP_WAKER_VTABLE) +} + +unsafe fn noop(_p: *const ()) {} + +fn main() { + let mut fut = async_main(); + + // Poll loop, just to test the future... + let waker = noop_waker(); + let ctx = &mut Context::from_waker(&waker); + + loop { + match unsafe { Pin::new_unchecked(&mut fut).poll(ctx) } { + Poll::Pending => {} + Poll::Ready(()) => break, + } + } +}