diff --git a/Cargo.toml b/Cargo.toml index d69999b..8c2e8ff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,12 +29,13 @@ tracing = { version = "0.1", default-features = false, features = ["std"] } tokio = { version = "1", optional = true, features = ["net", "rt", "time"] } tower-service ={ version = "0.3", optional = true } tower = { version = "0.4.1", optional = true, features = ["make", "util"] } +slab = { version = "0.4.9", optional = true } [dev-dependencies] hyper = { version = "1.1.0", features = ["full"] } bytes = "1" http-body-util = "0.1.0" -tokio = { version = "1", features = ["macros", "test-util"] } +tokio = { version = "1", features = ["macros", "test-util", "signal"] } tokio-test = "0.4" pretty_env_logger = "0.5" @@ -50,6 +51,7 @@ full = [ "client-legacy", "server", "server-auto", + "server-graceful", "service", "http1", "http2", @@ -61,6 +63,7 @@ client-legacy = ["client"] server = ["hyper/server"] server-auto = ["server", "http1", "http2"] +server-graceful = ["dep:slab"] service = ["dep:tower", "dep:tower-service"] @@ -75,3 +78,7 @@ __internal_happy_eyeballs_tests = [] [[example]] name = "client" required-features = ["client-legacy", "http1", "tokio"] + +[[example]] +name = "server_graceful" +required-features = ["tokio", "server-graceful", "server-auto"] diff --git a/examples/server_graceful.rs b/examples/server_graceful.rs new file mode 100644 index 0000000..8bd208b --- /dev/null +++ b/examples/server_graceful.rs @@ -0,0 +1,65 @@ +use bytes::Bytes; +use std::convert::Infallible; +use std::pin::pin; +use std::time::Duration; +use tokio::net::TcpListener; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + let listener = TcpListener::bind("127.0.0.1:8080").await?; + + let graceful = hyper_util::server::graceful::GracefulShutdown::new(); + let mut ctrl_c = pin!(tokio::signal::ctrl_c()); + + loop { + tokio::select! { + conn = listener.accept() => { + let (stream, peer_addr) = match conn { + Ok(conn) => conn, + Err(e) => { + eprintln!("accept error: {}", e); + tokio::time::sleep(Duration::from_secs(1)).await; + continue; + } + }; + eprintln!("incomming connection accepted: {}", peer_addr); + + let stream = hyper_util::rt::TokioIo::new(Box::pin(stream)); + let watcher = graceful.watcher(); + + tokio::spawn(async move { + let server = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new()); + let conn = server.serve_connection_with_upgrades(stream, hyper::service::service_fn(|_| async move { + tokio::time::sleep(Duration::from_secs(5)).await; // emulate slow request + let body = http_body_util::Full::::from("Hello World!".to_owned()); + Ok::<_, Infallible>(http::Response::new(body)) + })); + + let conn = watcher.watch(conn); + + if let Err(err) = conn.await { + eprintln!("connection error: {}", err); + } + eprintln!("connection dropped: {}", peer_addr); + }); + }, + + _ = ctrl_c.as_mut() => { + drop(listener); + eprintln!("Ctrl-C received, starting shutdown"); + break; + } + } + } + + tokio::select! { + _ = graceful.shutdown() => { + eprintln!("Gracefully shutdown!"); + }, + _ = tokio::time::sleep(Duration::from_secs(10)) => { + eprintln!("Waited 10 seconds for graceful shutdown, aborting..."); + } + } + + Ok(()) +} diff --git a/src/server/graceful.rs b/src/server/graceful.rs new file mode 100644 index 0000000..0037fe1 --- /dev/null +++ b/src/server/graceful.rs @@ -0,0 +1,620 @@ +//! Utility to gracefully shutdown a server. +//! +//! This module provides a [`GracefulShutdown`] type, +//! which can be used to gracefully shutdown a server. +//! +//! See +//! for an example of how to use this. + +use pin_project_lite::pin_project; +use slab::Slab; +use std::{ + fmt::{self, Debug}, + future::Future, + pin::Pin, + sync::{atomic::AtomicUsize, Arc, Mutex}, + task::{self, Poll, Waker}, +}; + +/// A graceful shutdown watcher +pub struct GracefulShutdown { + // state used to keep track of all futures that are being watched + state: Arc, + // state that the watched futures to know when shutdown signal is received + future_state: Arc, +} + +impl Debug for GracefulShutdown { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GracefulShutdown").finish() + } +} + +impl Default for GracefulShutdown { + fn default() -> Self { + Self::new() + } +} + +impl GracefulShutdown { + /// Create a new graceful shutdown watcher + pub fn new() -> Self { + Self { + state: Arc::new(GracefulState { + counter: AtomicUsize::new(0), + waker_list: Mutex::new(Slab::new()), + }), + future_state: Arc::new(GracefulState { + counter: AtomicUsize::new(1), + waker_list: Mutex::new(Slab::new()), + }), + } + } + + /// Get a graceful shutdown watcher + pub fn watcher(&self) -> GracefulWatcher { + self.state.subscribe(); + GracefulWatcher { + state: self.state.clone(), + future_state: self.future_state.clone(), + } + } + + /// Wait for a graceful shutdown + pub fn shutdown(self) -> GracefulWaiter { + // prepare futures and signal them to shutdown + self.future_state.unsubscribe(); + + // return the future to wait for shutdown + GracefulWaiter::new(self.state) + } +} + +/// A graceful shutdown watcher. +pub struct GracefulWatcher { + state: Arc, + future_state: Arc, +} + +impl Drop for GracefulWatcher { + fn drop(&mut self) { + self.state.unsubscribe(); + } +} + +impl GracefulWatcher { + /// Watch a future for graceful shutdown, + /// returning a wrapper that can be awaited on. + pub fn watch(&self, conn: C) -> GracefulFuture { + // add a counter for this future to ensure it is taken into account + self.state.subscribe(); + + let cancel = GracefulWaiter::new(self.future_state.clone()); + let future = GracefulConnectionFuture::new(conn, cancel); + + // return the graceful future, ready to be shutdown, + // and handling all the hyper graceful logic + GracefulFuture { + future, + state: Some(self.state.clone()), + } + } +} + +struct GracefulState { + counter: AtomicUsize, + waker_list: Mutex>>, +} + +impl GracefulState { + fn subscribe(&self) { + self.counter + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); + } + + fn unsubscribe(&self) { + if self + .counter + .fetch_sub(1, std::sync::atomic::Ordering::SeqCst) + == 1 + { + let mut waker_list = self.waker_list.lock().unwrap(); + for (_, waker) in waker_list.iter_mut() { + if let Some(waker) = waker.take() { + waker.wake(); + } + } + } + } +} + +pin_project! { + /// A wrapper around a future that's being watched for graceful shutdown. + /// + /// This is returned by [`GracefulShutdown::watch`]. + /// + /// # Panics + /// + /// This future might panic if it is polled + /// after the internal future has already returned `Poll::Ready` before. + /// + /// Whether or not this future panics in such cases is an implementation detail + /// and should not be relied upon. + pub struct GracefulFuture { + #[pin] + future: GracefulConnectionFuture, + state: Option>, + } +} + +impl Debug for GracefulFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GracefulFuture").finish() + } +} + +impl Future for GracefulFuture +where + C: GracefulConnection, +{ + type Output = C::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + let this = self.project(); + match this.future.poll(cx) { + Poll::Ready(v) => { + this.state.take().unwrap().unsubscribe(); + Poll::Ready(v) + } + Poll::Pending => Poll::Pending, + } + } +} + +/// A future that waits until the graceful shutdown is completed. +pub struct GracefulWaiter { + state: GracefulWaiterState, +} + +impl GracefulWaiter { + fn new(state: Arc) -> Self { + Self { + state: GracefulWaiterState { state, key: None }, + } + } +} + +impl Future for GracefulWaiter { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + let state = &mut self.state; + + if state + .state + .counter + .load(std::sync::atomic::Ordering::SeqCst) + == 0 + { + return Poll::Ready(()); + } + + let mut waker_list = state.state.waker_list.lock().unwrap(); + + if state + .state + .counter + .load(std::sync::atomic::Ordering::SeqCst) + == 0 + { + // check again in case of race condition + return Poll::Ready(()); + } + + let waker = Some(cx.waker().clone()); + state.key = Some(match state.key.take() { + Some(key) => { + *waker_list.get_mut(key).unwrap() = waker; + key + } + None => waker_list.insert(waker), + }); + + Poll::Pending + } +} + +struct GracefulWaiterState { + state: Arc, + key: Option, +} + +impl Drop for GracefulWaiterState { + /// When the waiter is dropped, we need to remove its waker from the waker list. + /// As to ensure the graceful waiter is cancel safe. + fn drop(&mut self) { + if let Some(key) = self.key.take() { + let mut wakers = self.state.waker_list.lock().unwrap(); + wakers.remove(key); + } + } +} + +pin_project! { + struct GracefulConnectionFuture { + #[pin] + conn: C, + #[pin] + cancel: F, + cancelled: bool, + } +} + +impl GracefulConnectionFuture { + fn new(conn: C, cancel: F) -> Self { + Self { + conn, + cancel, + cancelled: false, + } + } +} + +impl Debug for GracefulConnectionFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GracefulConnectionFuture").finish() + } +} + +impl Future for GracefulConnectionFuture +where + C: GracefulConnection, + F: Future, +{ + type Output = C::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + let mut this = self.project(); + if !*this.cancelled { + if let Poll::Ready(_) = this.cancel.poll(cx) { + *this.cancelled = true; + this.conn.as_mut().graceful_shutdown(); + } + } + this.conn.poll(cx) + } +} + +/// An internal utility trait as an umbrella target for all (hyper) connection +/// types that the [`GracefulShutdown`] can watch. +pub trait GracefulConnection: Future> + private::Sealed { + /// The error type returned by the connection when used as a future. + type Error; + + /// Start a graceful shutdown process for this connection. + fn graceful_shutdown(self: Pin<&mut Self>); +} + +#[cfg(feature = "http1")] +impl GracefulConnection for hyper::server::conn::http1::Connection +where + S: hyper::service::HttpService, + S::Error: Into>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, +{ + type Error = hyper::Error; + + fn graceful_shutdown(self: Pin<&mut Self>) { + hyper::server::conn::http1::Connection::graceful_shutdown(self); + } +} + +#[cfg(feature = "http2")] +impl GracefulConnection for hyper::server::conn::http2::Connection +where + S: hyper::service::HttpService, + S::Error: Into>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + E: hyper::rt::bounds::Http2ServerConnExec, +{ + type Error = hyper::Error; + + fn graceful_shutdown(self: Pin<&mut Self>) { + hyper::server::conn::http2::Connection::graceful_shutdown(self); + } +} + +#[cfg(feature = "server-auto")] +impl<'a, I, B, S, E> GracefulConnection for crate::server::conn::auto::Connection<'a, I, S, E> +where + S: hyper::service::Service, Response = http::Response>, + S::Error: Into>, + S::Future: 'static, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + E: hyper::rt::bounds::Http2ServerConnExec, +{ + type Error = Box; + + fn graceful_shutdown(self: Pin<&mut Self>) { + crate::server::conn::auto::Connection::graceful_shutdown(self); + } +} + +#[cfg(feature = "server-auto")] +impl<'a, I, B, S, E> GracefulConnection + for crate::server::conn::auto::UpgradeableConnection<'a, I, S, E> +where + S: hyper::service::Service, Response = http::Response>, + S::Error: Into>, + S::Future: 'static, + I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + E: hyper::rt::bounds::Http2ServerConnExec, +{ + type Error = Box; + + fn graceful_shutdown(self: Pin<&mut Self>) { + crate::server::conn::auto::UpgradeableConnection::graceful_shutdown(self); + } +} + +mod private { + pub trait Sealed {} + + #[cfg(feature = "http1")] + impl Sealed for hyper::server::conn::http1::Connection + where + S: hyper::service::HttpService, + S::Error: Into>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + { + } + + #[cfg(feature = "http1")] + impl Sealed for hyper::server::conn::http1::UpgradeableConnection + where + S: hyper::service::HttpService, + S::Error: Into>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + { + } + + #[cfg(feature = "http2")] + impl Sealed for hyper::server::conn::http2::Connection + where + S: hyper::service::HttpService, + S::Error: Into>, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + E: hyper::rt::bounds::Http2ServerConnExec, + { + } + + #[cfg(feature = "server-auto")] + impl<'a, I, B, S, E> Sealed for crate::server::conn::auto::Connection<'a, I, S, E> + where + S: hyper::service::Service< + http::Request, + Response = http::Response, + >, + S::Error: Into>, + S::Future: 'static, + I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + E: hyper::rt::bounds::Http2ServerConnExec, + { + } + + #[cfg(feature = "server-auto")] + impl<'a, I, B, S, E> Sealed for crate::server::conn::auto::UpgradeableConnection<'a, I, S, E> + where + S: hyper::service::Service< + http::Request, + Response = http::Response, + >, + S::Error: Into>, + S::Future: 'static, + I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, + B: hyper::body::Body + 'static, + B::Error: Into>, + E: hyper::rt::bounds::Http2ServerConnExec, + { + } +} + +#[cfg(test)] +mod test { + use super::*; + use pin_project_lite::pin_project; + use std::sync::atomic::{AtomicUsize, Ordering}; + + pin_project! { + #[derive(Debug)] + struct DummyConnection { + #[pin] + future: F, + shutdown_counter: Arc, + } + } + + impl private::Sealed for DummyConnection {} + + impl GracefulConnection for DummyConnection { + type Error = (); + + fn graceful_shutdown(self: Pin<&mut Self>) { + self.shutdown_counter.fetch_add(1, Ordering::SeqCst); + } + } + + impl Future for DummyConnection { + type Output = Result<(), ()>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + match self.project().future.poll(cx) { + Poll::Ready(_) => Poll::Ready(Ok(())), + Poll::Pending => Poll::Pending, + } + } + } + + #[tokio::test] + async fn test_graceful_shutdown_ok() { + let graceful = GracefulShutdown::new(); + let shutdown_counter = Arc::new(AtomicUsize::new(0)); + let (dummy_tx, _) = tokio::sync::broadcast::channel(1); + + for i in 1..=3 { + let watcher = graceful.watcher(); + let mut dummy_rx = dummy_tx.subscribe(); + let shutdown_counter = shutdown_counter.clone(); + + tokio::spawn(async move { + let future = async move { + tokio::time::sleep(std::time::Duration::from_millis(i * 50)).await; + let _ = dummy_rx.recv().await; + }; + let dummy_conn = DummyConnection { + future, + shutdown_counter, + }; + let conn = watcher.watch(dummy_conn); + conn.await.unwrap(); + }); + } + + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0); + let _ = dummy_tx.send(()); + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(500)) => { + panic!("timeout") + }, + _ = graceful.shutdown() => { + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3); + } + } + } + + #[tokio::test] + async fn test_graceful_shutdown_delayed_ok() { + let graceful = GracefulShutdown::new(); + let shutdown_counter = Arc::new(AtomicUsize::new(0)); + + for i in 1..=3 { + let watcher = graceful.watcher(); + let shutdown_counter = shutdown_counter.clone(); + + tokio::spawn(async move { + tokio::time::sleep(std::time::Duration::from_millis(i * 5)).await; + let future = async move { + tokio::time::sleep(std::time::Duration::from_millis(i * 50)).await; + }; + let dummy_conn = DummyConnection { + future, + shutdown_counter, + }; + let conn = watcher.watch(dummy_conn); + conn.await.unwrap(); + }); + } + + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0); + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(500)) => { + panic!("timeout") + }, + _ = graceful.shutdown() => { + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3); + } + } + } + + #[tokio::test] + async fn test_graceful_shutdown_multi_per_watcher_ok() { + let graceful = GracefulShutdown::new(); + let shutdown_counter = Arc::new(AtomicUsize::new(0)); + + for i in 1..=3 { + let watcher = graceful.watcher(); + let shutdown_counter = shutdown_counter.clone(); + + tokio::spawn(async move { + let mut futures = Vec::new(); + for u in 1..=i { + let future = tokio::time::sleep(std::time::Duration::from_millis(u * 50)); + let dummy_conn = DummyConnection { + future, + shutdown_counter: shutdown_counter.clone(), + }; + let conn = watcher.watch(dummy_conn); + futures.push(conn); + } + futures_util::future::join_all(futures).await; + }); + } + + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0); + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(500)) => { + panic!("timeout") + }, + _ = graceful.shutdown() => { + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 6); + } + } + } + + #[tokio::test] + async fn test_graceful_shutdown_timeout() { + let graceful = GracefulShutdown::new(); + let shutdown_counter = Arc::new(AtomicUsize::new(0)); + + for i in 1..=3 { + let watcher = graceful.watcher(); + let shutdown_counter = shutdown_counter.clone(); + + tokio::spawn(async move { + let future = async move { + if i == 1 { + std::future::pending::<()>().await + } else { + std::future::ready(()).await + } + }; + let dummy_conn = DummyConnection { + future, + shutdown_counter, + }; + let conn = watcher.watch(dummy_conn); + conn.await.unwrap(); + }); + } + + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0); + + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_millis(500)) => { + assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3); + }, + _ = graceful.shutdown() => { + panic!("shutdown should not be completed: as not all our conns finish") + } + } + } +} diff --git a/src/server/mod.rs b/src/server/mod.rs index 7b4515c..a4838ac 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,3 +1,6 @@ //! Server utilities. pub mod conn; + +#[cfg(feature = "server-graceful")] +pub mod graceful;