From 044bfe4844ae3e44177e6154b79054c3d30c1c5e Mon Sep 17 00:00:00 2001 From: Fangdun Tsai Date: Mon, 1 Jan 2024 15:45:44 +0800 Subject: [PATCH] chore(viz): use tokio_util Listener --- viz-handlers/src/embed.rs | 11 +++++------ viz/Cargo.toml | 1 + viz/src/lib.rs | 2 +- viz/src/server.rs | 30 +++++++++--------------------- viz/src/server/accept.rs | 14 -------------- viz/src/server/tcp.rs | 14 -------------- viz/src/server/unix.rs | 13 ------------- viz/src/tls/native_tls.rs | 28 +++++++++++++++++----------- viz/src/tls/rustls.rs | 24 +++++++++++++++++------- 9 files changed, 50 insertions(+), 87 deletions(-) delete mode 100644 viz/src/server/accept.rs delete mode 100644 viz/src/server/tcp.rs delete mode 100644 viz/src/server/unix.rs diff --git a/viz-handlers/src/embed.rs b/viz-handlers/src/embed.rs index 87b648a2..1ea13e71 100644 --- a/viz-handlers/src/embed.rs +++ b/viz-handlers/src/embed.rs @@ -35,7 +35,7 @@ where type Output = Result; async fn call(&self, req: Request) -> Self::Output { - serve::(self.0.to_string(), req) + serve::(&self.0, &req) } } @@ -67,14 +67,13 @@ where match req.route_info().params.first().map(|(_, v)| v) { Some(p) => p, None => "index.html", - } - .to_string(), - req, + }, + &req, ) } } -fn serve(path: String, req: Request) -> Result +fn serve(path: &str, req: &Request) -> Result where E: RustEmbed + Send + Sync + 'static, { @@ -82,7 +81,7 @@ where Err(StatusCode::METHOD_NOT_ALLOWED.into_error())?; } - match E::get(&path) { + match E::get(path) { Some(EmbeddedFile { data, metadata }) => { let hash = hex::encode(metadata.sha256_hash()); diff --git a/viz/Cargo.toml b/viz/Cargo.toml index 8f3190ac..768d3ef2 100644 --- a/viz/Cargo.toml +++ b/viz/Cargo.toml @@ -83,6 +83,7 @@ rustls-pemfile = { workspace = true, optional = true } tokio-native-tls = { workspace = true, optional = true } tokio-rustls = { workspace = true, optional = true } tokio = { workspace = true, features = ["macros"] } +tokio-util = { workspace = true, features = ["net"] } tracing.workspace = true [dev-dependencies] diff --git a/viz/src/lib.rs b/viz/src/lib.rs index 6332174c..ca4c2206 100644 --- a/viz/src/lib.rs +++ b/viz/src/lib.rs @@ -530,7 +530,7 @@ pub use responder::Responder; #[cfg(any(feature = "http1", feature = "http2"))] mod server; #[cfg(any(feature = "http1", feature = "http2"))] -pub use server::{serve, Accept, Server}; +pub use server::{serve, Server}; /// TLS #[cfg(any(feature = "native_tls", feature = "rustls"))] diff --git a/viz/src/server.rs b/viz/src/server.rs index 2ec31a81..3845bb3b 100644 --- a/viz/src/server.rs +++ b/viz/src/server.rs @@ -10,29 +10,17 @@ use hyper_util::{ rt::{TokioExecutor, TokioIo}, server::conn::auto::Builder, }; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - pin, select, - sync::watch, -}; +use tokio::{pin, select, sync::watch}; +use tokio_util::net::Listener; use crate::{future::FutureExt, Responder, Router, Tree}; -mod accept; -pub use accept::Accept; - -#[cfg(any(feature = "http1", feature = "http2"))] -mod tcp; - -#[cfg(feature = "unix-socket")] -mod unix; - /// Starts a server and serves the connections. pub fn serve(listener: L, router: Router) -> Server where - L: Accept + Send + 'static, - L::Stream: AsyncWrite + AsyncRead + Send + Unpin, - L::Addr: Send + Sync + Debug + 'static, + L: Listener + Send + 'static, + L::Io: Send + Unpin, + L::Addr: Send + Sync + Debug, { Server::::new(listener, router) } @@ -76,9 +64,9 @@ impl Server { /// Copied from Axum. Thanks. impl IntoFuture for Server where - L: Accept + Send + 'static, - L::Stream: AsyncWrite + AsyncRead + Send + Unpin, - L::Addr: Send + Sync + Debug + 'static, + L: Listener + Send + 'static, + L::Io: Send + Unpin, + L::Addr: Send + Sync + Debug, F: Future + Send + 'static, { type Output = io::Result<()>; @@ -89,7 +77,7 @@ where tree, signal, builder, - listener, + mut listener, } = self; let (shutdown_tx, shutdown_rx) = watch::channel(()); diff --git a/viz/src/server/accept.rs b/viz/src/server/accept.rs deleted file mode 100644 index 8a5dea0b..00000000 --- a/viz/src/server/accept.rs +++ /dev/null @@ -1,14 +0,0 @@ -//! The `Accept` trait and supporting types. - -use std::{future::Future, io::Result}; - -/// Asynchronously accept incoming connections. -pub trait Accept { - /// An accepted stream of the connection. - type Stream; - /// An accepted remote address of the connection. - type Addr; - - /// Accepts a new incoming connection from this listener. - fn accept(&self) -> impl Future> + Send; -} diff --git a/viz/src/server/tcp.rs b/viz/src/server/tcp.rs deleted file mode 100644 index b78fa75a..00000000 --- a/viz/src/server/tcp.rs +++ /dev/null @@ -1,14 +0,0 @@ -use std::future::Future; -use std::io::Result; -use std::net::SocketAddr; - -use tokio::net::{TcpListener, TcpStream}; - -impl super::Accept for TcpListener { - type Stream = TcpStream; - type Addr = SocketAddr; - - fn accept(&self) -> impl Future> + Send { - TcpListener::accept(self) - } -} diff --git a/viz/src/server/unix.rs b/viz/src/server/unix.rs deleted file mode 100644 index 9f658d94..00000000 --- a/viz/src/server/unix.rs +++ /dev/null @@ -1,13 +0,0 @@ -use std::future::Future; -use std::io::Result; - -use tokio::net::{unix::SocketAddr, UnixListener, UnixStream}; - -impl super::Accept for UnixListener { - type Stream = UnixStream; - type Addr = SocketAddr; - - fn accept(&self) -> impl Future> + Send { - UnixListener::accept(self) - } -} diff --git a/viz/src/tls/native_tls.rs b/viz/src/tls/native_tls.rs index 93a51d49..bb5180ab 100644 --- a/viz/src/tls/native_tls.rs +++ b/viz/src/tls/native_tls.rs @@ -1,9 +1,11 @@ use std::{ fmt, - io::{Error as IoError, ErrorKind}, + io::{Error as IoError, ErrorKind, Result as IoResult}, net::SocketAddr, + task::{Context, Poll}, }; +use futures_util::FutureExt; use tokio::net::{TcpListener, TcpStream}; use tokio_native_tls::{native_tls::TlsAcceptor as TlsAcceptorWrapper, TlsStream}; @@ -42,17 +44,21 @@ impl Config { } } -impl crate::Accept for Listener { - type Stream = TlsStream; +impl tokio_util::net::Listener for Listener { + type Io = TlsStream; type Addr = SocketAddr; - async fn accept(&self) -> std::io::Result<(Self::Stream, Self::Addr)> { - let (stream, addr) = self.inner.accept().await?; - let tls_stream = self - .acceptor - .accept(stream) - .await - .map_err(|e| IoError::new(ErrorKind::Other, e))?; - Ok((tls_stream, addr)) + fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll> { + let Poll::Ready((stream, addr)) = self.inner.poll_accept(cx)? else { + return Poll::Pending; + }; + Box::pin(self.acceptor.accept(stream)) + .poll_unpin(cx) + .map_ok(|stream| (stream, addr)) + .map_err(|e| IoError::new(ErrorKind::Other, e)) + } + + fn local_addr(&self) -> IoResult { + self.inner.local_addr() } } diff --git a/viz/src/tls/rustls.rs b/viz/src/tls/rustls.rs index a2b6033e..e9dd4c1c 100644 --- a/viz/src/tls/rustls.rs +++ b/viz/src/tls/rustls.rs @@ -1,8 +1,10 @@ use std::{ - io::{Error as IoError, ErrorKind}, + io::{Error as IoError, ErrorKind, Result as IoResult}, net::SocketAddr, + task::{Context, Poll}, }; +use futures_util::FutureExt; use tokio::net::{TcpListener, TcpStream}; use tokio_rustls::{ rustls::{ @@ -149,13 +151,21 @@ impl Config { } } -impl crate::Accept for Listener { - type Stream = TlsStream; +impl tokio_util::net::Listener for Listener { + type Io = TlsStream; type Addr = SocketAddr; - async fn accept(&self) -> std::io::Result<(Self::Stream, Self::Addr)> { - let (stream, addr) = self.inner.accept().await?; - let tls_stream = self.acceptor.accept(stream).await?; - Ok((tls_stream, addr)) + fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll> { + let Poll::Ready((stream, addr)) = self.inner.poll_accept(cx)? else { + return Poll::Pending; + }; + self.acceptor + .accept(stream) + .poll_unpin(cx) + .map_ok(|stream| (stream, addr)) + } + + fn local_addr(&self) -> IoResult { + self.inner.local_addr() } }