Skip to content

Commit

Permalink
Auto merge of #77035 - mibac138:fn-fat-arrow-return, r=davidtwco
Browse files Browse the repository at this point in the history
Gracefully handle mistyping -> as => in function return type

Fixes #77019
  • Loading branch information
bors committed Dec 19, 2020
2 parents 50a9097 + e916641 commit d1741e5
Show file tree
Hide file tree
Showing 14 changed files with 255 additions and 25 deletions.
5 changes: 3 additions & 2 deletions compiler/rustc_parse/src/parser/expr.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::pat::{GateOr, PARAM_EXPECTED};
use super::ty::{AllowPlus, RecoverQPath};
use super::ty::{AllowPlus, RecoverQPath, RecoverReturnSign};
use super::{BlockMode, Parser, PathStyle, Restrictions, TokenType};
use super::{SemiColonMode, SeqSep, TokenExpectType};
use crate::maybe_recover_from_interpolated_ty_qpath;
Expand Down Expand Up @@ -1647,7 +1647,8 @@ impl<'a> Parser<'a> {
self.expect_or()?;
args
};
let output = self.parse_ret_ty(AllowPlus::Yes, RecoverQPath::Yes)?;
let output =
self.parse_ret_ty(AllowPlus::Yes, RecoverQPath::Yes, RecoverReturnSign::Yes)?;

Ok(P(FnDecl { inputs, output }))
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_parse/src/parser/generics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ impl<'a> Parser<'a> {

// Parse type with mandatory colon and (possibly empty) bounds,
// or with mandatory equality sign and the second type.
let ty = self.parse_ty()?;
let ty = self.parse_ty_for_where_clause()?;
if self.eat(&token::Colon) {
let bounds = self.parse_generic_bounds(Some(self.prev_token.span))?;
Ok(ast::WherePredicate::BoundPredicate(ast::WhereBoundPredicate {
Expand Down
7 changes: 4 additions & 3 deletions compiler/rustc_parse/src/parser/item.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::diagnostics::{dummy_arg, ConsumeClosingDelim, Error};
use super::ty::{AllowPlus, RecoverQPath};
use super::ty::{AllowPlus, RecoverQPath, RecoverReturnSign};
use super::{FollowedByType, Parser, PathStyle};

use crate::maybe_whole;
Expand Down Expand Up @@ -1549,7 +1549,7 @@ impl<'a> Parser<'a> {
let header = self.parse_fn_front_matter()?; // `const ... fn`
let ident = self.parse_ident()?; // `foo`
let mut generics = self.parse_generics()?; // `<'a, T, ...>`
let decl = self.parse_fn_decl(req_name, AllowPlus::Yes)?; // `(p: u8, ...)`
let decl = self.parse_fn_decl(req_name, AllowPlus::Yes, RecoverReturnSign::Yes)?; // `(p: u8, ...)`
generics.where_clause = self.parse_where_clause()?; // `where T: Ord`

let mut sig_hi = self.prev_token.span;
Expand Down Expand Up @@ -1680,10 +1680,11 @@ impl<'a> Parser<'a> {
&mut self,
req_name: ReqName,
ret_allow_plus: AllowPlus,
recover_return_sign: RecoverReturnSign,
) -> PResult<'a, P<FnDecl>> {
Ok(P(FnDecl {
inputs: self.parse_fn_params(req_name)?,
output: self.parse_ret_ty(ret_allow_plus, RecoverQPath::Yes)?,
output: self.parse_ret_ty(ret_allow_plus, RecoverQPath::Yes, recover_return_sign)?,
}))
}

Expand Down
5 changes: 3 additions & 2 deletions compiler/rustc_parse/src/parser/path.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::ty::{AllowPlus, RecoverQPath};
use super::ty::{AllowPlus, RecoverQPath, RecoverReturnSign};
use super::{Parser, TokenType};
use crate::maybe_whole;
use rustc_ast::ptr::P;
Expand Down Expand Up @@ -231,7 +231,8 @@ impl<'a> Parser<'a> {
// `(T, U) -> R`
let (inputs, _) = self.parse_paren_comma_seq(|p| p.parse_ty())?;
let span = ident.span.to(self.prev_token.span);
let output = self.parse_ret_ty(AllowPlus::No, RecoverQPath::No)?;
let output =
self.parse_ret_ty(AllowPlus::No, RecoverQPath::No, RecoverReturnSign::No)?;
ParenthesizedArgs { inputs, output, span }.into()
};

Expand Down
105 changes: 96 additions & 9 deletions compiler/rustc_parse/src/parser/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,37 @@ pub(super) enum RecoverQPath {
No,
}

/// Signals whether parsing a type should recover `->`.
///
/// More specifically, when parsing a function like:
/// ```rust
/// fn foo() => u8 { 0 }
/// fn bar(): u8 { 0 }
/// ```
/// The compiler will try to recover interpreting `foo() => u8` as `foo() -> u8` when calling
/// `parse_ty` with anything except `RecoverReturnSign::No`, and it will try to recover `bar(): u8`
/// as `bar() -> u8` when passing `RecoverReturnSign::Yes` to `parse_ty`
#[derive(Copy, Clone, PartialEq)]
pub(super) enum RecoverReturnSign {
Yes,
OnlyFatArrow,
No,
}

impl RecoverReturnSign {
/// [RecoverReturnSign::Yes] allows for recovering `fn foo() => u8` and `fn foo(): u8`,
/// [RecoverReturnSign::OnlyFatArrow] allows for recovering only `fn foo() => u8` (recovering
/// colons can cause problems when parsing where clauses), and
/// [RecoverReturnSign::No] doesn't allow for any recovery of the return type arrow
fn can_recover(self, token: &TokenKind) -> bool {
match self {
Self::Yes => matches!(token, token::FatArrow | token::Colon),
Self::OnlyFatArrow => matches!(token, token::FatArrow),
Self::No => false,
}
}
}

// Is `...` (`CVarArgs`) legal at this level of type parsing?
#[derive(PartialEq)]
enum AllowCVariadic {
Expand All @@ -62,14 +93,24 @@ fn can_continue_type_after_non_fn_ident(t: &Token) -> bool {
impl<'a> Parser<'a> {
/// Parses a type.
pub fn parse_ty(&mut self) -> PResult<'a, P<Ty>> {
self.parse_ty_common(AllowPlus::Yes, RecoverQPath::Yes, AllowCVariadic::No)
self.parse_ty_common(
AllowPlus::Yes,
AllowCVariadic::No,
RecoverQPath::Yes,
RecoverReturnSign::Yes,
)
}

/// Parse a type suitable for a function or function pointer parameter.
/// The difference from `parse_ty` is that this version allows `...`
/// (`CVarArgs`) at the top level of the type.
pub(super) fn parse_ty_for_param(&mut self) -> PResult<'a, P<Ty>> {
self.parse_ty_common(AllowPlus::Yes, RecoverQPath::Yes, AllowCVariadic::Yes)
self.parse_ty_common(
AllowPlus::Yes,
AllowCVariadic::Yes,
RecoverQPath::Yes,
RecoverReturnSign::Yes,
)
}

/// Parses a type in restricted contexts where `+` is not permitted.
Expand All @@ -79,18 +120,58 @@ impl<'a> Parser<'a> {
/// Example 2: `value1 as TYPE + value2`
/// `+` is prohibited to avoid interactions with expression grammar.
pub(super) fn parse_ty_no_plus(&mut self) -> PResult<'a, P<Ty>> {
self.parse_ty_common(AllowPlus::No, RecoverQPath::Yes, AllowCVariadic::No)
self.parse_ty_common(
AllowPlus::No,
AllowCVariadic::No,
RecoverQPath::Yes,
RecoverReturnSign::Yes,
)
}

/// Parse a type without recovering `:` as `->` to avoid breaking code such as `where fn() : for<'a>`
pub(super) fn parse_ty_for_where_clause(&mut self) -> PResult<'a, P<Ty>> {
self.parse_ty_common(
AllowPlus::Yes,
AllowCVariadic::Yes,
RecoverQPath::Yes,
RecoverReturnSign::OnlyFatArrow,
)
}

/// Parses an optional return type `[ -> TY ]` in a function declaration.
pub(super) fn parse_ret_ty(
&mut self,
allow_plus: AllowPlus,
recover_qpath: RecoverQPath,
recover_return_sign: RecoverReturnSign,
) -> PResult<'a, FnRetTy> {
Ok(if self.eat(&token::RArrow) {
// FIXME(Centril): Can we unconditionally `allow_plus`?
let ty = self.parse_ty_common(allow_plus, recover_qpath, AllowCVariadic::No)?;
let ty = self.parse_ty_common(
allow_plus,
AllowCVariadic::No,
recover_qpath,
recover_return_sign,
)?;
FnRetTy::Ty(ty)
} else if recover_return_sign.can_recover(&self.token.kind) {
// Don't `eat` to prevent `=>` from being added as an expected token which isn't
// actually expected and could only confuse users
self.bump();
self.struct_span_err(self.prev_token.span, "return types are denoted using `->`")
.span_suggestion_short(
self.prev_token.span,
"use `->` instead",
"->".to_string(),
Applicability::MachineApplicable,
)
.emit();
let ty = self.parse_ty_common(
allow_plus,
AllowCVariadic::No,
recover_qpath,
recover_return_sign,
)?;
FnRetTy::Ty(ty)
} else {
FnRetTy::Default(self.token.span.shrink_to_lo())
Expand All @@ -100,8 +181,9 @@ impl<'a> Parser<'a> {
fn parse_ty_common(
&mut self,
allow_plus: AllowPlus,
recover_qpath: RecoverQPath,
allow_c_variadic: AllowCVariadic,
recover_qpath: RecoverQPath,
recover_return_sign: RecoverReturnSign,
) -> PResult<'a, P<Ty>> {
let allow_qpath_recovery = recover_qpath == RecoverQPath::Yes;
maybe_recover_from_interpolated_ty_qpath!(self, allow_qpath_recovery);
Expand Down Expand Up @@ -129,14 +211,14 @@ impl<'a> Parser<'a> {
TyKind::Infer
} else if self.check_fn_front_matter() {
// Function pointer type
self.parse_ty_bare_fn(lo, Vec::new())?
self.parse_ty_bare_fn(lo, Vec::new(), recover_return_sign)?
} else if self.check_keyword(kw::For) {
// Function pointer type or bound list (trait object type) starting with a poly-trait.
// `for<'lt> [unsafe] [extern "ABI"] fn (&'lt S) -> T`
// `for<'lt> Trait1<'lt> + Trait2 + 'a`
let lifetime_defs = self.parse_late_bound_lifetime_defs()?;
if self.check_fn_front_matter() {
self.parse_ty_bare_fn(lo, lifetime_defs)?
self.parse_ty_bare_fn(lo, lifetime_defs, recover_return_sign)?
} else {
let path = self.parse_path(PathStyle::Type)?;
let parse_plus = allow_plus == AllowPlus::Yes && self.check_plus();
Expand Down Expand Up @@ -338,9 +420,14 @@ impl<'a> Parser<'a> {
/// Function Style ABI Parameter types
/// ```
/// We actually parse `FnHeader FnDecl`, but we error on `const` and `async` qualifiers.
fn parse_ty_bare_fn(&mut self, lo: Span, params: Vec<GenericParam>) -> PResult<'a, TyKind> {
fn parse_ty_bare_fn(
&mut self,
lo: Span,
params: Vec<GenericParam>,
recover_return_sign: RecoverReturnSign,
) -> PResult<'a, TyKind> {
let ast::FnHeader { ext, unsafety, constness, asyncness } = self.parse_fn_front_matter()?;
let decl = self.parse_fn_decl(|_| false, AllowPlus::No)?;
let decl = self.parse_fn_decl(|_| false, AllowPlus::No, recover_return_sign)?;
let whole_span = lo.to(self.prev_token.span);
if let ast::Const::Yes(span) = constness {
self.error_fn_ptr_bad_qualifier(whole_span, span, "const");
Expand Down
28 changes: 28 additions & 0 deletions src/test/ui/fn/fn-recover-return-sign.fixed
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// run-rustfix
#![allow(unused)]
fn a() -> usize { 0 }
//~^ ERROR return types are denoted using `->`

fn b()-> usize { 0 }
//~^ ERROR return types are denoted using `->`

fn bar(_: u32) {}

fn baz() -> *const dyn Fn(u32) { unimplemented!() }

fn foo() {
match () {
_ if baz() == &bar as &dyn Fn(u32) => (),
() => (),
}
}

fn main() {
let foo = |a: bool| -> bool { a };
//~^ ERROR return types are denoted using `->`
dbg!(foo(false));

let bar = |a: bool|-> bool { a };
//~^ ERROR return types are denoted using `->`
dbg!(bar(false));
}
28 changes: 28 additions & 0 deletions src/test/ui/fn/fn-recover-return-sign.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// run-rustfix
#![allow(unused)]
fn a() => usize { 0 }
//~^ ERROR return types are denoted using `->`

fn b(): usize { 0 }
//~^ ERROR return types are denoted using `->`

fn bar(_: u32) {}

fn baz() -> *const dyn Fn(u32) { unimplemented!() }

fn foo() {
match () {
_ if baz() == &bar as &dyn Fn(u32) => (),
() => (),
}
}

fn main() {
let foo = |a: bool| => bool { a };
//~^ ERROR return types are denoted using `->`
dbg!(foo(false));

let bar = |a: bool|: bool { a };
//~^ ERROR return types are denoted using `->`
dbg!(bar(false));
}
26 changes: 26 additions & 0 deletions src/test/ui/fn/fn-recover-return-sign.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
error: return types are denoted using `->`
--> $DIR/fn-recover-return-sign.rs:3:8
|
LL | fn a() => usize { 0 }
| ^^ help: use `->` instead

error: return types are denoted using `->`
--> $DIR/fn-recover-return-sign.rs:6:7
|
LL | fn b(): usize { 0 }
| ^ help: use `->` instead

error: return types are denoted using `->`
--> $DIR/fn-recover-return-sign.rs:21:25
|
LL | let foo = |a: bool| => bool { a };
| ^^ help: use `->` instead

error: return types are denoted using `->`
--> $DIR/fn-recover-return-sign.rs:25:24
|
LL | let bar = |a: bool|: bool { a };
| ^ help: use `->` instead

error: aborting due to 4 previous errors

8 changes: 8 additions & 0 deletions src/test/ui/fn/fn-recover-return-sign2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// Separate test file because `Fn() => bool` isn't getting fixed and rustfix complained that
// even though a fix was applied the code was still incorrect

fn foo() => impl Fn() => bool {
//~^ ERROR return types are denoted using `->`
//~| ERROR expected one of `+`, `->`, `::`, `;`, `where`, or `{`, found `=>`
unimplemented!()
}
14 changes: 14 additions & 0 deletions src/test/ui/fn/fn-recover-return-sign2.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
error: return types are denoted using `->`
--> $DIR/fn-recover-return-sign2.rs:4:10
|
LL | fn foo() => impl Fn() => bool {
| ^^ help: use `->` instead

error: expected one of `+`, `->`, `::`, `;`, `where`, or `{`, found `=>`
--> $DIR/fn-recover-return-sign2.rs:4:23
|
LL | fn foo() => impl Fn() => bool {
| ^^ expected one of `+`, `->`, `::`, `;`, `where`, or `{`

error: aborting due to 2 previous errors

3 changes: 2 additions & 1 deletion src/test/ui/parser/fn-colon-return-type.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
fn foo(x: i32): i32 { //~ ERROR expected one of `->`, `;`, `where`, or `{`, found `:`
fn foo(x: i32): i32 {
//~^ ERROR return types are denoted using `->`
x
}

Expand Down
4 changes: 2 additions & 2 deletions src/test/ui/parser/fn-colon-return-type.stderr
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
error: expected one of `->`, `;`, `where`, or `{`, found `:`
error: return types are denoted using `->`
--> $DIR/fn-colon-return-type.rs:1:15
|
LL | fn foo(x: i32): i32 {
| ^ expected one of `->`, `;`, `where`, or `{`
| ^ help: use `->` instead

error: aborting due to previous error

13 changes: 11 additions & 2 deletions src/test/ui/parser/not-a-pred.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
fn f(a: isize, b: isize) : lt(a, b) { }
//~^ ERROR expected one of `->`, `;`, `where`, or `{`, found `:`
//~^ ERROR return types are denoted using `->`
//~| ERROR expected type, found function `lt` [E0573]
//~| ERROR expected type, found local variable `a` [E0573]
//~| ERROR expected type, found local variable `b` [E0573]

fn lt(a: isize, b: isize) { }

fn main() { let a: isize = 10; let b: isize = 23; check (lt(a, b)); f(a, b); }
fn main() {
let a: isize = 10;
let b: isize = 23;
check (lt(a, b));
//~^ ERROR cannot find function `check` in this scope [E0425]
f(a, b);
}
Loading

0 comments on commit d1741e5

Please sign in to comment.