diff --git a/volo-http/src/extract.rs b/volo-http/src/extract.rs index c9688a61..ebc4cba5 100644 --- a/volo-http/src/extract.rs +++ b/volo-http/src/extract.rs @@ -1,10 +1,28 @@ use std::convert::Infallible; +use bytes::Bytes; use futures_util::Future; -use hyper::http::{Method, Uri}; +use http_body_util::BodyExt; +use hyper::{ + body::Incoming, + http::{header, HeaderMap, Method, StatusCode, Uri}, +}; +use serde::de::DeserializeOwned; use volo::net::Address; -use crate::{HttpContext, Params, State}; +use crate::{ + param::Params, + response::{IntoResponse, Response}, + HttpContext, +}; + +mod private { + #[derive(Debug, Clone, Copy)] + pub enum ViaContext {} + + #[derive(Debug, Clone, Copy)] + pub enum ViaRequest {} +} pub trait FromContext: Sized { fn from_context( @@ -13,6 +31,19 @@ pub trait FromContext: Sized { ) -> impl Future> + Send; } +pub trait FromRequest: Sized { + fn from( + cx: &HttpContext, + body: Incoming, + state: &S, + ) -> impl Future> + Send; +} + +#[derive(Debug, Default, Clone, Copy)] +pub struct State(pub S); + +pub struct Json(pub T); + impl FromContext for Option where T: FromContext, @@ -67,3 +98,93 @@ where Ok(State(state.clone())) } } + +impl FromRequest for T +where + T: FromContext + Sync, + S: Sync, +{ + async fn from(cx: &HttpContext, _body: Incoming, state: &S) -> Result { + match T::from_context(cx, state).await { + Ok(value) => Ok(value), + Err(rejection) => Err(rejection.into_response()), + } + } +} + +impl FromRequest for Incoming +where + S: Sync, +{ + async fn from(_cx: &HttpContext, body: Incoming, _state: &S) -> Result { + Ok(body) + } +} + +impl FromRequest for Json { + fn from( + cx: &HttpContext, + body: Incoming, + _state: &S, + ) -> impl Future> + Send { + async move { + if !json_content_type(&cx.headers) { + return Err(Response::builder() + .status(StatusCode::UNSUPPORTED_MEDIA_TYPE) + .body(Bytes::new().into()) + .unwrap() + .into()); + } + + match body.collect().await { + Ok(body) => { + let body = body.to_bytes(); + match serde_json::from_slice::(body.as_ref()) { + Ok(t) => Ok(Self(t)), + Err(e) => { + tracing::warn!("json serialization error {e}"); + Err(Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Bytes::new().into()) + .unwrap() + .into()) + } + } + } + Err(e) => { + tracing::warn!("collect body error: {e}"); + Err(Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Bytes::new().into()) + .unwrap() + .into()) + } + } + } + } +} + +fn json_content_type(headers: &HeaderMap) -> bool { + let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) { + content_type + } else { + return false; + }; + + let content_type = if let Ok(content_type) = content_type.to_str() { + content_type + } else { + return false; + }; + + let mime = if let Ok(mime) = content_type.parse::() { + mime + } else { + return false; + }; + + let is_json_content_type = mime.type_() == "application" + && (mime.subtype() == "json" || mime.suffix().map_or(false, |name| name == "json")); + + is_json_content_type +} diff --git a/volo-http/src/handler.rs b/volo-http/src/handler.rs index b6c21481..3e57cfb3 100644 --- a/volo-http/src/handler.rs +++ b/volo-http/src/handler.rs @@ -4,9 +4,8 @@ use hyper::body::Incoming; use motore::Service; use crate::{ - extract::FromContext, + extract::{FromContext, FromRequest}, macros::{all_the_tuples, all_the_tuples_no_last_special_case}, - request::FromRequest, response::{IntoResponse, Response}, DynService, HttpContext, }; diff --git a/volo-http/src/lib.rs b/volo-http/src/lib.rs index c794c92b..cbbb7ace 100644 --- a/volo-http/src/lib.rs +++ b/volo-http/src/lib.rs @@ -19,17 +19,15 @@ pub use hyper::{ pub use volo::net::Address; pub use crate::{ + extract::{Json, State}, param::Params, - request::{Json, Request}, + request::Request, response::Response, server::Server, }; pub type DynService = motore::BoxCloneService; -#[derive(Debug, Default, Clone, Copy)] -pub struct State(pub S); - pub struct HttpContext { pub peer: Address, pub method: Method, diff --git a/volo-http/src/request.rs b/volo-http/src/request.rs index 111d5e34..48794c13 100644 --- a/volo-http/src/request.rs +++ b/volo-http/src/request.rs @@ -1,19 +1,6 @@ use std::ops::{Deref, DerefMut}; -use bytes::Bytes; -use futures_util::Future; -use http_body_util::BodyExt; -use hyper::{ - body::Incoming, - http::{header, request::Builder, HeaderMap, StatusCode}, -}; -use serde::de::DeserializeOwned; - -use crate::{ - extract::FromContext, - response::{IntoResponse, Response}, - HttpContext, -}; +use hyper::{body::Incoming, http::request::Builder}; pub struct Request(pub(crate) hyper::http::Request); @@ -42,111 +29,3 @@ impl From> for Request { Self(value) } } - -mod private { - #[derive(Debug, Clone, Copy)] - pub enum ViaContext {} - - #[derive(Debug, Clone, Copy)] - pub enum ViaRequest {} -} - -pub trait FromRequest: Sized { - fn from( - cx: &HttpContext, - body: Incoming, - state: &S, - ) -> impl Future> + Send; -} - -impl FromRequest for T -where - T: FromContext + Sync, - S: Sync, -{ - async fn from(cx: &HttpContext, _body: Incoming, state: &S) -> Result { - match T::from_context(cx, state).await { - Ok(value) => Ok(value), - Err(rejection) => Err(rejection.into_response()), - } - } -} - -impl FromRequest for Incoming -where - S: Sync, -{ - async fn from(_cx: &HttpContext, body: Incoming, _state: &S) -> Result { - Ok(body) - } -} - -pub struct Json(pub T); - -impl FromRequest for Json { - fn from( - cx: &HttpContext, - body: Incoming, - _state: &S, - ) -> impl Future> + Send { - async move { - if !json_content_type(&cx.headers) { - return Err(Response::builder() - .status(StatusCode::UNSUPPORTED_MEDIA_TYPE) - .body(Bytes::new().into()) - .unwrap() - .into()); - } - - match body.collect().await { - Ok(body) => { - let body = body.to_bytes(); - match serde_json::from_slice::(body.as_ref()) { - Ok(t) => Ok(Self(t)), - Err(e) => { - tracing::warn!("json serialization error {e}"); - Err(Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Bytes::new().into()) - .unwrap() - .into()) - } - } - } - Err(e) => { - tracing::warn!("collect body error: {e}"); - Err(Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Bytes::new().into()) - .unwrap() - .into()) - } - } - } - } -} - -fn json_content_type(headers: &HeaderMap) -> bool { - let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) { - content_type - } else { - return false; - }; - - let content_type = if let Ok(content_type) = content_type.to_str() { - content_type - } else { - return false; - }; - - let mime = if let Ok(mime) = content_type.parse::() { - mime - } else { - return false; - }; - - let is_json_content_type = mime.type_() == "application" - && (mime.subtype() == "json" || mime.suffix().map_or(false, |name| name == "json")); - - is_json_content_type -}