-
Notifications
You must be signed in to change notification settings - Fork 163
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add layer to decompress request bodies Still needs some refactor to remove duplicate code and needs documentation * Refactor decompression modules * Fix incorrect rename * Add ResponseFuture for RequestDecompression Which either polls its inner future or returns 415 Unsupported Media Type * Refactor DecompressionService * Rollback rename and move of `Decompression` Refactoring of `decompression` module will be done in a later PR * Re-add `request` module to the `decompression` module * Send "identity" encoding when no encodings are accepted * Add documentation * Add example * Fix styling * Fix some styling of imports and documentation * Add enable parameter to RequestDecompressionLayer::pass_through_unaccepted * Cleanup redundant code * fix imports * actually fix import * check for zstd * zstd --------- Co-authored-by: David Pedersen <david.pdrsn@gmail.com>
- Loading branch information
1 parent
f3d8528
commit 67130ab
Showing
6 changed files
with
548 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
use crate::compression_utils::AcceptEncoding; | ||
use crate::BoxError; | ||
use bytes::Buf; | ||
use http::{header, HeaderValue, Response, StatusCode}; | ||
use http_body::{combinators::UnsyncBoxBody, Body, Empty}; | ||
use pin_project_lite::pin_project; | ||
use std::future::Future; | ||
use std::pin::Pin; | ||
use std::task::Context; | ||
use std::task::Poll; | ||
|
||
pin_project! { | ||
#[derive(Debug)] | ||
/// Response future of [`RequestDecompression`] | ||
pub struct RequestDecompressionFuture<F, B, E> | ||
where | ||
F: Future<Output = Result<Response<B>, E>>, | ||
B: Body | ||
{ | ||
#[pin] | ||
kind: Kind<F, B, E>, | ||
} | ||
} | ||
|
||
pin_project! { | ||
#[derive(Debug)] | ||
#[project = StateProj] | ||
enum Kind<F, B, E> | ||
where | ||
F: Future<Output = Result<Response<B>, E>>, | ||
B: Body | ||
{ | ||
Inner { | ||
#[pin] | ||
fut: F | ||
}, | ||
Unsupported { | ||
#[pin] | ||
accept: AcceptEncoding | ||
}, | ||
} | ||
} | ||
|
||
impl<F, B, E> RequestDecompressionFuture<F, B, E> | ||
where | ||
F: Future<Output = Result<Response<B>, E>>, | ||
B: Body, | ||
{ | ||
#[must_use] | ||
pub(super) fn unsupported_encoding(accept: AcceptEncoding) -> Self { | ||
Self { | ||
kind: Kind::Unsupported { accept }, | ||
} | ||
} | ||
|
||
#[must_use] | ||
pub(super) fn inner(fut: F) -> Self { | ||
Self { | ||
kind: Kind::Inner { fut }, | ||
} | ||
} | ||
} | ||
|
||
impl<F, B, E> Future for RequestDecompressionFuture<F, B, E> | ||
where | ||
F: Future<Output = Result<Response<B>, E>>, | ||
B: Body + Send + 'static, | ||
B::Data: Buf + 'static, | ||
B::Error: Into<BoxError> + 'static, | ||
E: Into<BoxError>, | ||
{ | ||
type Output = Result<Response<UnsyncBoxBody<B::Data, BoxError>>, BoxError>; | ||
|
||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | ||
match self.project().kind.project() { | ||
StateProj::Inner { fut } => fut | ||
.poll(cx) | ||
.map_ok(|res| res.map(|body| body.map_err(Into::into).boxed_unsync())) | ||
.map_err(Into::into), | ||
StateProj::Unsupported { accept } => { | ||
let res = Response::builder() | ||
.header( | ||
header::ACCEPT_ENCODING, | ||
accept | ||
.to_header_value() | ||
.unwrap_or(HeaderValue::from_static("identity")), | ||
) | ||
.status(StatusCode::UNSUPPORTED_MEDIA_TYPE) | ||
.body(Empty::new().map_err(Into::into).boxed_unsync()) | ||
.unwrap(); | ||
Poll::Ready(Ok(res)) | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
use super::service::RequestDecompression; | ||
use crate::compression_utils::AcceptEncoding; | ||
use tower_layer::Layer; | ||
|
||
/// Decompresses request bodies and calls its underlying service. | ||
/// | ||
/// Transparently decompresses request bodies based on the `Content-Encoding` header. | ||
/// When the encoding in the `Content-Encoding` header is not accepted an `Unsupported Media Type` | ||
/// status code will be returned with the accepted encodings in the `Accept-Encoding` header. | ||
/// | ||
/// Enabling pass-through of unaccepted encodings will not return an `Unsupported Media Type`. But | ||
/// will call the underlying service with the unmodified request if the encoding is not supported. | ||
/// This is disabled by default. | ||
/// | ||
/// See the [module docs](crate::decompression) for more details. | ||
#[derive(Debug, Default, Clone)] | ||
pub struct RequestDecompressionLayer { | ||
accept: AcceptEncoding, | ||
pass_through_unaccepted: bool, | ||
} | ||
|
||
impl<S> Layer<S> for RequestDecompressionLayer { | ||
type Service = RequestDecompression<S>; | ||
|
||
fn layer(&self, service: S) -> Self::Service { | ||
RequestDecompression { | ||
inner: service, | ||
accept: self.accept, | ||
pass_through_unaccepted: self.pass_through_unaccepted, | ||
} | ||
} | ||
} | ||
|
||
impl RequestDecompressionLayer { | ||
/// Creates a new `RequestDecompressionLayer`. | ||
pub fn new() -> Self { | ||
Default::default() | ||
} | ||
|
||
/// Sets whether to support gzip encoding. | ||
#[cfg(feature = "decompression-gzip")] | ||
pub fn gzip(mut self, enable: bool) -> Self { | ||
self.accept.set_gzip(enable); | ||
self | ||
} | ||
|
||
/// Sets whether to support Deflate encoding. | ||
#[cfg(feature = "decompression-deflate")] | ||
pub fn deflate(mut self, enable: bool) -> Self { | ||
self.accept.set_deflate(enable); | ||
self | ||
} | ||
|
||
/// Sets whether to support Brotli encoding. | ||
#[cfg(feature = "decompression-br")] | ||
pub fn br(mut self, enable: bool) -> Self { | ||
self.accept.set_br(enable); | ||
self | ||
} | ||
|
||
/// Sets whether to support Zstd encoding. | ||
#[cfg(feature = "decompression-zstd")] | ||
pub fn zstd(mut self, enable: bool) -> Self { | ||
self.accept.set_zstd(enable); | ||
self | ||
} | ||
|
||
/// Disables support for gzip encoding. | ||
/// | ||
/// This method is available even if the `gzip` crate feature is disabled. | ||
pub fn no_gzip(mut self) -> Self { | ||
self.accept.set_gzip(false); | ||
self | ||
} | ||
|
||
/// Disables support for Deflate encoding. | ||
/// | ||
/// This method is available even if the `deflate` crate feature is disabled. | ||
pub fn no_deflate(mut self) -> Self { | ||
self.accept.set_deflate(false); | ||
self | ||
} | ||
|
||
/// Disables support for Brotli encoding. | ||
/// | ||
/// This method is available even if the `br` crate feature is disabled. | ||
pub fn no_br(mut self) -> Self { | ||
self.accept.set_br(false); | ||
self | ||
} | ||
|
||
/// Disables support for Zstd encoding. | ||
/// | ||
/// This method is available even if the `zstd` crate feature is disabled. | ||
pub fn no_zstd(mut self) -> Self { | ||
self.accept.set_zstd(false); | ||
self | ||
} | ||
|
||
/// Sets whether to pass through the request even when the encoding is not supported. | ||
pub fn pass_through_unaccepted(mut self, enable: bool) -> Self { | ||
self.pass_through_unaccepted = enable; | ||
self | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
pub(super) mod future; | ||
pub(super) mod layer; | ||
pub(super) mod service; | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::service::RequestDecompression; | ||
use crate::decompression::DecompressionBody; | ||
use bytes::BytesMut; | ||
use flate2::{write::GzEncoder, Compression}; | ||
use http::{header, Response, StatusCode}; | ||
use http_body::Body as _; | ||
use hyper::{Body, Error, Request, Server}; | ||
use std::io::Write; | ||
use std::net::SocketAddr; | ||
use tower::{make::Shared, service_fn, Service, ServiceExt}; | ||
|
||
#[tokio::test] | ||
async fn decompress_accepted_encoding() { | ||
let req = request_gzip(); | ||
let mut svc = RequestDecompression::new(service_fn(assert_request_is_decompressed)); | ||
let _ = svc.ready().await.unwrap().call(req).await.unwrap(); | ||
} | ||
|
||
#[tokio::test] | ||
async fn support_unencoded_body() { | ||
let req = Request::builder().body(Body::from("Hello?")).unwrap(); | ||
let mut svc = RequestDecompression::new(service_fn(assert_request_is_decompressed)); | ||
let _ = svc.ready().await.unwrap().call(req).await.unwrap(); | ||
} | ||
|
||
#[tokio::test] | ||
async fn unaccepted_content_encoding_returns_unsupported_media_type() { | ||
let req = request_gzip(); | ||
let mut svc = RequestDecompression::new(service_fn(should_not_be_called)).gzip(false); | ||
let res = svc.ready().await.unwrap().call(req).await.unwrap(); | ||
assert_eq!(StatusCode::UNSUPPORTED_MEDIA_TYPE, res.status()); | ||
} | ||
|
||
#[tokio::test] | ||
async fn pass_through_unsupported_encoding_when_enabled() { | ||
let req = request_gzip(); | ||
let mut svc = RequestDecompression::new(service_fn(assert_request_is_passed_through)) | ||
.pass_through_unaccepted(true) | ||
.gzip(false); | ||
let _ = svc.ready().await.unwrap().call(req).await.unwrap(); | ||
} | ||
|
||
async fn assert_request_is_decompressed( | ||
req: Request<DecompressionBody<Body>>, | ||
) -> Result<Response<Body>, Error> { | ||
let (parts, mut body) = req.into_parts(); | ||
let body = read_body(&mut body).await; | ||
|
||
assert_eq!(body, b"Hello?"); | ||
assert!(!parts.headers.contains_key(header::CONTENT_ENCODING)); | ||
|
||
Ok(Response::new(Body::from("Hello, World!"))) | ||
} | ||
|
||
async fn assert_request_is_passed_through( | ||
req: Request<DecompressionBody<Body>>, | ||
) -> Result<Response<Body>, Error> { | ||
let (parts, mut body) = req.into_parts(); | ||
let body = read_body(&mut body).await; | ||
|
||
assert_ne!(body, b"Hello?"); | ||
assert!(parts.headers.contains_key(header::CONTENT_ENCODING)); | ||
|
||
Ok(Response::new(Body::empty())) | ||
} | ||
|
||
async fn should_not_be_called( | ||
_: Request<DecompressionBody<Body>>, | ||
) -> Result<Response<Body>, Error> { | ||
panic!("Inner service should not be called"); | ||
} | ||
|
||
fn request_gzip() -> Request<Body> { | ||
let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); | ||
encoder.write_all(b"Hello?").unwrap(); | ||
let body = encoder.finish().unwrap(); | ||
Request::builder() | ||
.header(header::CONTENT_ENCODING, "gzip") | ||
.body(Body::from(body)) | ||
.unwrap() | ||
} | ||
|
||
async fn read_body(body: &mut DecompressionBody<Body>) -> Vec<u8> { | ||
let mut data = BytesMut::new(); | ||
while let Some(chunk) = body.data().await { | ||
let chunk = chunk.unwrap(); | ||
data.extend_from_slice(&chunk[..]); | ||
} | ||
data.freeze().to_vec() | ||
} | ||
|
||
#[allow(dead_code)] | ||
async fn is_compatible_with_hyper() { | ||
let svc = service_fn(assert_request_is_decompressed); | ||
let svc = RequestDecompression::new(svc); | ||
|
||
let make_service = Shared::new(svc); | ||
|
||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); | ||
let server = Server::bind(&addr).serve(make_service); | ||
server.await.unwrap(); | ||
} | ||
} |
Oops, something went wrong.