Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ValidateRequestHeaderLayer): add assert() function #360

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tower-http/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- **request_id:** Derive `Default` for `MakeRequestUuid` ([#335])
- **fs:** Derive `Default` for `ServeFileSystemResponseBody` ([#336])
- **validate-request:** Add `ValidateRequestHeaderLayer::assert()` function to reject requests when a header does not have an expected value ([#360])

## Changed

Expand All @@ -27,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#335]: https://github.com/tower-rs/tower-http/pull/335
[#336]: https://github.com/tower-rs/tower-http/pull/336
[#354]: https://github.com/tower-rs/tower-http/pull/354
[#360]: https://github.com/tower-rs/tower-http/pull/360

# 0.4.0 (February 24, 2023)

Expand Down
166 changes: 166 additions & 0 deletions tower-http/src/validate_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
//!
//! # Example
//!
//! Validation of the `Accept` header can be made by using [`ValidateRequestHeaderLayer::accept()`]:
//!
//! ```
//! use tower_http::validate_request::ValidateRequestHeaderLayer;
//! use hyper::{Request, Response, Body, Error};
Expand Down Expand Up @@ -50,6 +52,70 @@
//! # }
//! ```
//!
//! Validation of a custom header can be made by using [`ValidateRequestHeaderLayer::assert()`]:
//!
//! ```
//! use tower_http::validate_request::ValidateRequestHeaderLayer;
//! use hyper::{Request, Response, Body, Error};
//! use http::StatusCode;
//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn};
//!
//! async fn handle(request: Request<Body>) -> Result<Response<Body>, Error> {
//! Ok(Response::new(Body::empty()))
//! }
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let mut service = ServiceBuilder::new()
//! // Require a `X-Custom-Header` header to have the value `random-value-1234567890` or reject with a `403 Forbidden` response
//! .layer(ValidateRequestHeaderLayer::assert("x-custom-header", "random-value-1234567890", StatusCode::FORBIDDEN))
Copy link
Contributor Author

@Oliboy50 Oliboy50 Apr 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also thought of using a Builder pattern here

so instead of the assert function, the API would look like:

.layer(
    ValidateRequestHeaderLayer::has("x-custom-header")
        .with_value("random-value-1234567890")
        .or_reject_with(StatusCode::FORBIDDEN)
)

where or_reject_with would be the "build" method...

and we could say that with_value is optional, so this new layer could also be used by those who just want to make sure that a request has a specific header and they don't care about its value

.layer(
    ValidateRequestHeaderLayer::has("x-custom-header")
        .or_reject_with(StatusCode::FORBIDDEN)
)

WDYT?

//! .service_fn(handle);
//!
//! // Requests with the correct value are allowed through
//! let request = Request::builder()
//! .header("x-custom-header", "random-value-1234567890")
//! .body(Body::empty())
//! .unwrap();
//!
//! let response = service
//! .ready()
//! .await?
//! .call(request)
//! .await?;
//!
//! assert_eq!(StatusCode::OK, response.status());
//!
//! // Requests with an invalid value get a `403 Forbidden` response
//! let request = Request::builder()
//! .header("x-custom-header", "wrong-value")
//! .body(Body::empty())
//! .unwrap();
//!
//! let response = service
//! .ready()
//! .await?
//! .call(request)
//! .await?;
//!
//! assert_eq!(StatusCode::FORBIDDEN, response.status());
//! #
//! # // Requests without the expected header also get a `403 Forbidden` response
//! # let request = Request::builder()
//! # .body(Body::empty())
//! # .unwrap();
//! #
//! # let response = service
//! # .ready()
//! # .await?
//! # .call(request)
//! # .await?;
//! #
//! # assert_eq!(StatusCode::FORBIDDEN, response.status());
//! #
Comment on lines +101 to +114
Copy link
Contributor Author

@Oliboy50 Oliboy50 Apr 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ℹ️ this has been hidden from the documentation because it feels redundant... but it has been kept because it remains a valuable test

//! # Ok(())
//! # }
//! ```
//!
//! Custom validation can be made by implementing [`ValidateRequest`]:
//!
//! ```
Expand Down Expand Up @@ -112,6 +178,8 @@
//! # Ok(())
//! # }
//! ```
//!
//! [`Accept`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept
Copy link
Contributor Author

@Oliboy50 Oliboy50 Apr 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not too sure about this one...
I copied/pasted it from the existing accept() function documentation when I added this line of documentation:

Validation of the Accept header can be made by using [ValidateRequestHeaderLayer::accept()]:

but I think that it is mostly useless because I didn't find any link to the MDN in the generated documentation page 🤷

just tell me if I should remove it... BTW I can also delete the one in the accept() function documentation if you want (same issue)


use http::{header, Request, Response, StatusCode};
use http_body::Body;
Expand Down Expand Up @@ -165,6 +233,34 @@ impl<ResBody> ValidateRequestHeaderLayer<AcceptHeader<ResBody>> {
}
}

impl<ResBody> ValidateRequestHeaderLayer<AssertHeaderOrReject<ResBody>> {
/// Validate requests have a required header with a specific value.
///
/// # Example
///
/// ```
/// use http::StatusCode;
/// use hyper::Body;
/// use tower_http::validate_request::{AssertHeaderOrReject, ValidateRequestHeaderLayer};
///
/// let layer = ValidateRequestHeaderLayer::<AssertHeaderOrReject<Body>>::assert("x-custom-header", "random-value-1234567890", StatusCode::FORBIDDEN);
/// ```
pub fn assert(
expected_header_name: &str,
expected_header_value: &str,
response_status_code: StatusCode,
) -> Self
where
ResBody: Body + Default,
{
Self::custom(AssertHeaderOrReject::new(
expected_header_name,
expected_header_value,
response_status_code,
))
}
}

impl<T> ValidateRequestHeaderLayer<T> {
/// Validate requests using a custom method.
pub fn custom(validate: T) -> ValidateRequestHeaderLayer<T> {
Expand Down Expand Up @@ -409,6 +505,76 @@ where
}
}

/// Type that rejects requests if a header is not present or does not have an expected value.
pub struct AssertHeaderOrReject<ResBody> {
expected_header_name: String,
expected_header_value: String,
response_status_code: StatusCode,
_ty: PhantomData<fn() -> ResBody>,
}

impl<ResBody> AssertHeaderOrReject<ResBody> {
/// Create a new `AssertHeaderOrReject` struct.
fn new(
expected_header_name: &str,
expected_header_value: &str,
response_status_code: StatusCode,
) -> Self
where
ResBody: Body + Default,
{
Self {
expected_header_name: expected_header_name.to_string(),
expected_header_value: expected_header_value.to_string(),
response_status_code,
_ty: PhantomData,
}
}
}

impl<ResBody> Clone for AssertHeaderOrReject<ResBody> {
fn clone(&self) -> Self {
Self {
expected_header_name: self.expected_header_name.clone(),
expected_header_value: self.expected_header_value.clone(),
response_status_code: self.response_status_code.clone(),
_ty: PhantomData,
}
}
}

impl<ResBody> fmt::Debug for AssertHeaderOrReject<ResBody> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("AssertHeaderOrReject")
.field("expected_header_name", &self.expected_header_name)
.field("expected_header_value", &self.expected_header_value)
.field("response_status_code", &self.response_status_code)
.finish()
}
}

impl<B, ResBody> ValidateRequest<B> for AssertHeaderOrReject<ResBody>
where
ResBody: Body + Default,
{
type ResponseBody = ResBody;

fn validate(&mut self, req: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
let request_header_value = req
.headers()
.get(&self.expected_header_name)
.and_then(|v| v.to_str().ok());

if request_header_value != Some(&self.expected_header_value) {
let mut res = Response::new(ResBody::default());
*res.status_mut() = self.response_status_code;
return Err(res);
}

Ok(())
}
}

#[cfg(test)]
mod tests {
#[allow(unused_imports)]
Expand Down