diff --git a/Cargo.toml b/Cargo.toml index 7d348731..cf5df9d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ members = [ "examples/graceful-shutdown", "examples/databases/*", "examples/htmlx", + "examples/tower", ] [workspace.package] @@ -52,6 +53,7 @@ viz-router = { version = "0.7.1", path = "viz-router" } viz-handlers = { version = "0.7.1", path = "viz-handlers", default-features = false } viz-macros = { version = "0.2.0", path = "viz-macros" } viz-test = { version = "0.2.0", path = "viz-test" } +viz-tower = { version = "0.1.0", path = "viz-tower" } async-trait = "0.1" dyn-clone = "1.0" @@ -103,6 +105,10 @@ prometheus = "0.13" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } +# Tower +tower = "0.4" +tower-http = "0.5" + [workspace.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"] diff --git a/examples/README.md b/examples/README.md index 8b8044ab..3d4e1785 100644 --- a/examples/README.md +++ b/examples/README.md @@ -33,6 +33,7 @@ Here you can find a lot of small crabs 🦀. * [minijinja](templates/minijinja) * [Tracing aka logging](tracing) * [htmlx](htmlx) +* [Tower Services](tower) ## Usage diff --git a/examples/tower/Cargo.toml b/examples/tower/Cargo.toml new file mode 100644 index 00000000..3005148c --- /dev/null +++ b/examples/tower/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "tower-example" +version = "0.1.0" +edition.workspace = true +publish = false + +[dependencies] +viz.workspace = true +viz-tower.workspace = true + +tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } +tracing.workspace = true +tracing-subscriber = { workspace = true, features = ["env-filter"] } +tower.workspace = true +tower-http = { workspace = true, features = ["full"] } + +[lints] +workspace = true diff --git a/examples/tower/src/main.rs b/examples/tower/src/main.rs new file mode 100644 index 00000000..6371445a --- /dev/null +++ b/examples/tower/src/main.rs @@ -0,0 +1,75 @@ +//! Viz + Tower services + +use std::{net::SocketAddr, sync::Arc}; +use tokio::net::TcpListener; +use tower::{ + service_fn, + util::{MapErrLayer, MapRequestLayer, MapResponseLayer}, + ServiceBuilder, +}; +use tower_http::{ + limit::RequestBodyLimitLayer, + request_id::{MakeRequestUuid, SetRequestIdLayer}, + trace::TraceLayer, +}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +use viz::{serve, Body, Error, IntoResponse, Request, Response, Result, Router, Tree}; +use viz_tower::{Layered, ServiceHandler}; + +async fn index(_: Request) -> Result { + Ok("Hello, World!".into_response()) +} + +async fn about(_: Request) -> Result<&'static str> { + Ok("About me!") +} + +async fn any(_: Request) -> Result { + Ok("std::any::Any".to_string()) +} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "tower-example=debug,tower_http=debug".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let index_svc = ServiceBuilder::new() + .layer(MapRequestLayer::new(|req: Request<_>| req.map(Body::wrap))) + .service(service_fn(index)); + let index_handler = ServiceHandler::new(index_svc); + + let any_svc = ServiceBuilder::new() + .layer(MapResponseLayer::new(IntoResponse::into_response)) + .service_fn(any); + let any_handler = ServiceHandler::new(any_svc); + + let layer = ServiceBuilder::new() + .layer(TraceLayer::new_for_http()) + .layer(RequestBodyLimitLayer::new(1024)) + .layer(SetRequestIdLayer::x_request_id(MakeRequestUuid)) + .layer(MapErrLayer::new(Error::from)) + .layer(MapResponseLayer::new(IntoResponse::into_response)) + .layer(MapRequestLayer::new(|req: Request<_>| req.map(Body::wrap))); + + let app = Router::new() + .get("/", index_handler) + .get("/about", about) + .any("/*", any_handler) + .with(Layered::new(layer)); + let tree = Arc::new(Tree::from(app)); + + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + let listener = TcpListener::bind(addr).await?; + println!("listening on http://{addr}"); + + loop { + let (stream, addr) = listener.accept().await?; + let tree = tree.clone(); + tokio::task::spawn(serve(stream, tree, Some(addr))); + } +} diff --git a/examples/tracing/Cargo.toml b/examples/tracing/Cargo.toml index 2917a3e1..d813319c 100644 --- a/examples/tracing/Cargo.toml +++ b/examples/tracing/Cargo.toml @@ -7,6 +7,6 @@ publish = false [dependencies] viz.workspace = true -tokio = { workspace = true, features = [ "rt-multi-thread", "macros" ] } +tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tracing.workspace = true tracing-subscriber = { workspace = true, features = ["env-filter"] } diff --git a/viz-core/src/body.rs b/viz-core/src/body.rs index f9d13c2d..b33cbec5 100644 --- a/viz-core/src/body.rs +++ b/viz-core/src/body.rs @@ -46,15 +46,23 @@ impl Body { } /// Wraps a body into box. + #[allow(clippy::missing_panics_doc)] pub fn wrap(body: B) -> Self where B: HttpBody + Send + 'static, B::Data: Into, B::Error: Into, { - body.map_frame(|frame| frame.map_data(Into::into)) - .map_err(Error::boxed) - .boxed_unsync() + // Copied from Axum, thanks. + let mut body = Some(body); + ::downcast_mut::>>(&mut body) + .and_then(Option::take) + .unwrap_or_else(|| { + body.unwrap() + .map_frame(|frame| frame.map_data(Into::into)) + .map_err(Error::boxed) + .boxed_unsync() + }) .into() } diff --git a/viz-core/src/handler/service.rs b/viz-core/src/handler/service.rs index 5c7a3121..56c7fa43 100644 --- a/viz-core/src/handler/service.rs +++ b/viz-core/src/handler/service.rs @@ -1,46 +1,38 @@ -use http_body_util::BodyExt; use hyper::service::Service; -use crate::{async_trait, Bytes, Error, Handler, HttpBody, Request, Response, Result}; +use crate::{ + async_trait, Body, BoxError, Bytes, Error, Handler, HttpBody, Request, Response, Result, +}; /// Converts a hyper [`Service`] to a viz [`Handler`]. #[derive(Debug, Clone)] -pub struct ServiceHandler { - s: S, -} +pub struct ServiceHandler(S); impl ServiceHandler { /// Creates a new [`ServiceHandler`]. pub fn new(s: S) -> Self { - Self { s } + Self(s) } } #[async_trait] impl Handler> for ServiceHandler where - I: HttpBody + Send + Unpin + 'static, + I: HttpBody + Send + 'static, O: HttpBody + Send + 'static, O::Data: Into, - O::Error: Into, + O::Error: Into, S: Service, Response = Response> + Send + Sync + Clone + 'static, S::Future: Send, - S::Error: Into, + S::Error: Into, { type Output = Result; async fn call(&self, req: Request) -> Self::Output { - self.s + self.0 .call(req) .await - .map(|resp| { - resp.map(|body| { - body.map_frame(|f| f.map_data(Into::into)) - .map_err(Into::into) - .boxed_unsync() - .into() - }) - }) - .map_err(Into::into) + .map(|resp| resp.map(Body::wrap)) + .map_err(Error::boxed) } } diff --git a/viz-router/src/router.rs b/viz-router/src/router.rs index 7f25e5d0..4aa48cba 100644 --- a/viz-router/src/router.rs +++ b/viz-router/src/router.rs @@ -366,7 +366,7 @@ mod tests { req.extensions_mut().insert(Arc::from(RouteInfo { id: *route.id, pattern: route.pattern(), - params: Into::::into(route.params()), + params: route.params().into(), })); assert_eq!( h.call(req).await?.into_body().collect().await?.to_bytes(), @@ -396,7 +396,7 @@ mod tests { req.extensions_mut().insert(Arc::from(RouteInfo { id: *route.id, pattern: route.pattern(), - params: Into::::into(route.params()), + params: route.params().into(), })); assert_eq!( h.call(req).await?.into_body().collect().await?.to_bytes(), @@ -431,7 +431,7 @@ mod tests { req.extensions_mut().insert(Arc::from(RouteInfo { id: *route.id, pattern: route.pattern(), - params: Into::::into(route.params()), + params: route.params().into(), })); assert_eq!( h.call(req).await?.into_body().collect().await?.to_bytes(), @@ -446,7 +446,7 @@ mod tests { let route_info = Arc::from(RouteInfo { id: *route.id, pattern: route.pattern(), - params: Into::::into(route.params()), + params: route.params().into(), }); assert_eq!(route.pattern(), "/posts/:post_id/users/:user_id"); assert_eq!(route_info.pattern, "/posts/:post_id/users/:user_id"); @@ -464,7 +464,7 @@ mod tests { req.extensions_mut().insert(Arc::from(RouteInfo { id: *route.id, pattern: route.pattern(), - params: Into::::into(route.params()), + params: route.params().into(), })); assert_eq!( h.call(req).await?.into_body().collect().await?.to_bytes(), diff --git a/viz-tower/Cargo.toml b/viz-tower/Cargo.toml index 3eb0a749..2440efcb 100644 --- a/viz-tower/Cargo.toml +++ b/viz-tower/Cargo.toml @@ -3,6 +3,7 @@ name = "viz-tower" version = "0.1.0" documentation = "https://docs.rs/viz-tower" description = "An adapter for tower service" +readme = "README.md" authors.workspace = true edition.workspace = true homepage.workspace = true @@ -13,11 +14,11 @@ rust-version.workspace = true [dependencies] viz-core.workspace = true http-body-util.workspace = true -tower = { version = "0.4", features = ["util"] } +tower = { workspace = true, features = ["util"] } [dev-dependencies] tokio = { workspace = true, features = ["rt-multi-thread", "macros", "test-util"] } -tower-http = { version = "0.5", features = ["limit", "request-id", "timeout"] } +tower-http = { workspace = true, features = ["limit", "request-id", "timeout"] } [lints] workspace = true diff --git a/viz-tower/README.md b/viz-tower/README.md new file mode 100644 index 00000000..db0bff1a --- /dev/null +++ b/viz-tower/README.md @@ -0,0 +1,9 @@ +viz-tower +--------- + +An adapter that makes a tower [`Service`] into a [`Handler`]. + +See [tower example](../examples/tower/). + +[`Service`]: https://docs.rs/tower/latest/tower/trait.Service.html +[`Handler`]: https://docs.rs/viz/latest/viz/trait.Handler.html diff --git a/viz-tower/src/lib.rs b/viz-tower/src/lib.rs index 8ce1c1f6..fb5caef7 100644 --- a/viz-tower/src/lib.rs +++ b/viz-tower/src/lib.rs @@ -1,22 +1,29 @@ //! An adapter that makes a tower [`Service`] into a [`Handler`]. -use http_body_util::BodyExt; use tower::{Service, ServiceExt}; -use viz_core::{async_trait, BoxError, Bytes, Error, Handler, HttpBody, Request, Response, Result}; +use viz_core::{ + async_trait, Body, BoxError, Bytes, Error, Handler, HttpBody, Request, Response, Result, +}; + +mod service; +pub use service::HandlerService; + +mod middleware; +pub use middleware::{Layered, Middleware}; /// Converts a tower [`Service`] into a [`Handler`]. #[derive(Debug, Clone)] -pub struct TowerServiceHandler(S); +pub struct ServiceHandler(S); -impl TowerServiceHandler { - /// Creates a new [`TowerServiceHandler`]. +impl ServiceHandler { + /// Creates a new [`ServiceHandler`]. pub fn new(s: S) -> Self { Self(s) } } #[async_trait] -impl Handler for TowerServiceHandler +impl Handler for ServiceHandler where O: HttpBody + Send + 'static, O::Data: Into, @@ -32,14 +39,7 @@ where .clone() .oneshot(req) .await - .map(|resp| { - resp.map(|body| { - body.map_frame(|f| f.map_data(Into::into)) - .map_err(Error::boxed) - .boxed_unsync() - .into() - }) - }) + .map(|resp| resp.map(Body::wrap)) .map_err(Error::boxed) } } @@ -102,7 +102,7 @@ mod tests { .service(hello_svc); let r0 = Request::new(Body::Full("12".into())); - let h0 = TowerServiceHandler::new(svc); + let h0 = ServiceHandler::new(svc); assert!(h0.call(r0).await.is_err()); let r1 = Request::new(Body::Full("1".into())); diff --git a/viz-tower/src/middleware.rs b/viz-tower/src/middleware.rs new file mode 100644 index 00000000..ba2e5d22 --- /dev/null +++ b/viz-tower/src/middleware.rs @@ -0,0 +1,68 @@ +use tower::{Layer, Service, ServiceExt}; +use viz_core::{ + async_trait, Body, BoxError, Bytes, Error, Handler, HttpBody, Request, Response, Result, + Transform, +}; + +use crate::HandlerService; + +/// Transforms a Tower layer into Viz Middleware. +#[derive(Debug)] +pub struct Layered(L); + +impl Layered { + /// Creates a new tower layer. + pub fn new(l: L) -> Self { + Self(l) + } +} + +impl Transform for Layered +where + L: Clone, +{ + type Output = Middleware; + + fn transform(&self, h: H) -> Self::Output { + Middleware::new(self.0.clone(), h) + } +} + +/// A [`Service`] created from a [`Handler`] by applying a Tower middleware. +#[derive(Debug, Clone)] +pub struct Middleware { + l: L, + h: H, +} + +impl Middleware { + /// Creates a new tower middleware. + pub fn new(l: L, h: H) -> Self { + Self { l, h } + } +} + +#[async_trait] +impl Handler for Middleware +where + L: Layer> + Send + Sync + Clone + 'static, + H: Handler> + Send + Sync + Clone + 'static, + O: HttpBody + Send + 'static, + O::Data: Into, + O::Error: Into, + L::Service: Service> + Send + Sync + Clone + 'static, + >::Future: Send, + >::Error: Into, +{ + type Output = Result; + + async fn call(&self, req: Request) -> Self::Output { + self.l + .clone() + .layer(HandlerService::new(self.h.clone())) + .oneshot(req) + .await + .map(|resp| resp.map(Body::wrap)) + .map_err(Error::boxed) + } +} diff --git a/viz-tower/src/service.rs b/viz-tower/src/service.rs new file mode 100644 index 00000000..96c01830 --- /dev/null +++ b/viz-tower/src/service.rs @@ -0,0 +1,44 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tower::Service; +use viz_core::{Error, Future, Handler, Request, Response, Result}; + +/// An adapter that makes a [`Handler`] into a [`Service`]. +#[derive(Debug)] +pub struct HandlerService(H); + +impl HandlerService { + /// Creates a new [`HandlerService`]. + pub fn new(h: H) -> Self { + Self(h) + } +} + +impl Clone for HandlerService +where + H: Clone, +{ + fn clone(&self) -> Self { + HandlerService(self.0.clone()) + } +} + +impl Service for HandlerService +where + H: Handler> + Clone + Send + 'static, +{ + type Response = Response; + type Error = Error; + type Future = Pin> + Send + 'static>>; + + #[inline] + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + let handler = self.0.clone(); + Box::pin(async move { handler.call(req).await }) + } +} diff --git a/viz/README.md b/viz/README.md index 7c29ea59..f8ff5042 100644 --- a/viz/README.md +++ b/viz/README.md @@ -49,6 +49,8 @@ - Simple + Flexible `Handler` & `Middleware` +- Supports: Tower `Service` + ## Hello Viz ```rust