Skip to content

Commit

Permalink
feat(ValidateRequestHeaderLayer): add assert() function
Browse files Browse the repository at this point in the history
  • Loading branch information
Oliboy50 committed Apr 18, 2023
1 parent 92d1954 commit bbd0d90
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tower-http/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- **request_id:** Derive `Default` for `MakeRequestUuid` ([#335])
- **fs:** Derive `Default` for `ServeFileSystemResponseBody` ([#336])
- **validate-request:** Add `ValidateRequestHeaderLayer::assert()` function to reject requests when a header does not have an expected value ([#360])

## Changed

Expand All @@ -27,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#335]: https://github.com/tower-rs/tower-http/pull/335
[#336]: https://github.com/tower-rs/tower-http/pull/336
[#354]: https://github.com/tower-rs/tower-http/pull/354
[#360]: https://github.com/tower-rs/tower-http/pull/360

# 0.4.0 (February 24, 2023)

Expand Down
166 changes: 166 additions & 0 deletions tower-http/src/validate_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
//!
//! # Example
//!
//! Validation of the `Accept` header can be made by using [`ValidateRequestHeaderLayer::accept()`]:
//!
//! ```
//! use tower_http::validate_request::ValidateRequestHeaderLayer;
//! use hyper::{Request, Response, Body, Error};
Expand Down Expand Up @@ -50,6 +52,70 @@
//! # }
//! ```
//!
//! Validation of a custom header can be made by using [`ValidateRequestHeaderLayer::assert()`]:
//!
//! ```
//! use tower_http::validate_request::ValidateRequestHeaderLayer;
//! use hyper::{Request, Response, Body, Error};
//! use http::StatusCode;
//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn};
//!
//! async fn handle(request: Request<Body>) -> Result<Response<Body>, Error> {
//! Ok(Response::new(Body::empty()))
//! }
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let mut service = ServiceBuilder::new()
//! // Require a `X-Custom-Header` header to have the value `random-value-1234567890` or reject with a `403 Forbidden` response
//! .layer(ValidateRequestHeaderLayer::assert("x-custom-header", "random-value-1234567890", StatusCode::FORBIDDEN))
//! .service_fn(handle);
//!
//! // Requests with the correct value are allowed through
//! let request = Request::builder()
//! .header("x-custom-header", "random-value-1234567890")
//! .body(Body::empty())
//! .unwrap();
//!
//! let response = service
//! .ready()
//! .await?
//! .call(request)
//! .await?;
//!
//! assert_eq!(StatusCode::OK, response.status());
//!
//! // Requests with an invalid value get a `403 Forbidden` response
//! let request = Request::builder()
//! .header("x-custom-header", "wrong-value")
//! .body(Body::empty())
//! .unwrap();
//!
//! let response = service
//! .ready()
//! .await?
//! .call(request)
//! .await?;
//!
//! assert_eq!(StatusCode::FORBIDDEN, response.status());
//! #
//! # // Requests without the expected header also get a `403 Forbidden` response
//! # let request = Request::builder()
//! # .body(Body::empty())
//! # .unwrap();
//! #
//! # let response = service
//! # .ready()
//! # .await?
//! # .call(request)
//! # .await?;
//! #
//! # assert_eq!(StatusCode::FORBIDDEN, response.status());
//! #
//! # Ok(())
//! # }
//! ```
//!
//! Custom validation can be made by implementing [`ValidateRequest`]:
//!
//! ```
Expand Down Expand Up @@ -112,6 +178,8 @@
//! # Ok(())
//! # }
//! ```
//!
//! [`Accept`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept
use http::{header, Request, Response, StatusCode};
use http_body::Body;
Expand Down Expand Up @@ -165,6 +233,34 @@ impl<ResBody> ValidateRequestHeaderLayer<AcceptHeader<ResBody>> {
}
}

impl<ResBody> ValidateRequestHeaderLayer<AssertHeaderOrReject<ResBody>> {
/// Validate requests have a required header with a specific value.
///
/// # Example
///
/// ```
/// use http::StatusCode;
/// use hyper::Body;
/// use tower_http::validate_request::{AssertHeaderOrReject, ValidateRequestHeaderLayer};
///
/// let layer = ValidateRequestHeaderLayer::<AssertHeaderOrReject<Body>>::assert("x-custom-header", "random-value-1234567890", StatusCode::FORBIDDEN);
/// ```
pub fn assert(
expected_header_name: &str,
expected_header_value: &str,
response_status_code: StatusCode,
) -> Self
where
ResBody: Body + Default,
{
Self::custom(AssertHeaderOrReject::new(
expected_header_name,
expected_header_value,
response_status_code,
))
}
}

impl<T> ValidateRequestHeaderLayer<T> {
/// Validate requests using a custom method.
pub fn custom(validate: T) -> ValidateRequestHeaderLayer<T> {
Expand Down Expand Up @@ -409,6 +505,76 @@ where
}
}

/// Type that rejects requests if a header is not present or does not have an expected value.
pub struct AssertHeaderOrReject<ResBody> {
expected_header_name: String,
expected_header_value: String,
response_status_code: StatusCode,
_ty: PhantomData<fn() -> ResBody>,
}

impl<ResBody> AssertHeaderOrReject<ResBody> {
/// Create a new `AssertHeaderOrReject` struct.
fn new(
expected_header_name: &str,
expected_header_value: &str,
response_status_code: StatusCode,
) -> Self
where
ResBody: Body + Default,
{
Self {
expected_header_name: expected_header_name.to_string(),
expected_header_value: expected_header_value.to_string(),
response_status_code,
_ty: PhantomData,
}
}
}

impl<ResBody> Clone for AssertHeaderOrReject<ResBody> {
fn clone(&self) -> Self {
Self {
expected_header_name: self.expected_header_name.clone(),
expected_header_value: self.expected_header_value.clone(),
response_status_code: self.response_status_code.clone(),
_ty: PhantomData,
}
}
}

impl<ResBody> fmt::Debug for AssertHeaderOrReject<ResBody> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("AssertHeaderOrReject")
.field("expected_header_name", &self.expected_header_name)
.field("expected_header_value", &self.expected_header_value)
.field("response_status_code", &self.response_status_code)
.finish()
}
}

impl<B, ResBody> ValidateRequest<B> for AssertHeaderOrReject<ResBody>
where
ResBody: Body + Default,
{
type ResponseBody = ResBody;

fn validate(&mut self, req: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
let request_header_value = req
.headers()
.get(&self.expected_header_name)
.and_then(|v| v.to_str().ok());

if request_header_value != Some(&self.expected_header_value) {
let mut res = Response::new(ResBody::default());
*res.status_mut() = self.response_status_code;
return Err(res);
}

Ok(())
}
}

#[cfg(test)]
mod tests {
#[allow(unused_imports)]
Expand Down

0 comments on commit bbd0d90

Please sign in to comment.