From 44117c9d705c7f97a6c6de3f52bb40dbd3957bdb Mon Sep 17 00:00:00 2001 From: Fangdun Tsai Date: Thu, 4 Jan 2024 10:47:34 +0800 Subject: [PATCH] feat: add a trait for limited request body --- viz-core/src/lib.rs | 2 +- viz-core/src/request.rs | 213 ++++++++++++++++++++++------------- viz-core/src/types/limits.rs | 2 +- viz-test/tests/request.rs | 3 +- 4 files changed, 138 insertions(+), 82 deletions(-) diff --git a/viz-core/src/lib.rs b/viz-core/src/lib.rs index b6d7d289..eda9459c 100644 --- a/viz-core/src/lib.rs +++ b/viz-core/src/lib.rs @@ -33,7 +33,7 @@ mod into_response; pub use into_response::IntoResponse; mod request; -pub use request::RequestExt; +pub use request::{RequestExt, RequestLimitsExt}; mod response; pub use response::ResponseExt; diff --git a/viz-core/src/request.rs b/viz-core/src/request.rs index aac228e6..b8de371a 100644 --- a/viz-core/src/request.rs +++ b/viz-core/src/request.rs @@ -91,16 +91,6 @@ pub trait RequestExt: private::Sealed + Sized { /// [mdn]: fn bytes(&mut self) -> impl Future> + Send; - /// Return with a [Bytes][mdn] by a limit representation of the request body. - /// - /// [mdn]: - #[cfg(feature = "limits")] - fn bytes_with( - &mut self, - limit: Option, - max: u64, - ) -> impl Future> + Send; - /// Return with a [Text][mdn] representation of the request body. /// /// [mdn]: @@ -156,10 +146,6 @@ pub trait RequestExt: private::Sealed + Sized { where S: AsRef; - /// Get limits settings. - #[cfg(feature = "limits")] - fn limits(&self) -> &Limits; - /// Get current session. #[cfg(feature = "session")] fn session(&self) -> &Session; @@ -288,34 +274,8 @@ impl RequestExt for Request { .map(Collected::to_bytes) } - #[cfg(feature = "limits")] - async fn bytes_with(&mut self, limit: Option, max: u64) -> Result { - Limited::new( - self.incoming()?, - usize::try_from(limit.unwrap_or(max)).unwrap_or(usize::MAX), - ) - .collect() - .await - .map_err(|err| { - if err.is::() { - return PayloadError::TooLarge; - } - if let Ok(err) = err.downcast::() { - return PayloadError::Hyper(*err); - } - PayloadError::Read - }) - .map(Collected::to_bytes) - } - async fn text(&mut self) -> Result { - #[cfg(feature = "limits")] - let bytes = self - .bytes_with(self.limits().get("text"), Limits::NORMAL) - .await?; - #[cfg(not(feature = "limits"))] let bytes = self.bytes().await?; - String::from_utf8(bytes.to_vec()).map_err(PayloadError::Utf8) } @@ -324,18 +284,8 @@ impl RequestExt for Request { where T: serde::de::DeserializeOwned, { - #[cfg(feature = "limits")] - let limit = self.limits().get(
::NAME); - #[cfg(not(feature = "limits"))] - let limit = None; - - ::check_header(self.content_type(), self.content_length(), limit)?; - - #[cfg(feature = "limits")] - let bytes = self.bytes_with(limit, ::LIMIT).await?; - #[cfg(not(feature = "limits"))] + ::check_header(self.content_type(), self.content_length(), None)?; let bytes = self.bytes().await?; - serde_urlencoded::from_reader(bytes::Buf::reader(bytes)).map_err(PayloadError::UrlDecode) } @@ -344,33 +294,15 @@ impl RequestExt for Request { where T: serde::de::DeserializeOwned, { - #[cfg(feature = "limits")] - let limit = self.limits().get(::NAME); - #[cfg(not(feature = "limits"))] - let limit = None; - - ::check_header(self.content_type(), self.content_length(), limit)?; - - #[cfg(feature = "limits")] - let bytes = self.bytes_with(limit, ::LIMIT).await?; - #[cfg(not(feature = "limits"))] + ::check_header(self.content_type(), self.content_length(), None)?; let bytes = self.bytes().await?; - serde_json::from_slice(&bytes).map_err(PayloadError::Json) } #[cfg(feature = "multipart")] async fn multipart(&mut self) -> Result { - #[cfg(feature = "limits")] - let limit = self.limits().get(::NAME); - #[cfg(not(feature = "limits"))] - let limit = None; - - let m = ::check_header( - self.content_type(), - self.content_length(), - limit, - )?; + let m = + ::check_header(self.content_type(), self.content_length(), None)?; let boundary = m .get_param(mime::BOUNDARY) @@ -420,13 +352,6 @@ impl RequestExt for Request { self.extensions().get::()?.get(name.as_ref()) } - #[cfg(feature = "limits")] - fn limits(&self) -> &Limits { - self.extensions() - .get::() - .expect("Limits middleware is required") - } - #[cfg(feature = "session")] fn session(&self) -> &Session { self.extensions().get().expect("should get a session") @@ -463,6 +388,136 @@ impl RequestExt for Request { } } +/// The [`Request`] Extension with a limited body. +#[cfg(feature = "limits")] +pub trait RequestLimitsExt: private::Sealed + Sized { + /// Get limits settings. + fn limits(&self) -> &Limits; + + /// Return with a [Bytes][mdn] by a limit representation of the request body. + /// + /// [mdn]: + fn bytes_with( + &mut self, + limit: Option, + max: u64, + ) -> impl Future> + Send; + + /// Return with a limited [Text][mdn] representation of the request body. + /// + /// [mdn]: + fn limited_text(&mut self) -> impl Future> + Send; + + /// Return with a limited `application/x-www-form-urlencoded` [FormData][mdn] by the specified type + /// representation of the request body. + /// + /// [mdn]: + #[cfg(feature = "form")] + fn limited_form(&mut self) -> impl Future> + Send + where + T: serde::de::DeserializeOwned; + + /// Return with a limited [JSON][mdn] by the specified type representation of the request body. + /// + /// [mdn]: + #[cfg(feature = "json")] + fn limited_json(&mut self) -> impl Future> + Send + where + T: serde::de::DeserializeOwned; + + /// Return with a limited `multipart/form-data` [FormData][mdn] by the specified type + /// representation of the request body. + /// + /// [mdn]: + #[cfg(feature = "multipart")] + fn limited_multipart(&mut self) + -> impl Future> + Send; +} + +#[cfg(feature = "limits")] +impl RequestLimitsExt for Request { + fn limits(&self) -> &Limits { + self.extensions() + .get::() + .expect("Limits middleware is required") + } + + async fn bytes_with(&mut self, limit: Option, max: u64) -> Result { + Limited::new( + self.incoming()?, + usize::try_from(limit.unwrap_or(max)).unwrap_or(usize::MAX), + ) + .collect() + .await + .map_err(|err| { + if err.is::() { + return PayloadError::TooLarge; + } + if let Ok(err) = err.downcast::() { + return PayloadError::Hyper(*err); + } + PayloadError::Read + }) + .map(Collected::to_bytes) + } + + async fn limited_text(&mut self) -> Result { + let bytes = self + .bytes_with(self.limits().get("text"), Limits::NORMAL) + .await?; + + String::from_utf8(bytes.to_vec()).map_err(PayloadError::Utf8) + } + + #[cfg(feature = "form")] + async fn limited_form(&mut self) -> Result + where + T: serde::de::DeserializeOwned, + { + let limit = self.limits().get(::NAME); + ::check_header(self.content_type(), self.content_length(), limit)?; + let bytes = self.bytes_with(limit, ::LIMIT).await?; + serde_urlencoded::from_reader(bytes::Buf::reader(bytes)).map_err(PayloadError::UrlDecode) + } + + #[cfg(feature = "json")] + async fn limited_json(&mut self) -> Result + where + T: serde::de::DeserializeOwned, + { + let limit = self.limits().get(::NAME); + ::check_header(self.content_type(), self.content_length(), limit)?; + let bytes = self.bytes_with(limit, ::LIMIT).await?; + serde_json::from_slice(&bytes).map_err(PayloadError::Json) + } + + #[cfg(feature = "multipart")] + async fn limited_multipart(&mut self) -> Result { + let limit = self.limits().get(::NAME); + + let m = ::check_header( + self.content_type(), + self.content_length(), + limit, + )?; + + let boundary = m + .get_param(mime::BOUNDARY) + .ok_or(PayloadError::MissingBoundary)? + .as_str(); + + Ok(Multipart::with_limits( + self.incoming()?, + boundary, + self.extensions() + .get::>() + .map(AsRef::as_ref) + .cloned() + .unwrap_or_default(), + )) + } +} + mod private { pub trait Sealed {} impl Sealed for super::Request {} diff --git a/viz-core/src/types/limits.rs b/viz-core/src/types/limits.rs index ae189fc4..7da0d6ef 100644 --- a/viz-core/src/types/limits.rs +++ b/viz-core/src/types/limits.rs @@ -1,6 +1,6 @@ use std::{convert::Infallible, sync::Arc}; -use crate::{FromRequest, Request, RequestExt}; +use crate::{FromRequest, Request, RequestLimitsExt}; #[cfg(feature = "form")] use super::Form; diff --git a/viz-test/tests/request.rs b/viz-test/tests/request.rs index 0d6d24bc..c90e1fa0 100644 --- a/viz-test/tests/request.rs +++ b/viz-test/tests/request.rs @@ -11,6 +11,7 @@ use viz::{ IntoResponse, Request, RequestExt, + RequestLimitsExt, Response, ResponseExt, Result, @@ -89,7 +90,7 @@ async fn request_body() -> Result<()> { Ok(Response::json(data)) }) .post("/multipart", |mut req: Request| async move { - let mut multipart = req.multipart().await?; + let mut multipart = req.limited_multipart().await?; let mut data = HashMap::new(); while let Some(mut field) = multipart.try_next().await? {