diff --git a/Cargo.toml b/Cargo.toml index 7dd959a2..842f6e2b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ members = [ "examples/databases/*", "examples/htmlx", "examples/tower", + "examples/smol", ] [workspace.package] @@ -113,6 +114,14 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] } tower = "0.4" tower-http = "0.5" +# soml +async-channel = "2.1" +async-executor = "1.8" +async-io = "2.2" +async-net = "2.0" +smol-hyper = "0.1.1" +futures-lite = { version = "2.1.0", default-features = false, features = ["std"] } + [workspace.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"] diff --git a/examples/smol/Cargo.toml b/examples/smol/Cargo.toml new file mode 100644 index 00000000..4bfc81e8 --- /dev/null +++ b/examples/smol/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "smol-example" +version = "0.1.0" +edition.workspace = true +publish = false + +[dependencies] +viz = { workspace = true, features = ["smol"] } + +# smol +async-executor = "1.8" +async-io = "2.2" +async-net = "2.0" +smol-hyper = "0.1.1" +smol-macros = "0.1" +macro_rules_attribute = "0.2" diff --git a/examples/smol/src/main.rs b/examples/smol/src/main.rs new file mode 100644 index 00000000..0793e449 --- /dev/null +++ b/examples/smol/src/main.rs @@ -0,0 +1,23 @@ +use std::io; +use std::sync::Arc; + +use async_net::TcpListener; +use macro_rules_attribute::apply; +use viz::{IntoResponse, Request, Response, Result, Router}; + +#[apply(smol_macros::main!)] +async fn main(ex: &Arc>) -> io::Result<()> { + // Build our application with a route. + let app = Router::new().get("/", handler); + + // Create a `smol`-based TCP listener. + let listener = TcpListener::bind(("127.0.0.1", 3000)).await.unwrap(); + println!("listening on {}", listener.local_addr().unwrap()); + + // Run it + viz::serve(ex.clone(), listener, app).await +} + +async fn handler(_: Request) -> Result { + Ok("

Hello, World!

".into_response()) +} diff --git a/viz/Cargo.toml b/viz/Cargo.toml index 768d3ef2..99fddf87 100644 --- a/viz/Cargo.toml +++ b/viz/Cargo.toml @@ -67,6 +67,15 @@ otel-prometheus = ["handlers", "viz-handlers?/prometheus"] rustls = ["dep:rustls-pemfile", "dep:tokio-rustls", "dep:futures-util"] native-tls = ["dep:tokio-native-tls", "dep:futures-util"] +smol = [ + # "dep:async-channel", + "dep:async-executor", + # "dep:async-io", + "dep:async-net", + "dep:smol-hyper", + "dep:futures-lite" +] + [dependencies] viz-core.workspace = true viz-router.workspace = true @@ -86,6 +95,14 @@ tokio = { workspace = true, features = ["macros"] } tokio-util = { workspace = true, features = ["net"] } tracing.workspace = true +# smol +# async-channel = { workspace = true, optional = true } +async-executor = { workspace = true, optional = true } +# async-io = { workspace = true, optional = true } +async-net = { workspace = true, optional = true } +smol-hyper = { workspace = true, optional = true } +futures-lite = { workspace = true, optional = true } + [dev-dependencies] tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } diff --git a/viz/src/lib.rs b/viz/src/lib.rs index 2444960f..24d7ac46 100644 --- a/viz/src/lib.rs +++ b/viz/src/lib.rs @@ -528,9 +528,8 @@ pub use responder::Responder; mod server; pub use server::{serve, Listener, Server}; -/// TLS #[cfg(any(feature = "native_tls", feature = "rustls"))] -pub mod tls; +pub use server::tls; pub use viz_core::*; pub use viz_router::*; diff --git a/viz/src/server.rs b/viz/src/server.rs index 6999db2d..67c87b45 100644 --- a/viz/src/server.rs +++ b/viz/src/server.rs @@ -1,188 +1,53 @@ -use std::{ - fmt::Debug, - future::{pending, Future, IntoFuture, Pending}, - io, - pin::Pin, - sync::Arc, -}; - -use hyper_util::{ - rt::{TokioExecutor, TokioIo}, - server::conn::auto::Builder, -}; -use tokio::{pin, select, sync::watch}; - -use crate::{future::FutureExt, Responder, Router, Tree}; +use std::{fmt::Debug, future::Pending}; mod listener; pub use listener::Listener; -#[cfg(any(feature = "http1", feature = "http2"))] -mod tcp; +#[cfg(not(feature = "smol"))] +mod tokio; +#[cfg(not(feature = "smol"))] +pub use self::tokio::serve; + +#[cfg(feature = "smol")] +mod smol; +#[cfg(feature = "smol")] +pub use self::smol::serve; + +#[cfg(any(feature = "native_tls", feature = "rustls"))] +#[path = "server/tls.rs"] +pub(super) mod internal; + +/// TLS +#[cfg(any(feature = "native_tls", feature = "rustls"))] +pub mod tls { + pub use super::internal::*; -#[cfg(all(unix, feature = "unix-socket"))] -mod unix; + #[cfg(not(feature = "smol"))] + pub use super::tokio::tls::*; -/// Starts a server and serves the connections. -pub fn serve(listener: L, router: Router) -> Server -where - L: Listener + Send + 'static, - L::Io: Send + Unpin, - L::Addr: Send + Sync + Debug, -{ - Server::::new(listener, router) + #[cfg(feature = "smol")] + pub use super::smol::tls::*; } /// A listening HTTP server that accepts connections. #[derive(Debug)] -pub struct Server> { - signal: F, - tree: Tree, +pub struct Server> { + signal: S, + tree: crate::Tree, + executor: E, listener: L, - builder: Builder, + build: F, } -impl Server { - /// Starts a [`Server`] with a listener and a [`Tree`]. - pub fn new(listener: L, router: Router) -> Server { - Server { - listener, - signal: pending(), - tree: router.into(), - builder: Builder::new(TokioExecutor::new()), - } - } - +impl Server { /// Changes the signal for graceful shutdown. - pub fn signal(self, signal: T) -> Server { + pub fn signal(self, signal: X) -> Server { Server { signal, tree: self.tree, - builder: self.builder, + build: self.build, + executor: self.executor, listener: self.listener, } } - - /// Returns the HTTP1 or HTTP2 connection builder. - pub fn builder(&mut self) -> &mut Builder { - &mut self.builder - } -} - -/// Copied from Axum. Thanks. -impl IntoFuture for Server -where - L: Listener + Send + 'static, - L::Io: Send + Unpin, - L::Addr: Send + Sync + Debug, - F: Future + Send + 'static, -{ - type Output = io::Result<()>; - type IntoFuture = Pin + Send>>; - - fn into_future(self) -> Self::IntoFuture { - let Self { - tree, - signal, - builder, - listener, - } = self; - - let (shutdown_tx, shutdown_rx) = watch::channel(()); - let shutdown_tx = Arc::new(shutdown_tx); - - tokio::spawn(async move { - signal.await; - tracing::trace!("received graceful shutdown signal"); - drop(shutdown_rx); - }); - - let (close_tx, close_rx) = watch::channel(()); - - let tree = Arc::new(tree); - - Box::pin(async move { - loop { - let (stream, remote_addr) = select! { - res = listener.accept() => { - match res { - Ok(conn) => conn, - Err(e) => { - if !is_connection_error(&e) { - // [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186) - tracing::error!("listener accept error: {e}"); - tokio::time::sleep(std::time::Duration::from_secs(1)).await; - } - continue - } - } - } - () = shutdown_tx.closed() => { - tracing::trace!("server is closing"); - break; - } - }; - - tracing::trace!("connection {:?} accepted", remote_addr); - - let io = TokioIo::new(stream); - let remote_addr = Arc::new(remote_addr); - let builder = builder.clone(); - let responder = - Responder::>::new(tree.clone(), Some(remote_addr.clone())); - - let shutdown_tx = Arc::clone(&shutdown_tx); - let close_rx = close_rx.clone(); - - tokio::spawn(async move { - let conn = builder.serve_connection_with_upgrades(io, responder); - pin!(conn); - - let shutdown = shutdown_tx.closed().fuse(); - pin!(shutdown); - - loop { - select! { - res = conn.as_mut() => { - if let Err(e) = res { - tracing::error!("connection failed: {e}"); - } - break; - } - () = &mut shutdown => { - tracing::trace!("connection is starting to graceful shutdown"); - conn.as_mut().graceful_shutdown(); - } - } - } - - tracing::trace!("connection {:?} closed", remote_addr); - - drop(close_rx); - }); - } - - drop(close_rx); - drop(listener); - - tracing::trace!( - "waiting for {} task(s) to finish", - close_tx.receiver_count() - ); - close_tx.closed().await; - - tracing::trace!("server shutdown complete"); - - Ok(()) - }) - } -} - -fn is_connection_error(e: &io::Error) -> bool { - matches!( - e.kind(), - io::ErrorKind::ConnectionRefused - | io::ErrorKind::ConnectionAborted - | io::ErrorKind::ConnectionReset - ) } diff --git a/viz/src/server/listener.rs b/viz/src/server/listener.rs index 4b72e1eb..8f26ce80 100644 --- a/viz/src/server/listener.rs +++ b/viz/src/server/listener.rs @@ -1,11 +1,9 @@ use std::{future::Future, io::Result}; -use tokio::io::{AsyncRead, AsyncWrite}; - /// A trait for a listener: `TcpListener` and `UnixListener`. pub trait Listener { /// The stream's type of this listener. - type Io: AsyncRead + AsyncWrite; + type Io; /// The socket address type of this listener. type Addr; diff --git a/viz/src/server/smol.rs b/viz/src/server/smol.rs new file mode 100644 index 00000000..1d7da493 --- /dev/null +++ b/viz/src/server/smol.rs @@ -0,0 +1,112 @@ +use std::{ + borrow::Borrow, + fmt::Debug, + future::{pending, Future, IntoFuture, Pending}, + io, + pin::Pin, + sync::Arc, +}; + +use async_executor::Executor; +use futures_lite::io::{AsyncRead, AsyncWrite}; +use hyper::rt::Timer; +use hyper_util::server::conn::auto::Builder; +use smol_hyper::rt::{FuturesIo, SmolExecutor, SmolTimer}; + +use crate::{future::FutureExt, Responder, Router, Tree}; +use crate::{Listener, Server}; + +#[cfg(any(feature = "http1", feature = "http2"))] +mod tcp; + +#[cfg(all(unix, feature = "unix-socket"))] +mod unix; + +/// TLS +#[cfg(any(feature = "native_tls", feature = "rustls"))] +pub mod tls; + +// impl Server { +// /// Starts a [`Server`] with a listener and a [`Router`]. +// pub fn new<'ex>(executor: E, listener: L, router: Router) -> Server> +// where +// E: Borrow> + Clone + Send + 'ex, +// { +// Server { +// executor, +// listener, +// signal: pending(), +// tree: router.into(), +// build: +// } +// } +// } + +/// Serve a future using [`smol`]'s TCP listener. +pub async fn serve<'ex, E, L>(executor: E, listener: L, router: Router) -> io::Result<()> +where + E: Borrow> + Clone + Send + 'ex, + L: Listener + Send + 'static, + L::Io: AsyncRead + AsyncWrite + Send + Unpin, + L::Addr: Send + Sync + Debug, +{ + let tree = Arc::::new(router.into()); + + loop { + // Wait for a new connection. + let (stream, remote_addr) = match listener.accept().await { + Ok(conn) => conn, + Err(e) => { + if !is_connection_error(&e) { + // [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186) + tracing::error!("listener accept error: {e}"); + SmolTimer::new() + .sleep(std::time::Duration::from_secs(1)) + .await; + } + continue; + } + }; + + // Wrap it in a `FuturesIo`. + let io = FuturesIo::new(stream); + let remote_addr = Arc::new(remote_addr); + let responder = Responder::>::new(tree.clone(), Some(remote_addr.clone())); + + // Spawn the service on our executor. + let task = executor.borrow().spawn({ + let executor = executor.clone(); + async move { + let mut builder = Builder::new(SmolExecutor::new(AsRefExecutor(executor.borrow()))); + builder.http1().timer(SmolTimer::new()); + builder.http2().timer(SmolTimer::new()); + + if let Err(err) = builder.serve_connection_with_upgrades(io, responder).await { + tracing::error!("unintelligible hyper error: {err}"); + } + } + }); + + // Detach the task and let it run forever. + task.detach(); + } +} + +fn is_connection_error(e: &io::Error) -> bool { + matches!( + e.kind(), + io::ErrorKind::ConnectionRefused + | io::ErrorKind::ConnectionAborted + | io::ErrorKind::ConnectionReset + ) +} + +#[derive(Clone)] +struct AsRefExecutor<'this, 'ex>(&'this Executor<'ex>); + +impl<'ex> AsRef> for AsRefExecutor<'_, 'ex> { + #[inline] + fn as_ref(&self) -> &Executor<'ex> { + self.0 + } +} diff --git a/viz/src/server/smol/tcp.rs b/viz/src/server/smol/tcp.rs new file mode 100644 index 00000000..2eeb4e85 --- /dev/null +++ b/viz/src/server/smol/tcp.rs @@ -0,0 +1,16 @@ +use std::{future::Future, io::Result}; + +use async_net::{SocketAddr, TcpListener, TcpStream}; + +impl crate::Listener for TcpListener { + type Io = TcpStream; + type Addr = SocketAddr; + + fn accept(&self) -> impl Future> + Send { + TcpListener::accept(self) + } + + fn local_addr(&self) -> Result { + TcpListener::local_addr(self) + } +} diff --git a/viz/src/server/smol/tls.rs b/viz/src/server/smol/tls.rs new file mode 100644 index 00000000..70b786d1 --- /dev/null +++ b/viz/src/server/smol/tls.rs @@ -0,0 +1 @@ +// TODO diff --git a/viz/src/server/smol/unix.rs b/viz/src/server/smol/unix.rs new file mode 100644 index 00000000..29313d71 --- /dev/null +++ b/viz/src/server/smol/unix.rs @@ -0,0 +1,16 @@ +use std::{future::Future, io::Result}; + +use async_net::unix::{SocketAddr, UnixListener, UnixStream}; + +impl crate::Listener for UnixListener { + type Io = UnixStream; + type Addr = SocketAddr; + + fn accept(&self) -> impl Future> + Send { + UnixListener::accept(self) + } + + fn local_addr(&self) -> Result { + UnixListener::local_addr(self) + } +} diff --git a/viz/src/tls/listener.rs b/viz/src/server/tls.rs similarity index 59% rename from viz/src/tls/listener.rs rename to viz/src/server/tls.rs index 081d23a7..1491baa6 100644 --- a/viz/src/tls/listener.rs +++ b/viz/src/server/tls.rs @@ -1,3 +1,5 @@ +//! A TLS listener wrapper. + /// Unified TLS listener type. #[derive(Debug)] pub struct TlsListener { @@ -13,4 +15,14 @@ impl TlsListener { acceptor: a, } } + + /// Gets the listener. + pub fn get_ref(&self) -> &T { + &self.inner + } + + /// Gets the acceptor. + pub fn get_acceptor(&self) -> &A { + &self.acceptor + } } diff --git a/viz/src/server/tokio.rs b/viz/src/server/tokio.rs new file mode 100644 index 00000000..7e3165b0 --- /dev/null +++ b/viz/src/server/tokio.rs @@ -0,0 +1,177 @@ +use std::{ + fmt::Debug, + future::{pending, Future, IntoFuture, Pending}, + io, + pin::Pin, + sync::Arc, +}; + +use hyper_util::rt::{TokioExecutor, TokioIo}; +use hyper_util::server::conn::auto::Builder; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + pin, select, + sync::watch, +}; + +use crate::{future::FutureExt, Responder, Router}; +use crate::{Listener, Server}; + +#[cfg(any(feature = "http1", feature = "http2"))] +mod tcp; + +#[cfg(all(unix, feature = "unix-socket"))] +mod unix; + +/// TLS +#[cfg(any(feature = "native_tls", feature = "rustls"))] +pub mod tls; + +/// Starts a server and serves the connections. +pub fn serve( + listener: L, + router: Router, +) -> Server Builder> { + Server:: Builder>::new( + TokioExecutor::new(), + listener, + router, + Builder::new, + ) +} + +impl Server { + /// Starts a [`Server`] with a listener and a [`Router`]. + pub fn new(executor: E, listener: L, router: Router, build: F) -> Server> + where + F: Fn(TokioExecutor) -> Builder + Send + 'static, + { + Server { + build, + executor, + listener, + signal: pending(), + tree: router.into(), + } + } +} + +/// Copied from Axum. Thanks. +impl IntoFuture for Server +where + L: Listener + Send + 'static, + L::Io: AsyncRead + AsyncWrite + Send + Unpin, + L::Addr: Send + Sync + Debug, + F: Fn(TokioExecutor) -> Builder + Send + 'static, + S: Future + Send + 'static, +{ + type Output = io::Result<()>; + type IntoFuture = Pin + Send>>; + + fn into_future(self) -> Self::IntoFuture { + let Self { + tree, + build, + signal, + executor, + listener, + } = self; + + let (shutdown_tx, shutdown_rx) = watch::channel(()); + let shutdown_tx = Arc::new(shutdown_tx); + + tokio::spawn(async move { + signal.await; + tracing::trace!("received graceful shutdown signal"); + drop(shutdown_rx); + }); + + let (close_tx, close_rx) = watch::channel(()); + + let tree = Arc::new(tree); + + Box::pin(async move { + loop { + let (stream, remote_addr) = select! { + res = listener.accept() => { + match res { + Ok(conn) => conn, + Err(e) => { + if !is_connection_error(&e) { + // [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186) + tracing::error!("listener accept error: {e}"); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + } + continue + } + } + } + () = shutdown_tx.closed() => { + tracing::trace!("server is closing"); + break; + } + }; + + tracing::trace!("connection {:?} accepted", remote_addr); + + let io = TokioIo::new(stream); + let remote_addr = Arc::new(remote_addr); + let builder = (build)(executor.clone()); + let responder = + Responder::>::new(tree.clone(), Some(remote_addr.clone())); + + let shutdown_tx = Arc::clone(&shutdown_tx); + let close_rx = close_rx.clone(); + + tokio::spawn(async move { + let conn = builder.serve_connection_with_upgrades(io, responder); + pin!(conn); + + let shutdown = shutdown_tx.closed().fuse(); + pin!(shutdown); + + loop { + select! { + res = conn.as_mut() => { + if let Err(e) = res { + tracing::error!("connection failed: {e}"); + } + break; + } + () = &mut shutdown => { + tracing::trace!("connection is starting to graceful shutdown"); + conn.as_mut().graceful_shutdown(); + } + } + } + + tracing::trace!("connection {:?} closed", remote_addr); + + drop(close_rx); + }); + } + + drop(close_rx); + drop(listener); + + tracing::trace!( + "waiting for {} task(s) to finish", + close_tx.receiver_count() + ); + close_tx.closed().await; + + tracing::trace!("server shutdown complete"); + + Ok(()) + }) + } +} + +fn is_connection_error(e: &io::Error) -> bool { + matches!( + e.kind(), + io::ErrorKind::ConnectionRefused + | io::ErrorKind::ConnectionAborted + | io::ErrorKind::ConnectionReset + ) +} diff --git a/viz/src/server/tcp.rs b/viz/src/server/tokio/tcp.rs similarity index 100% rename from viz/src/server/tcp.rs rename to viz/src/server/tokio/tcp.rs diff --git a/viz/src/tls.rs b/viz/src/server/tokio/tls.rs similarity index 72% rename from viz/src/tls.rs rename to viz/src/server/tokio/tls.rs index 0a53abcc..673b1cba 100644 --- a/viz/src/tls.rs +++ b/viz/src/server/tokio/tls.rs @@ -1,7 +1,3 @@ -mod listener; - -pub use listener::TlsListener; - /// `native_tls` #[cfg(feature = "native-tls")] pub mod native_tls; diff --git a/viz/src/tls/native_tls.rs b/viz/src/server/tokio/tls/native_tls.rs similarity index 95% rename from viz/src/tls/native_tls.rs rename to viz/src/server/tokio/tls/native_tls.rs index f727f661..bdd79c57 100644 --- a/viz/src/tls/native_tls.rs +++ b/viz/src/server/tokio/tls/native_tls.rs @@ -37,7 +37,7 @@ impl Config { } } -impl crate::Listener for super::TlsListener { +impl crate::Listener for crate::tls::TlsListener { type Io = TlsStream; type Addr = SocketAddr; diff --git a/viz/src/tls/rustls.rs b/viz/src/server/tokio/tls/rustls.rs similarity index 98% rename from viz/src/tls/rustls.rs rename to viz/src/server/tokio/tls/rustls.rs index 3144bdd1..1a84a25c 100644 --- a/viz/src/tls/rustls.rs +++ b/viz/src/server/tokio/tls/rustls.rs @@ -148,7 +148,7 @@ impl Config { } } -impl crate::Listener for super::TlsListener { +impl crate::Listener for crate::tls::TlsListener { type Io = TlsStream; type Addr = SocketAddr; diff --git a/viz/src/server/unix.rs b/viz/src/server/tokio/unix.rs similarity index 100% rename from viz/src/server/unix.rs rename to viz/src/server/tokio/unix.rs