diff --git a/Cargo.toml b/Cargo.toml index d6bd4ae9..7f3aa1be 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "viz-handlers", "viz-macros", "viz-router", + "viz-tower", "viz-test", "examples/hello-world", diff --git a/viz-core/src/body.rs b/viz-core/src/body.rs index 3780ac39..e7ae63f7 100644 --- a/viz-core/src/body.rs +++ b/viz-core/src/body.rs @@ -45,6 +45,19 @@ impl Body { Self::Empty } + /// Wraps a body into box. + pub fn wrap(body: B) -> Self + where + B: HttpBody + Send + 'static, + B::Data: Into, + B::Error: Into, + { + body.map_frame(|frame| frame.map_data(Into::into)) + .map_err(Error::boxed) + .boxed_unsync() + .into() + } + /// A body created from a [`Stream`]. pub fn from_stream(stream: S) -> Self where @@ -52,15 +65,14 @@ impl Body { S::Ok: Into, S::Error: Into, { - Self::Boxed(SyncWrapper::new( - StreamBody::new( - stream - .map_ok(Into::into) - .map_ok(Frame::data) - .map_err(Error::boxed), - ) - .boxed_unsync(), - )) + StreamBody::new( + stream + .map_ok(Into::into) + .map_ok(Frame::data) + .map_err(Error::boxed), + ) + .boxed_unsync() + .into() } /// A stream created from a [`http_body::Body`]. diff --git a/viz-core/src/handler/service.rs b/viz-core/src/handler/service.rs index 258570b0..5c7a3121 100644 --- a/viz-core/src/handler/service.rs +++ b/viz-core/src/handler/service.rs @@ -38,8 +38,8 @@ where body.map_frame(|f| f.map_data(Into::into)) .map_err(Into::into) .boxed_unsync() + .into() }) - .map(Into::into) }) .map_err(Into::into) } diff --git a/viz-core/src/request.rs b/viz-core/src/request.rs index d58892d4..11222bac 100644 --- a/viz-core/src/request.rs +++ b/viz-core/src/request.rs @@ -254,11 +254,11 @@ impl RequestExt for Request { match state { BodyState::Empty => Err(PayloadError::Empty)?, BodyState::Used => Err(PayloadError::Used)?, - BodyState::Normal => unreachable!(), + BodyState::Normal => {} } } - let (state, result) = match std::mem::replace(self.body_mut(), Body::empty()) { + let (state, result) = match std::mem::replace(self.body_mut(), Body::Empty) { Body::Empty => (BodyState::Empty, Err(PayloadError::Empty)), body => (BodyState::Used, Ok(body)), }; diff --git a/viz-tower/Cargo.toml b/viz-tower/Cargo.toml new file mode 100644 index 00000000..e7d1165a --- /dev/null +++ b/viz-tower/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "viz-tower" +version = "0.1.0" +authors.workspace = true +edition.workspace = true +homepage.workspace = true +documentation.workspace = true +repository.workspace = true +license.workspace = true +rust-version.workspace = true + +[dependencies] +viz-core.workspace = true +http-body-util.workspace = true +tower = { version = "0.4", features = ["util"] } + +[dev-dependencies] +tokio = { workspace = true, features = ["rt-multi-thread", "macros", "test-util"] } +tower-http = { version = "0.5", features = ["limit", "request-id", "timeout"] } + +[lints] +workspace = true diff --git a/viz-tower/src/lib.rs b/viz-tower/src/lib.rs new file mode 100644 index 00000000..8ce1c1f6 --- /dev/null +++ b/viz-tower/src/lib.rs @@ -0,0 +1,112 @@ +//! An adapter that makes a tower [`Service`] into a [`Handler`]. + +use http_body_util::BodyExt; +use tower::{Service, ServiceExt}; +use viz_core::{async_trait, BoxError, Bytes, Error, Handler, HttpBody, Request, Response, Result}; + +/// Converts a tower [`Service`] into a [`Handler`]. +#[derive(Debug, Clone)] +pub struct TowerServiceHandler(S); + +impl TowerServiceHandler { + /// Creates a new [`TowerServiceHandler`]. + pub fn new(s: S) -> Self { + Self(s) + } +} + +#[async_trait] +impl Handler for TowerServiceHandler +where + O: HttpBody + Send + 'static, + O::Data: Into, + O::Error: Into, + S: Service> + Send + Sync + Clone + 'static, + S::Future: Send, + S::Error: Into, +{ + type Output = Result; + + async fn call(&self, req: Request) -> Self::Output { + self.0 + .clone() + .oneshot(req) + .await + .map(|resp| { + resp.map(|body| { + body.map_frame(|f| f.map_data(Into::into)) + .map_err(Error::boxed) + .boxed_unsync() + .into() + }) + }) + .map_err(Error::boxed) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::{ + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + time::Duration, + }; + use tower::util::{MapErrLayer, MapRequestLayer, MapResponseLayer}; + use tower::{service_fn, ServiceBuilder}; + use tower_http::{ + limit::RequestBodyLimitLayer, + request_id::{MakeRequestId, RequestId, SetRequestIdLayer}, + timeout::TimeoutLayer, + }; + use viz_core::{ + Body, BoxHandler, Handler, HandlerExt, IntoResponse, Request, RequestExt, Response, + }; + + #[derive(Clone, Default, Debug)] + struct MyMakeRequestId { + counter: Arc, + } + + impl MakeRequestId for MyMakeRequestId { + fn make_request_id(&mut self, _: &Request) -> Option { + let request_id = self + .counter + .fetch_add(1, Ordering::SeqCst) + .to_string() + .parse() + .unwrap(); + + Some(RequestId::new(request_id)) + } + } + + async fn hello(mut req: Request) -> Result { + let bytes = req.bytes().await?; + Ok(bytes.into_response()) + } + + #[tokio::test] + async fn tower_service_into_handler() { + let hello_svc = service_fn(hello); + + let svc = ServiceBuilder::new() + .layer(RequestBodyLimitLayer::new(1)) + .layer(MapErrLayer::new(Error::from)) + .layer(SetRequestIdLayer::x_request_id(MyMakeRequestId::default())) + .layer(MapResponseLayer::new(IntoResponse::into_response)) + .layer(MapRequestLayer::new(|req: Request<_>| req.map(Body::wrap))) + .layer(TimeoutLayer::new(Duration::from_secs(10))) + .service(hello_svc); + + let r0 = Request::new(Body::Full("12".into())); + let h0 = TowerServiceHandler::new(svc); + assert!(h0.call(r0).await.is_err()); + + let r1 = Request::new(Body::Full("1".into())); + let b0: BoxHandler = h0.boxed(); + assert!(b0.call(r1).await.is_ok()); + } +}