From 15781fe22b7b100075bf2be176976598a1402ed1 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Sun, 26 Nov 2023 13:21:35 +0100 Subject: [PATCH] Prepare `serve` for potentially supporting graceful shutdown (#2357) --- .../fail/argument_not_extractor.stderr | 2 +- .../fail/parts_extracting_body.stderr | 2 +- axum/benches/benches.rs | 4 +- axum/src/serve.rs | 156 +++++++++++++----- examples/testing-websockets/src/main.rs | 7 +- 5 files changed, 129 insertions(+), 42 deletions(-) diff --git a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr index acda57077a..71fc041749 100644 --- a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr +++ b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr @@ -12,8 +12,8 @@ error[E0277]: the trait bound `bool: FromRequestParts<()>` is not satisfied > > > - as FromRequestParts> > + as FromRequestParts> > and $N others = note: required for `bool` to implement `FromRequest<(), axum_core::extract::private::ViaParts>` diff --git a/axum-macros/tests/from_request/fail/parts_extracting_body.stderr b/axum-macros/tests/from_request/fail/parts_extracting_body.stderr index f473aefd29..e4a0d849a2 100644 --- a/axum-macros/tests/from_request/fail/parts_extracting_body.stderr +++ b/axum-macros/tests/from_request/fail/parts_extracting_body.stderr @@ -13,6 +13,6 @@ error[E0277]: the trait bound `String: FromRequestParts` is not satisfied > > > - as FromRequestParts> > + as FromRequestParts> and $N others diff --git a/axum/benches/benches.rs b/axum/benches/benches.rs index 9888fd5ebb..5bcdc906f9 100644 --- a/axum/benches/benches.rs +++ b/axum/benches/benches.rs @@ -5,6 +5,7 @@ use axum::{ }; use serde::{Deserialize, Serialize}; use std::{ + future::IntoFuture, io::BufRead, process::{Command, Stdio}, }; @@ -161,7 +162,8 @@ impl BenchmarkBuilder { let addr = listener.local_addr().unwrap(); std::thread::spawn(move || { - rt.block_on(axum::serve(listener, app)).unwrap(); + rt.block_on(axum::serve(listener, app).into_future()) + .unwrap(); }); let mut cmd = Command::new("rewrk"); diff --git a/axum/src/serve.rs b/axum/src/serve.rs index f28c1c2ba2..c6aa2784c8 100644 --- a/axum/src/serve.rs +++ b/axum/src/serve.rs @@ -2,8 +2,9 @@ use std::{ convert::Infallible, - future::Future, + future::{Future, IntoFuture}, io, + marker::PhantomData, net::SocketAddr, pin::Pin, task::{Context, Poll}, @@ -86,48 +87,129 @@ use tower_service::Service; /// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info /// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -pub async fn serve(tcp_listener: TcpListener, mut make_service: M) -> io::Result<()> +pub fn serve(tcp_listener: TcpListener, make_service: M) -> Serve where M: for<'a> Service, Error = Infallible, Response = S>, S: Service + Clone + Send + 'static, S::Future: Send, { - loop { - let (tcp_stream, remote_addr) = tcp_listener.accept().await?; - let tcp_stream = TokioIo::new(tcp_stream); - - poll_fn(|cx| make_service.poll_ready(cx)) - .await - .unwrap_or_else(|err| match err {}); - - let tower_service = make_service - .call(IncomingStream { - tcp_stream: &tcp_stream, - remote_addr, - }) - .await - .unwrap_or_else(|err| match err {}); - - let hyper_service = TowerToHyperService { - service: tower_service, - }; - - tokio::task::spawn(async move { - match Builder::new(TokioExecutor::new()) - // upgrades needed for websockets - .serve_connection_with_upgrades(tcp_stream, hyper_service) - .await - { - Ok(()) => {} - Err(_err) => { - // This error only appears when the client doesn't send a request and - // terminate the connection. - // - // If client sends one request then terminate connection whenever, it doesn't - // appear. - } + Serve { + tcp_listener, + make_service, + _marker: PhantomData, + } +} + +/// Future returned by [`serve`]. +#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] +pub struct Serve { + tcp_listener: TcpListener, + make_service: M, + _marker: PhantomData, +} + +#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] +impl std::fmt::Debug for Serve +where + M: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let Self { + tcp_listener, + make_service, + _marker: _, + } = self; + + f.debug_struct("Serve") + .field("tcp_listener", tcp_listener) + .field("make_service", make_service) + .finish() + } +} + +#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] +impl IntoFuture for Serve +where + M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, + for<'a> >>::Future: Send, + S: Service + Clone + Send + 'static, + S::Future: Send, +{ + type Output = io::Result<()>; + type IntoFuture = private::ServeFuture; + + fn into_future(self) -> Self::IntoFuture { + private::ServeFuture(Box::pin(async move { + let Self { + tcp_listener, + mut make_service, + _marker: _, + } = self; + + loop { + let (tcp_stream, remote_addr) = tcp_listener.accept().await?; + let tcp_stream = TokioIo::new(tcp_stream); + + poll_fn(|cx| make_service.poll_ready(cx)) + .await + .unwrap_or_else(|err| match err {}); + + let tower_service = make_service + .call(IncomingStream { + tcp_stream: &tcp_stream, + remote_addr, + }) + .await + .unwrap_or_else(|err| match err {}); + + let hyper_service = TowerToHyperService { + service: tower_service, + }; + + tokio::task::spawn(async move { + match Builder::new(TokioExecutor::new()) + // upgrades needed for websockets + .serve_connection_with_upgrades(tcp_stream, hyper_service) + .await + { + Ok(()) => {} + Err(_err) => { + // This error only appears when the client doesn't send a request and + // terminate the connection. + // + // If client sends one request then terminate connection whenever, it doesn't + // appear. + } + } + }); } - }); + })) + } +} + +mod private { + use std::{ + future::Future, + io, + pin::Pin, + task::{Context, Poll}, + }; + + pub struct ServeFuture(pub(super) futures_util::future::BoxFuture<'static, io::Result<()>>); + + impl Future for ServeFuture { + type Output = io::Result<()>; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.0.as_mut().poll(cx) + } + } + + impl std::fmt::Debug for ServeFuture { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ServeFuture").finish_non_exhaustive() + } } } diff --git a/examples/testing-websockets/src/main.rs b/examples/testing-websockets/src/main.rs index 954168b170..384be35d53 100644 --- a/examples/testing-websockets/src/main.rs +++ b/examples/testing-websockets/src/main.rs @@ -92,7 +92,10 @@ where #[cfg(test)] mod tests { use super::*; - use std::net::{Ipv4Addr, SocketAddr}; + use std::{ + future::IntoFuture, + net::{Ipv4Addr, SocketAddr}, + }; use tokio_tungstenite::tungstenite; // We can integration test one handler by running the server in a background task and @@ -103,7 +106,7 @@ mod tests { .await .unwrap(); let addr = listener.local_addr().unwrap(); - tokio::spawn(axum::serve(listener, app())); + tokio::spawn(axum::serve(listener, app()).into_future()); let (mut socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/integration-testable"))