Skip to content

Commit

Permalink
feat: add a trait for limited request body
Browse files Browse the repository at this point in the history
  • Loading branch information
fundon committed Jan 4, 2024
1 parent 98ecceb commit 44117c9
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 82 deletions.
2 changes: 1 addition & 1 deletion viz-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ mod into_response;
pub use into_response::IntoResponse;

mod request;
pub use request::RequestExt;
pub use request::{RequestExt, RequestLimitsExt};

mod response;
pub use response::ResponseExt;
Expand Down
213 changes: 134 additions & 79 deletions viz-core/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,6 @@ pub trait RequestExt: private::Sealed + Sized {
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/arrayBuffer>
fn bytes(&mut self) -> impl Future<Output = Result<Bytes, PayloadError>> + Send;

/// Return with a [Bytes][mdn] by a limit representation of the request body.
///
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/arrayBuffer>
#[cfg(feature = "limits")]
fn bytes_with(
&mut self,
limit: Option<u64>,
max: u64,
) -> impl Future<Output = Result<Bytes, PayloadError>> + Send;

/// Return with a [Text][mdn] representation of the request body.
///
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/text>
Expand Down Expand Up @@ -156,10 +146,6 @@ pub trait RequestExt: private::Sealed + Sized {
where
S: AsRef<str>;

/// Get limits settings.
#[cfg(feature = "limits")]
fn limits(&self) -> &Limits;

/// Get current session.
#[cfg(feature = "session")]
fn session(&self) -> &Session;
Expand Down Expand Up @@ -288,34 +274,8 @@ impl RequestExt for Request {
.map(Collected::to_bytes)
}

#[cfg(feature = "limits")]
async fn bytes_with(&mut self, limit: Option<u64>, max: u64) -> Result<Bytes, PayloadError> {
Limited::new(
self.incoming()?,
usize::try_from(limit.unwrap_or(max)).unwrap_or(usize::MAX),
)
.collect()
.await
.map_err(|err| {
if err.is::<LengthLimitError>() {
return PayloadError::TooLarge;
}
if let Ok(err) = err.downcast::<hyper::Error>() {
return PayloadError::Hyper(*err);
}
PayloadError::Read
})
.map(Collected::to_bytes)
}

async fn text(&mut self) -> Result<String, PayloadError> {
#[cfg(feature = "limits")]
let bytes = self
.bytes_with(self.limits().get("text"), Limits::NORMAL)
.await?;
#[cfg(not(feature = "limits"))]
let bytes = self.bytes().await?;

String::from_utf8(bytes.to_vec()).map_err(PayloadError::Utf8)
}

Expand All @@ -324,18 +284,8 @@ impl RequestExt for Request {
where
T: serde::de::DeserializeOwned,
{
#[cfg(feature = "limits")]
let limit = self.limits().get(<Form as Payload>::NAME);
#[cfg(not(feature = "limits"))]
let limit = None;

<Form as Payload>::check_header(self.content_type(), self.content_length(), limit)?;

#[cfg(feature = "limits")]
let bytes = self.bytes_with(limit, <Form as Payload>::LIMIT).await?;
#[cfg(not(feature = "limits"))]
<Form as Payload>::check_header(self.content_type(), self.content_length(), None)?;
let bytes = self.bytes().await?;

serde_urlencoded::from_reader(bytes::Buf::reader(bytes)).map_err(PayloadError::UrlDecode)
}

Expand All @@ -344,33 +294,15 @@ impl RequestExt for Request {
where
T: serde::de::DeserializeOwned,
{
#[cfg(feature = "limits")]
let limit = self.limits().get(<Json as Payload>::NAME);
#[cfg(not(feature = "limits"))]
let limit = None;

<Json as Payload>::check_header(self.content_type(), self.content_length(), limit)?;

#[cfg(feature = "limits")]
let bytes = self.bytes_with(limit, <Json as Payload>::LIMIT).await?;
#[cfg(not(feature = "limits"))]
<Json as Payload>::check_header(self.content_type(), self.content_length(), None)?;
let bytes = self.bytes().await?;

serde_json::from_slice(&bytes).map_err(PayloadError::Json)
}

#[cfg(feature = "multipart")]
async fn multipart(&mut self) -> Result<Multipart, PayloadError> {
#[cfg(feature = "limits")]
let limit = self.limits().get(<Multipart as Payload>::NAME);
#[cfg(not(feature = "limits"))]
let limit = None;

let m = <Multipart as Payload>::check_header(
self.content_type(),
self.content_length(),
limit,
)?;
let m =
<Multipart as Payload>::check_header(self.content_type(), self.content_length(), None)?;

let boundary = m
.get_param(mime::BOUNDARY)
Expand Down Expand Up @@ -420,13 +352,6 @@ impl RequestExt for Request {
self.extensions().get::<Cookies>()?.get(name.as_ref())
}

#[cfg(feature = "limits")]
fn limits(&self) -> &Limits {
self.extensions()
.get::<Limits>()
.expect("Limits middleware is required")
}

#[cfg(feature = "session")]
fn session(&self) -> &Session {
self.extensions().get().expect("should get a session")
Expand Down Expand Up @@ -463,6 +388,136 @@ impl RequestExt for Request {
}
}

/// The [`Request`] Extension with a limited body.
#[cfg(feature = "limits")]
pub trait RequestLimitsExt: private::Sealed + Sized {
/// Get limits settings.
fn limits(&self) -> &Limits;

/// Return with a [Bytes][mdn] by a limit representation of the request body.
///
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/arrayBuffer>
fn bytes_with(
&mut self,
limit: Option<u64>,
max: u64,
) -> impl Future<Output = Result<Bytes, PayloadError>> + Send;

/// Return with a limited [Text][mdn] representation of the request body.
///
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/text>
fn limited_text(&mut self) -> impl Future<Output = Result<String, PayloadError>> + Send;

/// Return with a limited `application/x-www-form-urlencoded` [FormData][mdn] by the specified type
/// representation of the request body.
///
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/FormData>
#[cfg(feature = "form")]
fn limited_form<T>(&mut self) -> impl Future<Output = Result<T, PayloadError>> + Send
where
T: serde::de::DeserializeOwned;

/// Return with a limited [JSON][mdn] by the specified type representation of the request body.
///
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/json>
#[cfg(feature = "json")]
fn limited_json<T>(&mut self) -> impl Future<Output = Result<T, PayloadError>> + Send
where
T: serde::de::DeserializeOwned;

/// Return with a limited `multipart/form-data` [FormData][mdn] by the specified type
/// representation of the request body.
///
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/FormData>
#[cfg(feature = "multipart")]
fn limited_multipart(&mut self)
-> impl Future<Output = Result<Multipart, PayloadError>> + Send;
}

#[cfg(feature = "limits")]
impl RequestLimitsExt for Request {
fn limits(&self) -> &Limits {
self.extensions()
.get::<Limits>()
.expect("Limits middleware is required")
}

async fn bytes_with(&mut self, limit: Option<u64>, max: u64) -> Result<Bytes, PayloadError> {
Limited::new(
self.incoming()?,
usize::try_from(limit.unwrap_or(max)).unwrap_or(usize::MAX),
)
.collect()
.await
.map_err(|err| {
if err.is::<LengthLimitError>() {
return PayloadError::TooLarge;
}
if let Ok(err) = err.downcast::<hyper::Error>() {
return PayloadError::Hyper(*err);
}
PayloadError::Read
})
.map(Collected::to_bytes)
}

async fn limited_text(&mut self) -> Result<String, PayloadError> {
let bytes = self
.bytes_with(self.limits().get("text"), Limits::NORMAL)
.await?;

String::from_utf8(bytes.to_vec()).map_err(PayloadError::Utf8)
}

#[cfg(feature = "form")]
async fn limited_form<T>(&mut self) -> Result<T, PayloadError>
where
T: serde::de::DeserializeOwned,
{
let limit = self.limits().get(<Form as Payload>::NAME);
<Form as Payload>::check_header(self.content_type(), self.content_length(), limit)?;
let bytes = self.bytes_with(limit, <Form as Payload>::LIMIT).await?;
serde_urlencoded::from_reader(bytes::Buf::reader(bytes)).map_err(PayloadError::UrlDecode)
}

#[cfg(feature = "json")]
async fn limited_json<T>(&mut self) -> Result<T, PayloadError>
where
T: serde::de::DeserializeOwned,
{
let limit = self.limits().get(<Json as Payload>::NAME);
<Json as Payload>::check_header(self.content_type(), self.content_length(), limit)?;
let bytes = self.bytes_with(limit, <Json as Payload>::LIMIT).await?;
serde_json::from_slice(&bytes).map_err(PayloadError::Json)
}

#[cfg(feature = "multipart")]
async fn limited_multipart(&mut self) -> Result<Multipart, PayloadError> {
let limit = self.limits().get(<Multipart as Payload>::NAME);

let m = <Multipart as Payload>::check_header(
self.content_type(),
self.content_length(),
limit,
)?;

let boundary = m
.get_param(mime::BOUNDARY)
.ok_or(PayloadError::MissingBoundary)?
.as_str();

Ok(Multipart::with_limits(
self.incoming()?,
boundary,
self.extensions()
.get::<std::sync::Arc<MultipartLimits>>()
.map(AsRef::as_ref)
.cloned()
.unwrap_or_default(),
))
}
}

mod private {
pub trait Sealed {}
impl Sealed for super::Request {}
Expand Down
2 changes: 1 addition & 1 deletion viz-core/src/types/limits.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{convert::Infallible, sync::Arc};

use crate::{FromRequest, Request, RequestExt};
use crate::{FromRequest, Request, RequestLimitsExt};

#[cfg(feature = "form")]
use super::Form;
Expand Down
3 changes: 2 additions & 1 deletion viz-test/tests/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use viz::{
IntoResponse,
Request,
RequestExt,
RequestLimitsExt,
Response,
ResponseExt,
Result,
Expand Down Expand Up @@ -89,7 +90,7 @@ async fn request_body() -> Result<()> {
Ok(Response::json(data))
})
.post("/multipart", |mut req: Request| async move {
let mut multipart = req.multipart().await?;
let mut multipart = req.limited_multipart().await?;
let mut data = HashMap::new();

while let Some(mut field) = multipart.try_next().await? {
Expand Down

0 comments on commit 44117c9

Please sign in to comment.