Skip to content

Commit

Permalink
feat: add a trait for limited request body (#130)
Browse files Browse the repository at this point in the history
* feat: add a trait for limited request body

* fix: clippy

* fix: export
  • Loading branch information
fundon committed Jan 6, 2024
1 parent 60711d3 commit b4b8de3
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 94 deletions.
2 changes: 2 additions & 0 deletions viz-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ pub use into_response::IntoResponse;

mod request;
pub use request::RequestExt;
#[cfg(feature = "limits")]
pub use request::RequestLimitsExt;

mod response;
pub use response::ResponseExt;
Expand Down
212 changes: 133 additions & 79 deletions viz-core/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,6 @@ pub trait RequestExt: private::Sealed + Sized {
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/arrayBuffer>
fn bytes(&mut self) -> impl Future<Output = Result<Bytes, PayloadError>> + Send;

/// Return with a [Bytes][mdn] by a limit representation of the request body.
///
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/arrayBuffer>
#[cfg(feature = "limits")]
fn bytes_with(
&mut self,
limit: Option<u64>,
max: u64,
) -> impl Future<Output = Result<Bytes, PayloadError>> + Send;

/// Return with a [Text][mdn] representation of the request body.
///
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/text>
Expand Down Expand Up @@ -156,10 +146,6 @@ pub trait RequestExt: private::Sealed + Sized {
where
S: AsRef<str>;

/// Get limits settings.
#[cfg(feature = "limits")]
fn limits(&self) -> &Limits;

/// Get current session.
#[cfg(feature = "session")]
fn session(&self) -> &Session;
Expand Down Expand Up @@ -288,34 +274,8 @@ impl RequestExt for Request {
.map(Collected::to_bytes)
}

#[cfg(feature = "limits")]
async fn bytes_with(&mut self, limit: Option<u64>, max: u64) -> Result<Bytes, PayloadError> {
Limited::new(
self.incoming()?,
usize::try_from(limit.unwrap_or(max)).unwrap_or(usize::MAX),
)
.collect()
.await
.map_err(|err| {
if err.is::<LengthLimitError>() {
return PayloadError::TooLarge;
}
if let Ok(err) = err.downcast::<hyper::Error>() {
return PayloadError::Hyper(*err);
}
PayloadError::Read
})
.map(Collected::to_bytes)
}

async fn text(&mut self) -> Result<String, PayloadError> {
#[cfg(feature = "limits")]
let bytes = self
.bytes_with(self.limits().get("text"), Limits::NORMAL)
.await?;
#[cfg(not(feature = "limits"))]
let bytes = self.bytes().await?;

String::from_utf8(bytes.to_vec()).map_err(PayloadError::Utf8)
}

Expand All @@ -324,18 +284,8 @@ impl RequestExt for Request {
where
T: serde::de::DeserializeOwned,
{
#[cfg(feature = "limits")]
let limit = self.limits().get(<Form as Payload>::NAME);
#[cfg(not(feature = "limits"))]
let limit = None;

<Form as Payload>::check_header(self.content_type(), self.content_length(), limit)?;

#[cfg(feature = "limits")]
let bytes = self.bytes_with(limit, <Form as Payload>::LIMIT).await?;
#[cfg(not(feature = "limits"))]
<Form as Payload>::check_type(self.content_type())?;
let bytes = self.bytes().await?;

serde_urlencoded::from_reader(bytes::Buf::reader(bytes)).map_err(PayloadError::UrlDecode)
}

Expand All @@ -344,33 +294,14 @@ impl RequestExt for Request {
where
T: serde::de::DeserializeOwned,
{
#[cfg(feature = "limits")]
let limit = self.limits().get(<Json as Payload>::NAME);
#[cfg(not(feature = "limits"))]
let limit = None;

<Json as Payload>::check_header(self.content_type(), self.content_length(), limit)?;

#[cfg(feature = "limits")]
let bytes = self.bytes_with(limit, <Json as Payload>::LIMIT).await?;
#[cfg(not(feature = "limits"))]
<Json as Payload>::check_type(self.content_type())?;
let bytes = self.bytes().await?;

serde_json::from_slice(&bytes).map_err(PayloadError::Json)
}

#[cfg(feature = "multipart")]
async fn multipart(&mut self) -> Result<Multipart, PayloadError> {
#[cfg(feature = "limits")]
let limit = self.limits().get(<Multipart as Payload>::NAME);
#[cfg(not(feature = "limits"))]
let limit = None;

let m = <Multipart as Payload>::check_header(
self.content_type(),
self.content_length(),
limit,
)?;
let m = <Multipart as Payload>::check_type(self.content_type())?;

let boundary = m
.get_param(mime::BOUNDARY)
Expand Down Expand Up @@ -420,13 +351,6 @@ impl RequestExt for Request {
self.extensions().get::<Cookies>()?.get(name.as_ref())
}

#[cfg(feature = "limits")]
fn limits(&self) -> &Limits {
self.extensions()
.get::<Limits>()
.expect("Limits middleware is required")
}

#[cfg(feature = "session")]
fn session(&self) -> &Session {
self.extensions().get().expect("should get a session")
Expand Down Expand Up @@ -463,6 +387,136 @@ impl RequestExt for Request {
}
}

/// The [`Request`] Extension with a limited body.
#[cfg(feature = "limits")]
pub trait RequestLimitsExt: private::Sealed + Sized {
/// Get limits settings.
fn limits(&self) -> &Limits;

/// Return with a [Bytes][mdn] by a limit representation of the request body.
///
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/arrayBuffer>
fn bytes_with(
&mut self,
limit: Option<u64>,
max: u64,
) -> impl Future<Output = Result<Bytes, PayloadError>> + Send;

/// Return with a limited [Text][mdn] representation of the request body.
///
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/text>
fn limited_text(&mut self) -> impl Future<Output = Result<String, PayloadError>> + Send;

/// Return with a limited `application/x-www-form-urlencoded` [FormData][mdn] by the specified type
/// representation of the request body.
///
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/FormData>
#[cfg(feature = "form")]
fn limited_form<T>(&mut self) -> impl Future<Output = Result<T, PayloadError>> + Send
where
T: serde::de::DeserializeOwned;

/// Return with a limited [JSON][mdn] by the specified type representation of the request body.
///
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/json>
#[cfg(feature = "json")]
fn limited_json<T>(&mut self) -> impl Future<Output = Result<T, PayloadError>> + Send
where
T: serde::de::DeserializeOwned;

/// Return with a limited `multipart/form-data` [FormData][mdn] by the specified type
/// representation of the request body.
///
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/FormData>
#[cfg(feature = "multipart")]
fn limited_multipart(&mut self)
-> impl Future<Output = Result<Multipart, PayloadError>> + Send;
}

#[cfg(feature = "limits")]
impl RequestLimitsExt for Request {
fn limits(&self) -> &Limits {
self.extensions()
.get::<Limits>()
.expect("Limits middleware is required")
}

async fn bytes_with(&mut self, limit: Option<u64>, max: u64) -> Result<Bytes, PayloadError> {
Limited::new(
self.incoming()?,
usize::try_from(limit.unwrap_or(max)).unwrap_or(usize::MAX),
)
.collect()
.await
.map_err(|err| {
if err.is::<LengthLimitError>() {
return PayloadError::TooLarge;
}
if let Ok(err) = err.downcast::<hyper::Error>() {
return PayloadError::Hyper(*err);
}
PayloadError::Read
})
.map(Collected::to_bytes)
}

async fn limited_text(&mut self) -> Result<String, PayloadError> {
let bytes = self
.bytes_with(self.limits().get("text"), Limits::NORMAL)
.await?;

String::from_utf8(bytes.to_vec()).map_err(PayloadError::Utf8)
}

#[cfg(feature = "form")]
async fn limited_form<T>(&mut self) -> Result<T, PayloadError>
where
T: serde::de::DeserializeOwned,
{
let limit = self.limits().get(<Form as Payload>::NAME);
<Form as Payload>::check_header(self.content_type(), self.content_length(), limit)?;
let bytes = self.bytes_with(limit, <Form as Payload>::LIMIT).await?;
serde_urlencoded::from_reader(bytes::Buf::reader(bytes)).map_err(PayloadError::UrlDecode)
}

#[cfg(feature = "json")]
async fn limited_json<T>(&mut self) -> Result<T, PayloadError>
where
T: serde::de::DeserializeOwned,
{
let limit = self.limits().get(<Json as Payload>::NAME);
<Json as Payload>::check_header(self.content_type(), self.content_length(), limit)?;
let bytes = self.bytes_with(limit, <Json as Payload>::LIMIT).await?;
serde_json::from_slice(&bytes).map_err(PayloadError::Json)
}

#[cfg(feature = "multipart")]
async fn limited_multipart(&mut self) -> Result<Multipart, PayloadError> {
let limit = self.limits().get(<Multipart as Payload>::NAME);

let m = <Multipart as Payload>::check_header(
self.content_type(),
self.content_length(),
limit,
)?;

let boundary = m
.get_param(mime::BOUNDARY)
.ok_or(PayloadError::MissingBoundary)?
.as_str();

Ok(Multipart::with_limits(
self.incoming()?,
boundary,
self.extensions()
.get::<std::sync::Arc<MultipartLimits>>()
.map(AsRef::as_ref)
.cloned()
.unwrap_or_default(),
))
}
}

mod private {
pub trait Sealed {}
impl Sealed for super::Request {}
Expand Down
2 changes: 1 addition & 1 deletion viz-core/src/types/limits.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{convert::Infallible, sync::Arc};

use crate::{FromRequest, Request, RequestExt};
use crate::{FromRequest, Request, RequestLimitsExt};

#[cfg(feature = "form")]
use super::Form;
Expand Down
38 changes: 29 additions & 9 deletions viz-core/src/types/payload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,32 @@ pub trait Payload {
limit.unwrap_or(Self::LIMIT)
}

/// Detects `Content-Type`
///
/// # Errors
///
/// Will return [`PayloadError::UnsupportedMediaType`] if the detected media type is not supported.
#[inline]
fn check_type(m: Option<mime::Mime>) -> Result<mime::Mime, PayloadError> {
m.filter(Self::detect)
.ok_or_else(|| PayloadError::UnsupportedMediaType(Self::mime()))
}

/// Checks `Content-Length`
///
/// # Errors
///
/// Will return [`PayloadError::TooLarge`] if the detected content length is too large.
#[inline]
fn check_length(len: Option<u64>, limit: Option<u64>) -> Result<(), PayloadError> {
match len {
None => Err(PayloadError::LengthRequired),
Some(len) => (len <= Self::limit(limit))
.then_some(())
.ok_or_else(|| PayloadError::TooLarge),
}
}

/// Checks `Content-Type` & `Content-Length`
///
/// # Errors
Expand All @@ -121,14 +147,8 @@ pub trait Payload {
len: Option<u64>,
limit: Option<u64>,
) -> Result<mime::Mime, PayloadError> {
let m = m
.filter(Self::detect)
.ok_or_else(|| PayloadError::UnsupportedMediaType(Self::mime()))?;

match len {
None => Err(PayloadError::LengthRequired),
Some(len) if len > Self::limit(limit) => Err(PayloadError::TooLarge),
Some(_) => Ok(m),
}
let m = Self::check_type(m)?;
Self::check_length(len, limit)?;
Ok(m)
}
}
9 changes: 5 additions & 4 deletions viz-test/tests/payload.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::HashMap;
use viz::{types, Error, Request, RequestExt, Response, ResponseExt, Result};

use viz::{types, Error, Request, RequestLimitsExt, Response, ResponseExt, Result};

#[tokio::test]
async fn payload() -> Result<()> {
Expand All @@ -9,15 +10,15 @@ async fn payload() -> Result<()> {

let router = Router::new()
.post("/form", |mut req: Request| async move {
let data = req.form::<HashMap<String, String>>().await?;
let data = req.limited_form::<HashMap<String, String>>().await?;
Ok(Response::json(data))
})
.post("/json", |mut req: Request| async move {
let data = req.json::<HashMap<String, String>>().await?;
let data = req.limited_json::<HashMap<String, String>>().await?;
Ok(Response::json(data))
})
.post("/multipart", |mut req: Request| async move {
let _ = req.multipart().await?;
let _ = req.limited_multipart().await?;
Ok(())
})
.with(
Expand Down
3 changes: 2 additions & 1 deletion viz-test/tests/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use viz::{
IntoResponse,
Request,
RequestExt,
RequestLimitsExt,
Response,
ResponseExt,
Result,
Expand Down Expand Up @@ -89,7 +90,7 @@ async fn request_body() -> Result<()> {
Ok(Response::json(data))
})
.post("/multipart", |mut req: Request| async move {
let mut multipart = req.multipart().await?;
let mut multipart = req.limited_multipart().await?;
let mut data = HashMap::new();

while let Some(mut field) = multipart.try_next().await? {
Expand Down
5 changes: 5 additions & 0 deletions viz/src/server/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,10 @@ pub trait Listener {
fn accept(&self) -> impl Future<Output = Result<(Self::Io, Self::Addr)>> + Send;

/// Returns the local address that this listener is bound to.
///
/// # Errors
///
/// An error will return if got the socket address of the local half of this connection is
/// failed.
fn local_addr(&self) -> Result<Self::Addr>;
}

0 comments on commit b4b8de3

Please sign in to comment.