diff --git a/Cargo.toml b/Cargo.toml index 1667aa59..fdbc1794 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,5 @@ [workspace] +resolver = "2" members = [ "tower-http", "examples/*", diff --git a/examples/axum-key-value-store/src/main.rs b/examples/axum-key-value-store/src/main.rs index 12841189..a5e5a327 100644 --- a/examples/axum-key-value-store/src/main.rs +++ b/examples/axum-key-value-store/src/main.rs @@ -1,3 +1,8 @@ +fn main() { + eprintln!("this example has not yet been updated to hyper 1.0"); +} + +/* use axum::{ body::Bytes, extract::{Path, State}, @@ -108,3 +113,4 @@ async fn set_key(Path(path): Path, state: State, value: Bytes) // See https://github.com/tokio-rs/axum/blob/main/examples/testing/src/main.rs for an example of // how to test axum apps +*/ diff --git a/examples/tonic-key-value-store/src/main.rs b/examples/tonic-key-value-store/src/main.rs index bb7ea04a..e615e0ed 100644 --- a/examples/tonic-key-value-store/src/main.rs +++ b/examples/tonic-key-value-store/src/main.rs @@ -1,3 +1,8 @@ +fn main() { + eprint!("this example has not yet been updated to hyper 1.0"); +} + +/* use bytes::Bytes; use clap::Parser; use futures::StreamExt; @@ -370,3 +375,4 @@ mod tests { addr } } +*/ diff --git a/examples/warp-key-value-store/src/main.rs b/examples/warp-key-value-store/src/main.rs index feb179f8..065f3808 100644 --- a/examples/warp-key-value-store/src/main.rs +++ b/examples/warp-key-value-store/src/main.rs @@ -1,3 +1,8 @@ +fn main() { + eprint!("this example has not yet been updated to hyper 1.0"); +} + +/* use bytes::Bytes; use clap::Parser; use hyper::{ @@ -222,3 +227,4 @@ mod tests { addr } } +*/ diff --git a/tower-http/CHANGELOG.md b/tower-http/CHANGELOG.md index c56bc14f..206dcf66 100644 --- a/tower-http/CHANGELOG.md +++ b/tower-http/CHANGELOG.md @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Changed - Bump Minimum Supported Rust Version to 1.66 ([#433]) +- Update to http-body 1.0 ([#348]) +- Update to http 1.0 ([#348]) - Preserve service error type in RequestDecompression ([#368]) ## Removed @@ -27,6 +29,7 @@ http-range-header to `0.4` [#418]: https://github.com/tower-rs/tower-http/pull/418 [#433]: https://github.com/tower-rs/tower-http/pull/433 +[#348]: https://github.com/tower-rs/tower-http/pull/348 [#368]: https://github.com/tower-rs/tower-http/pull/368 # 0.4.2 (July 19, 2023) diff --git a/tower-http/Cargo.toml b/tower-http/Cargo.toml index d2ad8ea3..d112ba9c 100644 --- a/tower-http/Cargo.toml +++ b/tower-http/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "tower-http" description = "Tower middleware and utilities for HTTP clients and servers" -version = "0.4.2" +version = "0.4.4" authors = ["Tower Maintainers "] edition = "2018" license = "MIT" @@ -15,8 +15,9 @@ rust-version = "1.66" [dependencies] bitflags = "2.0.2" bytes = "1" -http = "0.2.7" -http-body = "0.4.5" +http = "1.0" +http-body = "1.0.0" +http-body-util = "0.1.0" pin-project-lite = "0.2.7" tower-layer = "0.3" tower-service = "0.3" @@ -38,16 +39,19 @@ httpdate = { version = "1.0", optional = true } uuid = { version = "1.0", features = ["v4"], optional = true } [dev-dependencies] +async-trait = "0.1" +brotli = "3" bytes = "1" flate2 = "1.0" -brotli = "3" -hyper = { version = "0.14", features = ["full"] } +futures-util = "0.3.14" +hyper-util = { version = "0.1", features = ["client-legacy", "http1", "tokio"] } once_cell = "1" +serde_json = "1.0" +sync_wrapper = "0.1.1" tokio = { version = "1", features = ["full"] } tower = { version = "0.4.10", features = ["buffer", "util", "retry", "make", "timeout"] } tracing-subscriber = "0.3" uuid = { version = "1.0", features = ["v4"] } -serde_json = "1.0" zstd = "0.12" [features] diff --git a/tower-http/src/add_extension.rs b/tower-http/src/add_extension.rs index 4949a736..095646df 100644 --- a/tower-http/src/add_extension.rs +++ b/tower-http/src/add_extension.rs @@ -8,7 +8,8 @@ //! use tower_http::add_extension::AddExtensionLayer; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! use http::{Request, Response}; -//! use hyper::Body; +//! use bytes::Bytes; +//! use http_body_util::Full; //! use std::{sync::Arc, convert::Infallible}; //! //! # struct DatabaseConnectionPool; @@ -21,11 +22,11 @@ //! pool: DatabaseConnectionPool, //! } //! -//! async fn handle(req: Request) -> Result, Infallible> { +//! async fn handle(req: Request>) -> Result>, Infallible> { //! // Grab the state from the request extensions. //! let state = req.extensions().get::>().unwrap(); //! -//! Ok(Response::new(Body::empty())) +//! Ok(Response::new(Full::default())) //! } //! //! # #[tokio::main] @@ -44,7 +45,7 @@ //! let response = service //! .ready() //! .await? -//! .call(Request::new(Body::empty())) +//! .call(Request::new(Full::default())) //! .await?; //! # Ok(()) //! # } @@ -137,8 +138,8 @@ where mod tests { #[allow(unused_imports)] use super::*; + use crate::test_helpers::Body; use http::Response; - use hyper::Body; use std::{convert::Infallible, sync::Arc}; use tower::{service_fn, ServiceBuilder, ServiceExt}; diff --git a/tower-http/src/auth/add_authorization.rs b/tower-http/src/auth/add_authorization.rs index e370190e..246c13b6 100644 --- a/tower-http/src/auth/add_authorization.rs +++ b/tower-http/src/auth/add_authorization.rs @@ -7,15 +7,16 @@ //! ``` //! use tower_http::validate_request::{ValidateRequestHeader, ValidateRequestHeaderLayer}; //! use tower_http::auth::AddAuthorizationLayer; -//! use hyper::{Request, Response, Body, Error}; -//! use http::{StatusCode, header::AUTHORIZATION}; -//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; -//! # async fn handle(request: Request) -> Result, Error> { -//! # Ok(Response::new(Body::empty())) +//! use http::{Request, Response, StatusCode, header::AUTHORIZATION}; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError}; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! # async fn handle(request: Request>) -> Result>, BoxError> { +//! # Ok(Response::new(Full::default())) //! # } //! //! # #[tokio::main] -//! # async fn main() -> Result<(), Box> { +//! # async fn main() -> Result<(), BoxError> { //! # let service_that_requires_auth = ValidateRequestHeader::basic( //! # tower::service_fn(handle), //! # "username", @@ -30,7 +31,7 @@ //! let response = client //! .ready() //! .await? -//! .call(Request::new(Body::empty())) +//! .call(Request::new(Full::default())) //! .await?; //! //! assert_eq!(StatusCode::OK, response.status()); @@ -84,7 +85,7 @@ impl AddAuthorizationLayer { /// /// # Panics /// - /// Panics if the token is not a valid [`HeaderValue`](http::header::HeaderValue). + /// Panics if the token is not a valid [`HeaderValue`]. pub fn bearer(token: &str) -> Self { let value = HeaderValue::try_from(format!("Bearer {}", token)).expect("token is not valid header"); @@ -147,7 +148,7 @@ impl AddAuthorization { /// /// # Panics /// - /// Panics if the token is not a valid [`HeaderValue`](http::header::HeaderValue). + /// Panics if the token is not a valid [`HeaderValue`]. pub fn bearer(inner: S, token: &str) -> Self { AddAuthorizationLayer::bearer(token).layer(inner) } @@ -187,12 +188,11 @@ where #[cfg(test)] mod tests { - use crate::validate_request::ValidateRequestHeaderLayer; - - #[allow(unused_imports)] use super::*; + use crate::test_helpers::Body; + use crate::validate_request::ValidateRequestHeaderLayer; use http::{Response, StatusCode}; - use hyper::Body; + use std::convert::Infallible; use tower::{BoxError, Service, ServiceBuilder, ServiceExt}; #[tokio::test] @@ -245,7 +245,7 @@ mod tests { let auth = request.headers().get(http::header::AUTHORIZATION).unwrap(); assert!(auth.is_sensitive()); - Ok::<_, hyper::Error>(Response::new(Body::empty())) + Ok::<_, Infallible>(Response::new(Body::empty())) }); let mut client = AddAuthorization::bearer(svc, "foo").as_sensitive(true); diff --git a/tower-http/src/auth/async_require_authorization.rs b/tower-http/src/auth/async_require_authorization.rs index 32244e83..f086add2 100644 --- a/tower-http/src/auth/async_require_authorization.rs +++ b/tower-http/src/auth/async_require_authorization.rs @@ -6,10 +6,11 @@ //! //! ``` //! use tower_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest}; -//! use hyper::{Request, Response, Body, Error}; -//! use http::{StatusCode, header::AUTHORIZATION}; -//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! use http::{Request, Response, StatusCode, header::AUTHORIZATION}; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError}; //! use futures_util::future::BoxFuture; +//! use bytes::Bytes; +//! use http_body_util::Full; //! //! #[derive(Clone, Copy)] //! struct MyAuth; @@ -19,7 +20,7 @@ //! B: Send + Sync + 'static, //! { //! type RequestBody = B; -//! type ResponseBody = Body; +//! type ResponseBody = Full; //! type Future = BoxFuture<'static, Result, Response>>; //! //! fn authorize(&mut self, mut request: Request) -> Self::Future { @@ -33,7 +34,7 @@ //! } else { //! let unauthorized_response = Response::builder() //! .status(StatusCode::UNAUTHORIZED) -//! .body(Body::empty()) +//! .body(Full::::default()) //! .unwrap(); //! //! Err(unauthorized_response) @@ -47,10 +48,10 @@ //! # None //! } //! -//! #[derive(Debug)] +//! #[derive(Debug, Clone)] //! struct UserId(String); //! -//! async fn handle(request: Request) -> Result, Error> { +//! async fn handle(request: Request>) -> Result>, BoxError> { //! // Access the `UserId` that was set in `on_authorized`. If `handle` gets called the //! // request was authorized and `UserId` will be present. //! let user_id = request @@ -60,11 +61,11 @@ //! //! println!("request from {:?}", user_id); //! -//! Ok(Response::new(Body::empty())) +//! Ok(Response::new(Full::default())) //! } //! //! # #[tokio::main] -//! # async fn main() -> Result<(), Box> { +//! # async fn main() -> Result<(), BoxError> { //! let service = ServiceBuilder::new() //! // Authorize requests using `MyAuth` //! .layer(AsyncRequireAuthorizationLayer::new(MyAuth)) @@ -77,10 +78,11 @@ //! //! ``` //! use tower_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest}; -//! use hyper::{Request, Response, Body, Error}; -//! use http::StatusCode; -//! use tower::{Service, ServiceExt, ServiceBuilder}; +//! use http::{Request, Response, StatusCode}; +//! use tower::{Service, ServiceExt, ServiceBuilder, BoxError}; //! use futures_util::future::BoxFuture; +//! use http_body_util::Full; +//! use bytes::Bytes; //! //! async fn check_auth(request: &Request) -> Option { //! // ... @@ -90,21 +92,21 @@ //! #[derive(Debug)] //! struct UserId(String); //! -//! async fn handle(request: Request) -> Result, Error> { +//! async fn handle(request: Request>) -> Result>, BoxError> { //! # todo!(); //! // ... //! } //! //! # #[tokio::main] -//! # async fn main() -> Result<(), Box> { +//! # async fn main() -> Result<(), BoxError> { //! let service = ServiceBuilder::new() -//! .layer(AsyncRequireAuthorizationLayer::new(|request: Request| async move { +//! .layer(AsyncRequireAuthorizationLayer::new(|request: Request>| async move { //! if let Some(user_id) = check_auth(&request).await { //! Ok(request) //! } else { //! let unauthorized_response = Response::builder() //! .status(StatusCode::UNAUTHORIZED) -//! .body(Body::empty()) +//! .body(Full::::default()) //! .unwrap(); //! //! Err(unauthorized_response) @@ -306,9 +308,9 @@ where mod tests { #[allow(unused_imports)] use super::*; + use crate::test_helpers::Body; use futures_util::future::BoxFuture; use http::{header, StatusCode}; - use hyper::Body; use tower::{BoxError, ServiceBuilder, ServiceExt}; #[derive(Clone, Copy)] @@ -346,7 +348,7 @@ mod tests { } } - #[derive(Debug)] + #[derive(Clone, Debug)] struct UserId(String); #[tokio::test] diff --git a/tower-http/src/auth/require_authorization.rs b/tower-http/src/auth/require_authorization.rs index c92ae6e8..d5c9508f 100644 --- a/tower-http/src/auth/require_authorization.rs +++ b/tower-http/src/auth/require_authorization.rs @@ -6,16 +6,17 @@ //! //! ``` //! use tower_http::validate_request::{ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer}; -//! use hyper::{Request, Response, Body, Error}; -//! use http::{StatusCode, header::AUTHORIZATION}; -//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! use http::{Request, Response, StatusCode, header::AUTHORIZATION}; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError}; +//! use bytes::Bytes; +//! use http_body_util::Full; //! -//! async fn handle(request: Request) -> Result, Error> { -//! Ok(Response::new(Body::empty())) +//! async fn handle(request: Request>) -> Result>, BoxError> { +//! Ok(Response::new(Full::default())) //! } //! //! # #[tokio::main] -//! # async fn main() -> Result<(), Box> { +//! # async fn main() -> Result<(), BoxError> { //! let mut service = ServiceBuilder::new() //! // Require the `Authorization` header to be `Bearer passwordlol` //! .layer(ValidateRequestHeaderLayer::bearer("passwordlol")) @@ -24,7 +25,7 @@ //! // Requests with the correct token are allowed through //! let request = Request::builder() //! .header(AUTHORIZATION, "Bearer passwordlol") -//! .body(Body::empty()) +//! .body(Full::default()) //! .unwrap(); //! //! let response = service @@ -37,7 +38,7 @@ //! //! // Requests with an invalid token get a `401 Unauthorized` response //! let request = Request::builder() -//! .body(Body::empty()) +//! .body(Full::default()) //! .unwrap(); //! //! let response = service @@ -103,7 +104,7 @@ impl ValidateRequestHeader> { /// /// # Panics /// - /// Panics if the token is not a valid [`HeaderValue`](http::header::HeaderValue). + /// Panics if the token is not a valid [`HeaderValue`]. pub fn bearer(inner: S, token: &str) -> Self where ResBody: Body + Default, @@ -119,7 +120,7 @@ impl ValidateRequestHeaderLayer> { /// /// # Panics /// - /// Panics if the token is not a valid [`HeaderValue`](http::header::HeaderValue). + /// Panics if the token is not a valid [`HeaderValue`]. pub fn bearer(token: &str) -> Self where ResBody: Body + Default, @@ -250,8 +251,8 @@ mod tests { #[allow(unused_imports)] use super::*; + use crate::test_helpers::Body; use http::header; - use hyper::Body; use tower::{BoxError, ServiceBuilder, ServiceExt}; use tower_service::Service; diff --git a/tower-http/src/body.rs b/tower-http/src/body.rs new file mode 100644 index 00000000..815a0d10 --- /dev/null +++ b/tower-http/src/body.rs @@ -0,0 +1,121 @@ +//! Body types. +//! +//! All these are wrappers around other body types. You shouldn't have to use them in your code. +//! Use `http-body-util` instead. +//! +//! They exist because we don't want to expose types from `http-body-util` in `tower-http`s public +//! API. + +#![allow(missing_docs)] + +use std::convert::Infallible; + +use bytes::{Buf, Bytes}; +use http_body::Body; +use pin_project_lite::pin_project; + +use crate::BoxError; + +macro_rules! body_methods { + () => { + #[inline] + fn poll_frame( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll, Self::Error>>> { + self.project().inner.poll_frame(cx) + } + + #[inline] + fn is_end_stream(&self) -> bool { + Body::is_end_stream(&self.inner) + } + + #[inline] + fn size_hint(&self) -> http_body::SizeHint { + Body::size_hint(&self.inner) + } + }; +} + +pin_project! { + #[derive(Default)] + pub struct Full { + #[pin] + pub(crate) inner: http_body_util::Full + } +} + +impl Full { + #[allow(dead_code)] + pub(crate) fn new(inner: http_body_util::Full) -> Self { + Self { inner } + } +} + +impl Body for Full { + type Data = Bytes; + type Error = Infallible; + + body_methods!(); +} + +pin_project! { + pub struct Limited { + #[pin] + pub(crate) inner: http_body_util::Limited + } +} + +impl Limited { + #[allow(dead_code)] + pub(crate) fn new(inner: http_body_util::Limited) -> Self { + Self { inner } + } +} + +impl Body for Limited +where + B: Body, + B::Error: Into, +{ + type Data = B::Data; + type Error = BoxError; + + body_methods!(); +} + +pin_project! { + pub struct UnsyncBoxBody { + #[pin] + pub(crate) inner: http_body_util::combinators::UnsyncBoxBody + } +} + +impl Default for UnsyncBoxBody +where + D: Buf + 'static, +{ + fn default() -> Self { + Self { + inner: Default::default(), + } + } +} + +impl UnsyncBoxBody { + #[allow(dead_code)] + pub(crate) fn new(inner: http_body_util::combinators::UnsyncBoxBody) -> Self { + Self { inner } + } +} + +impl Body for UnsyncBoxBody +where + D: Buf, +{ + type Data = D; + type Error = E; + + body_methods!(); +} diff --git a/tower-http/src/builder.rs b/tower-http/src/builder.rs index 2cb4f94a..d8d14aa0 100644 --- a/tower-http/src/builder.rs +++ b/tower-http/src/builder.rs @@ -17,13 +17,14 @@ use tower_layer::Stack; /// /// ```rust /// use http::{Request, Response, header::HeaderName}; -/// use hyper::Body; +/// use bytes::Bytes; +/// use http_body_util::Full; /// use std::{time::Duration, convert::Infallible}; /// use tower::{ServiceBuilder, ServiceExt, Service}; /// use tower_http::ServiceBuilderExt; /// -/// async fn handle(request: Request) -> Result, Infallible> { -/// Ok(Response::new(Body::empty())) +/// async fn handle(request: Request>) -> Result>, Infallible> { +/// Ok(Response::new(Full::default())) /// } /// /// # #[tokio::main] @@ -33,11 +34,10 @@ use tower_layer::Stack; /// .timeout(Duration::from_secs(30)) /// // Methods from tower-http /// .trace_for_http() -/// .compression() /// .propagate_header(HeaderName::from_static("x-request-id")) /// .service_fn(handle); /// # let mut service = service; -/// # service.ready().await.unwrap().call(Request::new(Body::empty())).await.unwrap(); +/// # service.ready().await.unwrap().call(Request::new(Full::default())).await.unwrap(); /// # } /// ``` #[cfg(feature = "util")] diff --git a/tower-http/src/catch_panic.rs b/tower-http/src/catch_panic.rs index 5cf6bdbf..3f1c2279 100644 --- a/tower-http/src/catch_panic.rs +++ b/tower-http/src/catch_panic.rs @@ -10,11 +10,12 @@ //! use std::convert::Infallible; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! use tower_http::catch_panic::CatchPanicLayer; -//! use hyper::Body; +//! use http_body_util::Full; +//! use bytes::Bytes; //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { -//! async fn handle(req: Request) -> Result, Infallible> { +//! async fn handle(req: Request>) -> Result>, Infallible> { //! panic!("something went wrong...") //! } //! @@ -24,7 +25,7 @@ //! .service_fn(handle); //! //! // Call the service. -//! let request = Request::new(Body::empty()); +//! let request = Request::new(Full::default()); //! //! let response = svc.ready().await?.call(request).await?; //! @@ -41,15 +42,16 @@ //! use std::{any::Any, convert::Infallible}; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! use tower_http::catch_panic::CatchPanicLayer; -//! use hyper::Body; +//! use bytes::Bytes; +//! use http_body_util::Full; //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { -//! async fn handle(req: Request) -> Result, Infallible> { +//! async fn handle(req: Request>) -> Result>, Infallible> { //! panic!("something went wrong...") //! } //! -//! fn handle_panic(err: Box) -> Response { +//! fn handle_panic(err: Box) -> Response> { //! let details = if let Some(s) = err.downcast_ref::() { //! s.clone() //! } else if let Some(s) = err.downcast_ref::<&str>() { @@ -69,7 +71,7 @@ //! Response::builder() //! .status(StatusCode::INTERNAL_SERVER_ERROR) //! .header(header::CONTENT_TYPE, "application/json") -//! .body(Body::from(body)) +//! .body(Full::from(body)) //! .unwrap() //! } //! @@ -85,7 +87,8 @@ use bytes::Bytes; use futures_util::future::{CatchUnwind, FutureExt}; use http::{HeaderValue, Request, Response, StatusCode}; -use http_body::{combinators::UnsyncBoxBody, Body, Full}; +use http_body::Body; +use http_body_util::BodyExt; use pin_project_lite::pin_project; use std::{ any::Any, @@ -97,7 +100,10 @@ use std::{ use tower_layer::Layer; use tower_service::Service; -use crate::BoxError; +use crate::{ + body::{Full, UnsyncBoxBody}, + BoxError, +}; /// Layer that applies the [`CatchPanic`] middleware that catches panics and converts them into /// `500 Internal Server` responses. @@ -262,7 +268,9 @@ where panic_handler, } => match ready!(future.poll(cx)) { Ok(Ok(res)) => { - Poll::Ready(Ok(res.map(|body| body.map_err(Into::into).boxed_unsync()))) + Poll::Ready(Ok(res.map(|body| { + UnsyncBoxBody::new(body.map_err(Into::into).boxed_unsync()) + }))) } Ok(Err(svc_err)) => Poll::Ready(Err(svc_err)), Err(panic_err) => Poll::Ready(Ok(response_for_panic( @@ -287,7 +295,7 @@ where { panic_handler .response_for_panic(err) - .map(|body| body.map_err(Into::into).boxed_unsync()) + .map(|body| UnsyncBoxBody::new(body.map_err(Into::into).boxed_unsync())) } /// Trait for creating responses from panics. @@ -325,7 +333,7 @@ where pub struct DefaultResponseForPanic; impl ResponseForPanic for DefaultResponseForPanic { - type ResponseBody = Full; + type ResponseBody = Full; fn response_for_panic( &mut self, @@ -341,7 +349,7 @@ impl ResponseForPanic for DefaultResponseForPanic { ); }; - let mut res = Response::new(Full::from("Service panicked")); + let mut res = Response::new(Full::new(http_body_util::Full::from("Service panicked"))); *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; #[allow(clippy::declare_interior_mutable_const)] @@ -358,7 +366,8 @@ mod tests { #![allow(unreachable_code)] use super::*; - use hyper::{Body, Response}; + use crate::test_helpers::Body; + use http::Response; use std::convert::Infallible; use tower::{ServiceBuilder, ServiceExt}; @@ -376,7 +385,7 @@ mod tests { let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); - let body = hyper::body::to_bytes(res).await.unwrap(); + let body = crate::test_helpers::to_bytes(res).await.unwrap(); assert_eq!(&body[..], b"Service panicked"); } @@ -394,7 +403,7 @@ mod tests { let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); - let body = hyper::body::to_bytes(res).await.unwrap(); + let body = crate::test_helpers::to_bytes(res).await.unwrap(); assert_eq!(&body[..], b"Service panicked"); } } diff --git a/tower-http/src/classify/grpc_errors_as_failures.rs b/tower-http/src/classify/grpc_errors_as_failures.rs index 056cec93..b88606b5 100644 --- a/tower-http/src/classify/grpc_errors_as_failures.rs +++ b/tower-http/src/classify/grpc_errors_as_failures.rs @@ -3,7 +3,7 @@ use bitflags::bitflags; use http::{HeaderMap, Response}; use std::{fmt, num::NonZeroI32}; -/// gRPC status codes. Used in [`GrpcErrorsAsFailures::success_codes`]. +/// gRPC status codes. /// /// These variants match the [gRPC status codes]. /// @@ -125,7 +125,7 @@ impl GrpcCodeBitmask { /// /// Responses are considered successful if /// -/// - `grpc-status` header value matches [`GrpcErrorsAsFailures::success_codes`] (only `Ok` by +/// - `grpc-status` header value contains a success value. /// default). /// - `grpc-status` header is missing. /// - `grpc-status` header value isn't a valid `String`. @@ -262,7 +262,6 @@ impl fmt::Display for GrpcFailureClass { } } -#[allow(clippy::if_let_some_result)] pub(crate) fn classify_grpc_metadata( headers: &HeaderMap, success_codes: GrpcCodeBitmask, diff --git a/tower-http/src/classify/mod.rs b/tower-http/src/classify/mod.rs index 08b322a0..6ea32559 100644 --- a/tower-http/src/classify/mod.rs +++ b/tower-http/src/classify/mod.rs @@ -192,7 +192,7 @@ pub trait ClassifyResponse { /// ClassifyResponse, ClassifiedResponse /// }; /// use http::{Response, StatusCode}; - /// use http_body::Empty; + /// use http_body_util::Empty; /// use bytes::Bytes; /// /// fn transform_failure_class(class: ServerErrorsFailureClass) -> NewFailureClass { @@ -375,7 +375,7 @@ impl fmt::Display for ServerErrorsFailureClass { mod usable_for_retries { #[allow(unused_imports)] use super::*; - use hyper::{Request, Response}; + use http::{Request, Response}; use tower::retry::Policy; trait IsRetryable { diff --git a/tower-http/src/classify/status_in_range_is_error.rs b/tower-http/src/classify/status_in_range_is_error.rs index ce3202a2..934d08c5 100644 --- a/tower-http/src/classify/status_in_range_is_error.rs +++ b/tower-http/src/classify/status_in_range_is_error.rs @@ -12,20 +12,23 @@ use std::{fmt, ops::RangeInclusive}; /// ```no_run /// use tower_http::{trace::TraceLayer, classify::StatusInRangeAsFailures}; /// use tower::{ServiceBuilder, Service, ServiceExt}; -/// use hyper::{Client, Body}; /// use http::{Request, Method}; +/// use http_body_util::Full; +/// use bytes::Bytes; +/// use hyper_util::{rt::TokioExecutor, client::legacy::Client}; /// /// # async fn foo() -> Result<(), tower::BoxError> { /// let classifier = StatusInRangeAsFailures::new(400..=599); /// +/// let client = Client::builder(TokioExecutor::new()).build_http(); /// let mut client = ServiceBuilder::new() /// .layer(TraceLayer::new(classifier.into_make_classifier())) -/// .service(Client::new()); +/// .service(client); /// /// let request = Request::builder() /// .method(Method::GET) /// .uri("https://example.com") -/// .body(Body::empty()) +/// .body(Full::::default()) /// .unwrap(); /// /// let response = client.ready().await?.call(request).await?; diff --git a/tower-http/src/compression/body.rs b/tower-http/src/compression/body.rs index 90229d5e..013d605b 100644 --- a/tower-http/src/compression/body.rs +++ b/tower-http/src/compression/body.rs @@ -246,46 +246,29 @@ where type Data = Bytes; type Error = BoxError; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { match self.project().inner.project() { #[cfg(feature = "compression-gzip")] - BodyInnerProj::Gzip { inner } => inner.poll_data(cx), + BodyInnerProj::Gzip { inner } => inner.poll_frame(cx), #[cfg(feature = "compression-deflate")] - BodyInnerProj::Deflate { inner } => inner.poll_data(cx), + BodyInnerProj::Deflate { inner } => inner.poll_frame(cx), #[cfg(feature = "compression-br")] - BodyInnerProj::Brotli { inner } => inner.poll_data(cx), + BodyInnerProj::Brotli { inner } => inner.poll_frame(cx), #[cfg(feature = "compression-zstd")] - BodyInnerProj::Zstd { inner } => inner.poll_data(cx), - BodyInnerProj::Identity { inner } => match ready!(inner.poll_data(cx)) { - Some(Ok(mut buf)) => { - let bytes = buf.copy_to_bytes(buf.remaining()); - Poll::Ready(Some(Ok(bytes))) + BodyInnerProj::Zstd { inner } => inner.poll_frame(cx), + BodyInnerProj::Identity { inner } => match ready!(inner.poll_frame(cx)) { + Some(Ok(frame)) => { + let frame = frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining())); + Poll::Ready(Some(Ok(frame))) } Some(Err(err)) => Poll::Ready(Some(Err(err.into()))), None => Poll::Ready(None), }, } } - - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - match self.project().inner.project() { - #[cfg(feature = "compression-gzip")] - BodyInnerProj::Gzip { inner } => inner.poll_trailers(cx), - #[cfg(feature = "compression-deflate")] - BodyInnerProj::Deflate { inner } => inner.poll_trailers(cx), - #[cfg(feature = "compression-br")] - BodyInnerProj::Brotli { inner } => inner.poll_trailers(cx), - #[cfg(feature = "compression-zstd")] - BodyInnerProj::Zstd { inner } => inner.poll_trailers(cx), - BodyInnerProj::Identity { inner } => inner.poll_trailers(cx).map_err(Into::into), - } - } } #[cfg(feature = "compression-gzip")] diff --git a/tower-http/src/compression/future.rs b/tower-http/src/compression/future.rs index ca574fce..98763006 100644 --- a/tower-http/src/compression/future.rs +++ b/tower-http/src/compression/future.rs @@ -36,7 +36,7 @@ where { type Output = Result>, E>; - #[allow(unreachable_code, unused_mut, unused_variables)] + #[allow(unreachable_code, unused_mut, unused_variables, unreachable_patterns)] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let res = ready!(self.as_mut().project().inner.poll(cx)?); diff --git a/tower-http/src/compression/layer.rs b/tower-http/src/compression/layer.rs index cce6b65a..5dcab99c 100644 --- a/tower-http/src/compression/layer.rs +++ b/tower-http/src/compression/layer.rs @@ -123,13 +123,11 @@ impl CompressionLayer { #[cfg(test)] mod tests { use super::*; + use crate::test_helpers::Body; use http::{header::ACCEPT_ENCODING, Request, Response}; - use http_body::Body as _; - use hyper::Body; - use tokio::fs::File; - // for Body::data - use bytes::{Bytes, BytesMut}; + use http_body_util::BodyExt; use std::convert::Infallible; + use tokio::fs::File; use tokio_util::io::ReaderStream; use tower::{Service, ServiceBuilder, ServiceExt}; @@ -139,7 +137,7 @@ mod tests { // Convert the file into a `Stream`. let stream = ReaderStream::new(file); // Convert the `Stream` into a `Body`. - let body = Body::wrap_stream(stream); + let body = Body::from_stream(stream); // Create response. Ok(Response::new(body)) } @@ -166,13 +164,8 @@ mod tests { assert_eq!(response.headers()["content-encoding"], "deflate"); // Read the body - let mut body = response.into_body(); - let mut bytes = BytesMut::new(); - while let Some(chunk) = body.data().await { - let chunk = chunk?; - bytes.extend_from_slice(&chunk[..]); - } - let bytes: Bytes = bytes.freeze(); + let body = response.into_body(); + let bytes = body.collect().await.unwrap().to_bytes(); let deflate_bytes_len = bytes.len(); @@ -196,13 +189,8 @@ mod tests { assert_eq!(response.headers()["content-encoding"], "br"); // Read the body - let mut body = response.into_body(); - let mut bytes = BytesMut::new(); - while let Some(chunk) = body.data().await { - let chunk = chunk?; - bytes.extend_from_slice(&chunk[..]); - } - let bytes: Bytes = bytes.freeze(); + let body = response.into_body(); + let bytes = body.collect().await.unwrap().to_bytes(); let br_byte_length = bytes.len(); diff --git a/tower-http/src/compression/mod.rs b/tower-http/src/compression/mod.rs index b3beed9a..897304c3 100644 --- a/tower-http/src/compression/mod.rs +++ b/tower-http/src/compression/mod.rs @@ -7,23 +7,30 @@ //! ```rust //! use bytes::{Bytes, BytesMut}; //! use http::{Request, Response, header::ACCEPT_ENCODING}; -//! use http_body::Body as _; // for Body::data -//! use hyper::Body; +//! use http_body_util::{Full, BodyExt, StreamBody, combinators::UnsyncBoxBody}; +//! use http_body::Frame; //! use std::convert::Infallible; //! use tokio::fs::{self, File}; //! use tokio_util::io::ReaderStream; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! use tower_http::{compression::CompressionLayer, BoxError}; +//! use futures_util::TryStreamExt; +//! +//! type BoxBody = UnsyncBoxBody; //! //! # #[tokio::main] //! # async fn main() -> Result<(), BoxError> { -//! async fn handle(req: Request) -> Result, Infallible> { +//! async fn handle(req: Request>) -> Result, Infallible> { //! // Open the file. //! let file = File::open("Cargo.toml").await.expect("file missing"); -//! // Convert the file into a `Stream`. +//! // Convert the file into a `Stream` of `Bytes`. //! let stream = ReaderStream::new(file); +//! // Convert the stream into a stream of data `Frame`s. +//! let stream = stream.map_ok(Frame::data); //! // Convert the `Stream` into a `Body`. -//! let body = Body::wrap_stream(stream); +//! let body = StreamBody::new(stream); +//! // Erase the type because its very hard to name in the function signature. +//! let body = body.boxed_unsync(); //! // Create response. //! Ok(Response::new(body)) //! } @@ -36,7 +43,7 @@ //! // Call the service. //! let request = Request::builder() //! .header(ACCEPT_ENCODING, "gzip") -//! .body(Body::empty())?; +//! .body(Full::::default())?; //! //! let response = service //! .ready() @@ -47,13 +54,11 @@ //! assert_eq!(response.headers()["content-encoding"], "gzip"); //! //! // Read the body -//! let mut body = response.into_body(); -//! let mut bytes = BytesMut::new(); -//! while let Some(chunk) = body.data().await { -//! let chunk = chunk?; -//! bytes.extend_from_slice(&chunk[..]); -//! } -//! let bytes: Bytes = bytes.freeze(); +//! let bytes = response +//! .into_body() +//! .collect() +//! .await? +//! .to_bytes(); //! //! // The compressed body should be smaller 🤞 //! let uncompressed_len = fs::read_to_string("Cargo.toml").await?.len(); @@ -87,17 +92,18 @@ mod tests { use crate::compression::predicate::SizeAbove; use super::*; + use crate::test_helpers::{Body, WithTrailers}; use async_compression::tokio::write::{BrotliDecoder, BrotliEncoder}; - use bytes::BytesMut; use flate2::read::GzDecoder; use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_TYPE}; - use http_body::Body as _; - use hyper::{Body, Error, Request, Response, Server}; + use http::{HeaderMap, HeaderName, Request, Response}; + use http_body_util::BodyExt; + use std::convert::Infallible; + use std::io::Read; use std::sync::{Arc, RwLock}; - use std::{io::Read, net::SocketAddr}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio_util::io::StreamReader; - use tower::{make::Shared, service_fn, Service, ServiceExt}; + use tower::{service_fn, Service, ServiceExt}; // Compression filter allows every other request to be compressed #[derive(Clone)] @@ -125,13 +131,9 @@ mod tests { let res = svc.ready().await.unwrap().call(req).await.unwrap(); // read the compressed body - let mut body = res.into_body(); - let mut data = BytesMut::new(); - while let Some(chunk) = body.data().await { - let chunk = chunk.unwrap(); - data.extend_from_slice(&chunk[..]); - } - let compressed_data = data.freeze().to_vec(); + let collected = res.into_body().collect().await.unwrap(); + let trailers = collected.trailers().cloned().unwrap(); + let compressed_data = collected.to_bytes(); // decompress the body // doing this with flate2 as that is much easier than async-compression and blocking during @@ -141,6 +143,9 @@ mod tests { decoder.read_to_string(&mut decompressed).unwrap(); assert_eq!(decompressed, "Hello, World!"); + + // trailers are maintained + assert_eq!(trailers["foo"], "bar"); } #[tokio::test] @@ -156,13 +161,8 @@ mod tests { let res = svc.ready().await.unwrap().call(req).await.unwrap(); // read the compressed body - let mut body = res.into_body(); - let mut data = BytesMut::new(); - while let Some(chunk) = body.data().await { - let chunk = chunk.unwrap(); - data.extend_from_slice(&chunk[..]); - } - let compressed_data = data.freeze().to_vec(); + let body = res.into_body(); + let compressed_data = body.collect().await.unwrap().to_bytes(); // decompress the body let decompressed = zstd::stream::decode_all(std::io::Cursor::new(compressed_data)).unwrap(); @@ -171,18 +171,6 @@ mod tests { assert_eq!(decompressed, "Hello, World!"); } - #[allow(dead_code)] - async fn is_compatible_with_hyper() { - let svc = service_fn(handle); - let svc = Compression::new(svc); - - let make_service = Shared::new(svc); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - let server = Server::bind(&addr).serve(make_service); - server.await.unwrap(); - } - #[tokio::test] async fn no_recompress() { const DATA: &str = "Hello, World! I'm already compressed with br!"; @@ -225,12 +213,8 @@ mod tests { ); // read the compressed body - let mut body = res.into_body(); - let mut data = BytesMut::new(); - while let Some(chunk) = body.data().await { - let chunk = chunk.unwrap(); - data.extend_from_slice(&chunk[..]); - } + let body = res.into_body(); + let data = body.collect().await.unwrap().to_bytes(); // decompress the body let data = { @@ -247,8 +231,11 @@ mod tests { assert_eq!(data, DATA.as_bytes()); } - async fn handle(_req: Request) -> Result, Error> { - Ok(Response::new(Body::from("Hello, World!"))) + async fn handle(_req: Request) -> Result>, Infallible> { + let mut trailers = HeaderMap::new(); + trailers.insert(HeaderName::from_static("foo"), "bar".parse().unwrap()); + let body = Body::from("Hello, World!").with_trailers(trailers); + Ok(Response::builder().body(body).unwrap()) } #[tokio::test] @@ -269,6 +256,7 @@ mod tests { #[derive(Default, Clone)] struct EveryOtherResponse(Arc>); + #[allow(clippy::dbg_macro)] impl Predicate for EveryOtherResponse { fn should_compress(&self, _: &http::Response) -> bool where @@ -289,12 +277,8 @@ mod tests { let res = svc.ready().await.unwrap().call(req).await.unwrap(); // read the uncompressed body - let mut body = res.into_body(); - let mut data = BytesMut::new(); - while let Some(chunk) = body.data().await { - let chunk = chunk.unwrap(); - data.extend_from_slice(&chunk[..]); - } + let body = res.into_body(); + let data = body.collect().await.unwrap().to_bytes(); let still_uncompressed = String::from_utf8(data.to_vec()).unwrap(); assert_eq!(DATA, &still_uncompressed); @@ -306,18 +290,14 @@ mod tests { let res = svc.ready().await.unwrap().call(req).await.unwrap(); // read the compressed body - let mut body = res.into_body(); - let mut data = BytesMut::new(); - while let Some(chunk) = body.data().await { - let chunk = chunk.unwrap(); - data.extend_from_slice(&chunk[..]); - } + let body = res.into_body(); + let data = body.collect().await.unwrap().to_bytes(); assert!(String::from_utf8(data.to_vec()).is_err()); } #[tokio::test] async fn doesnt_compress_images() { - async fn handle(_req: Request) -> Result, Error> { + async fn handle(_req: Request) -> Result, Infallible> { let mut res = Response::new(Body::from( "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize), )); @@ -342,7 +322,7 @@ mod tests { #[tokio::test] async fn does_compress_svg() { - async fn handle(_req: Request) -> Result, Error> { + async fn handle(_req: Request) -> Result, Infallible> { let mut res = Response::new(Body::from( "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize), )); @@ -387,13 +367,8 @@ mod tests { let res = svc.ready().await.unwrap().call(req).await.unwrap(); // read the compressed body - let mut body = res.into_body(); - let mut data = BytesMut::new(); - while let Some(chunk) = body.data().await { - let chunk = chunk.unwrap(); - data.extend_from_slice(&chunk[..]); - } - let compressed_data = data.freeze().to_vec(); + let body = res.into_body(); + let compressed_data = body.collect().await.unwrap().to_bytes(); // build the compressed body with the same quality level let compressed_with_level = { @@ -411,7 +386,7 @@ mod tests { }; assert_eq!( - compressed_data.as_slice(), + compressed_data, compressed_with_level.as_slice(), "Compression level is not respected" ); diff --git a/tower-http/src/compression_utils.rs b/tower-http/src/compression_utils.rs index 2aabca68..1d851e67 100644 --- a/tower-http/src/compression_utils.rs +++ b/tower-http/src/compression_utils.rs @@ -1,10 +1,10 @@ //! Types used by compression and decompression middleware. use crate::{content_encoding::SupportedEncodings, BoxError}; -use bytes::{Bytes, BytesMut}; +use bytes::{Buf, Bytes, BytesMut}; use futures_util::Stream; use http::HeaderValue; -use http_body::Body; +use http_body::{Body, Frame}; use pin_project_lite::pin_project; use std::{ io, @@ -12,7 +12,7 @@ use std::{ task::{ready, Context, Poll}, }; use tokio::io::AsyncRead; -use tokio_util::io::{poll_read_buf, StreamReader}; +use tokio_util::io::StreamReader; #[derive(Debug, Clone, Copy)] pub(crate) struct AcceptEncoding { @@ -150,7 +150,10 @@ pin_project! { /// `Body` that has been decorated by an `AsyncRead` pub(crate) struct WrapBody { #[pin] - pub(crate) read: M::Output, + // rust-analyer thinks this field is private if its `pub(crate)` but works fine when its + // `pub` + pub read: M::Output, + read_all_data: bool, } } @@ -174,7 +177,10 @@ impl WrapBody { // apply decorator to `AsyncRead` yielding another `AsyncRead` let read = M::apply(read, quality); - Self { read } + Self { + read, + read_all_data: false, + } } } @@ -187,65 +193,80 @@ where type Data = Bytes; type Error = BoxError; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { let mut this = self.project(); let mut buf = BytesMut::new(); - - let read = match ready!(poll_read_buf(this.read.as_mut(), cx, &mut buf)) { - Ok(read) => read, - Err(err) => { - let body_error: Option = M::get_pin_mut(this.read) - .get_pin_mut() - .project() - .error - .take(); - - if let Some(body_error) = body_error { - return Poll::Ready(Some(Err(body_error.into()))); - } else if err.raw_os_error() == Some(SENTINEL_ERROR_CODE) { - // SENTINEL_ERROR_CODE only gets used when storing an underlying body error - unreachable!() - } else { - return Poll::Ready(Some(Err(err.into()))); + if !*this.read_all_data { + match tokio_util::io::poll_read_buf(this.read.as_mut(), cx, &mut buf) { + Poll::Ready(result) => { + match result { + Ok(read) => { + if read == 0 { + *this.read_all_data = true; + } else { + return Poll::Ready(Some(Ok(Frame::data(buf.freeze())))); + } + } + Err(err) => { + let body_error: Option = M::get_pin_mut(this.read) + .get_pin_mut() + .project() + .error + .take(); + + if let Some(body_error) = body_error { + return Poll::Ready(Some(Err(body_error.into()))); + } else if err.raw_os_error() == Some(SENTINEL_ERROR_CODE) { + // SENTINEL_ERROR_CODE only gets used when storing an underlying body error + unreachable!() + } else { + return Poll::Ready(Some(Err(err.into()))); + } + } + } } + Poll::Pending => return Poll::Pending, } - }; - - if read == 0 { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(buf.freeze()))) } - } - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - let this = self.project(); - let body = M::get_pin_mut(this.read) - .get_pin_mut() - .get_pin_mut() - .get_pin_mut(); - body.poll_trailers(cx).map_err(Into::into) + // poll any remaining frames, such as trailers + let body = M::get_pin_mut(this.read).get_pin_mut().get_pin_mut(); + body.poll_frame(cx).map(|option| { + option.map(|result| { + result + .map(|frame| frame.map_data(|mut data| data.copy_to_bytes(data.remaining()))) + .map_err(|err| err.into()) + }) + }) } } pin_project! { - // When https://github.com/hyperium/http-body/pull/36 is merged we can remove this - pub(crate) struct BodyIntoStream { + pub(crate) struct BodyIntoStream + where + B: Body, + { #[pin] body: B, + yielded_all_data: bool, + non_data_frame: Option>, } } #[allow(dead_code)] -impl BodyIntoStream { +impl BodyIntoStream +where + B: Body, +{ pub(crate) fn new(body: B) -> Self { - Self { body } + Self { + body, + yielded_all_data: false, + non_data_frame: None, + } } /// Get a reference to the inner body @@ -275,8 +296,63 @@ where { type Item = Result; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().body.poll_data(cx) + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + let this = self.as_mut().project(); + + if *this.yielded_all_data { + return Poll::Ready(None); + } + + match std::task::ready!(this.body.poll_frame(cx)) { + Some(Ok(frame)) => match frame.into_data() { + Ok(data) => return Poll::Ready(Some(Ok(data))), + Err(frame) => { + *this.yielded_all_data = true; + *this.non_data_frame = Some(frame); + } + }, + Some(Err(err)) => return Poll::Ready(Some(Err(err))), + None => { + *this.yielded_all_data = true; + } + } + } + } +} + +impl Body for BodyIntoStream +where + B: Body, +{ + type Data = B::Data; + type Error = B::Error; + + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + // First drive the stream impl. This consumes all data frames and buffer at most one + // trailers frame. + if let Some(frame) = std::task::ready!(self.as_mut().poll_next(cx)) { + return Poll::Ready(Some(frame.map(Frame::data))); + } + + let this = self.project(); + + // Yield the trailers frame `poll_next` hit. + if let Some(frame) = this.non_data_frame.take() { + return Poll::Ready(Some(Ok(frame))); + } + + // Yield any remaining frames in the body. There shouldn't be any after the trailers but + // you never know. + this.body.poll_frame(cx) + } + + #[inline] + fn size_hint(&self) -> http_body::SizeHint { + self.body.size_hint() } } @@ -337,13 +413,14 @@ pub(crate) const SENTINEL_ERROR_CODE: i32 = -837459418; /// Level of compression data should be compressed with. #[non_exhaustive] -#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, PartialEq, Default)] pub enum CompressionLevel { /// Fastest quality of compression, usually produces bigger size. Fastest, /// Best quality of compression, usually produces the smallest size. Best, /// Default quality of compression defined by the selected compression algorithm. + #[default] Default, /// Precise quality based on the underlying compression algorithms' /// qualities. The interpretation of this depends on the algorithm chosen @@ -352,12 +429,6 @@ pub enum CompressionLevel { Precise(i32), } -impl Default for CompressionLevel { - fn default() -> Self { - CompressionLevel::Default - } -} - #[cfg(any( feature = "compression-br", feature = "compression-gzip", diff --git a/tower-http/src/content_encoding.rs b/tower-http/src/content_encoding.rs index c962d0ee..9273a0f3 100644 --- a/tower-http/src/content_encoding.rs +++ b/tower-http/src/content_encoding.rs @@ -148,7 +148,7 @@ impl QValue { let mut c = s.chars(); // Parse "q=" (case-insensitively). match c.next() { - Some('q') | Some('Q') => (), + Some('q' | 'Q') => (), _ => return None, }; match c.next() { @@ -271,8 +271,7 @@ mod tests { #[test] fn no_accept_encoding_header() { - let encoding = - Encoding::from_headers(&http::HeaderMap::new(), SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&http::HeaderMap::new(), SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); } @@ -283,7 +282,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Gzip, encoding); } @@ -294,7 +293,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip,br"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -305,7 +304,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip,deflate,br"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -316,7 +315,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5,br"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -327,7 +326,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5,deflate,br"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -342,7 +341,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("br"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -357,7 +356,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("br"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -376,7 +375,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("br"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -387,7 +386,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5,br;q=0.8"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); let mut headers = http::HeaderMap::new(); @@ -395,7 +394,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.8,br;q=0.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Gzip, encoding); let mut headers = http::HeaderMap::new(); @@ -403,7 +402,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.995,br;q=0.999"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -414,7 +413,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5,deflate;q=0.6,br;q=0.8"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); let mut headers = http::HeaderMap::new(); @@ -422,7 +421,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.8,deflate;q=0.6,br;q=0.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Gzip, encoding); let mut headers = http::HeaderMap::new(); @@ -430,7 +429,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.6,deflate;q=0.8,br;q=0.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Deflate, encoding); let mut headers = http::HeaderMap::new(); @@ -438,7 +437,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.995,deflate;q=0.997,br;q=0.999"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -449,7 +448,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("invalid,gzip"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Gzip, encoding); } @@ -460,7 +459,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -468,7 +467,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0."), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -476,7 +475,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0,br;q=0.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -487,7 +486,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gZiP"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Gzip, encoding); let mut headers = http::HeaderMap::new(); @@ -495,7 +494,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5,br;Q=0.8"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -506,7 +505,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static(" gzip\t; q=0.5 ,\tbr ;\tq=0.8\t"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -517,7 +516,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q =0.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -525,7 +524,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q= 0.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); } @@ -536,7 +535,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=-0.1"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -544,7 +543,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=00.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -552,7 +551,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5000"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -560,7 +559,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -568,7 +567,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=1.01"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -576,7 +575,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=1.001"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll::default()); + let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); } } diff --git a/tower-http/src/cors/allow_credentials.rs b/tower-http/src/cors/allow_credentials.rs index e489c570..de53ffed 100644 --- a/tower-http/src/cors/allow_credentials.rs +++ b/tower-http/src/cors/allow_credentials.rs @@ -57,7 +57,7 @@ impl AllowCredentials { AllowCredentialsInner::Predicate(c) => c(origin?, parts), }; - allow_creds.then(|| (header::ACCESS_CONTROL_ALLOW_CREDENTIALS, TRUE)) + allow_creds.then_some((header::ACCESS_CONTROL_ALLOW_CREDENTIALS, TRUE)) } } diff --git a/tower-http/src/cors/allow_private_network.rs b/tower-http/src/cors/allow_private_network.rs index 4163014d..9f97dc11 100644 --- a/tower-http/src/cors/allow_private_network.rs +++ b/tower-http/src/cors/allow_private_network.rs @@ -39,6 +39,10 @@ impl AllowPrivateNetwork { Self(AllowPrivateNetworkInner::Predicate(Arc::new(f))) } + #[allow( + clippy::declare_interior_mutable_const, + clippy::borrow_interior_mutable_const + )] pub(super) fn to_header( &self, origin: Option<&HeaderValue>, @@ -71,7 +75,7 @@ impl AllowPrivateNetwork { AllowPrivateNetworkInner::Predicate(c) => c(origin?, parts), }; - allow_private_network.then(|| (ALLOW_PRIVATE_NETWORK, TRUE)) + allow_private_network.then_some((ALLOW_PRIVATE_NETWORK, TRUE)) } } @@ -111,11 +115,16 @@ impl Default for AllowPrivateNetworkInner { #[cfg(test)] mod tests { + #![allow( + clippy::declare_interior_mutable_const, + clippy::borrow_interior_mutable_const + )] + use super::AllowPrivateNetwork; use crate::cors::CorsLayer; + use crate::test_helpers::Body; use http::{header::ORIGIN, request::Parts, HeaderName, HeaderValue, Request, Response}; - use hyper::Body; use tower::{BoxError, ServiceBuilder, ServiceExt}; use tower_service::Service; diff --git a/tower-http/src/cors/mod.rs b/tower-http/src/cors/mod.rs index 20f2c336..0dac8e6b 100644 --- a/tower-http/src/cors/mod.rs +++ b/tower-http/src/cors/mod.rs @@ -4,13 +4,14 @@ //! //! ``` //! use http::{Request, Response, Method, header}; -//! use hyper::Body; +//! use http_body_util::Full; +//! use bytes::Bytes; //! use tower::{ServiceBuilder, ServiceExt, Service}; //! use tower_http::cors::{Any, CorsLayer}; //! use std::convert::Infallible; //! -//! async fn handle(request: Request) -> Result, Infallible> { -//! Ok(Response::new(Body::empty())) +//! async fn handle(request: Request>) -> Result>, Infallible> { +//! Ok(Response::new(Full::default())) //! } //! //! # #[tokio::main] @@ -27,7 +28,7 @@ //! //! let request = Request::builder() //! .header(header::ORIGIN, "https://example.com") -//! .body(Body::empty()) +//! .body(Full::default()) //! .unwrap(); //! //! let response = service diff --git a/tower-http/src/decompression/body.rs b/tower-http/src/decompression/body.rs index 58cc40a4..88197bbf 100644 --- a/tower-http/src/decompression/body.rs +++ b/tower-http/src/decompression/body.rs @@ -294,23 +294,23 @@ where type Data = Bytes; type Error = BoxError; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { match self.project().inner.project() { #[cfg(feature = "decompression-gzip")] - BodyInnerProj::Gzip { inner } => inner.poll_data(cx), + BodyInnerProj::Gzip { inner } => inner.poll_frame(cx), #[cfg(feature = "decompression-deflate")] - BodyInnerProj::Deflate { inner } => inner.poll_data(cx), + BodyInnerProj::Deflate { inner } => inner.poll_frame(cx), #[cfg(feature = "decompression-br")] - BodyInnerProj::Brotli { inner } => inner.poll_data(cx), + BodyInnerProj::Brotli { inner } => inner.poll_frame(cx), #[cfg(feature = "decompression-zstd")] - BodyInnerProj::Zstd { inner } => inner.poll_data(cx), - BodyInnerProj::Identity { inner } => match ready!(inner.poll_data(cx)) { - Some(Ok(mut buf)) => { - let bytes = buf.copy_to_bytes(buf.remaining()); - Poll::Ready(Some(Ok(bytes))) + BodyInnerProj::Zstd { inner } => inner.poll_frame(cx), + BodyInnerProj::Identity { inner } => match ready!(inner.poll_frame(cx)) { + Some(Ok(frame)) => { + let frame = frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining())); + Poll::Ready(Some(Ok(frame))) } Some(Err(err)) => Poll::Ready(Some(Err(err.into()))), None => Poll::Ready(None), @@ -326,32 +326,6 @@ where BodyInnerProj::Zstd { inner } => match inner.0 {}, } } - - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - match self.project().inner.project() { - #[cfg(feature = "decompression-gzip")] - BodyInnerProj::Gzip { inner } => inner.poll_trailers(cx), - #[cfg(feature = "decompression-deflate")] - BodyInnerProj::Deflate { inner } => inner.poll_trailers(cx), - #[cfg(feature = "decompression-br")] - BodyInnerProj::Brotli { inner } => inner.poll_trailers(cx), - #[cfg(feature = "decompression-zstd")] - BodyInnerProj::Zstd { inner } => inner.poll_trailers(cx), - BodyInnerProj::Identity { inner } => inner.poll_trailers(cx).map_err(Into::into), - - #[cfg(not(feature = "decompression-gzip"))] - BodyInnerProj::Gzip { inner } => match inner.0 {}, - #[cfg(not(feature = "decompression-deflate"))] - BodyInnerProj::Deflate { inner } => match inner.0 {}, - #[cfg(not(feature = "decompression-br"))] - BodyInnerProj::Brotli { inner } => match inner.0 {}, - #[cfg(not(feature = "decompression-zstd"))] - BodyInnerProj::Zstd { inner } => match inner.0 {}, - } - } } #[cfg(feature = "decompression-gzip")] diff --git a/tower-http/src/decompression/mod.rs b/tower-http/src/decompression/mod.rs index 265416fe..708df439 100644 --- a/tower-http/src/decompression/mod.rs +++ b/tower-http/src/decompression/mod.rs @@ -3,12 +3,12 @@ //! # Examples //! //! #### Request +//! //! ```rust -//! use bytes::BytesMut; +//! use bytes::Bytes; //! use flate2::{write::GzEncoder, Compression}; //! use http::{header, HeaderValue, Request, Response}; -//! use http_body::Body as _; // for Body::data -//! use hyper::Body; +//! use http_body_util::{Full, BodyExt}; //! use std::{error::Error, io::Write}; //! use tower::{Service, ServiceBuilder, service_fn, ServiceExt}; //! use tower_http::{BoxError, decompression::{DecompressionBody, RequestDecompressionLayer}}; @@ -20,7 +20,7 @@ //! encoder.write_all(b"Hello?")?; //! let request = Request::builder() //! .header(header::CONTENT_ENCODING, "gzip") -//! .body(Body::from(encoder.finish()?))?; +//! .body(Full::from(encoder.finish()?))?; //! //! // Our HTTP server //! let mut server = ServiceBuilder::new() @@ -32,33 +32,31 @@ //! let _response = server.ready().await?.call(request).await?; //! //! // Handler receives request whose body is decoded when read -//! async fn handler(mut req: Request>) -> Result, BoxError>{ -//! let mut data = BytesMut::new(); -//! while let Some(chunk) = req.body_mut().data().await { -//! let chunk = chunk?; -//! data.extend_from_slice(&chunk[..]); -//! } -//! assert_eq!(data.freeze().to_vec(), b"Hello?"); -//! Ok(Response::new(Body::from("Hello, World!"))) +//! async fn handler( +//! mut req: Request>>, +//! ) -> Result>, BoxError>{ +//! let data = req.into_body().collect().await?.to_bytes(); +//! assert_eq!(&data[..], b"Hello?"); +//! Ok(Response::new(Full::from("Hello, World!"))) //! } //! # Ok(()) //! # } //! ``` //! //! #### Response +//! //! ```rust -//! use bytes::BytesMut; +//! use bytes::Bytes; //! use http::{Request, Response}; -//! use http_body::Body as _; // for Body::data -//! use hyper::Body; +//! use http_body_util::{Full, BodyExt}; //! use std::convert::Infallible; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! use tower_http::{compression::Compression, decompression::DecompressionLayer, BoxError}; //! # //! # #[tokio::main] //! # async fn main() -> Result<(), tower_http::BoxError> { -//! # async fn handle(req: Request) -> Result, Infallible> { -//! # let body = Body::from("Hello, World!"); +//! # async fn handle(req: Request>) -> Result>, Infallible> { +//! # let body = Full::from("Hello, World!"); //! # Ok(Response::new(body)) //! # } //! @@ -74,7 +72,7 @@ //! // Call the service. //! // //! // `DecompressionLayer` takes care of setting `Accept-Encoding`. -//! let request = Request::new(Body::empty()); +//! let request = Request::new(Full::::default()); //! //! let response = client //! .ready() @@ -83,13 +81,9 @@ //! .await?; //! //! // Read the body -//! let mut body = response.into_body(); -//! let mut bytes = BytesMut::new(); -//! while let Some(chunk) = body.data().await { -//! let chunk = chunk?; -//! bytes.extend_from_slice(&chunk[..]); -//! } -//! let body = String::from_utf8(bytes.to_vec()).map_err(Into::::into)?; +//! let body = response.into_body(); +//! let bytes = body.collect().await?.to_bytes().to_vec(); +//! let body = String::from_utf8(bytes).map_err(Into::::into)?; //! //! assert_eq!(body, "Hello, World!"); //! # @@ -115,15 +109,16 @@ pub use self::request::service::RequestDecompression; #[cfg(test)] mod tests { + use std::convert::Infallible; use std::io::Write; use super::*; - use crate::compression::Compression; - use bytes::BytesMut; + use crate::test_helpers::Body; + use crate::{compression::Compression, test_helpers::WithTrailers}; use flate2::write::GzEncoder; use http::Response; - use http_body::Body as _; - use hyper::{Body, Client, Error, Request}; + use http::{HeaderMap, HeaderName, Request}; + use http_body_util::BodyExt; use tower::{service_fn, Service, ServiceExt}; #[tokio::test] @@ -137,15 +132,22 @@ mod tests { let res = client.ready().await.unwrap().call(req).await.unwrap(); // read the body, it will be decompressed automatically - let mut body = res.into_body(); - let mut data = BytesMut::new(); - while let Some(chunk) = body.data().await { - let chunk = chunk.unwrap(); - data.extend_from_slice(&chunk[..]); - } - let decompressed_data = String::from_utf8(data.freeze().to_vec()).unwrap(); + let body = res.into_body(); + let collected = body.collect().await.unwrap(); + let trailers = collected.trailers().cloned().unwrap(); + let decompressed_data = String::from_utf8(collected.to_bytes().to_vec()).unwrap(); assert_eq!(decompressed_data, "Hello, World!"); + + // maintains trailers + assert_eq!(trailers["foo"], "bar"); + } + + async fn handle(_req: Request) -> Result>, Infallible> { + let mut trailers = HeaderMap::new(); + trailers.insert(HeaderName::from_static("foo"), "bar".parse().unwrap()); + let body = Body::from("Hello, World!").with_trailers(trailers); + Ok(Response::builder().body(body).unwrap()) } #[tokio::test] @@ -159,22 +161,14 @@ mod tests { let res = client.ready().await.unwrap().call(req).await.unwrap(); // read the body, it will be decompressed automatically - let mut body = res.into_body(); - let mut data = BytesMut::new(); - while let Some(chunk) = body.data().await { - let chunk = chunk.unwrap(); - data.extend_from_slice(&chunk[..]); - } - let decompressed_data = String::from_utf8(data.freeze().to_vec()).unwrap(); + let body = res.into_body(); + let decompressed_data = + String::from_utf8(body.collect().await.unwrap().to_bytes().to_vec()).unwrap(); assert_eq!(decompressed_data, "Hello, World!"); } - async fn handle(_req: Request) -> Result, Error> { - Ok(Response::new(Body::from("Hello, World!"))) - } - - async fn handle_multi_gz(_req: Request) -> Result, Error> { + async fn handle_multi_gz(_req: Request) -> Result, Infallible> { let mut buf = Vec::new(); let mut enc1 = GzEncoder::new(&mut buf, Default::default()); enc1.write_all(b"Hello, ").unwrap(); @@ -192,11 +186,14 @@ mod tests { #[allow(dead_code)] async fn is_compatible_with_hyper() { - let mut client = Decompression::new(Client::new()); + let client = + hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()) + .build_http(); + let mut client = Decompression::new(client); let req = Request::new(Body::empty()); - let _: Response> = + let _: Response> = client.ready().await.unwrap().call(req).await.unwrap(); } } diff --git a/tower-http/src/decompression/request/future.rs b/tower-http/src/decompression/request/future.rs index ca6dcfa7..bdb22f8b 100644 --- a/tower-http/src/decompression/request/future.rs +++ b/tower-http/src/decompression/request/future.rs @@ -1,8 +1,11 @@ +use crate::body::UnsyncBoxBody; use crate::compression_utils::AcceptEncoding; use crate::BoxError; use bytes::Buf; use http::{header, HeaderValue, Response, StatusCode}; -use http_body::{combinators::UnsyncBoxBody, Body, Empty}; +use http_body::Body; +use http_body_util::BodyExt; +use http_body_util::Empty; use pin_project_lite::pin_project; use std::future::Future; use std::pin::Pin; @@ -72,9 +75,9 @@ where fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.project().kind.project() { - StateProj::Inner { fut } => fut - .poll(cx) - .map_ok(|res| res.map(|body| body.map_err(Into::into).boxed_unsync())), + StateProj::Inner { fut } => fut.poll(cx).map_ok(|res| { + res.map(|body| UnsyncBoxBody::new(body.map_err(Into::into).boxed_unsync())) + }), StateProj::Unsupported { accept } => { let res = Response::builder() .header( @@ -84,7 +87,9 @@ where .unwrap_or(HeaderValue::from_static("identity")), ) .status(StatusCode::UNSUPPORTED_MEDIA_TYPE) - .body(Empty::new().map_err(Into::into).boxed_unsync()) + .body(UnsyncBoxBody::new( + Empty::new().map_err(Into::into).boxed_unsync(), + )) .unwrap(); Poll::Ready(Ok(res)) } diff --git a/tower-http/src/decompression/request/mod.rs b/tower-http/src/decompression/request/mod.rs index 42621a17..da3d9409 100644 --- a/tower-http/src/decompression/request/mod.rs +++ b/tower-http/src/decompression/request/mod.rs @@ -6,14 +6,12 @@ pub(super) mod service; mod tests { use super::service::RequestDecompression; use crate::decompression::DecompressionBody; - use bytes::BytesMut; + use crate::test_helpers::Body; use flate2::{write::GzEncoder, Compression}; - use http::{header, Response, StatusCode}; - use http_body::Body as _; - use hyper::{Body, Error, Request, Server}; - use std::io::Write; - use std::net::SocketAddr; - use tower::{make::Shared, service_fn, Service, ServiceExt}; + use http::{header, Request, Response, StatusCode}; + use http_body_util::BodyExt; + use std::{convert::Infallible, io::Write}; + use tower::{service_fn, Service, ServiceExt}; #[tokio::test] async fn decompress_accepted_encoding() { @@ -48,7 +46,7 @@ mod tests { async fn assert_request_is_decompressed( req: Request>, - ) -> Result, Error> { + ) -> Result, Infallible> { let (parts, mut body) = req.into_parts(); let body = read_body(&mut body).await; @@ -60,7 +58,7 @@ mod tests { async fn assert_request_is_passed_through( req: Request>, - ) -> Result, Error> { + ) -> Result, Infallible> { let (parts, mut body) = req.into_parts(); let body = read_body(&mut body).await; @@ -72,7 +70,7 @@ mod tests { async fn should_not_be_called( _: Request>, - ) -> Result, Error> { + ) -> Result, Infallible> { panic!("Inner service should not be called"); } @@ -87,23 +85,6 @@ mod tests { } async fn read_body(body: &mut DecompressionBody) -> Vec { - let mut data = BytesMut::new(); - while let Some(chunk) = body.data().await { - let chunk = chunk.unwrap(); - data.extend_from_slice(&chunk[..]); - } - data.freeze().to_vec() - } - - #[allow(dead_code)] - async fn is_compatible_with_hyper() { - let svc = service_fn(assert_request_is_decompressed); - let svc = RequestDecompression::new(svc); - - let make_service = Shared::new(svc); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - let server = Server::bind(&addr).serve(make_service); - server.await.unwrap(); + body.collect().await.unwrap().to_bytes().to_vec() } } diff --git a/tower-http/src/decompression/request/service.rs b/tower-http/src/decompression/request/service.rs index 443f73a9..663436e5 100644 --- a/tower-http/src/decompression/request/service.rs +++ b/tower-http/src/decompression/request/service.rs @@ -1,5 +1,6 @@ use super::future::RequestDecompressionFuture as ResponseFuture; use super::layer::RequestDecompressionLayer; +use crate::body::UnsyncBoxBody; use crate::compression_utils::CompressionLevel; use crate::{ compression_utils::AcceptEncoding, decompression::body::BodyInner, @@ -7,7 +8,7 @@ use crate::{ }; use bytes::Buf; use http::{header, Request, Response}; -use http_body::{combinators::UnsyncBoxBody, Body}; +use http_body::Body; use std::task::{Context, Poll}; use tower_service::Service; diff --git a/tower-http/src/follow_redirect/mod.rs b/tower-http/src/follow_redirect/mod.rs index 8c13415e..516fabf7 100644 --- a/tower-http/src/follow_redirect/mod.rs +++ b/tower-http/src/follow_redirect/mod.rs @@ -18,7 +18,8 @@ //! //! ``` //! use http::{Request, Response}; -//! use hyper::Body; +//! use bytes::Bytes; +//! use http_body_util::Full; //! use tower::{Service, ServiceBuilder, ServiceExt}; //! use tower_http::follow_redirect::{FollowRedirectLayer, RequestUri}; //! @@ -32,7 +33,7 @@ //! # .status(http::StatusCode::MOVED_PERMANENTLY) //! # .header(http::header::LOCATION, dest); //! # } -//! # Ok::<_, std::convert::Infallible>(res.body(Body::empty()).unwrap()) +//! # Ok::<_, std::convert::Infallible>(res.body(Full::::default()).unwrap()) //! # }); //! let mut client = ServiceBuilder::new() //! .layer(FollowRedirectLayer::new()) @@ -40,7 +41,7 @@ //! //! let request = Request::builder() //! .uri("https://rust-lang.org/") -//! .body(Body::empty()) +//! .body(Full::::default()) //! .unwrap(); //! //! let response = client.ready().await?.call(request).await?; @@ -56,7 +57,8 @@ //! //! ``` //! use http::{Request, Response}; -//! use hyper::Body; +//! use http_body_util::Full; +//! use bytes::Bytes; //! use tower::{Service, ServiceBuilder, ServiceExt}; //! use tower_http::follow_redirect::{ //! policy::{self, PolicyExt}, @@ -65,14 +67,14 @@ //! //! #[derive(Debug)] //! enum MyError { -//! Hyper(hyper::Error), //! TooManyRedirects, +//! Other(tower::BoxError), //! } //! //! # #[tokio::main] //! # async fn main() -> Result<(), MyError> { //! # let http_client = -//! # tower::service_fn(|_: Request| async { Ok(Response::new(Body::empty())) }); +//! # tower::service_fn(|_: Request>| async { Ok(Response::new(Full::::default())) }); //! let policy = policy::Limited::new(10) // Set the maximum number of redirections to 10. //! // Return an error when the limit was reached. //! .or::<_, (), _>(policy::redirect_fn(|_| Err(MyError::TooManyRedirects))) @@ -81,7 +83,7 @@ //! //! let mut client = ServiceBuilder::new() //! .layer(FollowRedirectLayer::with_policy(policy)) -//! .map_err(MyError::Hyper) +//! .map_err(MyError::Other) //! .service(http_client); //! //! // ... @@ -323,6 +325,7 @@ where /// /// The value differs from the original request's effective URI if the middleware has followed /// redirections. +#[derive(Clone)] pub struct RequestUri(pub Uri); #[derive(Debug)] @@ -385,7 +388,8 @@ fn resolve_uri(relative: &str, base: &Uri) -> Option { #[cfg(test)] mod tests { use super::{policy::*, *}; - use hyper::{header::LOCATION, Body}; + use crate::test_helpers::Body; + use http::header::LOCATION; use std::convert::Infallible; use tower::{ServiceBuilder, ServiceExt}; diff --git a/tower-http/src/follow_redirect/policy/mod.rs b/tower-http/src/follow_redirect/policy/mod.rs index adbd4d27..8e5d39ce 100644 --- a/tower-http/src/follow_redirect/policy/mod.rs +++ b/tower-http/src/follow_redirect/policy/mod.rs @@ -122,12 +122,12 @@ pub trait PolicyExt { /// /// ``` /// use bytes::Bytes; - /// use hyper::Body; + /// use http_body_util::Full; /// use tower_http::follow_redirect::policy::{self, clone_body_fn, Limited, PolicyExt}; /// /// enum MyBody { /// Bytes(Bytes), - /// Hyper(Body), + /// Full(Full), /// } /// /// let policy = Limited::default().and::<_, _, ()>(clone_body_fn(|body| { diff --git a/tower-http/src/lib.rs b/tower-http/src/lib.rs index 6719ddbd..4c731e83 100644 --- a/tower-http/src/lib.rs +++ b/tower-http/src/lib.rs @@ -24,10 +24,11 @@ //! trace::TraceLayer, //! validate_request::ValidateRequestHeaderLayer, //! }; -//! use tower::{ServiceBuilder, service_fn, make::Shared}; +//! use tower::{ServiceBuilder, service_fn, BoxError}; //! use http::{Request, Response, header::{HeaderName, CONTENT_TYPE, AUTHORIZATION}}; -//! use hyper::{Body, Error, server::Server, service::make_service_fn}; //! use std::{sync::Arc, net::SocketAddr, convert::Infallible, iter::once}; +//! use bytes::Bytes; +//! use http_body_util::Full; //! # struct DatabaseConnectionPool; //! # impl DatabaseConnectionPool { //! # fn new() -> DatabaseConnectionPool { DatabaseConnectionPool } @@ -37,7 +38,7 @@ //! //! // Our request handler. This is where we would implement the application logic //! // for responding to HTTP requests... -//! async fn handler(request: Request) -> Result, Error> { +//! async fn handler(request: Request>) -> Result>, BoxError> { //! // ... //! # todo!() //! } @@ -75,13 +76,8 @@ //! .layer(ValidateRequestHeaderLayer::accept("application/json")) //! // Wrap a `Service` in our middleware stack //! .service_fn(handler); -//! -//! // And run our service using `hyper` -//! let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); -//! Server::bind(&addr) -//! .serve(Shared::new(service)) -//! .await -//! .expect("server error"); +//! # let mut service = service; +//! # tower::Service::call(&mut service, Request::new(Full::default())); //! } //! ``` //! @@ -100,11 +96,14 @@ //! classify::StatusInRangeAsFailures, //! }; //! use tower::{ServiceBuilder, Service, ServiceExt}; -//! use hyper::Body; +//! use hyper_util::{rt::TokioExecutor, client::legacy::Client}; +//! use http_body_util::Full; +//! use bytes::Bytes; //! use http::{Request, HeaderValue, header::USER_AGENT}; //! //! #[tokio::main] //! async fn main() { +//! let client = Client::builder(TokioExecutor::new()).build_http(); //! let mut client = ServiceBuilder::new() //! // Add tracing and consider server errors and client //! // errors as failures. @@ -118,15 +117,15 @@ //! )) //! // Decompress response bodies //! .layer(DecompressionLayer::new()) -//! // Wrap a `hyper::Client` in our middleware stack. -//! // This is possible because `hyper::Client` implements +//! // Wrap a `Client` in our middleware stack. +//! // This is possible because `Client` implements //! // `tower::Service`. -//! .service(hyper::Client::new()); +//! .service(client); //! //! // Make a request //! let request = Request::builder() //! .uri("http://example.com") -//! .body(Body::empty()) +//! .body(Full::::default()) //! .unwrap(); //! //! let response = client @@ -183,7 +182,6 @@ clippy::todo, clippy::empty_enum, clippy::enum_glob_use, - clippy::pub_enum_variant_names, clippy::mem_forget, clippy::unused_self, clippy::filter_map_next, @@ -211,7 +209,7 @@ nonstandard_style, missing_docs )] -#![deny(unreachable_pub, private_in_public)] +#![deny(unreachable_pub)] #![allow( elided_lifetimes_in_paths, // TODO: Remove this once the MSRV bumps to 1.42.0 or above. @@ -225,6 +223,9 @@ #[macro_use] pub(crate) mod macros; +#[cfg(test)] +mod test_helpers; + #[cfg(feature = "auth")] pub mod auth; @@ -342,6 +343,8 @@ pub use self::builder::ServiceBuilderExt; #[cfg(feature = "validate-request")] pub mod validate_request; +pub mod body; + /// The latency unit used to report latencies by middleware. #[non_exhaustive] #[derive(Copy, Clone, Debug)] diff --git a/tower-http/src/limit/body.rs b/tower-http/src/limit/body.rs index 4e746a5d..4e540f8b 100644 --- a/tower-http/src/limit/body.rs +++ b/tower-http/src/limit/body.rs @@ -1,6 +1,7 @@ use bytes::Bytes; -use http::{HeaderMap, HeaderValue, Response, StatusCode}; -use http_body::{Body, Full, SizeHint}; +use http::{HeaderValue, Response, StatusCode}; +use http_body::{Body, SizeHint}; +use http_body_util::Full; use pin_project_lite::pin_project; use std::pin::Pin; use std::task::{Context, Poll}; @@ -52,25 +53,13 @@ where type Data = Bytes; type Error = B::Error; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { match self.project().inner.project() { - BodyProj::PayloadTooLarge { body } => body.poll_data(cx).map_err(|err| match err {}), - BodyProj::Body { body } => body.poll_data(cx), - } - } - - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - match self.project().inner.project() { - BodyProj::PayloadTooLarge { body } => { - body.poll_trailers(cx).map_err(|err| match err {}) - } - BodyProj::Body { body } => body.poll_trailers(cx), + BodyProj::PayloadTooLarge { body } => body.poll_frame(cx).map_err(|err| match err {}), + BodyProj::Body { body } => body.poll_frame(cx), } } diff --git a/tower-http/src/limit/mod.rs b/tower-http/src/limit/mod.rs index a71ddbe7..3f2fede3 100644 --- a/tower-http/src/limit/mod.rs +++ b/tower-http/src/limit/mod.rs @@ -24,15 +24,15 @@ //! use bytes::Bytes; //! use std::convert::Infallible; //! use http::{Request, Response, StatusCode, HeaderValue, header::CONTENT_LENGTH}; -//! use http_body::{Limited, LengthLimitError}; +//! use http_body_util::{LengthLimitError}; //! use tower::{Service, ServiceExt, ServiceBuilder}; -//! use tower_http::limit::RequestBodyLimitLayer; -//! use hyper::Body; +//! use tower_http::{body::Limited, limit::RequestBodyLimitLayer}; +//! use http_body_util::Full; //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { -//! async fn handle(req: Request>) -> Result, Infallible> { -//! panic!("This will not be hit") +//! async fn handle(req: Request>>) -> Result>, Infallible> { +//! panic!("This should not be hit") //! } //! //! let mut svc = ServiceBuilder::new() @@ -43,10 +43,11 @@ //! // Call the service with a header that indicates the body is too large. //! let mut request = Request::builder() //! .header(CONTENT_LENGTH, HeaderValue::from_static("5000")) -//! .body(Body::empty()) +//! .body(Full::::default()) //! .unwrap(); //! -//! let response = svc.ready().await?.call(request).await?; +//! // let response = svc.ready().await?.call(request).await?; +//! let response = svc.call(request).await?; //! //! assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE); //! # @@ -58,8 +59,8 @@ //! //! If a `Content-Length` header is not present, then the body will be read //! until the configured limit has been reached. If the payload is larger than -//! the limit, the [`http_body::Limited`] body will return an error. This -//! error can be inspected to determine if it is a [`http_body::LengthLimitError`] +//! the limit, the [`http_body_util::Limited`] body will return an error. This +//! error can be inspected to determine if it is a [`http_body_util::LengthLimitError`] //! and return an appropriate response in such case. //! //! Note that no error will be generated if the body is never read. Similarly, @@ -71,19 +72,20 @@ //! # use bytes::Bytes; //! # use std::convert::Infallible; //! # use http::{Request, Response, StatusCode}; -//! # use http_body::{Limited, LengthLimitError}; +//! # use http_body_util::LengthLimitError; //! # use tower::{Service, ServiceExt, ServiceBuilder, BoxError}; -//! # use tower_http::limit::RequestBodyLimitLayer; -//! # use hyper::Body; +//! # use tower_http::{body::Limited, limit::RequestBodyLimitLayer}; +//! # use http_body_util::Full; +//! # use http_body_util::BodyExt; //! # //! # #[tokio::main] //! # async fn main() -> Result<(), BoxError> { -//! async fn handle(req: Request>) -> Result, BoxError> { -//! let data = match hyper::body::to_bytes(req.into_body()).await { -//! Ok(data) => data, +//! async fn handle(req: Request>>) -> Result>, BoxError> { +//! let data = match req.into_body().collect().await { +//! Ok(collected) => collected.to_bytes(), //! Err(err) => { //! if let Some(_) = err.downcast_ref::() { -//! let mut resp = Response::new(Body::empty()); +//! let mut resp = Response::new(Full::default()); //! *resp.status_mut() = StatusCode::PAYLOAD_TOO_LARGE; //! return Ok(resp); //! } else { @@ -92,7 +94,7 @@ //! } //! }; //! -//! Ok(Response::new(Body::empty())) +//! Ok(Response::new(Full::default())) //! } //! //! let mut svc = ServiceBuilder::new() @@ -101,14 +103,14 @@ //! .service_fn(handle); //! //! // Call the service. -//! let request = Request::new(Body::empty()); +//! let request = Request::new(Full::::default()); //! //! let response = svc.ready().await?.call(request).await?; //! //! assert_eq!(response.status(), StatusCode::OK); //! //! // Call the service with a body that is too large. -//! let request = Request::new(Body::from(Bytes::from(vec![0u8; 4097]))); +//! let request = Request::new(Full::::from(Bytes::from(vec![0u8; 4097]))); //! //! let response = svc.ready().await?.call(request).await?; //! @@ -123,7 +125,7 @@ //! If enforcement of body size limits is desired without preemptively //! handling requests with a `Content-Length` header indicating an over-sized //! request, consider using [`MapRequestBody`] to wrap the request body with -//! [`http_body::Limited`] and checking for [`http_body::LengthLimitError`] +//! [`http_body_util::Limited`] and checking for [`http_body_util::LengthLimitError`] //! like in the previous example. //! //! [`MapRequestBody`]: crate::map_request_body diff --git a/tower-http/src/limit/service.rs b/tower-http/src/limit/service.rs index 66ae41fe..fdf65d25 100644 --- a/tower-http/src/limit/service.rs +++ b/tower-http/src/limit/service.rs @@ -1,6 +1,7 @@ use super::{RequestBodyLimitLayer, ResponseBody, ResponseFuture}; +use crate::body::Limited; use http::{Request, Response}; -use http_body::{Body, Limited}; +use http_body::Body; use std::task::{Context, Poll}; use tower_service::Service; @@ -56,7 +57,7 @@ where None => self.limit, }; - let req = req.map(|body| Limited::new(body, body_limit)); + let req = req.map(|body| Limited::new(http_body_util::Limited::new(body, body_limit))); ResponseFuture::new(self.inner.call(req)) } diff --git a/tower-http/src/macros.rs b/tower-http/src/macros.rs index 6641199b..f58d34a6 100644 --- a/tower-http/src/macros.rs +++ b/tower-http/src/macros.rs @@ -46,19 +46,11 @@ macro_rules! opaque_body { type Error = <$actual as http_body::Body>::Error; #[inline] - fn poll_data( + fn poll_frame( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, - ) -> std::task::Poll>> { - self.project().inner.poll_data(cx) - } - - #[inline] - fn poll_trailers( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll, Self::Error>> { - self.project().inner.poll_trailers(cx) + ) -> std::task::Poll, Self::Error>>> { + self.project().inner.poll_frame(cx) } #[inline] diff --git a/tower-http/src/map_request_body.rs b/tower-http/src/map_request_body.rs index 9389322b..dd067e92 100644 --- a/tower-http/src/map_request_body.rs +++ b/tower-http/src/map_request_body.rs @@ -3,71 +3,51 @@ //! # Example //! //! ``` +//! use http_body_util::Full; //! use bytes::Bytes; //! use http::{Request, Response}; -//! use hyper::Body; //! use std::convert::Infallible; //! use std::{pin::Pin, task::{ready, Context, Poll}}; //! use tower::{ServiceBuilder, service_fn, ServiceExt, Service}; //! use tower_http::map_request_body::MapRequestBodyLayer; //! -//! // A wrapper for a `hyper::Body` that prints the size of data chunks -//! struct PrintChunkSizesBody { -//! inner: Body, +//! // A wrapper for a `Full` +//! struct BodyWrapper { +//! inner: Full, //! } //! -//! impl PrintChunkSizesBody { -//! fn new(inner: Body) -> Self { +//! impl BodyWrapper { +//! fn new(inner: Full) -> Self { //! Self { inner } //! } //! } //! -//! impl http_body::Body for PrintChunkSizesBody { -//! type Data = Bytes; -//! type Error = hyper::Error; -//! -//! fn poll_data( -//! mut self: Pin<&mut Self>, -//! cx: &mut Context<'_>, -//! ) -> Poll>> { -//! if let Some(chunk) = ready!(Pin::new(&mut self.inner).poll_data(cx)?) { -//! println!("chunk size = {}", chunk.len()); -//! Poll::Ready(Some(Ok(chunk))) -//! } else { -//! Poll::Ready(None) -//! } -//! } -//! -//! fn poll_trailers( -//! mut self: Pin<&mut Self>, -//! cx: &mut Context<'_>, -//! ) -> Poll, Self::Error>> { -//! Pin::new(&mut self.inner).poll_trailers(cx) -//! } -//! -//! fn is_end_stream(&self) -> bool { -//! self.inner.is_end_stream() -//! } -//! -//! fn size_hint(&self) -> http_body::SizeHint { -//! self.inner.size_hint() -//! } +//! impl http_body::Body for BodyWrapper { +//! // ... +//! # type Data = Bytes; +//! # type Error = tower::BoxError; +//! # fn poll_frame( +//! # self: Pin<&mut Self>, +//! # cx: &mut Context<'_> +//! # ) -> Poll, Self::Error>>> { unimplemented!() } +//! # fn is_end_stream(&self) -> bool { unimplemented!() } +//! # fn size_hint(&self) -> http_body::SizeHint { unimplemented!() } //! } //! -//! async fn handle(_: Request) -> Result, Infallible> { +//! async fn handle(_: Request) -> Result>, Infallible> { //! // ... -//! # Ok(Response::new(Body::empty())) +//! # Ok(Response::new(Full::default())) //! } //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let mut svc = ServiceBuilder::new() -//! // Wrap response bodies in `PrintChunkSizesBody` -//! .layer(MapRequestBodyLayer::new(PrintChunkSizesBody::new)) +//! // Wrap response bodies in `BodyWrapper` +//! .layer(MapRequestBodyLayer::new(BodyWrapper::new)) //! .service_fn(handle); //! //! // Call the service -//! let request = Request::new(Body::empty()); +//! let request = Request::new(Full::default()); //! //! svc.ready().await?.call(request).await?; //! # Ok(()) diff --git a/tower-http/src/map_response_body.rs b/tower-http/src/map_response_body.rs index 5d0bb4c8..5329e5d5 100644 --- a/tower-http/src/map_response_body.rs +++ b/tower-http/src/map_response_body.rs @@ -5,69 +5,49 @@ //! ``` //! use bytes::Bytes; //! use http::{Request, Response}; -//! use hyper::Body; +//! use http_body_util::Full; //! use std::convert::Infallible; //! use std::{pin::Pin, task::{ready, Context, Poll}}; //! use tower::{ServiceBuilder, service_fn, ServiceExt, Service}; //! use tower_http::map_response_body::MapResponseBodyLayer; //! -//! // A wrapper for a `hyper::Body` that prints the size of data chunks -//! struct PrintChunkSizesBody { -//! inner: Body, +//! // A wrapper for a `Full` +//! struct BodyWrapper { +//! inner: Full, //! } //! -//! impl PrintChunkSizesBody { -//! fn new(inner: Body) -> Self { +//! impl BodyWrapper { +//! fn new(inner: Full) -> Self { //! Self { inner } //! } //! } //! -//! impl http_body::Body for PrintChunkSizesBody { -//! type Data = Bytes; -//! type Error = hyper::Error; -//! -//! fn poll_data( -//! mut self: Pin<&mut Self>, -//! cx: &mut Context<'_>, -//! ) -> Poll>> { -//! if let Some(chunk) = ready!(Pin::new(&mut self.inner).poll_data(cx)?) { -//! println!("chunk size = {}", chunk.len()); -//! Poll::Ready(Some(Ok(chunk))) -//! } else { -//! Poll::Ready(None) -//! } -//! } -//! -//! fn poll_trailers( -//! mut self: Pin<&mut Self>, -//! cx: &mut Context<'_>, -//! ) -> Poll, Self::Error>> { -//! Pin::new(&mut self.inner).poll_trailers(cx) -//! } -//! -//! fn is_end_stream(&self) -> bool { -//! self.inner.is_end_stream() -//! } -//! -//! fn size_hint(&self) -> http_body::SizeHint { -//! self.inner.size_hint() -//! } +//! impl http_body::Body for BodyWrapper { +//! // ... +//! # type Data = Bytes; +//! # type Error = tower::BoxError; +//! # fn poll_frame( +//! # self: Pin<&mut Self>, +//! # cx: &mut Context<'_> +//! # ) -> Poll, Self::Error>>> { unimplemented!() } +//! # fn is_end_stream(&self) -> bool { unimplemented!() } +//! # fn size_hint(&self) -> http_body::SizeHint { unimplemented!() } //! } //! -//! async fn handle(_: Request) -> Result, Infallible> { +//! async fn handle(_: Request) -> Result>, Infallible> { //! // ... -//! # Ok(Response::new(Body::empty())) +//! # Ok(Response::new(Full::default())) //! } //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let mut svc = ServiceBuilder::new() -//! // Wrap response bodies in `PrintChunkSizesBody` -//! .layer(MapResponseBodyLayer::new(PrintChunkSizesBody::new)) +//! // Wrap response bodies in `BodyWrapper` +//! .layer(MapResponseBodyLayer::new(BodyWrapper::new)) //! .service_fn(handle); //! //! // Call the service -//! let request = Request::new(Body::from("foobar")); +//! let request = Request::new(Full::::from("foobar")); //! //! svc.ready().await?.call(request).await?; //! # Ok(()) diff --git a/tower-http/src/metrics/in_flight_requests.rs b/tower-http/src/metrics/in_flight_requests.rs index 2ca33619..dbb5e2ff 100644 --- a/tower-http/src/metrics/in_flight_requests.rs +++ b/tower-http/src/metrics/in_flight_requests.rs @@ -10,12 +10,13 @@ //! use tower::{Service, ServiceExt, ServiceBuilder}; //! use tower_http::metrics::InFlightRequestsLayer; //! use http::{Request, Response}; -//! use hyper::Body; +//! use bytes::Bytes; +//! use http_body_util::Full; //! use std::{time::Duration, convert::Infallible}; //! -//! async fn handle(req: Request) -> Result, Infallible> { +//! async fn handle(req: Request>) -> Result>, Infallible> { //! // ... -//! # Ok(Response::new(Body::empty())) +//! # Ok(Response::new(Full::default())) //! } //! //! async fn update_in_flight_requests_metric(count: usize) { @@ -44,7 +45,7 @@ //! let response = service //! .ready() //! .await? -//! .call(Request::new(Body::empty())) +//! .call(Request::new(Full::default())) //! .await?; //! # Ok(()) //! # } @@ -266,19 +267,11 @@ where type Error = B::Error; #[inline] - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { - self.project().inner.poll_data(cx) - } - - #[inline] - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - self.project().inner.poll_trailers(cx) + ) -> Poll, Self::Error>>> { + self.project().inner.poll_frame(cx) } #[inline] @@ -296,8 +289,8 @@ where mod tests { #[allow(unused_imports)] use super::*; + use crate::test_helpers::Body; use http::Request; - use hyper::Body; use tower::{BoxError, ServiceBuilder}; #[tokio::test] @@ -324,7 +317,7 @@ mod tests { assert_eq!(counter.get(), 1); let body = response.into_body(); - hyper::body::to_bytes(body).await.unwrap(); + crate::test_helpers::to_bytes(body).await.unwrap(); assert_eq!(counter.get(), 0); } diff --git a/tower-http/src/normalize_path.rs b/tower-http/src/normalize_path.rs index 91ee5b92..efc7be52 100644 --- a/tower-http/src/normalize_path.rs +++ b/tower-http/src/normalize_path.rs @@ -8,15 +8,16 @@ //! ``` //! use tower_http::normalize_path::NormalizePathLayer; //! use http::{Request, Response, StatusCode}; -//! use hyper::Body; +//! use http_body_util::Full; +//! use bytes::Bytes; //! use std::{iter::once, convert::Infallible}; //! use tower::{ServiceBuilder, Service, ServiceExt}; //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { -//! async fn handle(req: Request) -> Result, Infallible> { +//! async fn handle(req: Request>) -> Result>, Infallible> { //! // `req.uri().path()` will not have trailing slashes -//! # Ok(Response::new(Body::empty())) +//! # Ok(Response::new(Full::default())) //! } //! //! let mut service = ServiceBuilder::new() @@ -28,7 +29,7 @@ //! let request = Request::builder() //! // `handle` will see `/foo` //! .uri("/foo/") -//! .body(Body::empty())?; +//! .body(Full::default())?; //! //! service.ready().await?.call(request).await?; //! # diff --git a/tower-http/src/propagate_header.rs b/tower-http/src/propagate_header.rs index 214b32d8..6c77ec32 100644 --- a/tower-http/src/propagate_header.rs +++ b/tower-http/src/propagate_header.rs @@ -7,13 +7,14 @@ //! use std::convert::Infallible; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! use tower_http::propagate_header::PropagateHeaderLayer; -//! use hyper::Body; +//! use bytes::Bytes; +//! use http_body_util::Full; //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { -//! async fn handle(req: Request) -> Result, Infallible> { +//! async fn handle(req: Request>) -> Result>, Infallible> { //! // ... -//! # Ok(Response::new(Body::empty())) +//! # Ok(Response::new(Full::default())) //! } //! //! let mut svc = ServiceBuilder::new() @@ -24,7 +25,7 @@ //! // Call the service. //! let request = Request::builder() //! .header("x-request-id", "1337") -//! .body(Body::empty())?; +//! .body(Full::default())?; //! //! let response = svc.ready().await?.call(request).await?; //! diff --git a/tower-http/src/request_id.rs b/tower-http/src/request_id.rs index b328310c..1db2d02a 100644 --- a/tower-http/src/request_id.rs +++ b/tower-http/src/request_id.rs @@ -8,12 +8,13 @@ //! use tower_http::request_id::{ //! SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId, //! }; -//! use hyper::Body; +//! use http_body_util::Full; +//! use bytes::Bytes; //! use std::sync::{Arc, atomic::{AtomicU64, Ordering}}; //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { -//! # let handler = tower::service_fn(|request: Request| async move { +//! # let handler = tower::service_fn(|request: Request>| async move { //! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body())) //! # }); //! # @@ -47,7 +48,7 @@ //! .layer(PropagateRequestIdLayer::new(x_request_id)) //! .service(handler); //! -//! let request = Request::new(Body::empty()); +//! let request = Request::new(Full::default()); //! let response = svc.ready().await?.call(request).await?; //! //! assert_eq!(response.headers()["x-request-id"], "0"); @@ -65,11 +66,12 @@ //! # use tower_http::request_id::{ //! # SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId, //! # }; -//! # use hyper::Body; +//! # use bytes::Bytes; +//! # use http_body_util::Full; //! # use std::sync::{Arc, atomic::{AtomicU64, Ordering}}; //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { -//! # let handler = tower::service_fn(|request: Request| async move { +//! # let handler = tower::service_fn(|request: Request>| async move { //! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body())) //! # }); //! # #[derive(Clone, Default)] @@ -92,7 +94,7 @@ //! .propagate_x_request_id() //! .service(handler); //! -//! let request = Request::new(Body::empty()); +//! let request = Request::new(Full::default()); //! let response = svc.ready().await?.call(request).await?; //! //! assert_eq!(response.headers()["x-request-id"], "0"); @@ -118,11 +120,12 @@ //! # use tower_http::request_id::{ //! # SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId, //! # }; -//! # use hyper::Body; +//! # use http_body_util::Full; +//! # use bytes::Bytes; //! # use std::sync::{Arc, atomic::{AtomicU64, Ordering}}; //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { -//! # let handler = tower::service_fn(|request: Request| async move { +//! # let handler = tower::service_fn(|request: Request>| async move { //! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body())) //! # }); //! # #[derive(Clone, Default)] @@ -481,8 +484,9 @@ impl MakeRequestId for MakeRequestUuid { #[cfg(test)] mod tests { + use crate::test_helpers::Body; use crate::ServiceBuilderExt as _; - use hyper::{Body, Response}; + use http::Response; use std::{ convert::Infallible, sync::{ diff --git a/tower-http/src/sensitive_headers.rs b/tower-http/src/sensitive_headers.rs index 5249b43e..3bd081db 100644 --- a/tower-http/src/sensitive_headers.rs +++ b/tower-http/src/sensitive_headers.rs @@ -8,12 +8,13 @@ //! use tower_http::sensitive_headers::SetSensitiveHeadersLayer; //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! use http::{Request, Response, header::AUTHORIZATION}; -//! use hyper::Body; +//! use http_body_util::Full; +//! use bytes::Bytes; //! use std::{iter::once, convert::Infallible}; //! -//! async fn handle(req: Request) -> Result, Infallible> { +//! async fn handle(req: Request>) -> Result>, Infallible> { //! // ... -//! # Ok(Response::new(Body::empty())) +//! # Ok(Response::new(Full::default())) //! } //! //! # #[tokio::main] @@ -33,7 +34,7 @@ //! let response = service //! .ready() //! .await? -//! .call(Request::new(Body::empty())) +//! .call(Request::new(Full::default())) //! .await?; //! # Ok(()) //! # } @@ -56,10 +57,11 @@ //! use http::header; //! use std::sync::Arc; //! # use http::{Request, Response}; -//! # use hyper::Body; +//! # use bytes::Bytes; +//! # use http_body_util::Full; //! # use std::convert::Infallible; -//! # async fn handle(req: Request) -> Result, Infallible> { -//! # Ok(Response::new(Body::empty())) +//! # async fn handle(req: Request>) -> Result>, Infallible> { +//! # Ok(Response::new(Full::default())) //! # } //! //! # #[tokio::main] diff --git a/tower-http/src/services/fs/mod.rs b/tower-http/src/services/fs/mod.rs index 8902c31e..32dd6f1c 100644 --- a/tower-http/src/services/fs/mod.rs +++ b/tower-http/src/services/fs/mod.rs @@ -2,8 +2,7 @@ use bytes::Bytes; use futures_util::Stream; -use http::HeaderMap; -use http_body::Body; +use http_body::{Body, Frame}; use pin_project_lite::pin_project; use std::{ io, @@ -67,17 +66,14 @@ where type Data = Bytes; type Error = io::Error; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { - self.project().reader.poll_next(cx) - } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) + ) -> Poll, Self::Error>>> { + match std::task::ready!(self.project().reader.poll_next(cx)) { + Some(Ok(chunk)) => Poll::Ready(Some(Ok(Frame::data(chunk)))), + Some(Err(err)) => Poll::Ready(Some(Err(err))), + None => Poll::Ready(None), + } } } diff --git a/tower-http/src/services/fs/serve_dir/future.rs b/tower-http/src/services/fs/serve_dir/future.rs index 4ffa4cf8..8b255f72 100644 --- a/tower-http/src/services/fs/serve_dir/future.rs +++ b/tower-http/src/services/fs/serve_dir/future.rs @@ -2,14 +2,16 @@ use super::{ open_file::{FileOpened, FileRequestExtent, OpenFileOutput}, DefaultServeDirFallback, ResponseBody, }; -use crate::{content_encoding::Encoding, services::fs::AsyncReadBody, BoxError}; +use crate::{ + body::UnsyncBoxBody, content_encoding::Encoding, services::fs::AsyncReadBody, BoxError, +}; use bytes::Bytes; use futures_util::future::{BoxFuture, FutureExt, TryFutureExt}; use http::{ header::{self, ALLOW}, HeaderValue, Request, Response, StatusCode, }; -use http_body::{Body, Empty, Full}; +use http_body_util::{BodyExt, Empty, Full}; use pin_project_lite::pin_project; use std::{ convert::Infallible, @@ -199,11 +201,13 @@ where .map_ok(|response| { response .map(|body| { - body.map_err(|err| match err.into().downcast::() { - Ok(err) => *err, - Err(err) => io::Error::new(io::ErrorKind::Other, err), - }) - .boxed_unsync() + UnsyncBoxBody::new( + body.map_err(|err| match err.into().downcast::() { + Ok(err) => *err, + Err(err) => io::Error::new(io::ErrorKind::Other, err), + }) + .boxed_unsync(), + ) }) .map(ResponseBody::new) }) @@ -247,14 +251,14 @@ fn build_response(output: FileOpened) -> Response { } else { let body = if let Some(file) = maybe_file { let range_size = range.end() - range.start() + 1; - ResponseBody::new( + ResponseBody::new(UnsyncBoxBody::new( AsyncReadBody::with_capacity_limited( file, output.chunk_size, range_size, ) .boxed_unsync(), - ) + )) } else { empty_body() }; @@ -289,9 +293,9 @@ fn build_response(output: FileOpened) -> Response { // Not a range request None => { let body = if let Some(file) = maybe_file { - ResponseBody::new( + ResponseBody::new(UnsyncBoxBody::new( AsyncReadBody::with_capacity(file, output.chunk_size).boxed_unsync(), - ) + )) } else { empty_body() }; @@ -306,10 +310,10 @@ fn build_response(output: FileOpened) -> Response { fn body_from_bytes(bytes: Bytes) -> ResponseBody { let body = Full::from(bytes).map_err(|err| match err {}).boxed_unsync(); - ResponseBody::new(body) + ResponseBody::new(UnsyncBoxBody::new(body)) } fn empty_body() -> ResponseBody { let body = Empty::new().map_err(|err| match err {}).boxed_unsync(); - ResponseBody::new(body) + ResponseBody::new(UnsyncBoxBody::new(body)) } diff --git a/tower-http/src/services/fs/serve_dir/mod.rs b/tower-http/src/services/fs/serve_dir/mod.rs index 525cb680..ed643d56 100644 --- a/tower-http/src/services/fs/serve_dir/mod.rs +++ b/tower-http/src/services/fs/serve_dir/mod.rs @@ -1,12 +1,13 @@ use self::future::ResponseFuture; use crate::{ + body::UnsyncBoxBody, content_encoding::{encodings, SupportedEncodings}, set_status::SetStatus, }; use bytes::Bytes; use futures_util::FutureExt; use http::{header, HeaderValue, Method, Request, Response, StatusCode}; -use http_body::{combinators::UnsyncBoxBody, Body, Empty}; +use http_body_util::{BodyExt, Empty}; use percent_encoding::percent_decode; use std::{ convert::Infallible, @@ -47,15 +48,6 @@ const DEFAULT_CAPACITY: usize = 65536; /// // This will serve files in the "assets" directory and /// // its subdirectories /// let service = ServeDir::new("assets"); -/// -/// # async { -/// // Run our service using `hyper` -/// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3000)); -/// hyper::Server::bind(&addr) -/// .serve(tower::make::Shared::new(service)) -/// .await -/// .expect("server error"); -/// # }; /// ``` #[derive(Clone, Debug)] pub struct ServeDir { @@ -216,15 +208,6 @@ impl ServeDir { /// let service = ServeDir::new("assets") /// // respond with `not_found.html` for missing files /// .fallback(ServeFile::new("assets/not_found.html")); - /// - /// # async { - /// // Run our service using `hyper` - /// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3000)); - /// hyper::Server::bind(&addr) - /// .serve(tower::make::Shared::new(service)) - /// .await - /// .expect("server error"); - /// # }; /// ``` pub fn fallback(self, new_fallback: F2) -> ServeDir { ServeDir { @@ -251,15 +234,6 @@ impl ServeDir { /// let service = ServeDir::new("assets") /// // respond with `404 Not Found` and the contents of `not_found.html` for missing files /// .not_found_service(ServeFile::new("assets/not_found.html")); - /// - /// # async { - /// // Run our service using `hyper` - /// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3000)); - /// hyper::Server::bind(&addr) - /// .serve(tower::make::Shared::new(service)) - /// .await - /// .expect("server error"); - /// # }; /// ``` /// /// Setups like this are often found in single page applications. @@ -292,13 +266,13 @@ impl ServeDir { /// use tower_http::services::ServeDir; /// use std::{io, convert::Infallible}; /// use http::{Request, Response, StatusCode}; - /// use http_body::{combinators::UnsyncBoxBody, Body as _}; - /// use hyper::Body; + /// use http_body::Body as _; + /// use http_body_util::{Full, BodyExt, combinators::UnsyncBoxBody}; /// use bytes::Bytes; /// use tower::{service_fn, ServiceExt, BoxError}; /// /// async fn serve_dir( - /// request: Request + /// request: Request> /// ) -> Result>, Infallible> { /// let mut service = ServeDir::new("assets"); /// @@ -307,7 +281,7 @@ impl ServeDir { /// // /// // Its shown here for demonstration but you can do `service.try_call(request)` /// // otherwise - /// let ready_service = match ServiceExt::>::ready(&mut service).await { + /// let ready_service = match ServiceExt::>>::ready(&mut service).await { /// Ok(ready_service) => ready_service, /// Err(infallible) => match infallible {}, /// }; @@ -317,7 +291,7 @@ impl ServeDir { /// Ok(response.map(|body| body.map_err(Into::into).boxed_unsync())) /// } /// Err(err) => { - /// let body = Body::from("Something went wrong...") + /// let body = Full::from("Something went wrong...") /// .map_err(Into::into) /// .boxed_unsync(); /// let response = Response::builder() @@ -328,15 +302,6 @@ impl ServeDir { /// } /// } /// } - /// - /// # async { - /// // Run our service using `hyper` - /// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3000)); - /// hyper::Server::bind(&addr) - /// .serve(tower::make::Shared::new(service_fn(serve_dir))) - /// .await - /// .expect("server error"); - /// # }; /// ``` pub fn try_call( &mut self, @@ -447,8 +412,9 @@ where let response = result.unwrap_or_else(|err| { tracing::error!(error = %err, "Failed to read file"); - let body = - ResponseBody::new(Empty::new().map_err(|err| match err {}).boxed_unsync()); + let body = ResponseBody::new(UnsyncBoxBody::new( + Empty::new().map_err(|err| match err {}).boxed_unsync(), + )); Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) .body(body) diff --git a/tower-http/src/services/fs/serve_dir/open_file.rs b/tower-http/src/services/fs/serve_dir/open_file.rs index a24aa088..401d34de 100644 --- a/tower-http/src/services/fs/serve_dir/open_file.rs +++ b/tower-http/src/services/fs/serve_dir/open_file.rs @@ -5,7 +5,7 @@ use super::{ use crate::content_encoding::{Encoding, QValue}; use bytes::Bytes; use http::{header, HeaderValue, Method, Request, Uri}; -use http_body::Empty; +use http_body_util::Empty; use http_range_header::RangeUnsatisfiableError; use std::{ ffi::OsStr, diff --git a/tower-http/src/services/fs/serve_dir/tests.rs b/tower-http/src/services/fs/serve_dir/tests.rs index eb870d4a..07ae110b 100644 --- a/tower-http/src/services/fs/serve_dir/tests.rs +++ b/tower-http/src/services/fs/serve_dir/tests.rs @@ -1,4 +1,5 @@ use crate::services::{ServeDir, ServeFile}; +use crate::test_helpers::{to_bytes, Body}; use brotli::BrotliDecompress; use bytes::Bytes; use flate2::bufread::{DeflateDecoder, GzDecoder}; @@ -6,7 +7,7 @@ use http::header::ALLOW; use http::{header, Method, Response}; use http::{Request, StatusCode}; use http_body::Body as HttpBody; -use hyper::Body; +use http_body_util::BodyExt; use std::convert::Infallible; use std::io::Read; use tower::{service_fn, ServiceExt}; @@ -59,8 +60,7 @@ async fn head_request() { assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-length"], "23"); - let body = res.into_body().data().await; - assert!(body.is_none()); + assert!(res.into_body().frame().await.is_none()); } #[tokio::test] @@ -79,8 +79,7 @@ async fn precompresed_head_request() { assert_eq!(res.headers()["content-encoding"], "gzip"); assert_eq!(res.headers()["content-length"], "59"); - let body = res.into_body().data().await; - assert!(body.is_none()); + assert!(res.into_body().frame().await.is_none()); } #[tokio::test] @@ -116,7 +115,7 @@ async fn precompressed_gzip() { assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "gzip"); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); let mut decoder = GzDecoder::new(&body[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); @@ -137,7 +136,7 @@ async fn precompressed_br() { assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "br"); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); let mut decompressed = Vec::new(); BrotliDecompress(&mut &body[..], &mut decompressed).unwrap(); let decompressed = String::from_utf8(decompressed.to_vec()).unwrap(); @@ -157,7 +156,7 @@ async fn precompressed_deflate() { assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "deflate"); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); let mut decoder = DeflateDecoder::new(&body[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); @@ -178,7 +177,7 @@ async fn unsupported_precompression_alogrithm_fallbacks_to_uncompressed() { assert_eq!(res.headers()["content-type"], "text/plain"); assert!(res.headers().get("content-encoding").is_none()); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); let body = String::from_utf8(body.to_vec()).unwrap(); assert!(body.starts_with("\"This is a test file!\"")); } @@ -206,7 +205,7 @@ async fn only_precompressed_variant_existing() { assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "gzip"); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); let mut decoder = GzDecoder::new(&body[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); @@ -228,7 +227,7 @@ async fn missing_precompressed_variant_fallbacks_to_uncompressed() { // Uncompressed file is served because compressed version is missing assert!(res.headers().get("content-encoding").is_none()); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); let body = String::from_utf8(body.to_vec()).unwrap(); assert!(body.starts_with("Test file!")); } @@ -250,7 +249,7 @@ async fn missing_precompressed_variant_fallbacks_to_uncompressed_for_head_reques // Uncompressed file is served because compressed version is missing assert!(res.headers().get("content-encoding").is_none()); - assert!(res.into_body().data().await.is_none()); + assert!(res.into_body().frame().await.is_none()); } #[tokio::test] @@ -346,7 +345,7 @@ async fn fallbacks_to_different_precompressed_variant_if_not_found_for_head_requ assert_eq!(res.headers()["content-encoding"], "br"); assert_eq!(res.headers()["content-length"], "15"); - assert!(res.into_body().data().await.is_none()); + assert!(res.into_body().frame().await.is_none()); } #[tokio::test] @@ -365,7 +364,7 @@ async fn fallbacks_to_different_precompressed_variant_if_not_found() { assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "br"); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); let mut decompressed = Vec::new(); BrotliDecompress(&mut &body[..], &mut decompressed).unwrap(); let decompressed = String::from_utf8(decompressed.to_vec()).unwrap(); @@ -404,7 +403,7 @@ where B: HttpBody + Unpin, B::Error: std::fmt::Debug, { - let bytes = hyper::body::to_bytes(body).await.unwrap(); + let bytes = to_bytes(body).await.unwrap(); String::from_utf8(bytes.to_vec()).unwrap() } @@ -474,7 +473,7 @@ async fn read_partial_in_bounds() { ))); assert_eq!(res.headers()["content-type"], "text/markdown"); - let body = hyper::body::to_bytes(res.into_body()).await.ok().unwrap(); + let body = to_bytes(res.into_body()).await.ok().unwrap(); let source = Bytes::from(file_contents[bytes_start_incl..=bytes_end_incl].to_vec()); assert_eq!(body, source); } @@ -583,8 +582,7 @@ async fn last_modified() { let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::NOT_MODIFIED); - let body = res.into_body().data().await; - assert!(body.is_none()); + assert!(res.into_body().frame().await.is_none()); let svc = ServeDir::new(".."); let req = Request::builder() @@ -596,7 +594,7 @@ async fn last_modified() { let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); let readme_bytes = include_bytes!("../../../../../README.md"); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); assert_eq!(body.as_ref(), readme_bytes); // -- If-Unmodified-Since @@ -610,7 +608,7 @@ async fn last_modified() { let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); assert_eq!(body.as_ref(), readme_bytes); let svc = ServeDir::new(".."); @@ -622,8 +620,7 @@ async fn last_modified() { let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::PRECONDITION_FAILED); - let body = res.into_body().data().await; - assert!(body.is_none()); + assert!(res.into_body().frame().await.is_none()); } #[tokio::test] diff --git a/tower-http/src/services/fs/serve_file.rs b/tower-http/src/services/fs/serve_file.rs index ede79936..fc2553ea 100644 --- a/tower-http/src/services/fs/serve_file.rs +++ b/tower-http/src/services/fs/serve_file.rs @@ -128,14 +128,14 @@ where #[cfg(test)] mod tests { use crate::services::ServeFile; + use crate::test_helpers::Body; use brotli::BrotliDecompress; use flate2::bufread::DeflateDecoder; use flate2::bufread::GzDecoder; use http::header; use http::Method; use http::{Request, StatusCode}; - use http_body::Body as _; - use hyper::Body; + use http_body_util::BodyExt; use mime::Mime; use std::io::Read; use std::str::FromStr; @@ -149,7 +149,7 @@ mod tests { assert_eq!(res.headers()["content-type"], "text/markdown"); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); let body = String::from_utf8(body.to_vec()).unwrap(); assert!(body.starts_with("# Tower HTTP")); @@ -163,7 +163,7 @@ mod tests { assert_eq!(res.headers()["content-type"], "image/jpg"); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); let body = String::from_utf8(body.to_vec()).unwrap(); assert!(body.starts_with("# Tower HTTP")); @@ -180,8 +180,7 @@ mod tests { assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-length"], "23"); - let body = res.into_body().data().await; - assert!(body.is_none()); + assert!(res.into_body().frame().await.is_none()); } #[tokio::test] @@ -199,8 +198,7 @@ mod tests { assert_eq!(res.headers()["content-encoding"], "gzip"); assert_eq!(res.headers()["content-length"], "59"); - let body = res.into_body().data().await; - assert!(body.is_none()); + assert!(res.into_body().frame().await.is_none()); } #[tokio::test] @@ -216,7 +214,7 @@ mod tests { assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "gzip"); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); let mut decoder = GzDecoder::new(&body[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); @@ -236,7 +234,7 @@ mod tests { assert_eq!(res.headers()["content-type"], "text/plain"); assert!(res.headers().get("content-encoding").is_none()); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); let body = String::from_utf8(body.to_vec()).unwrap(); assert!(body.starts_with("\"This is a test file!\"")); } @@ -255,7 +253,7 @@ mod tests { // Uncompressed file is served because compressed version is missing assert!(res.headers().get("content-encoding").is_none()); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); let body = String::from_utf8(body.to_vec()).unwrap(); assert!(body.starts_with("Test file!")); } @@ -276,8 +274,7 @@ mod tests { // Uncompressed file is served because compressed version is missing assert!(res.headers().get("content-encoding").is_none()); - let body = res.into_body().data().await; - assert!(body.is_none()); + assert!(res.into_body().frame().await.is_none()); } #[tokio::test] @@ -299,7 +296,7 @@ mod tests { assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "gzip"); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); let mut decoder = GzDecoder::new(&body[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); @@ -319,7 +316,7 @@ mod tests { assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "br"); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); let mut decompressed = Vec::new(); BrotliDecompress(&mut &body[..], &mut decompressed).unwrap(); let decompressed = String::from_utf8(decompressed.to_vec()).unwrap(); @@ -338,7 +335,7 @@ mod tests { assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "deflate"); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); let mut decoder = DeflateDecoder::new(&body[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); @@ -360,7 +357,7 @@ mod tests { assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "gzip"); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); let mut decoder = GzDecoder::new(&body[..]); let mut decompressed = String::new(); decoder.read_to_string(&mut decompressed).unwrap(); @@ -375,7 +372,7 @@ mod tests { assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "br"); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); let mut decompressed = Vec::new(); BrotliDecompress(&mut &body[..], &mut decompressed).unwrap(); let decompressed = String::from_utf8(decompressed.to_vec()).unwrap(); @@ -390,7 +387,7 @@ mod tests { assert_eq!(res.headers()["content-type"], "text/markdown"); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); let body = String::from_utf8(body.to_vec()).unwrap(); assert!(body.starts_with("# Tower HTTP")); @@ -412,7 +409,7 @@ mod tests { assert_eq!(res.headers()["content-type"], "text/plain"); assert_eq!(res.headers()["content-encoding"], "br"); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); let mut decompressed = Vec::new(); BrotliDecompress(&mut &body[..], &mut decompressed).unwrap(); let decompressed = String::from_utf8(decompressed.to_vec()).unwrap(); @@ -437,8 +434,7 @@ mod tests { assert_eq!(res.headers()["content-length"], "15"); assert_eq!(res.headers()["content-encoding"], "br"); - let body = res.into_body().data().await; - assert!(body.is_none()); + assert!(res.into_body().frame().await.is_none()); } #[tokio::test] @@ -489,8 +485,7 @@ mod tests { let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::NOT_MODIFIED); - let body = res.into_body().data().await; - assert!(body.is_none()); + assert!(res.into_body().frame().await.is_none()); let svc = ServeFile::new("../README.md"); let req = Request::builder() @@ -501,7 +496,7 @@ mod tests { let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); let readme_bytes = include_bytes!("../../../../README.md"); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); assert_eq!(body.as_ref(), readme_bytes); // -- If-Unmodified-Since @@ -514,7 +509,7 @@ mod tests { let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); - let body = res.into_body().data().await.unwrap().unwrap(); + let body = res.into_body().collect().await.unwrap().to_bytes(); assert_eq!(body.as_ref(), readme_bytes); let svc = ServeFile::new("../README.md"); @@ -525,7 +520,6 @@ mod tests { let res = svc.oneshot(req).await.unwrap(); assert_eq!(res.status(), StatusCode::PRECONDITION_FAILED); - let body = res.into_body().data().await; - assert!(body.is_none()); + assert!(res.into_body().frame().await.is_none()); } } diff --git a/tower-http/src/services/redirect.rs b/tower-http/src/services/redirect.rs index f5e0552d..020927c9 100644 --- a/tower-http/src/services/redirect.rs +++ b/tower-http/src/services/redirect.rs @@ -7,18 +7,19 @@ //! //! ```rust //! use http::{Request, Uri, StatusCode}; -//! use hyper::Body; +//! use http_body_util::Full; +//! use bytes::Bytes; //! use tower::{Service, ServiceExt}; //! use tower_http::services::Redirect; //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { //! let uri: Uri = "https://example.com/".parse().unwrap(); -//! let mut service: Redirect = Redirect::permanent(uri); +//! let mut service: Redirect> = Redirect::permanent(uri); //! //! let request = Request::builder() //! .uri("http://example.com") -//! .body(Body::empty()) +//! .body(Full::::default()) //! .unwrap(); //! //! let response = service.oneshot(request).await?; diff --git a/tower-http/src/set_header/request.rs b/tower-http/src/set_header/request.rs index 1edec90c..4032e23a 100644 --- a/tower-http/src/set_header/request.rs +++ b/tower-http/src/set_header/request.rs @@ -12,12 +12,13 @@ //! use http::{Request, Response, header::{self, HeaderValue}}; //! use tower::{Service, ServiceExt, ServiceBuilder}; //! use tower_http::set_header::SetRequestHeaderLayer; -//! use hyper::Body; +//! use http_body_util::Full; +//! use bytes::Bytes; //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { -//! # let http_client = tower::service_fn(|_: Request| async move { -//! # Ok::<_, std::convert::Infallible>(Response::new(Body::empty())) +//! # let http_client = tower::service_fn(|_: Request>| async move { +//! # Ok::<_, std::convert::Infallible>(Response::new(Full::::default())) //! # }); //! # //! let mut svc = ServiceBuilder::new() @@ -33,7 +34,7 @@ //! ) //! .service(http_client); //! -//! let request = Request::new(Body::empty()); +//! let request = Request::new(Full::default()); //! //! let response = svc.ready().await?.call(request).await?; //! # @@ -47,12 +48,13 @@ //! use http::{Request, Response, header::{self, HeaderValue}}; //! use tower::{Service, ServiceExt, ServiceBuilder}; //! use tower_http::set_header::SetRequestHeaderLayer; -//! use hyper::Body; +//! use bytes::Bytes; +//! use http_body_util::Full; //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { -//! # let http_client = tower::service_fn(|_: Request| async move { -//! # Ok::<_, std::convert::Infallible>(Response::new(Body::empty())) +//! # let http_client = tower::service_fn(|_: Request>| async move { +//! # Ok::<_, std::convert::Infallible>(Response::new(Full::::default())) //! # }); //! fn date_header_value() -> HeaderValue { //! // ... @@ -67,14 +69,14 @@ //! // may have. //! SetRequestHeaderLayer::overriding( //! header::DATE, -//! |request: &Request| { +//! |request: &Request>| { //! Some(date_header_value()) //! } //! ) //! ) //! .service(http_client); //! -//! let request = Request::new(Body::empty()); +//! let request = Request::new(Full::default()); //! //! let response = svc.ready().await?.call(request).await?; //! # diff --git a/tower-http/src/set_header/response.rs b/tower-http/src/set_header/response.rs index a612b926..c7b8ea84 100644 --- a/tower-http/src/set_header/response.rs +++ b/tower-http/src/set_header/response.rs @@ -12,11 +12,12 @@ //! use http::{Request, Response, header::{self, HeaderValue}}; //! use tower::{Service, ServiceExt, ServiceBuilder}; //! use tower_http::set_header::SetResponseHeaderLayer; -//! use hyper::Body; +//! use http_body_util::Full; +//! use bytes::Bytes; //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { -//! # let render_html = tower::service_fn(|request: Request| async move { +//! # let render_html = tower::service_fn(|request: Request>| async move { //! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body())) //! # }); //! # @@ -33,7 +34,7 @@ //! ) //! .service(render_html); //! -//! let request = Request::new(Body::empty()); +//! let request = Request::new(Full::default()); //! //! let response = svc.ready().await?.call(request).await?; //! @@ -49,13 +50,14 @@ //! use http::{Request, Response, header::{self, HeaderValue}}; //! use tower::{Service, ServiceExt, ServiceBuilder}; //! use tower_http::set_header::SetResponseHeaderLayer; -//! use hyper::Body; +//! use bytes::Bytes; +//! use http_body_util::Full; //! use http_body::Body as _; // for `Body::size_hint` //! //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { -//! # let render_html = tower::service_fn(|request: Request| async move { -//! # Ok::<_, std::convert::Infallible>(Response::new(Body::from("1234567890"))) +//! # let render_html = tower::service_fn(|request: Request>| async move { +//! # Ok::<_, std::convert::Infallible>(Response::new(Full::from("1234567890"))) //! # }); //! # //! let mut svc = ServiceBuilder::new() @@ -67,7 +69,7 @@ //! // may have. //! SetResponseHeaderLayer::overriding( //! header::CONTENT_LENGTH, -//! |response: &Response| { +//! |response: &Response>| { //! if let Some(size) = response.body().size_hint().exact() { //! // If the response body has a known size, returning `Some` will //! // set the `Content-Length` header to that value. @@ -82,7 +84,7 @@ //! ) //! .service(render_html); //! -//! let request = Request::new(Body::empty()); +//! let request = Request::new(Full::default()); //! //! let response = svc.ready().await?.call(request).await?; //! @@ -300,8 +302,8 @@ where #[cfg(test)] mod tests { use super::*; + use crate::test_helpers::Body; use http::{header, HeaderValue}; - use hyper::Body; use std::convert::Infallible; use tower::{service_fn, ServiceExt}; diff --git a/tower-http/src/set_status.rs b/tower-http/src/set_status.rs index bdc4999a..65f5405e 100644 --- a/tower-http/src/set_status.rs +++ b/tower-http/src/set_status.rs @@ -5,13 +5,14 @@ //! ``` //! use tower_http::set_status::SetStatusLayer; //! use http::{Request, Response, StatusCode}; -//! use hyper::Body; +//! use bytes::Bytes; +//! use http_body_util::Full; //! use std::{iter::once, convert::Infallible}; //! use tower::{ServiceBuilder, Service, ServiceExt}; //! -//! async fn handle(req: Request) -> Result, Infallible> { +//! async fn handle(req: Request>) -> Result>, Infallible> { //! // ... -//! # Ok(Response::new(Body::empty())) +//! # Ok(Response::new(Full::default())) //! } //! //! # #[tokio::main] @@ -22,7 +23,7 @@ //! .service_fn(handle); //! //! // Call the service. -//! let request = Request::builder().body(Body::empty())?; +//! let request = Request::builder().body(Full::default())?; //! //! let response = service.ready().await?.call(request).await?; //! diff --git a/tower-http/src/test_helpers.rs b/tower-http/src/test_helpers.rs new file mode 100644 index 00000000..6add4233 --- /dev/null +++ b/tower-http/src/test_helpers.rs @@ -0,0 +1,166 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use futures_util::TryStream; +use http::HeaderMap; +use http_body::Frame; +use http_body_util::BodyExt; +use pin_project_lite::pin_project; +use sync_wrapper::SyncWrapper; +use tower::BoxError; + +type BoxBody = http_body_util::combinators::UnsyncBoxBody; + +#[derive(Debug)] +pub(crate) struct Body(BoxBody); + +impl Body { + pub(crate) fn new(body: B) -> Self + where + B: http_body::Body + Send + 'static, + B::Error: Into, + { + Self(body.map_err(Into::into).boxed_unsync()) + } + + pub(crate) fn empty() -> Self { + Self::new(http_body_util::Empty::new()) + } + + pub(crate) fn from_stream(stream: S) -> Self + where + S: TryStream + Send + 'static, + S::Ok: Into, + S::Error: Into, + { + Self::new(StreamBody { + stream: SyncWrapper::new(stream), + }) + } + + pub(crate) fn with_trailers(self, trailers: HeaderMap) -> WithTrailers { + WithTrailers { + inner: self, + trailers: Some(trailers), + } + } +} + +impl Default for Body { + fn default() -> Self { + Self::empty() + } +} + +macro_rules! body_from_impl { + ($ty:ty) => { + impl From<$ty> for Body { + fn from(buf: $ty) -> Self { + Self::new(http_body_util::Full::from(buf)) + } + } + }; +} + +body_from_impl!(&'static [u8]); +body_from_impl!(std::borrow::Cow<'static, [u8]>); +body_from_impl!(Vec); + +body_from_impl!(&'static str); +body_from_impl!(std::borrow::Cow<'static, str>); +body_from_impl!(String); + +body_from_impl!(Bytes); + +impl http_body::Body for Body { + type Data = Bytes; + type Error = BoxError; + + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + Pin::new(&mut self.0).poll_frame(cx) + } + + fn size_hint(&self) -> http_body::SizeHint { + self.0.size_hint() + } + + fn is_end_stream(&self) -> bool { + self.0.is_end_stream() + } +} + +pin_project! { + struct StreamBody { + #[pin] + stream: SyncWrapper, + } +} + +impl http_body::Body for StreamBody +where + S: TryStream, + S::Ok: Into, + S::Error: Into, +{ + type Data = Bytes; + type Error = BoxError; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + let stream = self.project().stream.get_pin_mut(); + match std::task::ready!(stream.try_poll_next(cx)) { + Some(Ok(chunk)) => Poll::Ready(Some(Ok(Frame::data(chunk.into())))), + Some(Err(err)) => Poll::Ready(Some(Err(err.into()))), + None => Poll::Ready(None), + } + } +} + +pub(crate) async fn to_bytes(body: T) -> Result +where + T: http_body::Body, +{ + futures_util::pin_mut!(body); + Ok(body.collect().await?.to_bytes()) +} + +pin_project! { + pub(crate) struct WithTrailers { + #[pin] + inner: B, + trailers: Option, + } +} + +impl http_body::Body for WithTrailers +where + B: http_body::Body, +{ + type Data = B::Data; + type Error = B::Error; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + let this = self.project(); + match std::task::ready!(this.inner.poll_frame(cx)) { + Some(frame) => Poll::Ready(Some(frame)), + None => { + if let Some(trailers) = this.trailers.take() { + Poll::Ready(Some(Ok(Frame::trailers(trailers)))) + } else { + Poll::Ready(None) + } + } + } + } +} diff --git a/tower-http/src/timeout/body.rs b/tower-http/src/timeout/body.rs index efdf3165..3705d1c0 100644 --- a/tower-http/src/timeout/body.rs +++ b/tower-http/src/timeout/body.rs @@ -29,12 +29,13 @@ pin_project! { /// /// ``` /// use http::{Request, Response}; - /// use hyper::Body; + /// use bytes::Bytes; + /// use http_body_util::Full; /// use std::time::Duration; /// use tower::ServiceBuilder; /// use tower_http::timeout::RequestBodyTimeoutLayer; /// - /// async fn handle(_: Request) -> Result, std::convert::Infallible> { + /// async fn handle(_: Request>) -> Result>, std::convert::Infallible> { /// // ... /// # todo!() /// } @@ -50,13 +51,8 @@ pin_project! { /// ``` pub struct TimeoutBody { timeout: Duration, - // In http-body 1.0, `poll_*` will be merged into `poll_frame`. - // Merge the two `sleep_data` and `sleep_trailers` into one `sleep`. - // See: https://github.com/tower-rs/tower-http/pull/303#discussion_r1004834958 #[pin] - sleep_data: Option, - #[pin] - sleep_trailers: Option, + sleep: Option, #[pin] body: B, } @@ -67,8 +63,7 @@ impl TimeoutBody { pub fn new(timeout: Duration, body: B) -> Self { TimeoutBody { timeout, - sleep_data: None, - sleep_trailers: None, + sleep: None, body, } } @@ -82,18 +77,18 @@ where type Data = B::Data; type Error = Box; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { let mut this = self.project(); // Start the `Sleep` if not active. - let sleep_pinned = if let Some(some) = this.sleep_data.as_mut().as_pin_mut() { + let sleep_pinned = if let Some(some) = this.sleep.as_mut().as_pin_mut() { some } else { - this.sleep_data.set(Some(sleep(*this.timeout))); - this.sleep_data.as_mut().as_pin_mut().unwrap() + this.sleep.set(Some(sleep(*this.timeout))); + this.sleep.as_mut().as_pin_mut().unwrap() }; // Error if the timeout has expired. @@ -102,36 +97,11 @@ where } // Check for body data. - let data = ready!(this.body.poll_data(cx)); - // Some data is ready. Reset the `Sleep`... - this.sleep_data.set(None); - - Poll::Ready(data.transpose().map_err(Into::into).transpose()) - } - - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - let mut this = self.project(); - - // In http-body 1.0, `poll_*` will be merged into `poll_frame`. - // Merge the two `sleep_data` and `sleep_trailers` into one `sleep`. - // See: https://github.com/tower-rs/tower-http/pull/303#discussion_r1004834958 - - let sleep_pinned = if let Some(some) = this.sleep_trailers.as_mut().as_pin_mut() { - some - } else { - this.sleep_trailers.set(Some(sleep(*this.timeout))); - this.sleep_trailers.as_mut().as_pin_mut().unwrap() - }; - - // Error if the timeout has expired. - if let Poll::Ready(()) = sleep_pinned.poll(cx) { - return Poll::Ready(Err(Box::new(TimeoutError(())))); - } + let frame = ready!(this.body.poll_frame(cx)); + // A frame is ready. Reset the `Sleep`... + this.sleep.set(None); - this.body.poll_trailers(cx).map_err(Into::into) + Poll::Ready(frame.transpose().map_err(Into::into).transpose()) } } @@ -151,6 +121,8 @@ mod tests { use super::*; use bytes::Bytes; + use http_body::Frame; + use http_body_util::BodyExt; use pin_project_lite::pin_project; use std::{error::Error, fmt::Display}; @@ -158,9 +130,10 @@ mod tests { struct MockError; impl Error for MockError {} + impl Display for MockError { - fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - todo!() + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "mock error") } } @@ -175,19 +148,14 @@ mod tests { type Data = Bytes; type Error = MockError; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { let this = self.project(); - this.sleep.poll(cx).map(|_| Some(Ok(vec![].into()))) - } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - todo!() + this.sleep + .poll(cx) + .map(|_| Some(Ok(Frame::data(vec![].into())))) } } @@ -201,7 +169,12 @@ mod tests { }; let timeout_body = TimeoutBody::new(timeout_sleep, mock_body); - assert!(timeout_body.boxed().data().await.unwrap().is_ok()); + assert!(timeout_body + .boxed() + .frame() + .await + .expect("no frame") + .is_ok()); } #[tokio::test] @@ -214,6 +187,6 @@ mod tests { }; let timeout_body = TimeoutBody::new(timeout_sleep, mock_body); - assert!(timeout_body.boxed().data().await.unwrap().is_err()); + assert!(timeout_body.boxed().frame().await.unwrap().is_err()); } } diff --git a/tower-http/src/timeout/mod.rs b/tower-http/src/timeout/mod.rs index 13b6e19b..facb6a92 100644 --- a/tower-http/src/timeout/mod.rs +++ b/tower-http/src/timeout/mod.rs @@ -17,14 +17,15 @@ //! //! ``` //! use http::{Request, Response}; -//! use hyper::Body; +//! use http_body_util::Full; +//! use bytes::Bytes; //! use std::{convert::Infallible, time::Duration}; //! use tower::ServiceBuilder; //! use tower_http::timeout::TimeoutLayer; //! -//! async fn handle(_: Request) -> Result, Infallible> { +//! async fn handle(_: Request>) -> Result>, Infallible> { //! // ... -//! # Ok(Response::new(Body::empty())) +//! # Ok(Response::new(Full::default())) //! } //! //! # #[tokio::main] diff --git a/tower-http/src/trace/body.rs b/tower-http/src/trace/body.rs index 55648747..7a27cd90 100644 --- a/tower-http/src/trace/body.rs +++ b/tower-http/src/trace/body.rs @@ -1,7 +1,6 @@ use super::{OnBodyChunk, OnEos, OnFailure}; use crate::classify::ClassifyEos; -use http::HeaderMap; -use http_body::Body; +use http_body::{Body, Frame}; use pin_project_lite::pin_project; use std::{ fmt, @@ -40,70 +39,57 @@ where type Data = B::Data; type Error = B::Error; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { let this = self.project(); let _guard = this.span.enter(); - - let result = if let Some(result) = ready!(this.inner.poll_data(cx)) { - result - } else { - return Poll::Ready(None); - }; + let result = ready!(this.inner.poll_frame(cx)); let latency = this.start.elapsed(); *this.start = Instant::now(); - match &result { - Ok(chunk) => { - this.on_body_chunk.on_body_chunk(chunk, latency, this.span); + match result { + Some(Ok(frame)) => { + let frame = match frame.into_data() { + Ok(chunk) => { + this.on_body_chunk.on_body_chunk(&chunk, latency, this.span); + Frame::data(chunk) + } + Err(frame) => frame, + }; + + let frame = match frame.into_trailers() { + Ok(trailers) => { + if let Some((on_eos, stream_start)) = this.on_eos.take() { + on_eos.on_eos(Some(&trailers), stream_start.elapsed(), this.span); + } + Frame::trailers(trailers) + } + Err(frame) => frame, + }; + + Poll::Ready(Some(Ok(frame))) } - Err(err) => { + Some(Err(err)) => { if let Some((classify_eos, mut on_failure)) = this.classify_eos.take().zip(this.on_failure.take()) { - let failure_class = classify_eos.classify_error(err); + let failure_class = classify_eos.classify_error(&err); on_failure.on_failure(failure_class, latency, this.span); } - } - } - - Poll::Ready(Some(result)) - } - - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - let this = self.project(); - let _guard = this.span.enter(); - let result = ready!(this.inner.poll_trailers(cx)); - - let latency = this.start.elapsed(); - - if let Some((classify_eos, mut on_failure)) = - this.classify_eos.take().zip(this.on_failure.take()) - { - match &result { - Ok(trailers) => { - if let Err(failure_class) = classify_eos.classify_eos(trailers.as_ref()) { - on_failure.on_failure(failure_class, latency, this.span); - } - if let Some((on_eos, stream_start)) = this.on_eos.take() { - on_eos.on_eos(trailers.as_ref(), stream_start.elapsed(), this.span); - } - } - Err(err) => { - let failure_class = classify_eos.classify_error(err); - on_failure.on_failure(failure_class, latency, this.span); + Poll::Ready(Some(Err(err))) + } + None => { + if let Some((on_eos, stream_start)) = this.on_eos.take() { + on_eos.on_eos(None, stream_start.elapsed(), this.span); } + + Poll::Ready(None) } } - - Poll::Ready(result) } fn is_end_stream(&self) -> bool { diff --git a/tower-http/src/trace/mod.rs b/tower-http/src/trace/mod.rs index f46087dc..255bb703 100644 --- a/tower-http/src/trace/mod.rs +++ b/tower-http/src/trace/mod.rs @@ -6,13 +6,14 @@ //! //! ```rust //! use http::{Request, Response}; -//! use hyper::Body; //! use tower::{ServiceBuilder, ServiceExt, Service}; //! use tower_http::trace::TraceLayer; //! use std::convert::Infallible; +//! use http_body_util::Full; +//! use bytes::Bytes; //! -//! async fn handle(request: Request) -> Result, Infallible> { -//! Ok(Response::new(Body::from("foo"))) +//! async fn handle(request: Request>) -> Result>, Infallible> { +//! Ok(Response::new(Full::default())) //! } //! //! # #[tokio::main] @@ -24,7 +25,7 @@ //! .layer(TraceLayer::new_for_http()) //! .service_fn(handle); //! -//! let request = Request::new(Body::from("foo")); +//! let request = Request::new(Full::from("foo")); //! //! let response = service //! .ready() @@ -50,7 +51,7 @@ //! //! ```rust //! use http::{Request, Response, HeaderMap, StatusCode}; -//! use hyper::Body; +//! use http_body_util::Full; //! use bytes::Bytes; //! use tower::ServiceBuilder; //! use tracing::Level; @@ -62,8 +63,8 @@ //! # use tower::{ServiceExt, Service}; //! # use std::convert::Infallible; //! -//! # async fn handle(request: Request) -> Result, Infallible> { -//! # Ok(Response::new(Body::from("foo"))) +//! # async fn handle(request: Request>) -> Result>, Infallible> { +//! # Ok(Response::new(Full::from("foo"))) //! # } //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { @@ -90,7 +91,7 @@ //! # let response = service //! # .ready() //! # .await? -//! # .call(Request::new(Body::from("foo"))) +//! # .call(Request::new(Full::from("foo"))) //! # .await?; //! # Ok(()) //! # } @@ -100,7 +101,7 @@ //! //! ```rust //! use http::{Request, Response, HeaderMap, StatusCode}; -//! use hyper::Body; +//! use http_body_util::Full; //! use bytes::Bytes; //! use tower::ServiceBuilder; //! use tower_http::{classify::ServerErrorsFailureClass, trace::TraceLayer}; @@ -109,8 +110,8 @@ //! # use tower::{ServiceExt, Service}; //! # use std::convert::Infallible; //! -//! # async fn handle(request: Request) -> Result, Infallible> { -//! # Ok(Response::new(Body::from("foo"))) +//! # async fn handle(request: Request>) -> Result>, Infallible> { +//! # Ok(Response::new(Full::from("foo"))) //! # } //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { @@ -119,13 +120,13 @@ //! let service = ServiceBuilder::new() //! .layer( //! TraceLayer::new_for_http() -//! .make_span_with(|request: &Request| { +//! .make_span_with(|request: &Request>| { //! tracing::debug_span!("http-request") //! }) -//! .on_request(|request: &Request, _span: &Span| { +//! .on_request(|request: &Request>, _span: &Span| { //! tracing::debug!("started {} {}", request.method(), request.uri().path()) //! }) -//! .on_response(|response: &Response, latency: Duration, _span: &Span| { +//! .on_response(|response: &Response>, latency: Duration, _span: &Span| { //! tracing::debug!("response generated in {:?}", latency) //! }) //! .on_body_chunk(|chunk: &Bytes, latency: Duration, _span: &Span| { @@ -143,7 +144,7 @@ //! # let response = service //! # .ready() //! # .await? -//! # .call(Request::new(Body::from("foo"))) +//! # .call(Request::new(Full::from("foo"))) //! # .await?; //! # Ok(()) //! # } @@ -160,12 +161,13 @@ //! use std::time::Duration; //! use tracing::Span; //! # use tower::{ServiceExt, Service}; -//! # use hyper::Body; +//! # use http_body_util::Full; +//! # use bytes::Bytes; //! # use http::{Response, Request}; //! # use std::convert::Infallible; //! -//! # async fn handle(request: Request) -> Result, Infallible> { -//! # Ok(Response::new(Body::from("foo"))) +//! # async fn handle(request: Request>) -> Result>, Infallible> { +//! # Ok(Response::new(Full::from("foo"))) //! # } //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { @@ -188,7 +190,7 @@ //! # let response = service //! # .ready() //! # .await? -//! # .call(Request::new(Body::from("foo"))) +//! # .call(Request::new(Full::from("foo"))) //! # .await?; //! # Ok(()) //! # } @@ -216,14 +218,14 @@ //! ### `on_body_chunk` //! //! The `on_body_chunk` callback is called when the response body produces a new -//! chunk, that is when [`Body::poll_data`] returns `Poll::Ready(Some(Ok(chunk)))`. +//! chunk, that is when [`Body::poll_frame`] returns a data frame. //! //! `on_body_chunk` is called even if the chunk is empty. //! //! ### `on_eos` //! //! The `on_eos` callback is called when a streaming response body ends, that is -//! when [`Body::poll_trailers`] returns `Poll::Ready(Ok(trailers))`. +//! when [`Body::poll_frame`] returns a trailers frame. //! //! `on_eos` is called even if the trailers produced are `None`. //! @@ -233,8 +235,7 @@ //! //! - The inner [`Service`]'s response future resolves to an error. //! - A response is classified as a failure. -//! - [`Body::poll_data`] returns an error. -//! - [`Body::poll_trailers`] returns an error. +//! - [`Body::poll_frame`] returns an error. //! - An end-of-stream is classified as a failure. //! //! # Recording fields on the span @@ -245,7 +246,7 @@ //! //! ```rust //! use http::{Request, Response, HeaderMap, StatusCode}; -//! use hyper::Body; +//! use http_body_util::Full; //! use bytes::Bytes; //! use tower::ServiceBuilder; //! use tower_http::trace::TraceLayer; @@ -253,8 +254,8 @@ //! use std::time::Duration; //! # use std::convert::Infallible; //! -//! # async fn handle(request: Request) -> Result, Infallible> { -//! # Ok(Response::new(Body::from("foo"))) +//! # async fn handle(request: Request>) -> Result>, Infallible> { +//! # Ok(Response::new(Full::from("foo"))) //! # } //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { @@ -263,13 +264,13 @@ //! let service = ServiceBuilder::new() //! .layer( //! TraceLayer::new_for_http() -//! .make_span_with(|request: &Request| { +//! .make_span_with(|request: &Request>| { //! tracing::debug_span!( //! "http-request", //! status_code = tracing::field::Empty, //! ) //! }) -//! .on_response(|response: &Response, _latency: Duration, span: &Span| { +//! .on_response(|response: &Response>, _latency: Duration, span: &Span| { //! span.record("status_code", &tracing::field::display(response.status())); //! //! tracing::debug!("response generated") @@ -290,7 +291,8 @@ //! //! ```rust //! use http::{Request, Response}; -//! use hyper::Body; +//! use http_body_util::Full; +//! use bytes::Bytes; //! use tower::ServiceBuilder; //! use tower_http::{ //! trace::TraceLayer, @@ -301,8 +303,8 @@ //! }; //! use std::convert::Infallible; //! -//! # async fn handle(request: Request) -> Result, Infallible> { -//! # Ok(Response::new(Body::from("foo"))) +//! # async fn handle(request: Request>) -> Result>, Infallible> { +//! # Ok(Response::new(Full::from("foo"))) //! # } //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { @@ -380,8 +382,7 @@ //! [`TraceLayer::make_span_with`]: crate::trace::TraceLayer::make_span_with //! [`Span`]: tracing::Span //! [`ServerErrorsAsFailures`]: crate::classify::ServerErrorsAsFailures -//! [`Body::poll_trailers`]: http_body::Body::poll_trailers -//! [`Body::poll_data`]: http_body::Body::poll_data +//! [`Body::poll_frame`]: http_body::Body::poll_frame use std::{fmt, time::Duration}; @@ -482,9 +483,9 @@ impl fmt::Display for Latency { mod tests { use super::*; use crate::classify::ServerErrorsFailureClass; + use crate::test_helpers::Body; use bytes::Bytes; use http::{HeaderMap, Request, Response}; - use hyper::Body; use once_cell::sync::Lazy; use std::{ sync::atomic::{AtomicU32, Ordering}, @@ -506,7 +507,7 @@ mod tests { tracing::info_span!("test-span", foo = tracing::field::Empty) }) .on_request(|_req: &Request, span: &Span| { - span.record("foo", &42); + span.record("foo", 42); ON_REQUEST_COUNT.fetch_add(1, Ordering::SeqCst); }) .on_response(|_res: &Response, _latency: Duration, _span: &Span| { @@ -542,7 +543,9 @@ mod tests { assert_eq!(0, ON_EOS.load(Ordering::SeqCst), "eos"); assert_eq!(0, ON_FAILURE.load(Ordering::SeqCst), "failure"); - hyper::body::to_bytes(res.into_body()).await.unwrap(); + crate::test_helpers::to_bytes(res.into_body()) + .await + .unwrap(); assert_eq!(1, ON_BODY_CHUNK_COUNT.load(Ordering::SeqCst), "body chunk"); assert_eq!(0, ON_EOS.load(Ordering::SeqCst), "eos"); assert_eq!(0, ON_FAILURE.load(Ordering::SeqCst), "failure"); @@ -595,7 +598,9 @@ mod tests { assert_eq!(0, ON_EOS.load(Ordering::SeqCst), "eos"); assert_eq!(0, ON_FAILURE.load(Ordering::SeqCst), "failure"); - hyper::body::to_bytes(res.into_body()).await.unwrap(); + crate::test_helpers::to_bytes(res.into_body()) + .await + .unwrap(); assert_eq!(3, ON_BODY_CHUNK_COUNT.load(Ordering::SeqCst), "body chunk"); assert_eq!(0, ON_EOS.load(Ordering::SeqCst), "eos"); assert_eq!(0, ON_FAILURE.load(Ordering::SeqCst), "failure"); @@ -614,7 +619,7 @@ mod tests { Ok::<_, BoxError>(Bytes::from("three")), ]); - let body = Body::wrap_stream(stream); + let body = Body::from_stream(stream); Ok(Response::new(body)) } diff --git a/tower-http/src/validate_request.rs b/tower-http/src/validate_request.rs index 4f1ea680..327266af 100644 --- a/tower-http/src/validate_request.rs +++ b/tower-http/src/validate_request.rs @@ -4,16 +4,17 @@ //! //! ``` //! use tower_http::validate_request::ValidateRequestHeaderLayer; -//! use hyper::{Request, Response, Body, Error}; -//! use http::{StatusCode, header::ACCEPT}; -//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! use http::{Request, Response, StatusCode, header::ACCEPT}; +//! use http_body_util::Full; +//! use bytes::Bytes; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError}; //! -//! async fn handle(request: Request) -> Result, Error> { -//! Ok(Response::new(Body::empty())) +//! async fn handle(request: Request>) -> Result>, BoxError> { +//! Ok(Response::new(Full::default())) //! } //! //! # #[tokio::main] -//! # async fn main() -> Result<(), Box> { +//! # async fn main() -> Result<(), BoxError> { //! let mut service = ServiceBuilder::new() //! // Require the `Accept` header to be `application/json`, `*/*` or `application/*` //! .layer(ValidateRequestHeaderLayer::accept("application/json")) @@ -22,7 +23,7 @@ //! // Requests with the correct value are allowed through //! let request = Request::builder() //! .header(ACCEPT, "application/json") -//! .body(Body::empty()) +//! .body(Full::default()) //! .unwrap(); //! //! let response = service @@ -36,7 +37,7 @@ //! // Requests with an invalid value get a `406 Not Acceptable` response //! let request = Request::builder() //! .header(ACCEPT, "text/strings") -//! .body(Body::empty()) +//! .body(Full::default()) //! .unwrap(); //! //! let response = service @@ -54,15 +55,16 @@ //! //! ``` //! use tower_http::validate_request::{ValidateRequestHeaderLayer, ValidateRequest}; -//! use hyper::{Request, Response, Body, Error}; -//! use http::{StatusCode, header::ACCEPT}; -//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! use http::{Request, Response, StatusCode, header::ACCEPT}; +//! use http_body_util::Full; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError}; +//! use bytes::Bytes; //! //! #[derive(Clone, Copy)] //! pub struct MyHeader { /* ... */ } //! //! impl ValidateRequest for MyHeader { -//! type ResponseBody = Body; +//! type ResponseBody = Full; //! //! fn validate( //! &mut self, @@ -73,13 +75,13 @@ //! } //! } //! -//! async fn handle(request: Request) -> Result, Error> { -//! Ok(Response::new(Body::empty())) +//! async fn handle(request: Request>) -> Result>, BoxError> { +//! Ok(Response::new(Full::default())) //! } //! //! //! # #[tokio::main] -//! # async fn main() -> Result<(), Box> { +//! # async fn main() -> Result<(), BoxError> { //! let service = ServiceBuilder::new() //! // Validate requests using `MyHeader` //! .layer(ValidateRequestHeaderLayer::custom(MyHeader { /* ... */ })) @@ -92,21 +94,22 @@ //! //! ``` //! use tower_http::validate_request::{ValidateRequestHeaderLayer, ValidateRequest}; -//! use hyper::{Request, Response, Body, Error}; -//! use http::{StatusCode, header::ACCEPT}; -//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! use http::{Request, Response, StatusCode, header::ACCEPT}; +//! use bytes::Bytes; +//! use http_body_util::Full; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError}; //! -//! async fn handle(request: Request) -> Result, Error> { +//! async fn handle(request: Request>) -> Result>, BoxError> { //! # todo!(); //! // ... //! } //! //! # #[tokio::main] -//! # async fn main() -> Result<(), Box> { +//! # async fn main() -> Result<(), BoxError> { //! let service = ServiceBuilder::new() -//! .layer(ValidateRequestHeaderLayer::custom(|request: &mut Request| { +//! .layer(ValidateRequestHeaderLayer::custom(|request: &mut Request>| { //! // Validate the request -//! # Ok::<_, Response>(()) +//! # Ok::<_, Response>>(()) //! })) //! .service_fn(handle); //! # Ok(()) @@ -150,10 +153,11 @@ impl ValidateRequestHeaderLayer> { /// # Example /// /// ``` - /// use hyper::Body; + /// use http_body_util::Full; + /// use bytes::Bytes; /// use tower_http::validate_request::{AcceptHeader, ValidateRequestHeaderLayer}; /// - /// let layer = ValidateRequestHeaderLayer::>::accept("application/json"); + /// let layer = ValidateRequestHeaderLayer::>>::accept("application/json"); /// ``` /// /// [`Accept`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept @@ -381,7 +385,7 @@ where .into_iter() .filter_map(|header| header.to_str().ok()) .any(|h| { - MimeIter::new(&h) + MimeIter::new(h) .map(|mim| { if let Ok(mim) = mim { let typ = self.header_value.type_(); @@ -412,8 +416,8 @@ where mod tests { #[allow(unused_imports)] use super::*; - use http::{header, StatusCode}; - use hyper::Body; + use crate::test_helpers::Body; + use http::header; use tower::{BoxError, ServiceBuilder, ServiceExt}; #[tokio::test]