diff --git a/viz-core/src/lib.rs b/viz-core/src/lib.rs index b6d7d289..0b2708c7 100644 --- a/viz-core/src/lib.rs +++ b/viz-core/src/lib.rs @@ -34,6 +34,8 @@ pub use into_response::IntoResponse; mod request; pub use request::RequestExt; +#[cfg(feature = "limits")] +pub use request::RequestLimitsExt; mod response; pub use response::ResponseExt; diff --git a/viz-core/src/request.rs b/viz-core/src/request.rs index aac228e6..1f5a0105 100644 --- a/viz-core/src/request.rs +++ b/viz-core/src/request.rs @@ -91,16 +91,6 @@ pub trait RequestExt: private::Sealed + Sized { /// [mdn]: fn bytes(&mut self) -> impl Future> + Send; - /// Return with a [Bytes][mdn] by a limit representation of the request body. - /// - /// [mdn]: - #[cfg(feature = "limits")] - fn bytes_with( - &mut self, - limit: Option, - max: u64, - ) -> impl Future> + Send; - /// Return with a [Text][mdn] representation of the request body. /// /// [mdn]: @@ -156,10 +146,6 @@ pub trait RequestExt: private::Sealed + Sized { where S: AsRef; - /// Get limits settings. - #[cfg(feature = "limits")] - fn limits(&self) -> &Limits; - /// Get current session. #[cfg(feature = "session")] fn session(&self) -> &Session; @@ -288,34 +274,8 @@ impl RequestExt for Request { .map(Collected::to_bytes) } - #[cfg(feature = "limits")] - async fn bytes_with(&mut self, limit: Option, max: u64) -> Result { - Limited::new( - self.incoming()?, - usize::try_from(limit.unwrap_or(max)).unwrap_or(usize::MAX), - ) - .collect() - .await - .map_err(|err| { - if err.is::() { - return PayloadError::TooLarge; - } - if let Ok(err) = err.downcast::() { - return PayloadError::Hyper(*err); - } - PayloadError::Read - }) - .map(Collected::to_bytes) - } - async fn text(&mut self) -> Result { - #[cfg(feature = "limits")] - let bytes = self - .bytes_with(self.limits().get("text"), Limits::NORMAL) - .await?; - #[cfg(not(feature = "limits"))] let bytes = self.bytes().await?; - String::from_utf8(bytes.to_vec()).map_err(PayloadError::Utf8) } @@ -324,18 +284,8 @@ impl RequestExt for Request { where T: serde::de::DeserializeOwned, { - #[cfg(feature = "limits")] - let limit = self.limits().get(
::NAME); - #[cfg(not(feature = "limits"))] - let limit = None; - - ::check_header(self.content_type(), self.content_length(), limit)?; - - #[cfg(feature = "limits")] - let bytes = self.bytes_with(limit, ::LIMIT).await?; - #[cfg(not(feature = "limits"))] + ::check_type(self.content_type())?; let bytes = self.bytes().await?; - serde_urlencoded::from_reader(bytes::Buf::reader(bytes)).map_err(PayloadError::UrlDecode) } @@ -344,33 +294,14 @@ impl RequestExt for Request { where T: serde::de::DeserializeOwned, { - #[cfg(feature = "limits")] - let limit = self.limits().get(::NAME); - #[cfg(not(feature = "limits"))] - let limit = None; - - ::check_header(self.content_type(), self.content_length(), limit)?; - - #[cfg(feature = "limits")] - let bytes = self.bytes_with(limit, ::LIMIT).await?; - #[cfg(not(feature = "limits"))] + ::check_type(self.content_type())?; let bytes = self.bytes().await?; - serde_json::from_slice(&bytes).map_err(PayloadError::Json) } #[cfg(feature = "multipart")] async fn multipart(&mut self) -> Result { - #[cfg(feature = "limits")] - let limit = self.limits().get(::NAME); - #[cfg(not(feature = "limits"))] - let limit = None; - - let m = ::check_header( - self.content_type(), - self.content_length(), - limit, - )?; + let m = ::check_type(self.content_type())?; let boundary = m .get_param(mime::BOUNDARY) @@ -420,13 +351,6 @@ impl RequestExt for Request { self.extensions().get::()?.get(name.as_ref()) } - #[cfg(feature = "limits")] - fn limits(&self) -> &Limits { - self.extensions() - .get::() - .expect("Limits middleware is required") - } - #[cfg(feature = "session")] fn session(&self) -> &Session { self.extensions().get().expect("should get a session") @@ -463,6 +387,136 @@ impl RequestExt for Request { } } +/// The [`Request`] Extension with a limited body. +#[cfg(feature = "limits")] +pub trait RequestLimitsExt: private::Sealed + Sized { + /// Get limits settings. + fn limits(&self) -> &Limits; + + /// Return with a [Bytes][mdn] by a limit representation of the request body. + /// + /// [mdn]: + fn bytes_with( + &mut self, + limit: Option, + max: u64, + ) -> impl Future> + Send; + + /// Return with a limited [Text][mdn] representation of the request body. + /// + /// [mdn]: + fn limited_text(&mut self) -> impl Future> + Send; + + /// Return with a limited `application/x-www-form-urlencoded` [FormData][mdn] by the specified type + /// representation of the request body. + /// + /// [mdn]: + #[cfg(feature = "form")] + fn limited_form(&mut self) -> impl Future> + Send + where + T: serde::de::DeserializeOwned; + + /// Return with a limited [JSON][mdn] by the specified type representation of the request body. + /// + /// [mdn]: + #[cfg(feature = "json")] + fn limited_json(&mut self) -> impl Future> + Send + where + T: serde::de::DeserializeOwned; + + /// Return with a limited `multipart/form-data` [FormData][mdn] by the specified type + /// representation of the request body. + /// + /// [mdn]: + #[cfg(feature = "multipart")] + fn limited_multipart(&mut self) + -> impl Future> + Send; +} + +#[cfg(feature = "limits")] +impl RequestLimitsExt for Request { + fn limits(&self) -> &Limits { + self.extensions() + .get::() + .expect("Limits middleware is required") + } + + async fn bytes_with(&mut self, limit: Option, max: u64) -> Result { + Limited::new( + self.incoming()?, + usize::try_from(limit.unwrap_or(max)).unwrap_or(usize::MAX), + ) + .collect() + .await + .map_err(|err| { + if err.is::() { + return PayloadError::TooLarge; + } + if let Ok(err) = err.downcast::() { + return PayloadError::Hyper(*err); + } + PayloadError::Read + }) + .map(Collected::to_bytes) + } + + async fn limited_text(&mut self) -> Result { + let bytes = self + .bytes_with(self.limits().get("text"), Limits::NORMAL) + .await?; + + String::from_utf8(bytes.to_vec()).map_err(PayloadError::Utf8) + } + + #[cfg(feature = "form")] + async fn limited_form(&mut self) -> Result + where + T: serde::de::DeserializeOwned, + { + let limit = self.limits().get(::NAME); + ::check_header(self.content_type(), self.content_length(), limit)?; + let bytes = self.bytes_with(limit, ::LIMIT).await?; + serde_urlencoded::from_reader(bytes::Buf::reader(bytes)).map_err(PayloadError::UrlDecode) + } + + #[cfg(feature = "json")] + async fn limited_json(&mut self) -> Result + where + T: serde::de::DeserializeOwned, + { + let limit = self.limits().get(::NAME); + ::check_header(self.content_type(), self.content_length(), limit)?; + let bytes = self.bytes_with(limit, ::LIMIT).await?; + serde_json::from_slice(&bytes).map_err(PayloadError::Json) + } + + #[cfg(feature = "multipart")] + async fn limited_multipart(&mut self) -> Result { + let limit = self.limits().get(::NAME); + + let m = ::check_header( + self.content_type(), + self.content_length(), + limit, + )?; + + let boundary = m + .get_param(mime::BOUNDARY) + .ok_or(PayloadError::MissingBoundary)? + .as_str(); + + Ok(Multipart::with_limits( + self.incoming()?, + boundary, + self.extensions() + .get::>() + .map(AsRef::as_ref) + .cloned() + .unwrap_or_default(), + )) + } +} + mod private { pub trait Sealed {} impl Sealed for super::Request {} diff --git a/viz-core/src/types/limits.rs b/viz-core/src/types/limits.rs index ae189fc4..7da0d6ef 100644 --- a/viz-core/src/types/limits.rs +++ b/viz-core/src/types/limits.rs @@ -1,6 +1,6 @@ use std::{convert::Infallible, sync::Arc}; -use crate::{FromRequest, Request, RequestExt}; +use crate::{FromRequest, Request, RequestLimitsExt}; #[cfg(feature = "form")] use super::Form; diff --git a/viz-core/src/types/payload.rs b/viz-core/src/types/payload.rs index f67f30f1..f99f84bd 100644 --- a/viz-core/src/types/payload.rs +++ b/viz-core/src/types/payload.rs @@ -108,6 +108,32 @@ pub trait Payload { limit.unwrap_or(Self::LIMIT) } + /// Detects `Content-Type` + /// + /// # Errors + /// + /// Will return [`PayloadError::UnsupportedMediaType`] if the detected media type is not supported. + #[inline] + fn check_type(m: Option) -> Result { + m.filter(Self::detect) + .ok_or_else(|| PayloadError::UnsupportedMediaType(Self::mime())) + } + + /// Checks `Content-Length` + /// + /// # Errors + /// + /// Will return [`PayloadError::TooLarge`] if the detected content length is too large. + #[inline] + fn check_length(len: Option, limit: Option) -> Result<(), PayloadError> { + match len { + None => Err(PayloadError::LengthRequired), + Some(len) => (len <= Self::limit(limit)) + .then_some(()) + .ok_or_else(|| PayloadError::TooLarge), + } + } + /// Checks `Content-Type` & `Content-Length` /// /// # Errors @@ -121,14 +147,8 @@ pub trait Payload { len: Option, limit: Option, ) -> Result { - let m = m - .filter(Self::detect) - .ok_or_else(|| PayloadError::UnsupportedMediaType(Self::mime()))?; - - match len { - None => Err(PayloadError::LengthRequired), - Some(len) if len > Self::limit(limit) => Err(PayloadError::TooLarge), - Some(_) => Ok(m), - } + let m = Self::check_type(m)?; + Self::check_length(len, limit)?; + Ok(m) } } diff --git a/viz-test/tests/payload.rs b/viz-test/tests/payload.rs index a9350e51..f655742e 100644 --- a/viz-test/tests/payload.rs +++ b/viz-test/tests/payload.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; -use viz::{types, Error, Request, RequestExt, Response, ResponseExt, Result}; + +use viz::{types, Error, Request, RequestLimitsExt, Response, ResponseExt, Result}; #[tokio::test] async fn payload() -> Result<()> { @@ -9,15 +10,15 @@ async fn payload() -> Result<()> { let router = Router::new() .post("/form", |mut req: Request| async move { - let data = req.form::>().await?; + let data = req.limited_form::>().await?; Ok(Response::json(data)) }) .post("/json", |mut req: Request| async move { - let data = req.json::>().await?; + let data = req.limited_json::>().await?; Ok(Response::json(data)) }) .post("/multipart", |mut req: Request| async move { - let _ = req.multipart().await?; + let _ = req.limited_multipart().await?; Ok(()) }) .with( diff --git a/viz-test/tests/request.rs b/viz-test/tests/request.rs index 0d6d24bc..c90e1fa0 100644 --- a/viz-test/tests/request.rs +++ b/viz-test/tests/request.rs @@ -11,6 +11,7 @@ use viz::{ IntoResponse, Request, RequestExt, + RequestLimitsExt, Response, ResponseExt, Result, @@ -89,7 +90,7 @@ async fn request_body() -> Result<()> { Ok(Response::json(data)) }) .post("/multipart", |mut req: Request| async move { - let mut multipart = req.multipart().await?; + let mut multipart = req.limited_multipart().await?; let mut data = HashMap::new(); while let Some(mut field) = multipart.try_next().await? { diff --git a/viz/src/server/listener.rs b/viz/src/server/listener.rs index 4b72e1eb..8bb2cd49 100644 --- a/viz/src/server/listener.rs +++ b/viz/src/server/listener.rs @@ -13,5 +13,10 @@ pub trait Listener { fn accept(&self) -> impl Future> + Send; /// Returns the local address that this listener is bound to. + /// + /// # Errors + /// + /// An error will return if got the socket address of the local half of this connection is + /// failed. fn local_addr(&self) -> Result; }