From bbd0d90e6f66a30137f8ec60d44fdff5e04529c2 Mon Sep 17 00:00:00 2001 From: Oliboy50 Date: Mon, 17 Apr 2023 20:17:58 +0200 Subject: [PATCH] feat(ValidateRequestHeaderLayer): add assert() function --- tower-http/CHANGELOG.md | 2 + tower-http/src/validate_request.rs | 166 +++++++++++++++++++++++++++++ 2 files changed, 168 insertions(+) diff --git a/tower-http/CHANGELOG.md b/tower-http/CHANGELOG.md index 69489759..828eaf3f 100644 --- a/tower-http/CHANGELOG.md +++ b/tower-http/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **request_id:** Derive `Default` for `MakeRequestUuid` ([#335]) - **fs:** Derive `Default` for `ServeFileSystemResponseBody` ([#336]) +- **validate-request:** Add `ValidateRequestHeaderLayer::assert()` function to reject requests when a header does not have an expected value ([#360]) ## Changed @@ -27,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#335]: https://github.com/tower-rs/tower-http/pull/335 [#336]: https://github.com/tower-rs/tower-http/pull/336 [#354]: https://github.com/tower-rs/tower-http/pull/354 +[#360]: https://github.com/tower-rs/tower-http/pull/360 # 0.4.0 (February 24, 2023) diff --git a/tower-http/src/validate_request.rs b/tower-http/src/validate_request.rs index c61c1bed..f43b72d9 100644 --- a/tower-http/src/validate_request.rs +++ b/tower-http/src/validate_request.rs @@ -2,6 +2,8 @@ //! //! # Example //! +//! Validation of the `Accept` header can be made by using [`ValidateRequestHeaderLayer::accept()`]: +//! //! ``` //! use tower_http::validate_request::ValidateRequestHeaderLayer; //! use hyper::{Request, Response, Body, Error}; @@ -50,6 +52,70 @@ //! # } //! ``` //! +//! Validation of a custom header can be made by using [`ValidateRequestHeaderLayer::assert()`]: +//! +//! ``` +//! use tower_http::validate_request::ValidateRequestHeaderLayer; +//! use hyper::{Request, Response, Body, Error}; +//! use http::StatusCode; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! +//! async fn handle(request: Request) -> Result, Error> { +//! Ok(Response::new(Body::empty())) +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box> { +//! let mut service = ServiceBuilder::new() +//! // Require a `X-Custom-Header` header to have the value `random-value-1234567890` or reject with a `403 Forbidden` response +//! .layer(ValidateRequestHeaderLayer::assert("x-custom-header", "random-value-1234567890", StatusCode::FORBIDDEN)) +//! .service_fn(handle); +//! +//! // Requests with the correct value are allowed through +//! let request = Request::builder() +//! .header("x-custom-header", "random-value-1234567890") +//! .body(Body::empty()) +//! .unwrap(); +//! +//! let response = service +//! .ready() +//! .await? +//! .call(request) +//! .await?; +//! +//! assert_eq!(StatusCode::OK, response.status()); +//! +//! // Requests with an invalid value get a `403 Forbidden` response +//! let request = Request::builder() +//! .header("x-custom-header", "wrong-value") +//! .body(Body::empty()) +//! .unwrap(); +//! +//! let response = service +//! .ready() +//! .await? +//! .call(request) +//! .await?; +//! +//! assert_eq!(StatusCode::FORBIDDEN, response.status()); +//! # +//! # // Requests without the expected header also get a `403 Forbidden` response +//! # let request = Request::builder() +//! # .body(Body::empty()) +//! # .unwrap(); +//! # +//! # let response = service +//! # .ready() +//! # .await? +//! # .call(request) +//! # .await?; +//! # +//! # assert_eq!(StatusCode::FORBIDDEN, response.status()); +//! # +//! # Ok(()) +//! # } +//! ``` +//! //! Custom validation can be made by implementing [`ValidateRequest`]: //! //! ``` @@ -112,6 +178,8 @@ //! # Ok(()) //! # } //! ``` +//! +//! [`Accept`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept use http::{header, Request, Response, StatusCode}; use http_body::Body; @@ -165,6 +233,34 @@ impl ValidateRequestHeaderLayer> { } } +impl ValidateRequestHeaderLayer> { + /// Validate requests have a required header with a specific value. + /// + /// # Example + /// + /// ``` + /// use http::StatusCode; + /// use hyper::Body; + /// use tower_http::validate_request::{AssertHeaderOrReject, ValidateRequestHeaderLayer}; + /// + /// let layer = ValidateRequestHeaderLayer::>::assert("x-custom-header", "random-value-1234567890", StatusCode::FORBIDDEN); + /// ``` + pub fn assert( + expected_header_name: &str, + expected_header_value: &str, + response_status_code: StatusCode, + ) -> Self + where + ResBody: Body + Default, + { + Self::custom(AssertHeaderOrReject::new( + expected_header_name, + expected_header_value, + response_status_code, + )) + } +} + impl ValidateRequestHeaderLayer { /// Validate requests using a custom method. pub fn custom(validate: T) -> ValidateRequestHeaderLayer { @@ -409,6 +505,76 @@ where } } +/// Type that rejects requests if a header is not present or does not have an expected value. +pub struct AssertHeaderOrReject { + expected_header_name: String, + expected_header_value: String, + response_status_code: StatusCode, + _ty: PhantomData ResBody>, +} + +impl AssertHeaderOrReject { + /// Create a new `AssertHeaderOrReject` struct. + fn new( + expected_header_name: &str, + expected_header_value: &str, + response_status_code: StatusCode, + ) -> Self + where + ResBody: Body + Default, + { + Self { + expected_header_name: expected_header_name.to_string(), + expected_header_value: expected_header_value.to_string(), + response_status_code, + _ty: PhantomData, + } + } +} + +impl Clone for AssertHeaderOrReject { + fn clone(&self) -> Self { + Self { + expected_header_name: self.expected_header_name.clone(), + expected_header_value: self.expected_header_value.clone(), + response_status_code: self.response_status_code.clone(), + _ty: PhantomData, + } + } +} + +impl fmt::Debug for AssertHeaderOrReject { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AssertHeaderOrReject") + .field("expected_header_name", &self.expected_header_name) + .field("expected_header_value", &self.expected_header_value) + .field("response_status_code", &self.response_status_code) + .finish() + } +} + +impl ValidateRequest for AssertHeaderOrReject +where + ResBody: Body + Default, +{ + type ResponseBody = ResBody; + + fn validate(&mut self, req: &mut Request) -> Result<(), Response> { + let request_header_value = req + .headers() + .get(&self.expected_header_name) + .and_then(|v| v.to_str().ok()); + + if request_header_value != Some(&self.expected_header_value) { + let mut res = Response::new(ResBody::default()); + *res.status_mut() = self.response_status_code; + return Err(res); + } + + Ok(()) + } +} + #[cfg(test)] mod tests { #[allow(unused_imports)]