From f485c28d430d9c15e4982c4bdc9a01659ce606e7 Mon Sep 17 00:00:00 2001 From: Fangdun Tsai Date: Thu, 28 Dec 2023 20:59:25 +0800 Subject: [PATCH] refactor: remove async_trait on Handler --- Cargo.toml | 2 +- examples/hello-world/src/main.rs | 14 +- viz-core/src/handler.rs | 16 ++- viz-core/src/handler/after.rs | 12 +- viz-core/src/handler/and_then.rs | 12 +- viz-core/src/handler/around.rs | 10 +- viz-core/src/handler/before.rs | 15 +- viz-core/src/handler/boxed.rs | 33 +++-- viz-core/src/handler/catch_error.rs | 36 ++--- viz-core/src/handler/catch_unwind.rs | 31 ++--- viz-core/src/handler/cloneable.rs | 16 +++ viz-core/src/handler/either.rs | 7 +- viz-core/src/handler/fn_ext.rs | 7 +- viz-core/src/handler/fn_ext_hanlder.rs | 13 +- viz-core/src/handler/into_handler.rs | 8 +- viz-core/src/handler/map.rs | 12 +- viz-core/src/handler/map_err.rs | 13 +- viz-core/src/handler/map_into_response.rs | 11 +- viz-core/src/handler/or_else.rs | 12 +- viz-core/src/handler/service.rs | 15 +- viz-core/src/handler/try_handler.rs | 14 +- viz-core/src/lib.rs | 7 +- viz-core/src/macros.rs | 14 +- viz-core/src/middleware/compression.rs | 31 +++-- viz-core/src/middleware/cookie.rs | 46 ++++--- viz-core/src/middleware/cors.rs | 160 +++++++++++----------- viz-core/src/middleware/csrf.rs | 74 +++++----- viz-core/src/middleware/limits.rs | 14 +- viz-core/src/middleware/otel/metrics.rs | 59 ++++---- viz-core/src/middleware/otel/tracing.rs | 119 ++++++++-------- viz-core/src/middleware/session/config.rs | 97 +++++++------ viz-core/src/types/state.rs | 16 ++- viz-core/tests/handler.rs | 69 +++++----- viz-handlers/src/embed.rs | 24 ++-- viz-handlers/src/prometheus.rs | 40 +++--- viz-handlers/src/serve.rs | 74 +++++----- viz-macros/src/lib.rs | 11 +- viz-router/Cargo.toml | 1 + viz-router/src/resources.rs | 31 ++--- viz-router/src/route.rs | 45 +++--- viz-router/src/router.rs | 24 ++-- viz-router/src/tree.rs | 4 +- viz-tower/src/lib.rs | 21 +-- viz-tower/src/middleware.rs | 25 ++-- viz-tower/src/service.rs | 5 +- viz/Cargo.toml | 1 + viz/src/lib.rs | 41 +++--- 47 files changed, 729 insertions(+), 633 deletions(-) create mode 100644 viz-core/src/handler/cloneable.rs diff --git a/Cargo.toml b/Cargo.toml index c41335ed..cd5983c0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -67,7 +67,7 @@ sync_wrapper = "0.1.2" thiserror = "1.0" # router -path-tree = "0.7" +path-tree = "0.7.3" # http headers = "0.4" diff --git a/examples/hello-world/src/main.rs b/examples/hello-world/src/main.rs index 520de837..8114da12 100644 --- a/examples/hello-world/src/main.rs +++ b/examples/hello-world/src/main.rs @@ -1,7 +1,7 @@ #![deny(warnings)] #![allow(clippy::unused_async)] -use std::{net::SocketAddr, sync::Arc}; +use std::{net::SocketAddr, str::FromStr, sync::Arc}; use tokio::net::TcpListener; use viz::{serve, Request, Result, Router, Tree}; @@ -11,16 +11,20 @@ async fn index(_: Request) -> Result<&'static str> { #[tokio::main] async fn main() -> Result<()> { - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + let addr = SocketAddr::from_str("[::1]:3000").unwrap(); let listener = TcpListener::bind(addr).await?; println!("listening on http://{addr}"); - let app = Router::new().get("/", index); + let mut app = Router::new().get("/", |_| async { Ok("Hello, World!") }); + + for n in 0..1000 { + app = app.get(&format!("/{}", n), index); + } + let tree = Arc::new(Tree::from(app)); loop { let (stream, addr) = listener.accept().await?; - let tree = tree.clone(); - tokio::task::spawn(serve(stream, tree, Some(addr))); + tokio::task::spawn(serve(stream, tree.clone(), Some(addr))); } } diff --git a/viz-core/src/handler.rs b/viz-core/src/handler.rs index b817a150..5cdfd9d1 100644 --- a/viz-core/src/handler.rs +++ b/viz-core/src/handler.rs @@ -1,7 +1,9 @@ //! Traits and types for handling an HTTP. -use crate::Future; -use futures_util::future::BoxFuture; +use crate::future::{BoxFuture, Future}; + +mod cloneable; +pub use cloneable::{BoxCloneable, Cloneable}; mod after; pub use after::After; @@ -71,8 +73,8 @@ pub trait Handler { impl Handler for F where I: Send + 'static, - F: Fn(I) -> Fut + ?Sized + Clone + Send + Sync + 'static, - Fut: Future + Send, + F: Fn(I) -> Fut + ?Sized + Clone + Send + 'static, + Fut: Future + Send + 'static, { type Output = Fut::Output; @@ -170,7 +172,7 @@ pub trait HandlerExt: Handler { } /// Catches rejected error while calling the handler. - fn catch_error(self, f: F) -> CatchError + fn catch_error(self, f: F) -> CatchError where Self: Sized, { @@ -188,9 +190,9 @@ pub trait HandlerExt: Handler { /// Converts this Handler into a [`BoxHandler`]. fn boxed(self) -> BoxHandler where - Self: Sized, + Self: Sized + Send + Clone + 'static, { - Box::new(self) + BoxHandler::new(self) } /// Returns a new [`Handler`] that wrapping the `Self` and a type implementing [`Transform`]. diff --git a/viz-core/src/handler/after.rs b/viz-core/src/handler/after.rs index 2c23a4ad..4f52f476 100644 --- a/viz-core/src/handler/after.rs +++ b/viz-core/src/handler/after.rs @@ -1,6 +1,7 @@ -use futures_util::{future::BoxFuture, FutureExt}; - -use crate::{Handler, Result}; +use crate::{ + future::{BoxFuture, FutureExt}, + Handler, Result, +}; /// Maps the output `Result` after the handler called. #[derive(Debug, Clone)] @@ -20,12 +21,13 @@ impl After { impl Handler for After where H: Handler>, - F: Handler + Send, + F: Handler + Send + Clone + 'static, + O: 'static, { type Output = F::Output; fn call(&self, i: I) -> BoxFuture<'static, Self::Output> { - let f = self.f; + let f = self.f.clone(); let fut = self.h.call(i).then(move |o| f.call(o)); Box::pin(fut) } diff --git a/viz-core/src/handler/and_then.rs b/viz-core/src/handler/and_then.rs index eef7ca4c..40cfffca 100644 --- a/viz-core/src/handler/and_then.rs +++ b/viz-core/src/handler/and_then.rs @@ -1,6 +1,7 @@ -use futures_util::{future::BoxFuture, TryFutureExt}; - -use crate::{Handler, Result}; +use crate::{ + future::{BoxFuture, TryFutureExt}, + Handler, Result, +}; /// Calls `op` if the output is `Ok`, otherwise returns the `Err` value of the output. #[derive(Debug, Clone)] @@ -20,12 +21,13 @@ impl AndThen { impl Handler for AndThen where H: Handler>, - F: Handler + Send, + F: Handler + Send + Clone + 'static, + O: 'static, { type Output = F::Output; fn call(&self, i: I) -> BoxFuture<'static, Self::Output> { - let f = self.f; + let f = self.f.clone(); let fut = self.h.call(i).and_then(move |o| f.call(o)); Box::pin(fut) } diff --git a/viz-core/src/handler/around.rs b/viz-core/src/handler/around.rs index 8b801b2a..98a0102e 100644 --- a/viz-core/src/handler/around.rs +++ b/viz-core/src/handler/around.rs @@ -1,6 +1,4 @@ -use futures_util::future::BoxFuture; - -use crate::{Handler, Result}; +use crate::{future::BoxFuture, Handler, Result}; /// Represents a middleware parameter, which is a tuple that includes Requset and `BoxHandler`. pub type Next = (I, H); @@ -22,12 +20,14 @@ impl Around { impl Handler for Around where - H: Handler> + Copy, + H: Handler> + Clone + 'static, F: Handler, Output = H::Output>, + O: 'static, { type Output = F::Output; fn call(&self, i: I) -> BoxFuture<'static, Self::Output> { - Box::pin(self.f.call((i, self.h))) + let h = self.h.clone(); + Box::pin(self.f.call((i, h))) } } diff --git a/viz-core/src/handler/before.rs b/viz-core/src/handler/before.rs index e1720304..2424a629 100644 --- a/viz-core/src/handler/before.rs +++ b/viz-core/src/handler/before.rs @@ -1,6 +1,7 @@ -use futures_util::{future::BoxFuture, TryFutureExt}; - -use crate::{Handler, Result}; +use crate::{ + future::{BoxFuture, TryFutureExt}, + Handler, Result, +}; /// Maps the input before the handler calls. #[derive(Debug, Clone)] @@ -19,13 +20,15 @@ impl Before { impl Handler for Before where - F: Handler>, - H: Handler> + Send, + I: Send + 'static, + F: Handler> + 'static, + H: Handler> + Send + Clone + 'static, + O: 'static, { type Output = H::Output; fn call(&self, i: I) -> BoxFuture<'static, Self::Output> { - let h = self.h; + let h = self.h.clone(); let fut = self.f.call(i).and_then(move |i| h.call(i)); Box::pin(fut) } diff --git a/viz-core/src/handler/boxed.rs b/viz-core/src/handler/boxed.rs index b81b34d0..d4fe0dde 100644 --- a/viz-core/src/handler/boxed.rs +++ b/viz-core/src/handler/boxed.rs @@ -1,19 +1,32 @@ -use crate::{async_trait, Handler, Request, Response, Result}; +use crate::{future::BoxFuture, handler::BoxCloneable, Handler, Request, Response, Result}; -/// Alias the boxed Handler. -pub type BoxHandler> = Box>; +pub struct BoxHandler>(BoxCloneable); -impl Clone for BoxHandler { +impl BoxHandler { + pub fn new(h: H) -> Self + where + H: Handler + Send + Clone + 'static, + { + Self(Box::new(h)) + } +} + +impl Clone for BoxHandler { fn clone(&self) -> Self { - dyn_clone::clone_box(&**self) + Self(self.0.clone_box()) } } -#[async_trait] -impl Handler for BoxHandler { - type Output = Result; +impl Handler for BoxHandler { + type Output = O; + + fn call(&self, i: I) -> BoxFuture<'static, Self::Output> { + self.0.call(i) + } +} - async fn call(&self, req: Request) -> Self::Output { - self.as_ref().call(req).await +impl From> for BoxHandler { + fn from(value: BoxCloneable) -> Self { + Self(value) } } diff --git a/viz-core/src/handler/catch_error.rs b/viz-core/src/handler/catch_error.rs index 271bbcf1..3a20909c 100644 --- a/viz-core/src/handler/catch_error.rs +++ b/viz-core/src/handler/catch_error.rs @@ -1,18 +1,19 @@ use std::marker::PhantomData; -use futures_util::future::BoxFuture; - -use crate::{Handler, IntoResponse, Response, Result}; +use crate::{ + future::{BoxFuture, TryFutureExt}, + Error, Handler, IntoResponse, Response, Result, +}; /// Catches rejected error while calling the handler. #[derive(Debug)] -pub struct CatchError { +pub struct CatchError { h: H, f: F, _marker: PhantomData R>, } -impl Clone for CatchError +impl Clone for CatchError where H: Clone, F: Clone, @@ -26,7 +27,7 @@ where } } -impl CatchError { +impl CatchError { /// Creates a [`CatchError`] handler. #[inline] pub fn new(h: H, f: F) -> Self { @@ -38,21 +39,24 @@ impl CatchError { } } -impl Handler for CatchError +impl Handler for CatchError where - I: Send + 'static, - H: Handler> + Clone, - O: IntoResponse + Send, + H: Handler>, + O: IntoResponse + 'static, E: std::error::Error + Send + 'static, - F: Handler + Clone, - R: IntoResponse + 'static, + F: Handler + Send + Clone + 'static, + R: IntoResponse, { type Output = Result; fn call(&self, i: I) -> BoxFuture<'static, Self::Output> { - match self.h.call(i).await { - Ok(r) => Ok(r.into_response()), - Err(e) => Ok(self.f.call(e.downcast::()?).await.into_response()), - } + let f = self.f.clone(); + let fut = self + .h + .call(i) + .map_ok(IntoResponse::into_response) + .map_err(Error::downcast::) + .or_else(move |r| async move { Ok(f.call(r?).await.into_response()) }); + Box::pin(fut) } } diff --git a/viz-core/src/handler/catch_unwind.rs b/viz-core/src/handler/catch_unwind.rs index 81fe847d..125482f8 100644 --- a/viz-core/src/handler/catch_unwind.rs +++ b/viz-core/src/handler/catch_unwind.rs @@ -1,8 +1,7 @@ -use std::{any::Any, panic::AssertUnwindSafe}; - -use futures_util::FutureExt; - -use crate::{async_trait, Handler, IntoResponse, Response, Result}; +use crate::{ + future::{BoxFuture, FutureExt, TryFutureExt}, + Handler, IntoResponse, Response, Result, +}; /// Catches unwinding panics while calling the handler. #[derive(Debug, Clone)] @@ -19,21 +18,21 @@ impl CatchUnwind { } } -#[async_trait] impl Handler for CatchUnwind where - I: Send + 'static, - H: Handler> + Clone, - O: IntoResponse + Send, - F: Handler, Output = R> + Clone, - R: IntoResponse, + H: Handler> + 'static, + O: IntoResponse + 'static, + F: Handler, Output = R> + Send + Clone + 'static, + R: IntoResponse + 'static, { type Output = Result; - async fn call(&self, i: I) -> Self::Output { - match AssertUnwindSafe(self.h.call(i)).catch_unwind().await { - Ok(r) => r.map(IntoResponse::into_response), - Err(e) => Ok(self.f.call(e).await.into_response()), - } + fn call(&self, i: I) -> BoxFuture<'static, Self::Output> { + let f = self.f.clone(); + let fut = ::core::panic::AssertUnwindSafe(self.h.call(i)) + .catch_unwind() + .map_ok(IntoResponse::into_response) + .or_else(move |e| f.call(e).map(IntoResponse::into_response).map(Result::Ok)); + Box::pin(fut) } } diff --git a/viz-core/src/handler/cloneable.rs b/viz-core/src/handler/cloneable.rs new file mode 100644 index 00000000..7a64744a --- /dev/null +++ b/viz-core/src/handler/cloneable.rs @@ -0,0 +1,16 @@ +use super::Handler; + +pub type BoxCloneable = Box + Send>; + +pub trait Cloneable: Handler { + fn clone_box(&self) -> BoxCloneable; +} + +impl Cloneable for T +where + T: Handler + Send + Clone + 'static, +{ + fn clone_box(&self) -> BoxCloneable { + Box::new(self.clone()) + } +} diff --git a/viz-core/src/handler/either.rs b/viz-core/src/handler/either.rs index f2e08229..6de28c17 100644 --- a/viz-core/src/handler/either.rs +++ b/viz-core/src/handler/either.rs @@ -1,6 +1,4 @@ -use futures_util::future::BoxFuture; - -use crate::Handler; +use crate::{future::BoxFuture, Handler}; /// Combines two different handlers having the same associated types into a single type. #[derive(Debug, Clone)] @@ -13,7 +11,8 @@ pub enum Either { impl Handler for Either where - I: Send + 'static, + I: 'static, + O: 'static, L: Handler, R: Handler, { diff --git a/viz-core/src/handler/fn_ext.rs b/viz-core/src/handler/fn_ext.rs index a0677874..24b63361 100644 --- a/viz-core/src/handler/fn_ext.rs +++ b/viz-core/src/handler/fn_ext.rs @@ -1,11 +1,10 @@ -use crate::{async_trait, Request}; +use crate::{future::BoxFuture, Request}; /// A handler with extractors. -#[async_trait] -pub trait FnExt: Clone + Send + Sync + 'static { +pub trait FnExt { /// The returned type after the call operator is used. type Output; /// Performs the call operation. - async fn call(&self, req: Request) -> Self::Output; + fn call(&self, req: Request) -> BoxFuture<'static, Self::Output>; } diff --git a/viz-core/src/handler/fn_ext_hanlder.rs b/viz-core/src/handler/fn_ext_hanlder.rs index 665f9bb9..e7cbe074 100644 --- a/viz-core/src/handler/fn_ext_hanlder.rs +++ b/viz-core/src/handler/fn_ext_hanlder.rs @@ -1,6 +1,9 @@ use std::marker::PhantomData; -use crate::{async_trait, FnExt, FromRequest, Handler, IntoResponse, Request, Result}; +use crate::{ + future::{BoxFuture, TryFutureExt}, + FnExt, FromRequest, Handler, IntoResponse, Request, Result, +}; /// A wrapper of the extractors handler. #[derive(Debug)] @@ -22,17 +25,17 @@ impl FnExtHandler { } } -#[async_trait] impl Handler for FnExtHandler where E: FromRequest + 'static, - E::Error: IntoResponse + Send, + E::Error: IntoResponse, H: FnExt>, O: 'static, { type Output = H::Output; - async fn call(&self, req: Request) -> Self::Output { - self.0.call(req).await.map_err(IntoResponse::into_error) + fn call(&self, req: Request) -> BoxFuture<'static, Self::Output> { + let fut = self.0.call(req).map_err(IntoResponse::into_error); + Box::pin(fut) } } diff --git a/viz-core/src/handler/into_handler.rs b/viz-core/src/handler/into_handler.rs index 5ca17b8d..f8e8c9af 100644 --- a/viz-core/src/handler/into_handler.rs +++ b/viz-core/src/handler/into_handler.rs @@ -1,6 +1,4 @@ -use crate::{FromRequest, IntoResponse, Request, Result}; - -use super::{FnExt, FnExtHandler, Handler}; +use crate::{handler::FnExtHandler, FnExt, FromRequest, Handler, IntoResponse, Request, Result}; /// The trait implemented by types that can be converted to a [`Handler`]. pub trait IntoHandler { @@ -15,8 +13,8 @@ pub trait IntoHandler { impl IntoHandler for H where E: FromRequest + 'static, - E::Error: IntoResponse + Send, - H: FnExt>, + E::Error: IntoResponse, + H: FnExt> + Send + Copy + 'static, O: 'static, { type Handler = FnExtHandler; diff --git a/viz-core/src/handler/map.rs b/viz-core/src/handler/map.rs index 89f89dc3..5ff6ee6c 100644 --- a/viz-core/src/handler/map.rs +++ b/viz-core/src/handler/map.rs @@ -1,6 +1,7 @@ -use futures_util::{future::BoxFuture, TryFutureExt}; - -use crate::{Handler, Result}; +use crate::{ + future::{BoxFuture, TryFutureExt}, + Handler, Result, +}; /// Maps the `Ok` value of the output if after the handler called. #[derive(Debug, Clone)] @@ -20,12 +21,13 @@ impl Map { impl Handler for Map where H: Handler>, - F: FnOnce(O) -> T + Send, + F: FnOnce(O) -> T + Send + Clone + 'static, + O: 'static, { type Output = Result; fn call(&self, i: I) -> BoxFuture<'static, Self::Output> { - let fut = self.h.call(i).map_ok(self.f); + let fut = self.h.call(i).map_ok(self.f.clone()); Box::pin(fut) } } diff --git a/viz-core/src/handler/map_err.rs b/viz-core/src/handler/map_err.rs index f1424faf..e276393a 100644 --- a/viz-core/src/handler/map_err.rs +++ b/viz-core/src/handler/map_err.rs @@ -1,6 +1,7 @@ -use futures_util::{future::BoxFuture, TryFutureExt}; - -use crate::{Error, Handler, Result}; +use crate::{ + future::{BoxFuture, TryFutureExt}, + Error, Handler, Result, +}; /// Maps the `Err` value of the output if after the handler called. #[derive(Debug, Clone)] @@ -20,12 +21,14 @@ impl MapErr { impl Handler for MapErr where H: Handler>, - F: FnOnce(E) -> Error + Send, + F: FnOnce(E) -> Error + Send + Clone + 'static, + O: 'static, + E: 'static, { type Output = Result; fn call(&self, i: I) -> BoxFuture<'static, Self::Output> { - let fut = self.h.call(i).map_err(self.f); + let fut = self.h.call(i).map_err(self.f.clone()); Box::pin(fut) } } diff --git a/viz-core/src/handler/map_into_response.rs b/viz-core/src/handler/map_into_response.rs index 6757700b..dba2d2aa 100644 --- a/viz-core/src/handler/map_into_response.rs +++ b/viz-core/src/handler/map_into_response.rs @@ -1,6 +1,7 @@ -use futures_util::{future::BoxFuture, TryFutureExt}; - -use crate::{Handler, IntoResponse, Response, Result}; +use crate::{ + future::{BoxFuture, TryFutureExt}, + Handler, IntoResponse, Response, Result, +}; /// Maps the handler's output type to the [`Response`]. #[derive(Debug, Clone)] @@ -16,8 +17,8 @@ impl MapInToResponse { impl Handler for MapInToResponse where - H: Handler> + Clone, - O: IntoResponse + Send, + H: Handler>, + O: IntoResponse + 'static, { type Output = Result; diff --git a/viz-core/src/handler/or_else.rs b/viz-core/src/handler/or_else.rs index 1ae79af3..c2d16155 100644 --- a/viz-core/src/handler/or_else.rs +++ b/viz-core/src/handler/or_else.rs @@ -1,6 +1,7 @@ -use futures_util::{future::BoxFuture, TryFutureExt}; - -use crate::{Error, Handler, Result}; +use crate::{ + future::{BoxFuture, TryFutureExt}, + Error, Handler, Result, +}; /// Calls `op` if the output is `Err`, otherwise returns the `Ok` value of the output. #[derive(Debug, Clone)] @@ -20,12 +21,13 @@ impl OrElse { impl Handler for OrElse where H: Handler>, - F: Handler + Send + Copy, + F: Handler + Send + Clone + 'static, + O: 'static, { type Output = F::Output; fn call(&self, i: I) -> BoxFuture<'static, Self::Output> { - let f = self.f; + let f = self.f.clone(); let fut = self.h.call(i).or_else(move |e| f.call(e)); Box::pin(fut) } diff --git a/viz-core/src/handler/service.rs b/viz-core/src/handler/service.rs index 56c7fa43..ae675cc0 100644 --- a/viz-core/src/handler/service.rs +++ b/viz-core/src/handler/service.rs @@ -1,7 +1,8 @@ use hyper::service::Service; use crate::{ - async_trait, Body, BoxError, Bytes, Error, Handler, HttpBody, Request, Response, Result, + future::{BoxFuture, TryFutureExt}, + Body, BoxError, Bytes, Error, Handler, HttpBody, Request, Response, Result, }; /// Converts a hyper [`Service`] to a viz [`Handler`]. @@ -15,7 +16,6 @@ impl ServiceHandler { } } -#[async_trait] impl Handler> for ServiceHandler where I: HttpBody + Send + 'static, @@ -28,11 +28,12 @@ where { type Output = Result; - async fn call(&self, req: Request) -> Self::Output { - self.0 + fn call(&self, req: Request) -> BoxFuture<'static, Self::Output> { + let fut = self + .0 .call(req) - .await - .map(|resp| resp.map(Body::wrap)) - .map_err(Error::boxed) + .map_ok(|resp| resp.map(Body::wrap)) + .map_err(Error::boxed); + Box::pin(fut) } } diff --git a/viz-core/src/handler/try_handler.rs b/viz-core/src/handler/try_handler.rs index c0354334..e5ce9db9 100644 --- a/viz-core/src/handler/try_handler.rs +++ b/viz-core/src/handler/try_handler.rs @@ -1,6 +1,4 @@ -use futures_util::future::BoxFuture; - -use super::{Handler, MapErr}; +use crate::{future::BoxFuture, Handler}; pub trait TryHandler: Handler { type Ok; @@ -23,12 +21,4 @@ where } } -pub trait TryHandlerExt: TryHandler { - fn map_err(self, f: F) -> MapErr - where - F: FnOnce(Self::Error) -> E, - Self: Sized, - { - MapErr::new(self, f) - } -} +pub trait TryHandlerExt: TryHandler {} diff --git a/viz-core/src/lib.rs b/viz-core/src/lib.rs index 201e84a4..2b214934 100644 --- a/viz-core/src/lib.rs +++ b/viz-core/src/lib.rs @@ -47,6 +47,7 @@ pub type Result = core::result::Result; pub use async_trait::async_trait; pub use bytes::{Bytes, BytesMut}; +pub use futures_util::future; #[doc(inline)] pub use headers; pub use http::{header, Method, StatusCode}; @@ -57,7 +58,11 @@ pub use thiserror::Error as ThisError; #[doc(hidden)] mod tuples { - use super::{async_trait, Error, FnExt, FromRequest, Future, IntoResponse, Request, Result}; + use super::{ + async_trait, + future::{BoxFuture, TryFutureExt}, + Error, FnExt, FromRequest, Future, IntoResponse, Request, Result, + }; tuple_impls!(A B C D E F G H I J K L); } diff --git a/viz-core/src/macros.rs b/viz-core/src/macros.rs index 729eef7d..63a307d8 100644 --- a/viz-core/src/macros.rs +++ b/viz-core/src/macros.rs @@ -22,20 +22,22 @@ macro_rules! tuple_impls { } } - #[async_trait] impl<$($T,)* Fun, Fut, Out> FnExt<($($T,)*)> for Fun where $($T: FromRequest + Send,)* $($T::Error: IntoResponse + Send,)* - Fun: Fn($($T,)*) -> Fut + Clone + Send + Sync + 'static, + Fun: Fn($($T,)*) -> Fut + Send + Copy + 'static, Fut: Future> + Send, { type Output = Fut::Output; - #[allow(unused, unused_mut)] - async fn call(&self, mut req: Request) -> Self::Output { - (self)($($T::extract(&mut req).await.map_err(IntoResponse::into_error)?,)*) - .await + #[allow(unused, unused_mut, non_snake_case)] + fn call(&self, mut req: Request) -> BoxFuture<'static, Self::Output> { + let this = *self; + let fut = async move { + <($($T,)*)>::extract(&mut req).and_then(move |($($T,)*)| this($($T,)*)).await + }; + Box::pin(fut) } } }; diff --git a/viz-core/src/middleware/compression.rs b/viz-core/src/middleware/compression.rs index 11b15921..3e4c82a3 100644 --- a/viz-core/src/middleware/compression.rs +++ b/viz-core/src/middleware/compression.rs @@ -6,7 +6,7 @@ use async_compression::tokio::bufread; use tokio_util::io::{ReaderStream, StreamReader}; use crate::{ - async_trait, + future::BoxFuture, header::{HeaderValue, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH}, Body, Handler, IntoResponse, Request, Response, Result, Transform, }; @@ -32,27 +32,30 @@ pub struct CompressionMiddleware { h: H, } -#[async_trait] impl Handler for CompressionMiddleware where + H: Handler> + Send + Clone + 'static, O: IntoResponse, - H: Handler> + Clone, { type Output = Result; - async fn call(&self, req: Request) -> Self::Output { - let accept_encoding = req - .headers() - .get(ACCEPT_ENCODING) - .map(HeaderValue::to_str) - .and_then(Result::ok) - .and_then(parse_accept_encoding); + fn call(&self, req: Request) -> BoxFuture<'static, Self::Output> { + let h = self.h.clone(); - let raw = self.h.call(req).await?; + Box::pin(async move { + let accept_encoding = req + .headers() + .get(ACCEPT_ENCODING) + .map(HeaderValue::to_str) + .and_then(Result::ok) + .and_then(parse_accept_encoding); - Ok(match accept_encoding { - Some(algo) => Compress::new(raw, algo).into_response(), - None => raw.into_response(), + let raw = h.call(req).await?; + + Ok(match accept_encoding { + Some(algo) => Compress::new(raw, algo).into_response(), + None => raw.into_response(), + }) }) } } diff --git a/viz-core/src/middleware/cookie.rs b/viz-core/src/middleware/cookie.rs index e8ac8b64..71ee7cad 100644 --- a/viz-core/src/middleware/cookie.rs +++ b/viz-core/src/middleware/cookie.rs @@ -3,7 +3,7 @@ use std::fmt; use crate::{ - async_trait, + future::BoxFuture, header::{HeaderValue, COOKIE, SET_COOKIE}, types::{Cookie, CookieJar, CookieKey, Cookies}, Handler, IntoResponse, Request, Response, Result, Transform, @@ -80,15 +80,14 @@ impl fmt::Debug for CookieMiddleware { } } -#[async_trait] impl Handler for CookieMiddleware where - H: Handler> + Clone, - O: IntoResponse, + H: Handler> + Send + Clone + 'static, + O: IntoResponse + 'static, { type Output = Result; - async fn call(&self, mut req: Request) -> Self::Output { + fn call(&self, mut req: Request) -> BoxFuture<'static, Self::Output> { let jar = req .headers() .get_all(COOKIE) @@ -103,23 +102,26 @@ where req.extensions_mut().insert::(cookies.clone()); - self.h - .call(req) - .await - .map(IntoResponse::into_response) - .map(|mut res| { - if let Ok(c) = cookies.jar().lock() { - c.delta() - .map(Cookie::encoded) - .map(|cookie| HeaderValue::from_str(&cookie.to_string())) - .filter_map(Result::ok) - .fold(res.headers_mut(), |headers, cookie| { - headers.append(SET_COOKIE, cookie); - headers - }); - } - res - }) + let h = self.h.clone(); + + Box::pin(async move { + h.call(req) + .await + .map(IntoResponse::into_response) + .map(|mut res| { + if let Ok(c) = cookies.jar().lock() { + c.delta() + .map(Cookie::encoded) + .map(|cookie| HeaderValue::from_str(&cookie.to_string())) + .filter_map(Result::ok) + .fold(res.headers_mut(), |headers, cookie| { + headers.append(SET_COOKIE, cookie); + headers + }); + } + res + }) + }) } } diff --git a/viz-core/src/middleware/cors.rs b/viz-core/src/middleware/cors.rs index d0327899..6b609aea 100644 --- a/viz-core/src/middleware/cors.rs +++ b/viz-core/src/middleware/cors.rs @@ -3,7 +3,7 @@ use std::{collections::HashSet, fmt, sync::Arc}; use crate::{ - async_trait, + future::BoxFuture, header::{ HeaderMap, HeaderName, HeaderValue, ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_REQUEST_HEADERS, @@ -203,99 +203,103 @@ pub struct CorsMiddleware { aceh: AccessControlExposeHeaders, } -#[async_trait] impl Handler for CorsMiddleware where - H: Handler> + Clone, - O: IntoResponse, + H: Handler> + Send + Clone + 'static, + O: IntoResponse + 'static, { type Output = Result; - async fn call(&self, req: Request) -> Self::Output { - let Some(origin) = req.header(ORIGIN).filter(is_not_empty) else { - return self.h.call(req).await.map(IntoResponse::into_response); - }; - - if !self.config.allow_origins.contains(&origin) - || !self - .config - .origin_verify - .as_ref() - .map_or(true, |f| (f)(&origin)) - { - return Err(StatusCode::FORBIDDEN.into_error()); - } + fn call(&self, req: Request) -> BoxFuture<'static, Self::Output> { + let Self { + config, + acam, + acah, + aceh, + h, + } = self.clone(); + + Box::pin(async move { + let Some(origin) = req.header(ORIGIN).filter(is_not_empty) else { + return h.call(req).await.map(IntoResponse::into_response); + }; - let mut headers = HeaderMap::new(); - let mut resp = if req.method() == Method::OPTIONS { - // Preflight request - if req - .header(ACCESS_CONTROL_REQUEST_METHOD) - .map_or(false, |method| { - self.config.allow_methods.is_empty() - || self.config.allow_methods.contains(&method) - }) + if !config.allow_origins.contains(&origin) + || !config.origin_verify.as_ref().map_or(true, |f| (f)(&origin)) { - headers.typed_insert(self.acam.clone()); - } else { - return Err((StatusCode::FORBIDDEN, "Invalid Preflight Request").into_error()); + return Err(StatusCode::FORBIDDEN.into_error()); } - let (allow_headers, request_headers) = req - .header(ACCESS_CONTROL_REQUEST_HEADERS) - .map_or((true, None), |hs: HeaderValue| { - ( - hs.to_str() - .map(|hs| { - hs.split(',') - .map(str::as_bytes) - .map(HeaderName::from_bytes) - .filter_map(Result::ok) - .any(|header| self.config.allow_headers.contains(&header)) - }) - .unwrap_or(false), - Some(hs), - ) - }); - - if !allow_headers { - return Err((StatusCode::FORBIDDEN, "Invalid Preflight Request").into_error()); - } + let mut headers = HeaderMap::new(); + let mut resp = if req.method() == Method::OPTIONS { + // Preflight request + if req + .header(ACCESS_CONTROL_REQUEST_METHOD) + .map_or(false, |method| { + config.allow_methods.is_empty() || config.allow_methods.contains(&method) + }) + { + headers.typed_insert(acam); + } else { + return Err((StatusCode::FORBIDDEN, "Invalid Preflight Request").into_error()); + } - if self.config.allow_headers.is_empty() { - headers.insert( - ACCESS_CONTROL_ALLOW_HEADERS, - request_headers.unwrap_or(HeaderValue::from_static("*")), - ); - } else { - headers.typed_insert(self.acah.clone()); - } + let (allow_headers, request_headers) = req + .header(ACCESS_CONTROL_REQUEST_HEADERS) + .map_or((true, None), |hs: HeaderValue| { + ( + hs.to_str() + .map(|hs| { + hs.split(',') + .map(str::as_bytes) + .map(HeaderName::from_bytes) + .filter_map(Result::ok) + .any(|header| config.allow_headers.contains(&header)) + }) + .unwrap_or(false), + Some(hs), + ) + }); - // 204 - no content - StatusCode::NO_CONTENT.into_response() - } else { - // Simple Request - if !self.config.expose_headers.is_empty() { - headers.typed_insert(self.aceh.clone()); - } + if !allow_headers { + return Err((StatusCode::FORBIDDEN, "Invalid Preflight Request").into_error()); + } + + if config.allow_headers.is_empty() { + headers.insert( + ACCESS_CONTROL_ALLOW_HEADERS, + request_headers.unwrap_or(HeaderValue::from_static("*")), + ); + } else { + headers.typed_insert(acah.clone()); + } - self.h.call(req).await.map(IntoResponse::into_response)? - }; + // 204 - no content + StatusCode::NO_CONTENT.into_response() + } else { + // Simple Request + if !config.expose_headers.is_empty() { + headers.typed_insert(aceh.clone()); + } - // https://github.com/rs/cors/issues/10 - headers.insert(VARY, ORIGIN.into()); - headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, origin); + h.call(req).await.map(IntoResponse::into_response)? + }; - if self.config.credentials { - headers.insert( - ACCESS_CONTROL_ALLOW_CREDENTIALS, - HeaderValue::from_static("true"), - ); - } + // https://github.com/rs/cors/issues/10 + headers.insert(VARY, ORIGIN.into()); + headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, origin); + + if config.credentials { + headers.insert( + ACCESS_CONTROL_ALLOW_CREDENTIALS, + HeaderValue::from_static("true"), + ); + } - resp.headers_mut().extend(headers); + resp.headers_mut().extend(headers); - Ok(resp) + Ok(resp) + }) } } diff --git a/viz-core/src/middleware/csrf.rs b/viz-core/src/middleware/csrf.rs index 321e8241..a7a7e68c 100644 --- a/viz-core/src/middleware/csrf.rs +++ b/viz-core/src/middleware/csrf.rs @@ -6,6 +6,7 @@ use base64::Engine as _; use crate::{ async_trait, + future::BoxFuture, header::{HeaderName, HeaderValue, VARY}, middleware::helper::{CookieOptions, Cookieable}, Error, FromRequest, Handler, IntoResponse, Method, Request, RequestExt, Response, Result, @@ -186,10 +187,9 @@ where } } -#[async_trait] impl Handler for CsrfMiddleware where - H: Handler> + Clone, + H: Handler> + Send + Clone + 'static, O: IntoResponse, S: Fn() -> Result> + Send + Sync + 'static, G: Fn(&[u8], Vec) -> Vec + Send + Sync + 'static, @@ -197,38 +197,46 @@ where { type Output = Result; - async fn call(&self, mut req: Request) -> Self::Output { - let mut secret = self.config.get(&req)?; - let config = self.config.as_ref(); - - if !config.ignored_methods.contains(req.method()) { - let mut forbidden = true; - if let Some(secret) = secret.take() { - if let Some(raw_token) = req.header(&config.header) { - forbidden = !(config.verify)(&secret, raw_token); + fn call(&self, mut req: Request) -> BoxFuture<'static, Self::Output> { + let Self { config, h } = self.clone(); + + Box::pin(async move { + let mut secret = config.get(&req)?; + + let (token, secret) = { + let config = config.as_ref(); + + if !config.ignored_methods.contains(req.method()) { + let mut forbidden = true; + if let Some(secret) = secret.take() { + if let Some(raw_token) = req.header(&config.header) { + forbidden = !(config.verify)(&secret, raw_token); + } + } + if forbidden { + return Err((StatusCode::FORBIDDEN, "Invalid csrf token").into_error()); + } } - } - if forbidden { - return Err((StatusCode::FORBIDDEN, "Invalid csrf token").into_error()); - } - } - - let otp = (config.secret)()?; - let secret = (config.secret)()?; - let token = base64::engine::general_purpose::URL_SAFE_NO_PAD - .encode((config.generate)(&secret, otp)); - req.extensions_mut().insert(CsrfToken(token.to_string())); - self.config.set(&req, token, secret)?; - - self.h - .call(req) - .await - .map(IntoResponse::into_response) - .map(|mut res| { - res.headers_mut() - .insert(VARY, HeaderValue::from_static("Cookie")); - res - }) + let otp = (config.secret)()?; + let secret = (config.secret)()?; + let token = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode((config.generate)(&secret, otp)); + + (token, secret) + }; + + req.extensions_mut().insert(CsrfToken(token.to_string())); + config.set(&req, token, secret)?; + + h.call(req) + .await + .map(IntoResponse::into_response) + .map(|mut res| { + res.headers_mut() + .insert(VARY, HeaderValue::from_static("Cookie")); + res + }) + }) } } diff --git a/viz-core/src/middleware/limits.rs b/viz-core/src/middleware/limits.rs index 9b4a7441..8c7a7315 100644 --- a/viz-core/src/middleware/limits.rs +++ b/viz-core/src/middleware/limits.rs @@ -3,7 +3,10 @@ #[cfg(feature = "multipart")] use std::sync::Arc; -use crate::{async_trait, types, Handler, IntoResponse, Request, Response, Result, Transform}; +use crate::{ + future::{BoxFuture, TryFutureExt}, + types, Handler, IntoResponse, Request, Response, Result, Transform, +}; /// A configuration for [`LimitsMiddleware`]. #[derive(Debug, Clone)] @@ -67,19 +70,18 @@ pub struct LimitsMiddleware { config: Config, } -#[async_trait] impl Handler for LimitsMiddleware where - H: Handler> + Clone, - O: IntoResponse, + H: Handler>, + O: IntoResponse + 'static, { type Output = Result; - async fn call(&self, mut req: Request) -> Self::Output { + fn call(&self, mut req: Request) -> BoxFuture<'static, Self::Output> { req.extensions_mut().insert(self.config.limits.clone()); #[cfg(feature = "multipart")] req.extensions_mut().insert(self.config.multipart.clone()); - self.h.call(req).await.map(IntoResponse::into_response) + Box::pin(self.h.call(req).map_ok(IntoResponse::into_response)) } } diff --git a/viz-core/src/middleware/otel/metrics.rs b/viz-core/src/middleware/otel/metrics.rs index 16b3c51b..5a1e1124 100644 --- a/viz-core/src/middleware/otel/metrics.rs +++ b/viz-core/src/middleware/otel/metrics.rs @@ -15,7 +15,7 @@ use opentelemetry_semantic_conventions::trace::{ }; use crate::{ - async_trait, Handler, IntoResponse, Request, RequestExt, Response, ResponseExt, Result, + future::BoxFuture, Handler, IntoResponse, Request, RequestExt, Response, ResponseExt, Result, Transform, }; @@ -96,45 +96,52 @@ pub struct MetricsMiddleware { response_size: Histogram, } -#[async_trait] impl Handler for MetricsMiddleware where - H: Handler> + Clone, + H: Handler> + Send + Clone + 'static, O: IntoResponse, { type Output = Result; - async fn call(&self, req: Request) -> Self::Output { - let timer = SystemTime::now(); - let mut attributes = build_attributes(&req, req.route_info().pattern.as_str()); + fn call(&self, req: Request) -> BoxFuture<'static, Self::Output> { + let Self { + active_requests, + duration, + request_size, + response_size, + h, + } = self.clone(); + + Box::pin(async move { + let timer = SystemTime::now(); + let mut attributes = build_attributes(&req, req.route_info().pattern.as_str()); - self.active_requests.add(1, &attributes); + active_requests.add(1, &attributes); - self.request_size - .record(req.content_length().unwrap_or(0), &attributes); + request_size.record(req.content_length().unwrap_or(0), &attributes); - let resp = self - .h - .call(req) - .await - .map(IntoResponse::into_response) - .map(|resp| { - self.active_requests.add(-1, &attributes); + let resp = h + .call(req) + .await + .map(IntoResponse::into_response) + .map(|resp| { + active_requests.add(-1, &attributes); - attributes.push(HTTP_RESPONSE_STATUS_CODE.i64(i64::from(resp.status().as_u16()))); + attributes + .push(HTTP_RESPONSE_STATUS_CODE.i64(i64::from(resp.status().as_u16()))); - self.response_size - .record(resp.content_length().unwrap_or(0), &attributes); + response_size.record(resp.content_length().unwrap_or(0), &attributes); - resp - }); + resp + }); - self.duration.record( - timer.elapsed().map(|t| t.as_secs_f64()).unwrap_or_default(), - &attributes, - ); + duration.record( + timer.elapsed().map(|t| t.as_secs_f64()).unwrap_or_default(), + &attributes, + ); - resp + resp + }) } } diff --git a/viz-core/src/middleware/otel/tracing.rs b/viz-core/src/middleware/otel/tracing.rs index a6c0d977..7ab9c47a 100644 --- a/viz-core/src/middleware/otel/tracing.rs +++ b/viz-core/src/middleware/otel/tracing.rs @@ -19,7 +19,7 @@ use opentelemetry_semantic_conventions::trace::{ }; use crate::{ - async_trait, + future::BoxFuture, header::{HeaderMap, HeaderName}, headers::UserAgent, Handler, IntoResponse, Request, RequestExt, Response, ResponseExt, Result, Transform, @@ -58,76 +58,77 @@ pub struct TracingMiddleware { tracer: Arc, } -#[async_trait] impl Handler for TracingMiddleware where - H: Handler> + Clone, + H: Handler> + Send + Clone + 'static, O: IntoResponse, T: Tracer + Send + Sync + Clone + 'static, T::Span: Send + Sync + 'static, { type Output = Result; - async fn call(&self, req: Request) -> Self::Output { - let parent_context = global::get_text_map_propagator(|propagator| { - propagator.extract(&RequestHeaderCarrier::new(req.headers())) - }); - - let http_route = &req.route_info().pattern; - let attributes = build_attributes(&req, http_route.as_str()); - - let mut span = self - .tracer - .span_builder(format!("{} {}", req.method(), http_route)) - .with_kind(SpanKind::Server) - .with_attributes(attributes) - .start_with_context(&*self.tracer, &parent_context); - - span.add_event("request.started".to_string(), vec![]); - - let resp = self - .h - .call(req) - .with_context(Context::current_with_span(span)) - .await; - - let cx = Context::current(); - let span = cx.span(); - - match resp { - Ok(resp) => { - let resp = resp.into_response(); - span.add_event("request.completed".to_string(), vec![]); - span.set_attribute( - HTTP_RESPONSE_STATUS_CODE.i64(i64::from(resp.status().as_u16())), - ); - if let Some(content_length) = resp.content_length() { + fn call(&self, req: Request) -> BoxFuture<'static, Self::Output> { + let Self { tracer, h } = self.clone(); + + Box::pin(async move { + let parent_context = global::get_text_map_propagator(|propagator| { + propagator.extract(&RequestHeaderCarrier::new(req.headers())) + }); + + let http_route = &req.route_info().pattern; + let attributes = build_attributes(&req, http_route.as_str()); + + let mut span = tracer + .span_builder(format!("{} {}", req.method(), http_route)) + .with_kind(SpanKind::Server) + .with_attributes(attributes) + .start_with_context(&*tracer, &parent_context); + + span.add_event("request.started".to_string(), vec![]); + + let resp = h + .call(req) + .with_context(Context::current_with_span(span)) + .await; + + let cx = Context::current(); + let span = cx.span(); + + match resp { + Ok(resp) => { + let resp = resp.into_response(); + span.add_event("request.completed".to_string(), vec![]); span.set_attribute( - HTTP_RESPONSE_BODY_SIZE - .i64(i64::try_from(content_length).unwrap_or(i64::MAX)), + HTTP_RESPONSE_STATUS_CODE.i64(i64::from(resp.status().as_u16())), ); + if let Some(content_length) = resp.content_length() { + span.set_attribute( + HTTP_RESPONSE_BODY_SIZE + .i64(i64::try_from(content_length).unwrap_or(i64::MAX)), + ); + } + if resp.status().is_server_error() { + span.set_status(Status::error( + resp.status() + .canonical_reason() + .map(ToString::to_string) + .unwrap_or_default(), + )); + }; + span.end(); + Ok(resp) + } + Err(err) => { + span.add_event( + "request.error".to_string(), + vec![EXCEPTION_MESSAGE.string(err.to_string())], + ); + span.set_status(Status::error(err.to_string())); + span.end(); + Err(err) } - if resp.status().is_server_error() { - span.set_status(Status::error( - resp.status() - .canonical_reason() - .map(ToString::to_string) - .unwrap_or_default(), - )); - }; - span.end(); - Ok(resp) - } - Err(err) => { - span.add_event( - "request.error".to_string(), - vec![EXCEPTION_MESSAGE.string(err.to_string())], - ); - span.set_status(Status::error(err.to_string())); - span.end(); - Err(err) } - } + }) } } diff --git a/viz-core/src/middleware/session/config.rs b/viz-core/src/middleware/session/config.rs index c1fdec8c..1974b15c 100644 --- a/viz-core/src/middleware/session/config.rs +++ b/viz-core/src/middleware/session/config.rs @@ -5,7 +5,7 @@ use std::{ }; use crate::{ - async_trait, + future::BoxFuture, middleware::helper::{CookieOptions, Cookieable}, types::{Cookie, Session}, Error, Handler, IntoResponse, Request, RequestExt, Response, Result, StatusCode, Transform, @@ -89,10 +89,9 @@ where } } -#[async_trait] impl Handler for SessionMiddleware where - H: Handler> + Clone, + H: Handler> + Send + Clone + 'static, O: IntoResponse, S: Storage + 'static, G: Fn() -> String + Send + Sync + 'static, @@ -100,61 +99,61 @@ where { type Output = Result; - async fn call(&self, mut req: Request) -> Self::Output { - let cookies = req.cookies().map_err(Error::from)?; - let cookie = self.config.get_cookie(&cookies); - - let mut session_id = cookie.as_ref().map(Cookie::value).map(ToString::to_string); - let data = match &session_id { - Some(sid) if (self.config.store().verify)(sid) => self.config.store().get(sid).await?, - _ => None, - }; - if data.is_none() && session_id.is_some() { - session_id.take(); - } - let session = Session::new(data.unwrap_or_default()); - req.extensions_mut().insert(session.clone()); + fn call(&self, mut req: Request) -> BoxFuture<'static, Self::Output> { + let Self { config, h } = self.clone(); - let resp = self.h.call(req).await.map(IntoResponse::into_response); + Box::pin(async move { + let cookies = req.cookies().map_err(Error::from)?; + let cookie = config.get_cookie(&cookies); - let status = session.status().load(Ordering::Acquire); + let mut session_id = cookie.as_ref().map(Cookie::value).map(ToString::to_string); + let data = match &session_id { + Some(sid) if (config.store().verify)(sid) => config.store().get(sid).await?, + _ => None, + }; + if data.is_none() && session_id.is_some() { + session_id.take(); + } + let session = Session::new(data.unwrap_or_default()); + req.extensions_mut().insert(session.clone()); - if status == UNCHANGED { - return resp; - } + let resp = h.call(req).await.map(IntoResponse::into_response); - if status == PURGED { - if let Some(sid) = &session_id { - self.config.store().remove(sid).await.map_err(Error::from)?; - self.config.remove_cookie(&cookies); + let status = session.status().load(Ordering::Acquire); + + if status == UNCHANGED { + return resp; } - return resp; - } + if status == PURGED { + if let Some(sid) = &session_id { + config.store().remove(sid).await.map_err(Error::from)?; + config.remove_cookie(&cookies); + } - if status == RENEWED { - if let Some(sid) = &session_id.take() { - self.config.store().remove(sid).await.map_err(Error::from)?; + return resp; } - } - let sid = session_id.unwrap_or_else(|| { - let sid = (self.config.store().generate)(); - self.config.set_cookie(&cookies, &sid); - sid - }); - - self.config - .store() - .set( - &sid, - session.data()?, - &self.config.ttl().unwrap_or_else(max_age), - ) - .await - .map_err(Error::from)?; - - resp + if status == RENEWED { + if let Some(sid) = &session_id.take() { + config.store().remove(sid).await.map_err(Error::from)?; + } + } + + let sid = session_id.unwrap_or_else(|| { + let sid = (config.store().generate)(); + config.set_cookie(&cookies, &sid); + sid + }); + + config + .store() + .set(&sid, session.data()?, &config.ttl().unwrap_or_else(max_age)) + .await + .map_err(Error::from)?; + + resp + }) } } diff --git a/viz-core/src/types/state.rs b/viz-core/src/types/state.rs index c72e4681..231df9b7 100644 --- a/viz-core/src/types/state.rs +++ b/viz-core/src/types/state.rs @@ -6,8 +6,11 @@ use std::{ }; use crate::{ - async_trait, handler::Transform, Error, FromRequest, Handler, IntoResponse, Request, - RequestExt, Response, Result, StatusCode, ThisError, + async_trait, + future::{BoxFuture, TryFutureExt}, + handler::Transform, + Error, FromRequest, Handler, IntoResponse, Request, RequestExt, Response, Result, StatusCode, + ThisError, }; /// Extracts state from the extensions of a request. @@ -72,19 +75,18 @@ where } } -#[async_trait] impl Handler for State<(T, H)> where T: Clone + Send + Sync + 'static, - H: Handler> + Clone, - O: IntoResponse, + H: Handler>, + O: IntoResponse + 'static, { type Output = Result; - async fn call(&self, mut req: Request) -> Self::Output { + fn call(&self, mut req: Request) -> BoxFuture<'static, Self::Output> { let Self((t, h)) = self; req.extensions_mut().insert(t.clone()); - h.call(req).await.map(IntoResponse::into_response) + Box::pin(h.call(req).map_ok(IntoResponse::into_response)) } } diff --git a/viz-core/tests/handler.rs b/viz-core/tests/handler.rs index 2c8d6545..9d286258 100644 --- a/viz-core/tests/handler.rs +++ b/viz-core/tests/handler.rs @@ -3,6 +3,7 @@ #![allow(clippy::similar_names)] #![allow(clippy::wildcard_imports)] +use futures_util::future::BoxFuture; use http_body_util::Full; use viz_core::handler::CatchError; use viz_core::*; @@ -159,12 +160,11 @@ async fn handler() -> Result<()> { name: String, } - #[async_trait] impl Handler for MyBefore { type Output = Result; - async fn call(&self, i: I) -> Self::Output { - Ok(i) + fn call(&self, i: I) -> BoxFuture<'static, Self::Output> { + Box::pin(async move { Ok(i) }) } } @@ -173,12 +173,11 @@ async fn handler() -> Result<()> { name: String, } - #[async_trait] impl Handler> for MyAfter { type Output = Result; - async fn call(&self, o: Self::Output) -> Self::Output { - o + fn call(&self, o: Self::Output) -> BoxFuture<'static, Self::Output> { + Box::pin(async move { o }) } } @@ -187,24 +186,24 @@ async fn handler() -> Result<()> { name: String, } - #[async_trait] impl Handler> for MyAround where I: Send + 'static, H: Handler>, + O: 'static, { type Output = H::Output; - async fn call(&self, (i, h): Next) -> Self::Output { - h.call(i).await + fn call(&self, (i, h): Next) -> BoxFuture<'static, Self::Output> { + Box::pin(h.call(i)) } } - async fn map(res: Response) -> Response { + fn map(res: Response) -> Response { res } - async fn map_err(err: Error) -> Error { + fn map_err(err: Error) -> Error { err } @@ -232,7 +231,9 @@ async fn handler() -> Result<()> { name: "round 1".to_string(), }) .map(map) - .catch_error(|_: CustomError2| async move { "Custom Error 2" }) + .catch_error::<_, CustomError2, &'static str>(|_: CustomError2| async move { + "Custom Error 2" + }) .catch_unwind( |_: Box| async move { panic!("Custom Error 2") }, ); @@ -255,12 +256,14 @@ async fn handler() -> Result<()> { .map_err(map_err) .or_else(or_else); let rhb = b.map_into_response(); - let rhc = c - .map_into_response() - .catch_error(|_: CustomError2| async move { "Custom Error 2" }) - .catch_error2(|e: std::io::Error| async move { - (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()) - }); + let rhc = + c.map_into_response() + .catch_error(|_: CustomError2| async move { "Custom Error 2" }) + .catch_error2::<_, std::io::Error, (StatusCode, String)>( + |e: std::io::Error| async move { + (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()) + }, + ); let rhd = d .map_into_response() .map(map) @@ -289,25 +292,25 @@ async fn handler() -> Result<()> { assert!(rhb.call(Request::default()).await.is_err()); - let brha: BoxHandler = rha.boxed(); - let brhb: BoxHandler = Box::new(rhb) + let brha: BoxHandler<_, _> = rha.boxed(); + let brhb: BoxHandler<_, _> = Box::new(rhb) .around(MyAround { name: "MyRound 3".to_string(), }) .boxed(); - let brhc: BoxHandler = Box::new(rhc); - let brhd: BoxHandler = Box::new(rhd); - let brhe: BoxHandler = rhe.boxed(); - let brhf: BoxHandler = Box::new(rhf); - let brhg: BoxHandler = Box::new(rhg); - let brhh: BoxHandler = Box::new(rhh); - let brhi: BoxHandler = Box::new(rhi); - let brhj: BoxHandler = Box::new(rhj); - let brhk: BoxHandler = rhk.boxed(); - let brhl: BoxHandler = Box::new(rhl); - let brhm: BoxHandler = rhm.boxed(); - - let v: Vec = vec![ + let brhc: BoxHandler<_, _> = rhc.boxed(); + let brhd: BoxHandler<_, _> = rhd.boxed(); + let brhe: BoxHandler<_, _> = rhe.boxed(); + let brhf: BoxHandler<_, _> = rhf.boxed(); + let brhg: BoxHandler<_, _> = rhg.boxed(); + let brhh: BoxHandler<_, _> = rhh.boxed(); + let brhi: BoxHandler<_, _> = rhi.boxed(); + let brhj: BoxHandler<_, _> = rhj.boxed(); + let brhk: BoxHandler<_, _> = rhk.boxed(); + let brhl: BoxHandler<_, _> = rhl.boxed(); + let brhm: BoxHandler<_, _> = rhm.boxed(); + + let v: Vec> = vec![ brha, brhb, brhc, brhd, brhe, brhf, brhg, brhh, brhi, brhj, brhk, brhl, brhm, ]; diff --git a/viz-handlers/src/embed.rs b/viz-handlers/src/embed.rs index 86191ddd..f053712c 100644 --- a/viz-handlers/src/embed.rs +++ b/viz-handlers/src/embed.rs @@ -5,8 +5,8 @@ use std::{borrow::Cow, marker::PhantomData}; use http_body_util::Full; use rust_embed::{EmbeddedFile, RustEmbed}; use viz_core::{ - async_trait, - header::{HeaderMap, CONTENT_TYPE, ETAG, IF_NONE_MATCH}, + future::BoxFuture, + header::{CONTENT_TYPE, ETAG, IF_NONE_MATCH}, Handler, IntoResponse, Method, Request, RequestExt, Response, Result, StatusCode, }; @@ -28,15 +28,14 @@ impl File { } } -#[async_trait] impl Handler for File where E: RustEmbed + Send + Sync + 'static, { type Output = Result; - async fn call(&self, req: Request) -> Self::Output { - serve::(&self.0, req.method(), req.headers()) + fn call(&self, req: Request) -> BoxFuture<'static, Self::Output> { + Box::pin(serve::(self.0.to_string(), req)) } } @@ -56,32 +55,35 @@ impl Default for Dir { } } -#[async_trait] impl Handler for Dir where E: RustEmbed + Send + Sync + 'static, { type Output = Result; - async fn call(&self, req: Request) -> Self::Output { + fn call(&self, req: Request) -> BoxFuture<'static, Self::Output> { let path = match req.route_info().params.first().map(|(_, v)| v) { Some(p) => p, None => "index.html", - }; + } + .to_string(); - serve::(path, req.method(), req.headers()) + Box::pin(serve::(path, req)) } } -fn serve(path: &str, method: &Method, headers: &HeaderMap) -> Result +async fn serve(path: String, req: Request) -> Result where E: RustEmbed + Send + Sync + 'static, { + let method = req.method(); + let headers = req.headers(); + if method != Method::GET { Err(StatusCode::METHOD_NOT_ALLOWED.into_error())?; } - match E::get(path) { + match E::get(&path) { Some(EmbeddedFile { data, metadata }) => { let hash = hex::encode(metadata.sha256_hash()); diff --git a/viz-handlers/src/prometheus.rs b/viz-handlers/src/prometheus.rs index 1326281c..8a1bc728 100644 --- a/viz-handlers/src/prometheus.rs +++ b/viz-handlers/src/prometheus.rs @@ -7,7 +7,7 @@ use opentelemetry::{global::handle_error, metrics::MetricsError}; use prometheus::{Encoder, TextEncoder}; use viz_core::{ - async_trait, + future::BoxFuture, header::{HeaderValue, CONTENT_TYPE}, Handler, IntoResponse, Request, Response, Result, StatusCode, }; @@ -31,29 +31,33 @@ impl Prometheus { } } -#[async_trait] impl Handler for Prometheus { type Output = Result; - async fn call(&self, _: Request) -> Self::Output { - let metric_families = self.registry.gather(); - let encoder = TextEncoder::new(); - let mut body = Vec::new(); + fn call(&self, _: Request) -> BoxFuture<'static, Self::Output> { + let Self { registry } = self.clone(); - if let Err(err) = encoder.encode(&metric_families, &mut body) { - let text = err.to_string(); - handle_error(MetricsError::Other(text.clone())); - Err((StatusCode::INTERNAL_SERVER_ERROR, text).into_error())?; - } + Box::pin(async move { + let metric_families = registry.gather(); + let encoder = TextEncoder::new(); + let mut body = Vec::new(); - let mut res = Response::new(Full::from(body).into()); + if let Err(err) = encoder.encode(&metric_families, &mut body) { + let text = err.to_string(); + handle_error(MetricsError::Other(text.clone())); + Err((StatusCode::INTERNAL_SERVER_ERROR, text).into_error())?; + } - res.headers_mut().append( - CONTENT_TYPE, - HeaderValue::from_str(encoder.format_type()) - .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_error())?, - ); + let mut res = Response::new(Full::from(body).into()); - Ok(res) + res.headers_mut().append( + CONTENT_TYPE, + HeaderValue::from_str(encoder.format_type()).map_err(|err| { + (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_error() + })?, + ); + + Ok(res) + }) } } diff --git a/viz-handlers/src/serve.rs b/viz-handlers/src/serve.rs index 09280c90..aaca0c63 100644 --- a/viz-handlers/src/serve.rs +++ b/viz-handlers/src/serve.rs @@ -11,7 +11,7 @@ use tokio::io::AsyncReadExt; use tokio_util::io::ReaderStream; use viz_core::{ - async_trait, + future::BoxFuture, headers::{ AcceptRanges, ContentLength, ContentRange, ContentType, ETag, HeaderMap, HeaderMapExt, IfMatch, IfModifiedSince, IfNoneMatch, IfUnmodifiedSince, LastModified, Range, @@ -47,12 +47,12 @@ impl File { } } -#[async_trait] impl Handler for File { type Output = Result; - async fn call(&self, req: Request) -> Self::Output { - serve(&self.path, req.headers()) + fn call(&self, req: Request) -> BoxFuture<'static, Self::Output> { + let path = self.path.clone(); + Box::pin(async move { serve(&path, req.headers()) }) } } @@ -98,46 +98,52 @@ impl Dir { } } -#[async_trait] impl Handler for Dir { type Output = Result; - async fn call(&self, req: Request) -> Self::Output { - if req.method() != Method::GET { - Err(Error::MethodNotAllowed)?; - } + fn call(&self, req: Request) -> BoxFuture<'static, Self::Output> { + let Self { + mut path, + listing, + unlisted, + } = self.clone(); - let mut prev = false; - let mut path = self.path.clone(); + Box::pin(async move { + if req.method() != Method::GET { + Err(Error::MethodNotAllowed)?; + } - if let Some(param) = req.route_info().params.first().map(|(_, v)| v) { - let p = percent_encoding::percent_decode_str(param) - .decode_utf8() - .map_err(|_| Error::InvalidPath)?; - sanitize_path(&mut path, &p)?; - prev = true; - } + let mut prev = false; - if !path.exists() { - Err(StatusCode::NOT_FOUND.into_error())?; - } + if let Some(param) = req.route_info().params.first().map(|(_, v)| v) { + let p = percent_encoding::percent_decode_str(param) + .decode_utf8() + .map_err(|_| Error::InvalidPath)?; + sanitize_path(&mut path, &p)?; + prev = true; + } - if path.is_file() { - return serve(&path, req.headers()); - } + if !path.exists() { + Err(StatusCode::NOT_FOUND.into_error())?; + } - let index = path.join("index.html"); - if index.exists() { - return serve(&index, req.headers()); - } + if path.is_file() { + return serve(&path, req.headers()); + } - if self.listing { - return Directory::new(req.path(), prev, &path, &self.unlisted) - .ok_or_else(|| StatusCode::INTERNAL_SERVER_ERROR.into_error()) - .map(IntoResponse::into_response); - } + let index = path.join("index.html"); + if index.exists() { + return serve(&index, req.headers()); + } + + if listing { + return Directory::new(req.path(), prev, &path, &unlisted) + .ok_or_else(|| StatusCode::INTERNAL_SERVER_ERROR.into_error()) + .map(IntoResponse::into_response); + } - Ok(StatusCode::NOT_FOUND.into_response()) + Ok(StatusCode::NOT_FOUND.into_response()) + }) } } diff --git a/viz-macros/src/lib.rs b/viz-macros/src/lib.rs index 026101a9..567601ed 100644 --- a/viz-macros/src/lib.rs +++ b/viz-macros/src/lib.rs @@ -121,16 +121,17 @@ fn generate_handler(input: TokenStream) -> Result { #[derive(Clone)] #vis struct #name; - #[viz_core::async_trait] impl viz_core::Handler for #name { type Output = viz_core::Result; #[allow(unused, unused_mut)] - async fn call(&self, mut req: viz_core::Request) -> Self::Output { - #ast - let res = #name(#(#extractors),*)#asyncness; - #out.map(viz_core::IntoResponse::into_response) + fn call(&self, mut req: viz_core::Request) -> viz_core::future::BoxFuture<'static, Self::Output> { + Box::pin(async move { + #ast + let res = #name(#(#extractors),*)#asyncness; + #out.map(viz_core::IntoResponse::into_response) + }) } } }; diff --git a/viz-router/Cargo.toml b/viz-router/Cargo.toml index e4e1a73a..0c596369 100644 --- a/viz-router/Cargo.toml +++ b/viz-router/Cargo.toml @@ -34,3 +34,4 @@ tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } [lints] workspace = true + diff --git a/viz-router/src/resources.rs b/viz-router/src/resources.rs index 42c4b89a..62c278ce 100644 --- a/viz-router/src/resources.rs +++ b/viz-router/src/resources.rs @@ -71,7 +71,7 @@ impl Resources { pub(crate) fn on(mut self, kind: Kind, method: Method, handler: H) -> Self where - H: Handler> + Clone, + H: Handler> + Send + Clone + 'static, O: IntoResponse + Send + 'static, { match self @@ -94,7 +94,7 @@ impl Resources { #[must_use] pub fn index(self, handler: H) -> Self where - H: Handler> + Clone, + H: Handler> + Send + Clone + 'static, O: IntoResponse + Send + 'static, { self.on(Kind::Empty, Method::GET, handler) @@ -104,7 +104,7 @@ impl Resources { #[must_use] pub fn new(self, handler: H) -> Self where - H: Handler> + Clone, + H: Handler> + Send + Clone + 'static, O: IntoResponse + Send + 'static, { self.on(Kind::New, Method::GET, handler) @@ -114,7 +114,7 @@ impl Resources { #[must_use] pub fn create(self, handler: H) -> Self where - H: Handler> + Clone, + H: Handler> + Send + Clone + 'static, O: IntoResponse + Send + 'static, { self.on(Kind::Empty, Method::POST, handler) @@ -124,7 +124,7 @@ impl Resources { #[must_use] pub fn show(self, handler: H) -> Self where - H: Handler> + Clone, + H: Handler> + Send + Clone + 'static, O: IntoResponse + Send + 'static, { self.on(Kind::Id, Method::GET, handler) @@ -134,7 +134,7 @@ impl Resources { #[must_use] pub fn edit(self, handler: H) -> Self where - H: Handler> + Clone, + H: Handler> + Send + Clone + 'static, O: IntoResponse + Send + 'static, { self.on(Kind::Edit, Method::GET, handler) @@ -144,7 +144,7 @@ impl Resources { #[must_use] pub fn update(self, handler: H) -> Self where - H: Handler> + Clone, + H: Handler> + Send + Clone + 'static, O: IntoResponse + Send + 'static, { self.on(Kind::Id, Method::PUT, handler) @@ -154,7 +154,7 @@ impl Resources { #[must_use] pub fn update_with_patch(self, handler: H) -> Self where - H: Handler> + Clone, + H: Handler> + Send + Clone + 'static, O: IntoResponse + Send + 'static, { self.on(Kind::Id, Method::PATCH, handler) @@ -164,7 +164,7 @@ impl Resources { #[must_use] pub fn destroy(self, handler: H) -> Self where - H: Handler> + Clone, + H: Handler> + Send + Clone + 'static, O: IntoResponse + Send + 'static, { self.on(Kind::Id, Method::DELETE, handler) @@ -200,7 +200,7 @@ impl Resources { pub fn with(self, t: T) -> Self where T: Transform, - T::Output: Handler>, + T::Output: Handler> + Send + Clone + 'static, { self.map_handler(|handler| t.transform(handler).boxed()) } @@ -209,7 +209,7 @@ impl Resources { #[must_use] pub fn with_handler(self, f: H) -> Self where - H: Handler, Output = Result> + Clone, + H: Handler, Output = Result> + Send + Clone + 'static, { self.map_handler(|handler| handler.around(f.clone()).boxed()) } @@ -259,7 +259,7 @@ mod tests { use crate::{get, Resources}; use http_body_util::BodyExt; use viz_core::{ - async_trait, Handler, HandlerExt, IntoResponse, Method, Next, Request, Response, + future::BoxFuture, Handler, HandlerExt, IntoResponse, Method, Next, Request, Response, ResponseExt, Result, Transform, }; @@ -285,15 +285,14 @@ mod tests { #[derive(Clone)] struct LoggerHandler(H); - #[async_trait] impl Handler for LoggerHandler where - H: Handler + Clone, + H: Handler + Send + Clone + 'static, { type Output = H::Output; - async fn call(&self, req: Request) -> Self::Output { - self.0.call(req).await + fn call(&self, req: Request) -> BoxFuture<'static, Self::Output> { + Box::pin(self.0.call(req)) } } diff --git a/viz-router/src/route.rs b/viz-router/src/route.rs index c86334d7..f75310c8 100644 --- a/viz-router/src/route.rs +++ b/viz-router/src/route.rs @@ -13,7 +13,7 @@ macro_rules! export_internal_verb { #[must_use] pub fn $name(self, handler: H) -> Self where - H: Handler> + Clone, + H: Handler> + Send + Clone + 'static, O: IntoResponse + Send + 'static, { self.on(Method::$verb, handler) @@ -27,7 +27,7 @@ macro_rules! export_verb { #[must_use] pub fn $name(handler: H) -> Route where - H: Handler> + Clone, + H: Handler> + Send + Clone + 'static, O: IntoResponse + Send + 'static, { Route::new().$name(handler) @@ -70,7 +70,7 @@ impl Route { #[must_use] pub fn on(self, method: Method, handler: H) -> Self where - H: Handler> + Clone, + H: Handler> + Send + Clone + 'static, O: IntoResponse + Send + 'static, { self.push(method, handler.map_into_response().boxed()) @@ -80,7 +80,7 @@ impl Route { #[must_use] pub fn any(self, handler: H) -> Self where - H: Handler> + Clone, + H: Handler> + Send + Clone + 'static, O: IntoResponse + Send + 'static, { [ @@ -127,7 +127,7 @@ impl Route { pub fn with(self, t: T) -> Self where T: Transform, - T::Output: Handler>, + T::Output: Handler> + Send + Clone + 'static, { self.map_handler(|handler| t.transform(handler).boxed()) } @@ -136,7 +136,7 @@ impl Route { #[must_use] pub fn with_handler(self, f: H) -> Self where - H: Handler, Output = Result> + Clone, + H: Handler, Output = Result> + Send + Clone + 'static, { self.map_handler(|handler| handler.around(f.clone()).boxed()) } @@ -166,7 +166,7 @@ impl FromIterator<(Method, BoxHandler)> for Route { /// Creates a route with a handler and HTTP verb pair. pub fn on(method: Method, handler: H) -> Route where - H: Handler> + Clone, + H: Handler> + Send + Clone + 'static, O: IntoResponse + Send + 'static, { Route::new().on(method, handler) @@ -188,7 +188,7 @@ repeat!( /// Creates a route with a handler and any HTTP verbs. pub fn any(handler: H) -> Route where - H: Handler> + Clone, + H: Handler> + Send + Clone + 'static, O: IntoResponse + Send + 'static, { Route::new().any(handler) @@ -218,7 +218,7 @@ mod tests { use serde::Deserialize; use std::sync::Arc; use viz_core::{ - async_trait, + future::{BoxFuture, TryFutureExt}, handler::Transform, types::{Query, State}, Handler, HandlerExt, IntoHandler, IntoResponse, Method, Next, Request, RequestExt, @@ -250,15 +250,14 @@ mod tests { #[derive(Clone)] struct LoggerHandler(H); - #[async_trait] impl Handler for LoggerHandler where - H: Handler + Clone, + H: Handler + Send + Clone + 'static, { type Output = H::Output; - async fn call(&self, req: Request) -> Self::Output { - self.0.call(req).await + fn call(&self, req: Request) -> BoxFuture<'static, Self::Output> { + Box::pin(self.0.call(req)) } } @@ -298,16 +297,16 @@ mod tests { name: String, } - #[async_trait] impl Handler> for Around2 where I: Send + 'static, - H: Handler> + Clone, + H: Handler> + Clone + 'static, + O: 'static, { type Output = H::Output; - async fn call(&self, (i, h): Next) -> Self::Output { - h.call(i).await + fn call(&self, (i, h): Next) -> BoxFuture<'static, Self::Output> { + Box::pin(h.call(i)) } } @@ -316,16 +315,15 @@ mod tests { name: String, } - #[async_trait] impl Handler> for Around3 where H: Handler> + Clone, - O: IntoResponse, + O: IntoResponse + 'static, { type Output = Result; - async fn call(&self, (i, h): Next) -> Self::Output { - h.call(i).await.map(IntoResponse::into_response) + fn call(&self, (i, h): Next) -> BoxFuture<'static, Self::Output> { + Box::pin(h.call(i).map_ok(IntoResponse::into_response)) } } @@ -334,15 +332,14 @@ mod tests { name: String, } - #[async_trait] impl Handler> for Around4 where H: Handler> + Clone, { type Output = Result; - async fn call(&self, (i, h): Next) -> Self::Output { - h.call(i).await + fn call(&self, (i, h): Next) -> BoxFuture<'static, Self::Output> { + Box::pin(h.call(i)) } } diff --git a/viz-router/src/router.rs b/viz-router/src/router.rs index 4aa48cba..24c1e585 100644 --- a/viz-router/src/router.rs +++ b/viz-router/src/router.rs @@ -11,7 +11,7 @@ macro_rules! export_verb { pub fn $name(self, path: S, handler: H) -> Self where S: AsRef, - H: Handler> + Clone, + H: Handler> + Send + Clone + 'static, O: IntoResponse + Send + 'static, { self.route(path, Route::new().$name(handler)) @@ -130,7 +130,7 @@ impl Router { pub fn any(self, path: S, handler: H) -> Self where S: AsRef, - H: Handler> + Clone, + H: Handler> + Send + Clone + 'static, O: IntoResponse + Send + 'static, { self.route(path, Route::new().any(handler)) @@ -140,7 +140,7 @@ impl Router { #[must_use] pub fn map_handler(self, f: F) -> Self where - F: Fn(BoxHandler) -> BoxHandler, + F: Fn(BoxHandler>) -> BoxHandler>, { Self { routes: self.routes.map(|routes| { @@ -164,8 +164,8 @@ impl Router { #[must_use] pub fn with(self, t: T) -> Self where - T: Transform, - T::Output: Handler>, + T: Transform>>, + T::Output: Handler> + Send + Clone + 'static, { self.map_handler(|handler| t.transform(handler).boxed()) } @@ -174,7 +174,10 @@ impl Router { #[must_use] pub fn with_handler(self, f: H) -> Self where - H: Handler, Output = Result> + Clone, + H: Handler>>, Output = Result> + + Send + + Clone + + 'static, { self.map_handler(|handler| handler.around(f.clone()).boxed()) } @@ -186,7 +189,7 @@ mod tests { use http_body_util::{BodyExt, Full}; use std::sync::Arc; use viz_core::{ - async_trait, + future::BoxFuture, types::{Params, RouteInfo}, Body, Error, Handler, HandlerExt, IntoResponse, Method, Next, Request, RequestExt, Response, ResponseExt, Result, StatusCode, Transform, @@ -214,15 +217,14 @@ mod tests { #[derive(Clone)] struct LoggerHandler(H); - #[async_trait] impl Handler for LoggerHandler where - H: Handler + Clone, + H: Handler + Clone + 'static, { type Output = H::Output; - async fn call(&self, req: Request) -> Self::Output { - self.0.call(req).await + fn call(&self, req: Request) -> BoxFuture<'static, Self::Output> { + Box::pin(self.0.call(req)) } } diff --git a/viz-router/src/tree.rs b/viz-router/src/tree.rs index c49235fe..bb0b2fe9 100644 --- a/viz-router/src/tree.rs +++ b/viz-router/src/tree.rs @@ -7,7 +7,7 @@ use viz_core::{BoxHandler, Method}; use crate::{Route, Router}; /// Store all final routes. -#[derive(Default)] +#[derive(Clone, Default)] pub struct Tree(Vec<(Method, PathTree)>); impl Tree { @@ -20,7 +20,7 @@ impl Tree { ) -> Option<(&'a BoxHandler, Path<'a, 'b>)> { self.0 .iter() - .find_map(|(m, t)| if m == method { t.find(path) } else { None }) + .find_map(|(m, t)| (m == method).then(|| t.find(path)).flatten()) } /// Consumes the Tree, returning the wrapped value. diff --git a/viz-tower/src/lib.rs b/viz-tower/src/lib.rs index 535690e8..4459391f 100644 --- a/viz-tower/src/lib.rs +++ b/viz-tower/src/lib.rs @@ -2,7 +2,8 @@ use tower::{Service, ServiceExt}; use viz_core::{ - async_trait, Body, BoxError, Bytes, Error, Handler, HttpBody, Request, Response, Result, + future::{BoxFuture, TryFutureExt}, + Body, BoxError, Bytes, Error, Handler, HttpBody, Request, Response, Result, }; mod service; @@ -25,25 +26,25 @@ impl ServiceHandler { } } -#[async_trait] impl Handler for ServiceHandler where O: HttpBody + Send + 'static, O::Data: Into, O::Error: Into, - S: Service> + Send + Sync + Clone + 'static, + S: Service> + Send + Clone + 'static, S::Future: Send, S::Error: Into, { type Output = Result; - async fn call(&self, req: Request) -> Self::Output { - self.0 - .clone() - .oneshot(req) - .await - .map(|resp| resp.map(Body::wrap)) - .map_err(Error::boxed) + fn call(&self, req: Request) -> BoxFuture<'static, Self::Output> { + Box::pin( + self.0 + .clone() + .oneshot(req) + .map_ok(|resp| resp.map(Body::wrap)) + .map_err(Error::boxed), + ) } } diff --git a/viz-tower/src/middleware.rs b/viz-tower/src/middleware.rs index 45d9f5b0..f673f6ec 100644 --- a/viz-tower/src/middleware.rs +++ b/viz-tower/src/middleware.rs @@ -1,6 +1,7 @@ use tower::{Layer, Service, ServiceExt}; use viz_core::{ - async_trait, Body, BoxError, Bytes, Error, Handler, HttpBody, Request, Response, Result, + future::{BoxFuture, TryFutureExt}, + Body, BoxError, Bytes, Error, Handler, HttpBody, Request, Response, Result, }; use crate::HandlerService; @@ -19,27 +20,27 @@ impl Middleware { } } -#[async_trait] impl Handler for Middleware where L: Layer> + Send + Sync + Clone + 'static, - H: Handler> + Send + Sync + Clone + 'static, + H: Handler> + Send + Clone + 'static, O: HttpBody + Send + 'static, O::Data: Into, O::Error: Into, - L::Service: Service> + Send + Sync + Clone + 'static, + L::Service: Service> + Send + Clone + 'static, >::Future: Send, >::Error: Into, { type Output = Result; - async fn call(&self, req: Request) -> Self::Output { - self.l - .clone() - .layer(HandlerService::new(self.h.clone())) - .oneshot(req) - .await - .map(|resp| resp.map(Body::wrap)) - .map_err(Error::boxed) + fn call(&self, req: Request) -> BoxFuture<'static, Self::Output> { + Box::pin( + self.l + .clone() + .layer(HandlerService::new(self.h.clone())) + .oneshot(req) + .map_ok(|resp| resp.map(Body::wrap)) + .map_err(Error::boxed), + ) } } diff --git a/viz-tower/src/service.rs b/viz-tower/src/service.rs index 6247618b..bc03ea4b 100644 --- a/viz-tower/src/service.rs +++ b/viz-tower/src/service.rs @@ -1,8 +1,7 @@ -use std::pin::Pin; use std::task::{Context, Poll}; use tower::Service; -use viz_core::{Error, Future, Handler, Request, Response, Result}; +use viz_core::{future::BoxFuture, Error, Handler, Request, Response, Result}; /// An adapter that makes a [`Handler`] into a [`Service`]. #[derive(Debug)] @@ -30,7 +29,7 @@ where { type Response = Response; type Error = Error; - type Future = Pin> + Send + 'static>>; + type Future = BoxFuture<'static, Result>; #[inline] fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { diff --git a/viz/Cargo.toml b/viz/Cargo.toml index b341bea8..fab7a76f 100644 --- a/viz/Cargo.toml +++ b/viz/Cargo.toml @@ -78,6 +78,7 @@ hyper-util.workspace = true rustls-pemfile = { workspace = true, optional = true } +sync_wrapper.workspace = true futures-util = { workspace = true, optional = true } tokio-native-tls = { workspace = true, optional = true } tokio-rustls = { workspace = true, optional = true } diff --git a/viz/src/lib.rs b/viz/src/lib.rs index 354cc9f7..e9c408a2 100644 --- a/viz/src/lib.rs +++ b/viz/src/lib.rs @@ -72,21 +72,23 @@ //! //! ``` //! # use std::sync::{Arc, atomic::{AtomicUsize, Ordering}}; -//! # use viz::{async_trait, Handler, IntoResponse, Request, RequestExt, Response, Result}; +//! # use viz::{Handler, IntoResponse, Request, RequestExt, Response, Result, future::BoxFuture}; //! #[derive(Clone)] //! struct MyHandler { //! code: Arc, //! } //! -//! #[async_trait] //! impl Handler for MyHandler { -//! type Output = Result; -//! -//! async fn call(&self, req: Request) -> Self::Output { -//! let path = req.path(); -//! let method = req.method().clone(); -//! let code = self.code.fetch_add(1, Ordering::SeqCst); -//! Ok(format!("code = {}, method = {}, path = {}", code, method, path).into_response()) +//! type Output = Result; +//! +//! fn call(&self, req: Request) -> BoxFuture<'static, Self::Output> { +//! let code = self.code.clone(); +//! Box::pin(async move { +//! let path = req.path(); +//! let method = req.method().clone(); +//! let code = code.fetch_add(1, Ordering::SeqCst); +//! Ok(format!("code = {}, method = {}, path = {}", code, method, path).into_response()) +//! }) //! } //! } //! ``` @@ -220,8 +222,9 @@ //! ``` //! # use std::time::Duration; //! # use viz::{ -//! # async_trait, get, types::Params, Transform, HandlerExt, IntoResponse, IntoHandler, -//! # Request, Response, ResponseExt, Result, Router, StatusCode, Next, Handler +//! # get, types::Params, Transform, HandlerExt, IntoResponse, IntoHandler, +//! # Request, Response, ResponseExt, Result, Router, StatusCode, Next, Handler, +//! # future::BoxFuture, //! # }; //! async fn index(_: Request) -> Result { //! Ok(StatusCode::OK.into_response()) @@ -238,7 +241,7 @@ //! // middleware fn //! async fn around((req, handler): Next) -> Result //! where -//! H: Handler>, +//! H: Handler> + Send + Clone, //! { //! // before ... //! let result = handler.call(req).await; @@ -250,15 +253,14 @@ //! #[derive(Clone)] //! struct MyMiddleware {} //! -//! #[async_trait] //! impl Handler> for MyMiddleware //! where -//! H: Handler, +//! H: Handler + Send + Clone + 'static, //! { //! type Output = H::Output; //! -//! async fn call(&self, (i, h): Next) -> Self::Output { -//! h.call(i).await +//! fn call(&self, (i, h): Next) -> BoxFuture<'static, Self::Output> { +//! Box::pin(h.call(i)) //! } //! } //! @@ -285,15 +287,14 @@ //! #[derive(Clone)] //! struct TimeoutMiddleware(H, Duration); //! -//! #[async_trait] //! impl Handler for TimeoutMiddleware //! where -//! H: Handler + Clone, +//! H: Handler + Send + Clone + 'static, //! { //! type Output = H::Output; //! -//! async fn call(&self, req: Request) -> Self::Output { -//! self.0.call(req).await +//! fn call(&self, req: Request) -> BoxFuture<'static, Self::Output> { +//! Box::pin(self.0.call(req)) //! } //! } //!