diff --git a/tower-http/src/validate_request.rs b/tower-http/src/validate_request.rs index c61c1bed..f43b72d9 100644 --- a/tower-http/src/validate_request.rs +++ b/tower-http/src/validate_request.rs @@ -2,6 +2,8 @@ //! //! # Example //! +//! Validation of the `Accept` header can be made by using [`ValidateRequestHeaderLayer::accept()`]: +//! //! ``` //! use tower_http::validate_request::ValidateRequestHeaderLayer; //! use hyper::{Request, Response, Body, Error}; @@ -50,6 +52,70 @@ //! # } //! ``` //! +//! Validation of a custom header can be made by using [`ValidateRequestHeaderLayer::assert()`]: +//! +//! ``` +//! use tower_http::validate_request::ValidateRequestHeaderLayer; +//! use hyper::{Request, Response, Body, Error}; +//! use http::StatusCode; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! +//! async fn handle(request: Request) -> Result, Error> { +//! Ok(Response::new(Body::empty())) +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box> { +//! let mut service = ServiceBuilder::new() +//! // Require a `X-Custom-Header` header to have the value `random-value-1234567890` or reject with a `403 Forbidden` response +//! .layer(ValidateRequestHeaderLayer::assert("x-custom-header", "random-value-1234567890", StatusCode::FORBIDDEN)) +//! .service_fn(handle); +//! +//! // Requests with the correct value are allowed through +//! let request = Request::builder() +//! .header("x-custom-header", "random-value-1234567890") +//! .body(Body::empty()) +//! .unwrap(); +//! +//! let response = service +//! .ready() +//! .await? +//! .call(request) +//! .await?; +//! +//! assert_eq!(StatusCode::OK, response.status()); +//! +//! // Requests with an invalid value get a `403 Forbidden` response +//! let request = Request::builder() +//! .header("x-custom-header", "wrong-value") +//! .body(Body::empty()) +//! .unwrap(); +//! +//! let response = service +//! .ready() +//! .await? +//! .call(request) +//! .await?; +//! +//! assert_eq!(StatusCode::FORBIDDEN, response.status()); +//! # +//! # // Requests without the expected header also get a `403 Forbidden` response +//! # let request = Request::builder() +//! # .body(Body::empty()) +//! # .unwrap(); +//! # +//! # let response = service +//! # .ready() +//! # .await? +//! # .call(request) +//! # .await?; +//! # +//! # assert_eq!(StatusCode::FORBIDDEN, response.status()); +//! # +//! # Ok(()) +//! # } +//! ``` +//! //! Custom validation can be made by implementing [`ValidateRequest`]: //! //! ``` @@ -112,6 +178,8 @@ //! # Ok(()) //! # } //! ``` +//! +//! [`Accept`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept use http::{header, Request, Response, StatusCode}; use http_body::Body; @@ -165,6 +233,34 @@ impl ValidateRequestHeaderLayer> { } } +impl ValidateRequestHeaderLayer> { + /// Validate requests have a required header with a specific value. + /// + /// # Example + /// + /// ``` + /// use http::StatusCode; + /// use hyper::Body; + /// use tower_http::validate_request::{AssertHeaderOrReject, ValidateRequestHeaderLayer}; + /// + /// let layer = ValidateRequestHeaderLayer::>::assert("x-custom-header", "random-value-1234567890", StatusCode::FORBIDDEN); + /// ``` + pub fn assert( + expected_header_name: &str, + expected_header_value: &str, + response_status_code: StatusCode, + ) -> Self + where + ResBody: Body + Default, + { + Self::custom(AssertHeaderOrReject::new( + expected_header_name, + expected_header_value, + response_status_code, + )) + } +} + impl ValidateRequestHeaderLayer { /// Validate requests using a custom method. pub fn custom(validate: T) -> ValidateRequestHeaderLayer { @@ -409,6 +505,76 @@ where } } +/// Type that rejects requests if a header is not present or does not have an expected value. +pub struct AssertHeaderOrReject { + expected_header_name: String, + expected_header_value: String, + response_status_code: StatusCode, + _ty: PhantomData ResBody>, +} + +impl AssertHeaderOrReject { + /// Create a new `AssertHeaderOrReject` struct. + fn new( + expected_header_name: &str, + expected_header_value: &str, + response_status_code: StatusCode, + ) -> Self + where + ResBody: Body + Default, + { + Self { + expected_header_name: expected_header_name.to_string(), + expected_header_value: expected_header_value.to_string(), + response_status_code, + _ty: PhantomData, + } + } +} + +impl Clone for AssertHeaderOrReject { + fn clone(&self) -> Self { + Self { + expected_header_name: self.expected_header_name.clone(), + expected_header_value: self.expected_header_value.clone(), + response_status_code: self.response_status_code.clone(), + _ty: PhantomData, + } + } +} + +impl fmt::Debug for AssertHeaderOrReject { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AssertHeaderOrReject") + .field("expected_header_name", &self.expected_header_name) + .field("expected_header_value", &self.expected_header_value) + .field("response_status_code", &self.response_status_code) + .finish() + } +} + +impl ValidateRequest for AssertHeaderOrReject +where + ResBody: Body + Default, +{ + type ResponseBody = ResBody; + + fn validate(&mut self, req: &mut Request) -> Result<(), Response> { + let request_header_value = req + .headers() + .get(&self.expected_header_name) + .and_then(|v| v.to_str().ok()); + + if request_header_value != Some(&self.expected_header_value) { + let mut res = Response::new(ResBody::default()); + *res.status_mut() = self.response_status_code; + return Err(res); + } + + Ok(()) + } +} + #[cfg(test)] mod tests { #[allow(unused_imports)]