Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add a trait for limited request body #130

Merged
merged 3 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions viz-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ pub use into_response::IntoResponse;

mod request;
pub use request::RequestExt;
#[cfg(feature = "limits")]
pub use request::RequestLimitsExt;

mod response;
pub use response::ResponseExt;
Expand Down
212 changes: 133 additions & 79 deletions viz-core/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,6 @@
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/arrayBuffer>
fn bytes(&mut self) -> impl Future<Output = Result<Bytes, PayloadError>> + Send;

/// Return with a [Bytes][mdn] by a limit representation of the request body.
///
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/arrayBuffer>
#[cfg(feature = "limits")]
fn bytes_with(
&mut self,
limit: Option<u64>,
max: u64,
) -> impl Future<Output = Result<Bytes, PayloadError>> + Send;

/// Return with a [Text][mdn] representation of the request body.
///
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/text>
Expand Down Expand Up @@ -156,10 +146,6 @@
where
S: AsRef<str>;

/// Get limits settings.
#[cfg(feature = "limits")]
fn limits(&self) -> &Limits;

/// Get current session.
#[cfg(feature = "session")]
fn session(&self) -> &Session;
Expand Down Expand Up @@ -288,34 +274,8 @@
.map(Collected::to_bytes)
}

#[cfg(feature = "limits")]
async fn bytes_with(&mut self, limit: Option<u64>, max: u64) -> Result<Bytes, PayloadError> {
Limited::new(
self.incoming()?,
usize::try_from(limit.unwrap_or(max)).unwrap_or(usize::MAX),
)
.collect()
.await
.map_err(|err| {
if err.is::<LengthLimitError>() {
return PayloadError::TooLarge;
}
if let Ok(err) = err.downcast::<hyper::Error>() {
return PayloadError::Hyper(*err);
}
PayloadError::Read
})
.map(Collected::to_bytes)
}

async fn text(&mut self) -> Result<String, PayloadError> {
#[cfg(feature = "limits")]
let bytes = self
.bytes_with(self.limits().get("text"), Limits::NORMAL)
.await?;
#[cfg(not(feature = "limits"))]
let bytes = self.bytes().await?;

String::from_utf8(bytes.to_vec()).map_err(PayloadError::Utf8)
}

Expand All @@ -324,18 +284,8 @@
where
T: serde::de::DeserializeOwned,
{
#[cfg(feature = "limits")]
let limit = self.limits().get(<Form as Payload>::NAME);
#[cfg(not(feature = "limits"))]
let limit = None;

<Form as Payload>::check_header(self.content_type(), self.content_length(), limit)?;

#[cfg(feature = "limits")]
let bytes = self.bytes_with(limit, <Form as Payload>::LIMIT).await?;
#[cfg(not(feature = "limits"))]
<Form as Payload>::check_type(self.content_type())?;
let bytes = self.bytes().await?;

serde_urlencoded::from_reader(bytes::Buf::reader(bytes)).map_err(PayloadError::UrlDecode)
}

Expand All @@ -344,33 +294,14 @@
where
T: serde::de::DeserializeOwned,
{
#[cfg(feature = "limits")]
let limit = self.limits().get(<Json as Payload>::NAME);
#[cfg(not(feature = "limits"))]
let limit = None;

<Json as Payload>::check_header(self.content_type(), self.content_length(), limit)?;

#[cfg(feature = "limits")]
let bytes = self.bytes_with(limit, <Json as Payload>::LIMIT).await?;
#[cfg(not(feature = "limits"))]
<Json as Payload>::check_type(self.content_type())?;
let bytes = self.bytes().await?;

serde_json::from_slice(&bytes).map_err(PayloadError::Json)
}

#[cfg(feature = "multipart")]
async fn multipart(&mut self) -> Result<Multipart, PayloadError> {
#[cfg(feature = "limits")]
let limit = self.limits().get(<Multipart as Payload>::NAME);
#[cfg(not(feature = "limits"))]
let limit = None;

let m = <Multipart as Payload>::check_header(
self.content_type(),
self.content_length(),
limit,
)?;
let m = <Multipart as Payload>::check_type(self.content_type())?;

Check warning on line 304 in viz-core/src/request.rs

View check run for this annotation

Codecov / codecov/patch

viz-core/src/request.rs#L304

Added line #L304 was not covered by tests

let boundary = m
.get_param(mime::BOUNDARY)
Expand Down Expand Up @@ -420,13 +351,6 @@
self.extensions().get::<Cookies>()?.get(name.as_ref())
}

#[cfg(feature = "limits")]
fn limits(&self) -> &Limits {
self.extensions()
.get::<Limits>()
.expect("Limits middleware is required")
}

#[cfg(feature = "session")]
fn session(&self) -> &Session {
self.extensions().get().expect("should get a session")
Expand Down Expand Up @@ -463,6 +387,136 @@
}
}

/// The [`Request`] Extension with a limited body.
#[cfg(feature = "limits")]
pub trait RequestLimitsExt: private::Sealed + Sized {
/// Get limits settings.
fn limits(&self) -> &Limits;

/// Return with a [Bytes][mdn] by a limit representation of the request body.
///
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/arrayBuffer>
fn bytes_with(
&mut self,
limit: Option<u64>,
max: u64,
) -> impl Future<Output = Result<Bytes, PayloadError>> + Send;

/// Return with a limited [Text][mdn] representation of the request body.
///
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/text>
fn limited_text(&mut self) -> impl Future<Output = Result<String, PayloadError>> + Send;

/// Return with a limited `application/x-www-form-urlencoded` [FormData][mdn] by the specified type
/// representation of the request body.
///
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/FormData>
#[cfg(feature = "form")]
fn limited_form<T>(&mut self) -> impl Future<Output = Result<T, PayloadError>> + Send
where
T: serde::de::DeserializeOwned;

/// Return with a limited [JSON][mdn] by the specified type representation of the request body.
///
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/Response/json>
#[cfg(feature = "json")]
fn limited_json<T>(&mut self) -> impl Future<Output = Result<T, PayloadError>> + Send
where
T: serde::de::DeserializeOwned;

/// Return with a limited `multipart/form-data` [FormData][mdn] by the specified type
/// representation of the request body.
///
/// [mdn]: <https://developer.mozilla.org/en-US/docs/Web/API/FormData>
#[cfg(feature = "multipart")]
fn limited_multipart(&mut self)
-> impl Future<Output = Result<Multipart, PayloadError>> + Send;
}

#[cfg(feature = "limits")]
impl RequestLimitsExt for Request {
fn limits(&self) -> &Limits {
self.extensions()
.get::<Limits>()
.expect("Limits middleware is required")
}

async fn bytes_with(&mut self, limit: Option<u64>, max: u64) -> Result<Bytes, PayloadError> {
Limited::new(
self.incoming()?,
usize::try_from(limit.unwrap_or(max)).unwrap_or(usize::MAX),
)
.collect()
.await

Check warning on line 450 in viz-core/src/request.rs

View check run for this annotation

Codecov / codecov/patch

viz-core/src/request.rs#L450

Added line #L450 was not covered by tests
.map_err(|err| {
if err.is::<LengthLimitError>() {
return PayloadError::TooLarge;
}
if let Ok(err) = err.downcast::<hyper::Error>() {
return PayloadError::Hyper(*err);
}
PayloadError::Read

Check warning on line 458 in viz-core/src/request.rs

View check run for this annotation

Codecov / codecov/patch

viz-core/src/request.rs#L454-L458

Added lines #L454 - L458 were not covered by tests
})
.map(Collected::to_bytes)
}

async fn limited_text(&mut self) -> Result<String, PayloadError> {
let bytes = self
.bytes_with(self.limits().get("text"), Limits::NORMAL)
.await?;

Check warning on line 466 in viz-core/src/request.rs

View check run for this annotation

Codecov / codecov/patch

viz-core/src/request.rs#L463-L466

Added lines #L463 - L466 were not covered by tests

String::from_utf8(bytes.to_vec()).map_err(PayloadError::Utf8)
}

Check warning on line 469 in viz-core/src/request.rs

View check run for this annotation

Codecov / codecov/patch

viz-core/src/request.rs#L468-L469

Added lines #L468 - L469 were not covered by tests

#[cfg(feature = "form")]
async fn limited_form<T>(&mut self) -> Result<T, PayloadError>
where
T: serde::de::DeserializeOwned,
{
let limit = self.limits().get(<Form as Payload>::NAME);
<Form as Payload>::check_header(self.content_type(), self.content_length(), limit)?;
let bytes = self.bytes_with(limit, <Form as Payload>::LIMIT).await?;
serde_urlencoded::from_reader(bytes::Buf::reader(bytes)).map_err(PayloadError::UrlDecode)

Check warning on line 479 in viz-core/src/request.rs

View check run for this annotation

Codecov / codecov/patch

viz-core/src/request.rs#L478-L479

Added lines #L478 - L479 were not covered by tests
}

#[cfg(feature = "json")]
async fn limited_json<T>(&mut self) -> Result<T, PayloadError>
where
T: serde::de::DeserializeOwned,
{
let limit = self.limits().get(<Json as Payload>::NAME);
<Json as Payload>::check_header(self.content_type(), self.content_length(), limit)?;
let bytes = self.bytes_with(limit, <Json as Payload>::LIMIT).await?;
serde_json::from_slice(&bytes).map_err(PayloadError::Json)

Check warning on line 490 in viz-core/src/request.rs

View check run for this annotation

Codecov / codecov/patch

viz-core/src/request.rs#L489-L490

Added lines #L489 - L490 were not covered by tests
}

#[cfg(feature = "multipart")]
async fn limited_multipart(&mut self) -> Result<Multipart, PayloadError> {
let limit = self.limits().get(<Multipart as Payload>::NAME);

let m = <Multipart as Payload>::check_header(
self.content_type(),
self.content_length(),
limit,
)?;

let boundary = m
.get_param(mime::BOUNDARY)
.ok_or(PayloadError::MissingBoundary)?
.as_str();

Ok(Multipart::with_limits(
self.incoming()?,
boundary,
self.extensions()
.get::<std::sync::Arc<MultipartLimits>>()
.map(AsRef::as_ref)
.cloned()
.unwrap_or_default(),
))
}
}

mod private {
pub trait Sealed {}
impl Sealed for super::Request {}
Expand Down
2 changes: 1 addition & 1 deletion viz-core/src/types/limits.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{convert::Infallible, sync::Arc};

use crate::{FromRequest, Request, RequestExt};
use crate::{FromRequest, Request, RequestLimitsExt};

#[cfg(feature = "form")]
use super::Form;
Expand Down
38 changes: 29 additions & 9 deletions viz-core/src/types/payload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,32 @@
limit.unwrap_or(Self::LIMIT)
}

/// Detects `Content-Type`
///
/// # Errors
///
/// Will return [`PayloadError::UnsupportedMediaType`] if the detected media type is not supported.
#[inline]
fn check_type(m: Option<mime::Mime>) -> Result<mime::Mime, PayloadError> {
m.filter(Self::detect)
.ok_or_else(|| PayloadError::UnsupportedMediaType(Self::mime()))
}

/// Checks `Content-Length`
///
/// # Errors
///
/// Will return [`PayloadError::TooLarge`] if the detected content length is too large.
#[inline]
fn check_length(len: Option<u64>, limit: Option<u64>) -> Result<(), PayloadError> {
match len {
None => Err(PayloadError::LengthRequired),

Check warning on line 130 in viz-core/src/types/payload.rs

View check run for this annotation

Codecov / codecov/patch

viz-core/src/types/payload.rs#L130

Added line #L130 was not covered by tests
Some(len) => (len <= Self::limit(limit))
.then_some(())
.ok_or_else(|| PayloadError::TooLarge),
}
}

/// Checks `Content-Type` & `Content-Length`
///
/// # Errors
Expand All @@ -121,14 +147,8 @@
len: Option<u64>,
limit: Option<u64>,
) -> Result<mime::Mime, PayloadError> {
let m = m
.filter(Self::detect)
.ok_or_else(|| PayloadError::UnsupportedMediaType(Self::mime()))?;

match len {
None => Err(PayloadError::LengthRequired),
Some(len) if len > Self::limit(limit) => Err(PayloadError::TooLarge),
Some(_) => Ok(m),
}
let m = Self::check_type(m)?;
Self::check_length(len, limit)?;
Ok(m)
}
}
9 changes: 5 additions & 4 deletions viz-test/tests/payload.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::HashMap;
use viz::{types, Error, Request, RequestExt, Response, ResponseExt, Result};

use viz::{types, Error, Request, RequestLimitsExt, Response, ResponseExt, Result};

#[tokio::test]
async fn payload() -> Result<()> {
Expand All @@ -9,15 +10,15 @@ async fn payload() -> Result<()> {

let router = Router::new()
.post("/form", |mut req: Request| async move {
let data = req.form::<HashMap<String, String>>().await?;
let data = req.limited_form::<HashMap<String, String>>().await?;
Ok(Response::json(data))
})
.post("/json", |mut req: Request| async move {
let data = req.json::<HashMap<String, String>>().await?;
let data = req.limited_json::<HashMap<String, String>>().await?;
Ok(Response::json(data))
})
.post("/multipart", |mut req: Request| async move {
let _ = req.multipart().await?;
let _ = req.limited_multipart().await?;
Ok(())
})
.with(
Expand Down
3 changes: 2 additions & 1 deletion viz-test/tests/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use viz::{
IntoResponse,
Request,
RequestExt,
RequestLimitsExt,
Response,
ResponseExt,
Result,
Expand Down Expand Up @@ -89,7 +90,7 @@ async fn request_body() -> Result<()> {
Ok(Response::json(data))
})
.post("/multipart", |mut req: Request| async move {
let mut multipart = req.multipart().await?;
let mut multipart = req.limited_multipart().await?;
let mut data = HashMap::new();

while let Some(mut field) = multipart.try_next().await? {
Expand Down
5 changes: 5 additions & 0 deletions viz/src/server/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,10 @@ pub trait Listener {
fn accept(&self) -> impl Future<Output = Result<(Self::Io, Self::Addr)>> + Send;

/// Returns the local address that this listener is bound to.
///
/// # Errors
///
/// An error will return if got the socket address of the local half of this connection is
/// failed.
fn local_addr(&self) -> Result<Self::Addr>;
}