Skip to content

Commit

Permalink
util: Two call_all bug fixes (tower-rs#709)
Browse files Browse the repository at this point in the history
One is handling poll_ready errors (tower-rs#706).

The other is fixing the TODO about disarming poll_ready, since there is no disarm this makes sure
`poll_ready` is only called if `call` will immediately follow.
  • Loading branch information
leoyvens authored Nov 1, 2022
1 parent d27ba65 commit c9d84cd
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 19 deletions.
59 changes: 42 additions & 17 deletions tower/src/util/call_all/common.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use futures_core::{ready, Stream};
use pin_project_lite::pin_project;
use std::{
fmt,
future::Future,
pin::Pin,
task::{Context, Poll},
Expand All @@ -9,13 +10,30 @@ use tower_service::Service;

pin_project! {
/// The [`Future`] returned by the [`ServiceExt::call_all`] combinator.
#[derive(Debug)]
pub(crate) struct CallAll<Svc, S, Q> {
pub(crate) struct CallAll<Svc, S, Q>
where
S: Stream,
{
service: Option<Svc>,
#[pin]
stream: S,
queue: Q,
eof: bool,
curr_req: Option<S::Item>
}
}

impl<Svc, S, Q> fmt::Debug for CallAll<Svc, S, Q>
where
Svc: fmt::Debug,
S: Stream + fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("CallAll")
.field("service", &self.service)
.field("stream", &self.stream)
.field("eof", &self.eof)
.finish()
}
}

Expand All @@ -39,6 +57,7 @@ where
stream,
queue,
eof: false,
curr_req: None,
}
}

Expand Down Expand Up @@ -90,27 +109,33 @@ where
}
}

// If not done, and we don't have a stored request, gather the next request from the
// stream (if there is one), or return `Pending` if the stream is not ready.
if this.curr_req.is_none() {
*this.curr_req = match ready!(this.stream.as_mut().poll_next(cx)) {
Some(next_req) => Some(next_req),
None => {
// Mark that there will be no more requests.
*this.eof = true;
continue;
}
};
}

// Then, see that the service is ready for another request
let svc = this
.service
.as_mut()
.expect("Using CallAll after extracing inner Service");
ready!(svc.poll_ready(cx))?;

// If it is, gather the next request (if there is one), or return `Pending` if the
// stream is not ready.
// TODO: We probably want to "release" the slot we reserved in Svc if the
// stream returns `Pending`. It may be a while until we get around to actually
// using it.
match ready!(this.stream.as_mut().poll_next(cx)) {
Some(req) => {
this.queue.push(svc.call(req));
}
None => {
// We're all done once any outstanding requests have completed
*this.eof = true;
}

if let Err(e) = ready!(svc.poll_ready(cx)) {
// Set eof to prevent the service from being called again after a `poll_ready` error
*this.eof = true;
return Poll::Ready(Some(Err(e)));
}

// Unwrap: The check above always sets `this.curr_req` if none.
this.queue.push(svc.call(this.curr_req.take().unwrap()));
}
}
}
91 changes: 89 additions & 2 deletions tower/tests/util/call_all.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@ use futures_util::{
future::{ready, Ready},
pin_mut,
};
use std::fmt;
use std::future::Future;
use std::task::{Context, Poll};
use std::{cell::Cell, rc::Rc};
use tokio_test::{assert_pending, assert_ready, task};
use tower::util::ServiceExt;
use tower_service::*;
use tower_test::{assert_request_eq, mock};
use tower_test::{assert_request_eq, mock, mock::Mock};

type Error = Box<dyn std::error::Error + Send + Sync>;
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

#[derive(Debug, Eq, PartialEq)]
struct Srv {
Expand Down Expand Up @@ -164,3 +166,88 @@ async fn pending() {
assert_eq!(res.transpose().unwrap(), Some("res"));
assert_pending!(task.enter(|cx, _| ca.as_mut().poll_next(cx)));
}

#[tokio::test]
async fn poll_ready_error() {
struct ReadyOnceThenErr {
polled: bool,
inner: Mock<&'static str, &'static str>,
}

#[derive(Debug)]
pub struct StringErr(String);

impl fmt::Display for StringErr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(&self.0, f)
}
}

impl std::error::Error for StringErr {}

impl Service<&'static str> for ReadyOnceThenErr {
type Response = &'static str;
type Error = Error;
type Future = <Mock<&'static str, &'static str> as Service<&'static str>>::Future;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self.polled {
false => {
self.polled = true;
self.inner.poll_ready(cx)
}
true => Poll::Ready(Err(Box::new(StringErr("poll_ready error".to_string())))),
}
}

fn call(&mut self, req: &'static str) -> Self::Future {
self.inner.call(req)
}
}

let _t = support::trace_init();

let (mock, mut handle) = mock::pair::<_, &'static str>();
let svc = ReadyOnceThenErr {
polled: false,
inner: mock,
};
let mut task = task::spawn(());

// "req0" is called, then "req1" receives a poll_ready error so "req2" will never be called.
// Still the response from "req0" is waited on before ending the `call_all` stream.
let requests = futures_util::stream::iter(vec!["req0", "req1", "req2"]);
let ca = svc.call_all(requests);
pin_mut!(ca);
let err = assert_ready!(task.enter(|cx, _| ca.as_mut().poll_next(cx)));
assert_eq!(err.unwrap().unwrap_err().to_string(), "poll_ready error");
assert_request_eq!(handle, "req0").send_response("res0");
let res = assert_ready!(task.enter(|cx, _| ca.as_mut().poll_next(cx)));
assert_eq!(res.transpose().unwrap(), Some("res0"));
let res = assert_ready!(task.enter(|cx, _| ca.as_mut().poll_next(cx)));
assert_eq!(res.transpose().unwrap(), None);
}

#[tokio::test]
async fn stream_does_not_block_service() {
use tower::buffer::Buffer;
use tower::limit::ConcurrencyLimit;

let _t = support::trace_init();
let (mock, mut handle) = mock::pair::<_, &'static str>();
let mut task = task::spawn(());

let svc = Buffer::new(ConcurrencyLimit::new(mock, 1), 1);

// Always pending, but should not occupy a concurrency slot.
let pending = svc.clone().call_all(futures_util::stream::pending());
pin_mut!(pending);
assert_pending!(task.enter(|cx, _| pending.as_mut().poll_next(cx)));

let call = svc.oneshot("req");
pin_mut!(call);
assert_pending!(task.enter(|cx, _| call.as_mut().poll(cx)));
assert_request_eq!(handle, "req").send_response("res");
let res = assert_ready!(task.enter(|cx, _| call.as_mut().poll(cx)));
assert_eq!(res.unwrap(), "res");
}

0 comments on commit c9d84cd

Please sign in to comment.