Skip to content

Commit

Permalink
util: add BodyExt::with_trailers (#99)
Browse files Browse the repository at this point in the history
* Add `BodyExt::with_trailers`

* fix tests

* merge with trailers from inner body

* clean up setting state

* fix `is_end_stream`
  • Loading branch information
davidpdrsn authored Nov 20, 2023
1 parent c58b641 commit e5b462a
Show file tree
Hide file tree
Showing 4 changed files with 260 additions and 1 deletion.
2 changes: 1 addition & 1 deletion http-body-util/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ http-body = { version = "1", path = "../http-body" }
pin-project-lite = "0.2"

[dev-dependencies]
tokio = { version = "1", features = ["macros", "rt"] }
tokio = { version = "1", features = ["macros", "rt", "sync", "rt-multi-thread"] }
2 changes: 2 additions & 0 deletions http-body-util/src/combinators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ mod collect;
mod frame;
mod map_err;
mod map_frame;
mod with_trailers;

pub use self::{
box_body::{BoxBody, UnsyncBoxBody},
collect::Collect,
frame::Frame,
map_err::MapErr,
map_frame::MapFrame,
with_trailers::WithTrailers,
};
213 changes: 213 additions & 0 deletions http-body-util/src/combinators/with_trailers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};

use futures_util::ready;
use http::HeaderMap;
use http_body::{Body, Frame};
use pin_project_lite::pin_project;

pin_project! {
/// Adds trailers to a body.
///
/// See [`BodyExt::with_trailers`] for more details.
pub struct WithTrailers<T, F> {
#[pin]
state: State<T, F>,
}
}

impl<T, F> WithTrailers<T, F> {
pub(crate) fn new(body: T, trailers: F) -> Self {
Self {
state: State::PollBody {
body,
trailers: Some(trailers),
},
}
}
}

pin_project! {
#[project = StateProj]
enum State<T, F> {
PollBody {
#[pin]
body: T,
trailers: Option<F>,
},
PollTrailers {
#[pin]
trailers: F,
prev_trailers: Option<HeaderMap>,
},
Done,
}
}

impl<T, F> Body for WithTrailers<T, F>
where
T: Body,
F: Future<Output = Option<Result<HeaderMap, T::Error>>>,
{
type Data = T::Data;
type Error = T::Error;

fn poll_frame(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
loop {
let mut this = self.as_mut().project();

match this.state.as_mut().project() {
StateProj::PollBody { body, trailers } => match ready!(body.poll_frame(cx)?) {
Some(frame) => match frame.into_trailers() {
Ok(prev_trailers) => {
let trailers = trailers.take().unwrap();
this.state.set(State::PollTrailers {
trailers,
prev_trailers: Some(prev_trailers),
});
}
Err(frame) => {
return Poll::Ready(Some(Ok(frame)));
}
},
None => {
let trailers = trailers.take().unwrap();
this.state.set(State::PollTrailers {
trailers,
prev_trailers: None,
});
}
},
StateProj::PollTrailers {
trailers,
prev_trailers,
} => {
let trailers = ready!(trailers.poll(cx)?);
match (trailers, prev_trailers.take()) {
(None, None) => return Poll::Ready(None),
(None, Some(trailers)) | (Some(trailers), None) => {
this.state.set(State::Done);
return Poll::Ready(Some(Ok(Frame::trailers(trailers))));
}
(Some(new_trailers), Some(mut prev_trailers)) => {
prev_trailers.extend(new_trailers);
this.state.set(State::Done);
return Poll::Ready(Some(Ok(Frame::trailers(prev_trailers))));
}
}
}
StateProj::Done => {
return Poll::Ready(None);
}
}
}
}

#[inline]
fn size_hint(&self) -> http_body::SizeHint {
match &self.state {
State::PollBody { body, .. } => body.size_hint(),
State::PollTrailers { .. } | State::Done => Default::default(),
}
}
}

#[cfg(test)]
mod tests {
use std::convert::Infallible;

use bytes::Bytes;
use http::{HeaderMap, HeaderName, HeaderValue};

use crate::{BodyExt, Empty, Full};

#[allow(unused_imports)]
use super::*;

#[tokio::test]
async fn works() {
let mut trailers = HeaderMap::new();
trailers.insert(
HeaderName::from_static("foo"),
HeaderValue::from_static("bar"),
);

let body =
Full::<Bytes>::from("hello").with_trailers(std::future::ready(Some(
Ok::<_, Infallible>(trailers.clone()),
)));

futures_util::pin_mut!(body);
let waker = futures_util::task::noop_waker();
let mut cx = Context::from_waker(&waker);

let data = unwrap_ready(body.as_mut().poll_frame(&mut cx))
.unwrap()
.unwrap()
.into_data()
.unwrap();
assert_eq!(data, "hello");

let body_trailers = unwrap_ready(body.as_mut().poll_frame(&mut cx))
.unwrap()
.unwrap()
.into_trailers()
.unwrap();
assert_eq!(body_trailers, trailers);

assert!(unwrap_ready(body.as_mut().poll_frame(&mut cx)).is_none());
}

#[tokio::test]
async fn merges_trailers() {
let mut trailers_1 = HeaderMap::new();
trailers_1.insert(
HeaderName::from_static("foo"),
HeaderValue::from_static("bar"),
);

let mut trailers_2 = HeaderMap::new();
trailers_2.insert(
HeaderName::from_static("baz"),
HeaderValue::from_static("qux"),
);

let body = Empty::<Bytes>::new()
.with_trailers(std::future::ready(Some(Ok::<_, Infallible>(
trailers_1.clone(),
))))
.with_trailers(std::future::ready(Some(Ok::<_, Infallible>(
trailers_2.clone(),
))));

futures_util::pin_mut!(body);
let waker = futures_util::task::noop_waker();
let mut cx = Context::from_waker(&waker);

let body_trailers = unwrap_ready(body.as_mut().poll_frame(&mut cx))
.unwrap()
.unwrap()
.into_trailers()
.unwrap();

let mut all_trailers = HeaderMap::new();
all_trailers.extend(trailers_1);
all_trailers.extend(trailers_2);
assert_eq!(body_trailers, all_trailers);

assert!(unwrap_ready(body.as_mut().poll_frame(&mut cx)).is_none());
}

fn unwrap_ready<T>(poll: Poll<T>) -> T {
match poll {
Poll::Ready(t) => t,
Poll::Pending => panic!("pending"),
}
}
}
44 changes: 44 additions & 0 deletions http-body-util/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,50 @@ pub trait BodyExt: http_body::Body {
collected: Some(crate::Collected::default()),
}
}

/// Add trailers to the body.
///
/// The trailers will be sent when all previous frames have been sent and the `trailers` future
/// resolves.
///
/// # Example
///
/// ```
/// use http::HeaderMap;
/// use http_body_util::{Full, BodyExt};
/// use bytes::Bytes;
///
/// # #[tokio::main]
/// async fn main() {
/// let (tx, rx) = tokio::sync::oneshot::channel::<HeaderMap>();
///
/// let body = Full::<Bytes>::from("Hello, World!")
/// // add trailers via a future
/// .with_trailers(async move {
/// match rx.await {
/// Ok(trailers) => Some(Ok(trailers)),
/// Err(_err) => None,
/// }
/// });
///
/// // compute the trailers in the background
/// tokio::spawn(async move {
/// let _ = tx.send(compute_trailers().await);
/// });
///
/// async fn compute_trailers() -> HeaderMap {
/// // ...
/// # unimplemented!()
/// }
/// # }
/// ```
fn with_trailers<F>(self, trailers: F) -> combinators::WithTrailers<Self, F>
where
Self: Sized,
F: std::future::Future<Output = Option<Result<http::HeaderMap, Self::Error>>>,
{
combinators::WithTrailers::new(self, trailers)
}
}

impl<T: ?Sized> BodyExt for T where T: http_body::Body {}

0 comments on commit e5b462a

Please sign in to comment.