From 9056447d0c37f8ad29451863a5348fad84ede8a4 Mon Sep 17 00:00:00 2001 From: Fangdun Tsai Date: Thu, 5 Oct 2023 20:50:57 +0800 Subject: [PATCH] feat(viz): add serve function --- Cargo.toml | 1 + examples/hello-world/src/main.rs | 9 +- viz-core/Cargo.toml | 1 + viz-core/src/io.rs | 157 ------------------------------- viz-core/src/lib.rs | 5 +- viz/Cargo.toml | 1 + viz/src/lib.rs | 11 ++- viz/src/serve.rs | 24 +++++ 8 files changed, 39 insertions(+), 170 deletions(-) delete mode 100644 viz-core/src/io.rs create mode 100644 viz/src/serve.rs diff --git a/Cargo.toml b/Cargo.toml index b2624247..234a8716 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,6 +66,7 @@ http = "0.2" http-body = "=1.0.0-rc.2" http-body-util = "=0.1.0-rc.3" hyper = { version = "=1.0.0-rc.4", features = ["server"] } +hyper-util = { git = "https://github.com/hyperium/hyper-util", rev = "63e84bf", features = ["auto"] } futures-util = "0.3" tokio = { version = "1.32", features = ["net"] } diff --git a/examples/hello-world/src/main.rs b/examples/hello-world/src/main.rs index 09ad005f..b30bf785 100644 --- a/examples/hello-world/src/main.rs +++ b/examples/hello-world/src/main.rs @@ -3,7 +3,7 @@ use std::{net::SocketAddr, sync::Arc}; use tokio::net::TcpListener; -use viz::{server::conn::http1, Io, Request, Responder, Result, Router, Tree}; +use viz::{serve, Request, Result, Router, Tree}; async fn index(_: Request) -> Result<&'static str> { Ok("Hello, World!") @@ -13,7 +13,7 @@ async fn index(_: Request) -> Result<&'static str> { async fn main() -> Result<()> { let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); let listener = TcpListener::bind(addr).await?; - println!("listening on {addr}"); + println!("listening on http://{addr}"); let app = Router::new().get("/", index); let tree = Arc::new(Tree::from(app)); @@ -22,10 +22,7 @@ async fn main() -> Result<()> { let (stream, addr) = listener.accept().await?; let tree = tree.clone(); tokio::task::spawn(async move { - if let Err(err) = http1::Builder::new() - .serve_connection(Io::new(stream), Responder::new(tree, Some(addr))) - .await - { + if let Err(err) = serve(stream, Some(addr), tree).await { eprintln!("Error while serving HTTP connection: {err}"); } }); diff --git a/viz-core/Cargo.toml b/viz-core/Cargo.toml index 2173ee3f..418c91a1 100644 --- a/viz-core/Cargo.toml +++ b/viz-core/Cargo.toml @@ -63,6 +63,7 @@ http-body.workspace = true http-body-util.workspace = true hyper.workspace = true +hyper-util.workspace = true mime.workspace = true thiserror.workspace = true diff --git a/viz-core/src/io.rs b/viz-core/src/io.rs deleted file mode 100644 index e696ef02..00000000 --- a/viz-core/src/io.rs +++ /dev/null @@ -1,157 +0,0 @@ -use hyper::rt::{Read, ReadBufCursor, Write}; -use pin_project_lite::pin_project; -use std::{ - io::{Error, IoSlice}, - pin::Pin, - task::{Context, Poll}, -}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; - -pin_project! { - /// A wrapping implementing hyper IO traits for a type that - /// implements Tokio's IO traits. - #[derive(Debug)] - pub struct Io { - #[pin] - inner: T, - } -} - -impl Io { - /// Wrap a type implementing Tokio's IO traits. - pub fn new(inner: T) -> Self { - Self { inner } - } - - /// Borrow the inner type. - pub fn inner(&self) -> &T { - &self.inner - } - - /// Consume this wrapper and get the inner type. - pub fn into_inner(self) -> T { - self.inner - } -} - -impl Read for Io -where - T: AsyncRead, -{ - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - mut buf: ReadBufCursor<'_>, - ) -> Poll> { - let n = unsafe { - let mut tempbuf = ReadBuf::uninit(buf.as_mut()); - match AsyncRead::poll_read(self.project().inner, cx, &mut tempbuf) { - Poll::Ready(Ok(())) => tempbuf.filled().len(), - other => return other, - } - }; - - unsafe { - buf.advance(n); - } - Poll::Ready(Ok(())) - } -} - -impl Write for Io -where - T: AsyncWrite, -{ - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - AsyncWrite::poll_write(self.project().inner, cx, buf) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - AsyncWrite::poll_flush(self.project().inner, cx) - } - - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - AsyncWrite::poll_shutdown(self.project().inner, cx) - } - - fn is_write_vectored(&self) -> bool { - AsyncWrite::is_write_vectored(&self.inner) - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) - } -} - -impl AsyncRead for Io -where - T: Read, -{ - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - tempbuf: &mut ReadBuf<'_>, - ) -> Poll> { - //let init = tempbuf.initialized().len(); - let filled = tempbuf.filled().len(); - let sub_filled = unsafe { - let mut buf = hyper::rt::ReadBuf::uninit(tempbuf.unfilled_mut()); - - match Read::poll_read(self.project().inner, cx, buf.unfilled()) { - Poll::Ready(Ok(())) => buf.filled().len(), - other => return other, - } - }; - - let n_filled = filled + sub_filled; - // At least sub_filled bytes had to have been initialized. - let n_init = sub_filled; - unsafe { - tempbuf.assume_init(n_init); - tempbuf.set_filled(n_filled); - } - - Poll::Ready(Ok(())) - } -} - -impl AsyncWrite for Io -where - T: Write, -{ - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Write::poll_write(self.project().inner, cx, buf) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Write::poll_flush(self.project().inner, cx) - } - - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Write::poll_shutdown(self.project().inner, cx) - } - - fn is_write_vectored(&self) -> bool { - Write::is_write_vectored(&self.inner) - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - Write::poll_write_vectored(self.project().inner, cx, bufs) - } -} diff --git a/viz-core/src/lib.rs b/viz-core/src/lib.rs index 34d01f9e..39aa6a4f 100644 --- a/viz-core/src/lib.rs +++ b/viz-core/src/lib.rs @@ -5,7 +5,7 @@ #![doc(html_logo_url = "https://viz.rs/logo.svg")] #![doc(html_favicon_url = "https://viz.rs/logo.svg")] #![allow(clippy::module_name_repetitions)] -// #![forbid(unsafe_code)] +#![forbid(unsafe_code)] #![warn( missing_debug_implementations, missing_docs, @@ -43,7 +43,6 @@ mod body; mod error; mod from_request; mod into_response; -mod io; mod request; mod response; @@ -51,7 +50,6 @@ pub use body::{IncomingBody, OutgoingBody}; pub use error::Error; pub use from_request::FromRequest; pub use into_response::IntoResponse; -pub use io::Io; pub use request::RequestExt; pub use response::ResponseExt; @@ -61,6 +59,7 @@ pub use bytes::{Bytes, BytesMut}; pub use headers; pub use http::{header, Method, StatusCode}; pub use hyper::body::{Body, Incoming}; +pub use hyper_util::rt::TokioIo as Io; pub use std::future::Future; pub use thiserror::Error as ThisError; diff --git a/viz/Cargo.toml b/viz/Cargo.toml index 43778478..0421a38f 100644 --- a/viz/Cargo.toml +++ b/viz/Cargo.toml @@ -72,6 +72,7 @@ viz-handlers = { workspace = true, optional = true } viz-macros = { workspace = true, optional = true } hyper.workspace = true +hyper-util.workspace = true tokio.workspace = true futures-util = { workspace = true, optional = true } diff --git a/viz/src/lib.rs b/viz/src/lib.rs index ffb3647e..2781bd7c 100644 --- a/viz/src/lib.rs +++ b/viz/src/lib.rs @@ -522,7 +522,7 @@ #![doc(html_logo_url = "https://viz.rs/logo.svg")] #![doc(html_favicon_url = "https://viz.rs/logo.svg")] -// #![forbid(unsafe_code)] +#![forbid(unsafe_code)] #![warn( missing_debug_implementations, missing_docs, @@ -543,6 +543,9 @@ mod responder; #[cfg(any(feature = "http1", feature = "http2"))] pub use responder::Responder; +mod serve; +pub use serve::serve; + /// TLS pub mod tls; pub use viz_core::*; @@ -553,10 +556,10 @@ pub use viz_router::*; #[doc(inline)] pub use viz_handlers as handlers; +#[cfg(any(feature = "http1", feature = "http2"))] +pub use hyper::server; + #[cfg(feature = "macros")] #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] #[doc(inline)] pub use viz_macros::handler; - -#[cfg(any(feature = "http1", feature = "http2"))] -pub use hyper::server; diff --git a/viz/src/serve.rs b/viz/src/serve.rs new file mode 100644 index 00000000..7969a103 --- /dev/null +++ b/viz/src/serve.rs @@ -0,0 +1,24 @@ +use std::{net::SocketAddr, sync::Arc}; + +use hyper_util::{rt::TokioExecutor, server::conn::auto::Builder}; +use tokio::io::{AsyncRead, AsyncWrite}; + +use viz_core::{Io, Result}; +use viz_router::Tree; + +use crate::Responder; + +/// Serve the connections. +/// +/// # Errors +/// +/// Will return `Err` if the connection does not be served. +pub async fn serve(stream: I, addr: Option, tree: Arc) -> Result<()> +where + I: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + Builder::new(TokioExecutor::new()) + .serve_connection_with_upgrades(Io::new(stream).into_inner(), Responder::new(tree, addr)) + .await + .map_err(Into::into) +}