Perform return type heuristics to determine when to return and when not to
udoprog committed Feb 18, 2022
1 parent 439841d commit dcc8a7a
Showing 3 changed files with 94 additions and 22 deletions.
38 changes: 27 additions & 11 deletions tokio-macros/src/entry/output.rs
@@ -7,12 +7,22 @@ use crate::into_tokens::{bracketed, from_fn, parens, string, IntoTokens, S};
use crate::parsing::Buf;
use crate::token_stream::TokenStream;

#[derive(Debug, Clone, Copy)]
pub(crate) enum ReturnHeuristics {
/// Unknown how to treat the return type.
Unknown,
/// Generated function explicitly returns the special `()` unit type.
Unit,
/// Generated function explicitly returns the special `!` never type.
Never,
}

#[derive(Default)]
pub(crate) struct TailState {
pub(crate) start: Option<Span>,
pub(crate) end: Option<Span>,
/// Indicates if the last expression is a return.
pub(crate) return_: bool,
pub(crate) has_return: bool,
}

#[derive(Debug, Clone, Copy)]
@@ -126,7 +136,10 @@ pub(crate) struct ItemOutput {
async_keyword: Option<usize>,
signature: Option<ops::Range<usize>>,
block: Option<usize>,
/// What's known about the tail statement.
tail_state: TailState,
/// Best-effort heuristics to determine the return type of the function being processed.
return_heuristics: ReturnHeuristics,
}

impl ItemOutput {
@@ -136,13 +149,15 @@ impl ItemOutput {
signature: Option<ops::Range<usize>>,
block: Option<usize>,
tail_state: TailState,
return_heuristics: ReturnHeuristics,
) -> Self {
Self {
tokens,
async_keyword,
signature,
block,
tail_state,
return_heuristics,
}
}

@@ -285,18 +300,19 @@ impl ItemOutput {
parens(string("Failed building the Runtime")),
);

let statement = (
with_span((build, '.', "block_on"), start),
parens(("async", block.clone())),
);

let should_return =
self.tail_state.has_return || matches!(self.return_heuristics, ReturnHeuristics::Unit);

from_fn(move |s| {
if self.tail_state.return_ {
s.write((
with_span(("return", build, '.', "block_on"), start),
parens(("async", block.clone())),
';',
));
if should_return {
s.write(((with_span("return", start), statement), ';'));
} else {
s.write((
with_span((build, '.', "block_on"), start),
parens(("async", block.clone())),
));
s.write(statement);
}
})
}
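For context, a rough sketch (not part of the diff) of the two expansions this change selects between. The function names, the exact builder chain, and the use of the tokio crate as a dependency are illustrative assumptions; the real macro also configures the runtime flavor and features.

// When the return heuristic is Unit (or the body ends in `return`), block_on
// is emitted as an explicit `return` statement:
fn main_unit() {
    let rt = tokio::runtime::Builder::new_multi_thread()
        .build()
        .expect("Failed building the Runtime");
    return rt.block_on(async { /* original body */ });
}

// For a return type the heuristics do not recognize (Unknown), block_on is
// left in tail position so its value becomes the function's return value:
fn main_unknown() -> std::io::Result<()> {
    let rt = tokio::runtime::Builder::new_multi_thread()
        .build()
        .expect("Failed building the Runtime");
    rt.block_on(async { /* original body */ Ok(()) })
}

fn main() {
    main_unit();
    let _ = main_unknown();
}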
51 changes: 47 additions & 4 deletions tokio-macros/src/entry/parser.rs
@@ -4,9 +4,11 @@ use crate::entry::output::{
Config, EntryKind, ItemOutput, RuntimeFlavor, SupportsThreading, TailState,
};
use crate::error::Error;
use crate::parsing::{BaseParser, Buf};
use crate::parsing::{BaseParser, Buf, ROCKET};
use crate::parsing::{Punct, COMMA, EQ};

use super::output::ReturnHeuristics;

/// A parser for the arguments provided to an entry macro.
pub(crate) struct ConfigParser<'a> {
base: BaseParser<'a>,
@@ -206,7 +208,22 @@ impl<'a> ItemParser<'a> {
let mut generics = None;
let mut tail_state = TailState::default();

while let Some(tt) = self.base.bump() {
// We default to assuming that the return type is the unit type, until we
// spot a `->` token, at which point we try to process it.
let mut return_heuristics = ReturnHeuristics::Unit;

while self.base.nth(0).is_some() {
if let Some(p @ Punct { chars: ROCKET, .. }) = self.base.peek_punct() {
self.base.consume(p.len());
self.parse_return_heuristics(&mut return_heuristics);
continue;
}

let tt = match self.base.bump() {
Some(tt) => tt,
None => break,
};

match &tt {
TokenTree::Ident(ident) if self.base.buf.display_as_str(&ident) == "async" => {
if async_keyword.is_none() {
@@ -242,7 +259,33 @@

let tokens = self.base.into_tokens();

ItemOutput::new(tokens, async_keyword, signature, block, tail_state)
ItemOutput::new(
tokens,
async_keyword,
signature,
block,
tail_state,
return_heuristics,
)
}

/// Parse out return type heuristics. Only a *very* limited set of
/// return types is recognized here.
fn parse_return_heuristics(&mut self, return_heuristics: &mut ReturnHeuristics) {
match self.base.nth(0) {
Some(TokenTree::Punct(p)) if p.as_char() == '!' => {
*return_heuristics = ReturnHeuristics::Never;
}
Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Parenthesis => {
if g.stream().is_empty() {
*return_heuristics = ReturnHeuristics::Unit;
}
}
_ => {
// Return type is something we don't understand :(
*return_heuristics = ReturnHeuristics::Unknown;
}
}
}

/// Since generics are implemented using angle brackets.
@@ -289,7 +332,7 @@ impl<'a> ItemParser<'a> {
}
tt => {
if std::mem::take(&mut update) {
tail_state.return_ = matches!(&tt, TokenTree::Ident(ident) if self.base.buf.display_as_str(ident) == "return");
tail_state.has_return = matches!(&tt, TokenTree::Ident(ident) if self.base.buf.display_as_str(ident) == "return");
tail_state.start = Some(span);
}
}
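A minimal standalone sketch of the classification rule that parse_return_heuristics applies to whatever follows `->`. It mirrors the logic above on plain strings rather than token trees, purely for illustration:

// Hypothetical helper, not part of the macro's real code path.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ReturnHeuristics { Unknown, Unit, Never }

fn classify(ret: Option<&str>) -> ReturnHeuristics {
    match ret {
        None => ReturnHeuristics::Unit,       // no `->` at all: assume unit
        Some("!") => ReturnHeuristics::Never, // explicit never type
        Some("()") => ReturnHeuristics::Unit, // explicit unit type
        _ => ReturnHeuristics::Unknown,       // anything else is left alone
    }
}

fn main() {
    assert_eq!(classify(None), ReturnHeuristics::Unit);
    assert_eq!(classify(Some("()")), ReturnHeuristics::Unit);
    assert_eq!(classify(Some("!")), ReturnHeuristics::Never);
    assert_eq!(classify(Some("std::io::Result<()>")), ReturnHeuristics::Unknown);
}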
27 changes: 20 additions & 7 deletions tokio-macros/src/parsing.rs
@@ -2,11 +2,12 @@ use core::fmt;

use proc_macro::{Delimiter, Spacing, Span, TokenTree};

const BUF: usize = 2;
const BUF: usize = 4;

// Punctuations that we look for.
pub(crate) const COMMA: [char; 2] = [',', '\0'];
pub(crate) const EQ: [char; 2] = ['=', '\0'];
pub(crate) const ROCKET: [char; 2] = ['-', '>'];

pub(crate) struct Buf {
// Static ring buffer used for processing tokens.
@@ -20,7 +21,7 @@ pub(crate) struct Buf {
impl Buf {
pub(crate) fn new() -> Self {
Self {
ring: [None, None],
ring: [None, None, None, None],
string: String::new(),
head: 0,
tail: 0,
@@ -29,7 +30,7 @@

/// Clear the buffer.
fn clear(&mut self) {
self.ring = [None, None];
self.ring = [None, None, None, None];
self.head = 0;
self.tail = 0;
self.string.clear();
@@ -128,9 +129,21 @@ impl<'a> BaseParser<'a> {
}
}

/// Process a punctuation.
/// Step over the given number of tokens.
pub(crate) fn consume(&mut self, n: usize) {
for _ in 0..n {
if let Some(tt) = self.bump() {
self.push(tt);
}
}
}

/// Peek a punctuation with joint characters.
///
/// This looks at up to the next 3 punctuation tokens (if present) to ensure
/// that when we encounter a particular punctuation it occurs in isolation.
pub(crate) fn peek_punct(&mut self) -> Option<Punct> {
let mut out = [None; 2];
let mut out = [None; 3];

for (n, o) in out.iter_mut().enumerate() {
match (n, self.nth(n)) {
@@ -148,7 +161,7 @@
}

match out {
[Some((span, head)), tail] => Some(Punct {
[Some((span, head)), tail, None] => Some(Punct {
span,
chars: [head, tail.map(|(_, c)| c).unwrap_or('\0')],
}),
@@ -206,7 +219,7 @@
}
}

/// A complete punctuation.
/// A complete punctuation of at most two characters.
#[derive(Debug)]
pub(crate) struct Punct {
pub(crate) span: Span,
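For reference, `->` arrives from the tokenizer as two punct tokens, the first with joint spacing, which is why ROCKET is a two-character pair and why peek_punct now inspects up to three punctuation tokens (and the ring buffer grew from 2 to 4 slots) to confirm the pair stands alone. A small sketch using the proc-macro2 crate, an assumption here since the macro itself works directly on proc_macro tokens:

use proc_macro2::{Spacing, TokenStream, TokenTree};

fn main() {
    // Parse a tiny signature fragment into a token stream.
    let stream: TokenStream = "-> ()".parse().unwrap();
    let puncts: Vec<(char, Spacing)> = stream
        .into_iter()
        .filter_map(|tt| match tt {
            TokenTree::Punct(p) => Some((p.as_char(), p.spacing())),
            _ => None,
        })
        .collect();
    // `-` is Joint (glued to the next punct) and `>` is Alone: together they
    // form the ROCKET pair ['-', '>'] that the parser looks for.
    assert_eq!(puncts, vec![('-', Spacing::Joint), ('>', Spacing::Alone)]);
}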
