diff --git a/http-body-util/Cargo.toml b/http-body-util/Cargo.toml index c9d36aa..4b733e4 100644 --- a/http-body-util/Cargo.toml +++ b/http-body-util/Cargo.toml @@ -33,4 +33,4 @@ http-body = { version = "1", path = "../http-body" } pin-project-lite = "0.2" [dev-dependencies] -tokio = { version = "1", features = ["macros", "rt"] } +tokio = { version = "1", features = ["macros", "rt", "sync", "rt-multi-thread"] } diff --git a/http-body-util/src/combinators/mod.rs b/http-body-util/src/combinators/mod.rs index 0ecdb0b..38d2637 100644 --- a/http-body-util/src/combinators/mod.rs +++ b/http-body-util/src/combinators/mod.rs @@ -5,6 +5,7 @@ mod collect; mod frame; mod map_err; mod map_frame; +mod with_trailers; pub use self::{ box_body::{BoxBody, UnsyncBoxBody}, @@ -12,4 +13,5 @@ pub use self::{ frame::Frame, map_err::MapErr, map_frame::MapFrame, + with_trailers::WithTrailers, }; diff --git a/http-body-util/src/combinators/with_trailers.rs b/http-body-util/src/combinators/with_trailers.rs new file mode 100644 index 0000000..383e1ec --- /dev/null +++ b/http-body-util/src/combinators/with_trailers.rs @@ -0,0 +1,213 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use futures_util::ready; +use http::HeaderMap; +use http_body::{Body, Frame}; +use pin_project_lite::pin_project; + +pin_project! { + /// Adds trailers to a body. + /// + /// See [`BodyExt::with_trailers`] for more details. + pub struct WithTrailers { + #[pin] + state: State, + } +} + +impl WithTrailers { + pub(crate) fn new(body: T, trailers: F) -> Self { + Self { + state: State::PollBody { + body, + trailers: Some(trailers), + }, + } + } +} + +pin_project! { + #[project = StateProj] + enum State { + PollBody { + #[pin] + body: T, + trailers: Option, + }, + PollTrailers { + #[pin] + trailers: F, + prev_trailers: Option, + }, + Done, + } +} + +impl Body for WithTrailers +where + T: Body, + F: Future>>, +{ + type Data = T::Data; + type Error = T::Error; + + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + loop { + let mut this = self.as_mut().project(); + + match this.state.as_mut().project() { + StateProj::PollBody { body, trailers } => match ready!(body.poll_frame(cx)?) { + Some(frame) => match frame.into_trailers() { + Ok(prev_trailers) => { + let trailers = trailers.take().unwrap(); + this.state.set(State::PollTrailers { + trailers, + prev_trailers: Some(prev_trailers), + }); + } + Err(frame) => { + return Poll::Ready(Some(Ok(frame))); + } + }, + None => { + let trailers = trailers.take().unwrap(); + this.state.set(State::PollTrailers { + trailers, + prev_trailers: None, + }); + } + }, + StateProj::PollTrailers { + trailers, + prev_trailers, + } => { + let trailers = ready!(trailers.poll(cx)?); + match (trailers, prev_trailers.take()) { + (None, None) => return Poll::Ready(None), + (None, Some(trailers)) | (Some(trailers), None) => { + this.state.set(State::Done); + return Poll::Ready(Some(Ok(Frame::trailers(trailers)))); + } + (Some(new_trailers), Some(mut prev_trailers)) => { + prev_trailers.extend(new_trailers); + this.state.set(State::Done); + return Poll::Ready(Some(Ok(Frame::trailers(prev_trailers)))); + } + } + } + StateProj::Done => { + return Poll::Ready(None); + } + } + } + } + + #[inline] + fn size_hint(&self) -> http_body::SizeHint { + match &self.state { + State::PollBody { body, .. } => body.size_hint(), + State::PollTrailers { .. } | State::Done => Default::default(), + } + } +} + +#[cfg(test)] +mod tests { + use std::convert::Infallible; + + use bytes::Bytes; + use http::{HeaderMap, HeaderName, HeaderValue}; + + use crate::{BodyExt, Empty, Full}; + + #[allow(unused_imports)] + use super::*; + + #[tokio::test] + async fn works() { + let mut trailers = HeaderMap::new(); + trailers.insert( + HeaderName::from_static("foo"), + HeaderValue::from_static("bar"), + ); + + let body = + Full::::from("hello").with_trailers(std::future::ready(Some( + Ok::<_, Infallible>(trailers.clone()), + ))); + + futures_util::pin_mut!(body); + let waker = futures_util::task::noop_waker(); + let mut cx = Context::from_waker(&waker); + + let data = unwrap_ready(body.as_mut().poll_frame(&mut cx)) + .unwrap() + .unwrap() + .into_data() + .unwrap(); + assert_eq!(data, "hello"); + + let body_trailers = unwrap_ready(body.as_mut().poll_frame(&mut cx)) + .unwrap() + .unwrap() + .into_trailers() + .unwrap(); + assert_eq!(body_trailers, trailers); + + assert!(unwrap_ready(body.as_mut().poll_frame(&mut cx)).is_none()); + } + + #[tokio::test] + async fn merges_trailers() { + let mut trailers_1 = HeaderMap::new(); + trailers_1.insert( + HeaderName::from_static("foo"), + HeaderValue::from_static("bar"), + ); + + let mut trailers_2 = HeaderMap::new(); + trailers_2.insert( + HeaderName::from_static("baz"), + HeaderValue::from_static("qux"), + ); + + let body = Empty::::new() + .with_trailers(std::future::ready(Some(Ok::<_, Infallible>( + trailers_1.clone(), + )))) + .with_trailers(std::future::ready(Some(Ok::<_, Infallible>( + trailers_2.clone(), + )))); + + futures_util::pin_mut!(body); + let waker = futures_util::task::noop_waker(); + let mut cx = Context::from_waker(&waker); + + let body_trailers = unwrap_ready(body.as_mut().poll_frame(&mut cx)) + .unwrap() + .unwrap() + .into_trailers() + .unwrap(); + + let mut all_trailers = HeaderMap::new(); + all_trailers.extend(trailers_1); + all_trailers.extend(trailers_2); + assert_eq!(body_trailers, all_trailers); + + assert!(unwrap_ready(body.as_mut().poll_frame(&mut cx)).is_none()); + } + + fn unwrap_ready(poll: Poll) -> T { + match poll { + Poll::Ready(t) => t, + Poll::Pending => panic!("pending"), + } + } +} diff --git a/http-body-util/src/lib.rs b/http-body-util/src/lib.rs index 059ada6..1b715b3 100644 --- a/http-body-util/src/lib.rs +++ b/http-body-util/src/lib.rs @@ -89,6 +89,50 @@ pub trait BodyExt: http_body::Body { collected: Some(crate::Collected::default()), } } + + /// Add trailers to the body. + /// + /// The trailers will be sent when all previous frames have been sent and the `trailers` future + /// resolves. + /// + /// # Example + /// + /// ``` + /// use http::HeaderMap; + /// use http_body_util::{Full, BodyExt}; + /// use bytes::Bytes; + /// + /// # #[tokio::main] + /// async fn main() { + /// let (tx, rx) = tokio::sync::oneshot::channel::(); + /// + /// let body = Full::::from("Hello, World!") + /// // add trailers via a future + /// .with_trailers(async move { + /// match rx.await { + /// Ok(trailers) => Some(Ok(trailers)), + /// Err(_err) => None, + /// } + /// }); + /// + /// // compute the trailers in the background + /// tokio::spawn(async move { + /// let _ = tx.send(compute_trailers().await); + /// }); + /// + /// async fn compute_trailers() -> HeaderMap { + /// // ... + /// # unimplemented!() + /// } + /// # } + /// ``` + fn with_trailers(self, trailers: F) -> combinators::WithTrailers + where + Self: Sized, + F: std::future::Future>>, + { + combinators::WithTrailers::new(self, trailers) + } } impl BodyExt for T where T: http_body::Body {}