diff --git a/tower-http/CHANGELOG.md b/tower-http/CHANGELOG.md
index 69489759..828eaf3f 100644
--- a/tower-http/CHANGELOG.md
+++ b/tower-http/CHANGELOG.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **request_id:** Derive `Default` for `MakeRequestUuid` ([#335])
- **fs:** Derive `Default` for `ServeFileSystemResponseBody` ([#336])
+- **validate-request:** Add `ValidateRequestHeaderLayer::assert()` function to reject requests when a header does not have an expected value ([#360])
## Changed
@@ -27,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#335]: https://github.com/tower-rs/tower-http/pull/335
[#336]: https://github.com/tower-rs/tower-http/pull/336
[#354]: https://github.com/tower-rs/tower-http/pull/354
+[#360]: https://github.com/tower-rs/tower-http/pull/360
# 0.4.0 (February 24, 2023)
diff --git a/tower-http/src/validate_request.rs b/tower-http/src/validate_request.rs
index c61c1bed..f43b72d9 100644
--- a/tower-http/src/validate_request.rs
+++ b/tower-http/src/validate_request.rs
@@ -2,6 +2,8 @@
//!
//! # Example
//!
+//! Validation of the `Accept` header can be made by using [`ValidateRequestHeaderLayer::accept()`]:
+//!
//! ```
//! use tower_http::validate_request::ValidateRequestHeaderLayer;
//! use hyper::{Request, Response, Body, Error};
@@ -50,6 +52,70 @@
//! # }
//! ```
//!
+//! Validation of a custom header can be made by using [`ValidateRequestHeaderLayer::assert()`]:
+//!
+//! ```
+//! use tower_http::validate_request::ValidateRequestHeaderLayer;
+//! use hyper::{Request, Response, Body, Error};
+//! use http::StatusCode;
+//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn};
+//!
+//! async fn handle(request: Request
) -> Result, Error> {
+//! Ok(Response::new(Body::empty()))
+//! }
+//!
+//! # #[tokio::main]
+//! # async fn main() -> Result<(), Box> {
+//! let mut service = ServiceBuilder::new()
+//! // Require a `X-Custom-Header` header to have the value `random-value-1234567890` or reject with a `403 Forbidden` response
+//! .layer(ValidateRequestHeaderLayer::assert("x-custom-header", "random-value-1234567890", StatusCode::FORBIDDEN))
+//! .service_fn(handle);
+//!
+//! // Requests with the correct value are allowed through
+//! let request = Request::builder()
+//! .header("x-custom-header", "random-value-1234567890")
+//! .body(Body::empty())
+//! .unwrap();
+//!
+//! let response = service
+//! .ready()
+//! .await?
+//! .call(request)
+//! .await?;
+//!
+//! assert_eq!(StatusCode::OK, response.status());
+//!
+//! // Requests with an invalid value get a `403 Forbidden` response
+//! let request = Request::builder()
+//! .header("x-custom-header", "wrong-value")
+//! .body(Body::empty())
+//! .unwrap();
+//!
+//! let response = service
+//! .ready()
+//! .await?
+//! .call(request)
+//! .await?;
+//!
+//! assert_eq!(StatusCode::FORBIDDEN, response.status());
+//! #
+//! # // Requests without the expected header also get a `403 Forbidden` response
+//! # let request = Request::builder()
+//! # .body(Body::empty())
+//! # .unwrap();
+//! #
+//! # let response = service
+//! # .ready()
+//! # .await?
+//! # .call(request)
+//! # .await?;
+//! #
+//! # assert_eq!(StatusCode::FORBIDDEN, response.status());
+//! #
+//! # Ok(())
+//! # }
+//! ```
+//!
//! Custom validation can be made by implementing [`ValidateRequest`]:
//!
//! ```
@@ -112,6 +178,8 @@
//! # Ok(())
//! # }
//! ```
+//!
+//! [`Accept`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept
use http::{header, Request, Response, StatusCode};
use http_body::Body;
@@ -165,6 +233,34 @@ impl ValidateRequestHeaderLayer> {
}
}
+impl ValidateRequestHeaderLayer> {
+ /// Validate requests have a required header with a specific value.
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// use http::StatusCode;
+ /// use hyper::Body;
+ /// use tower_http::validate_request::{AssertHeaderOrReject, ValidateRequestHeaderLayer};
+ ///
+ /// let layer = ValidateRequestHeaderLayer::>::assert("x-custom-header", "random-value-1234567890", StatusCode::FORBIDDEN);
+ /// ```
+ pub fn assert(
+ expected_header_name: &str,
+ expected_header_value: &str,
+ response_status_code: StatusCode,
+ ) -> Self
+ where
+ ResBody: Body + Default,
+ {
+ Self::custom(AssertHeaderOrReject::new(
+ expected_header_name,
+ expected_header_value,
+ response_status_code,
+ ))
+ }
+}
+
impl ValidateRequestHeaderLayer {
/// Validate requests using a custom method.
pub fn custom(validate: T) -> ValidateRequestHeaderLayer {
@@ -409,6 +505,76 @@ where
}
}
+/// Type that rejects requests if a header is not present or does not have an expected value.
+pub struct AssertHeaderOrReject {
+ expected_header_name: String,
+ expected_header_value: String,
+ response_status_code: StatusCode,
+ _ty: PhantomData ResBody>,
+}
+
+impl AssertHeaderOrReject {
+ /// Create a new `AssertHeaderOrReject` struct.
+ fn new(
+ expected_header_name: &str,
+ expected_header_value: &str,
+ response_status_code: StatusCode,
+ ) -> Self
+ where
+ ResBody: Body + Default,
+ {
+ Self {
+ expected_header_name: expected_header_name.to_string(),
+ expected_header_value: expected_header_value.to_string(),
+ response_status_code,
+ _ty: PhantomData,
+ }
+ }
+}
+
+impl Clone for AssertHeaderOrReject {
+ fn clone(&self) -> Self {
+ Self {
+ expected_header_name: self.expected_header_name.clone(),
+ expected_header_value: self.expected_header_value.clone(),
+ response_status_code: self.response_status_code.clone(),
+ _ty: PhantomData,
+ }
+ }
+}
+
+impl fmt::Debug for AssertHeaderOrReject {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("AssertHeaderOrReject")
+ .field("expected_header_name", &self.expected_header_name)
+ .field("expected_header_value", &self.expected_header_value)
+ .field("response_status_code", &self.response_status_code)
+ .finish()
+ }
+}
+
+impl ValidateRequest for AssertHeaderOrReject
+where
+ ResBody: Body + Default,
+{
+ type ResponseBody = ResBody;
+
+ fn validate(&mut self, req: &mut Request) -> Result<(), Response> {
+ let request_header_value = req
+ .headers()
+ .get(&self.expected_header_name)
+ .and_then(|v| v.to_str().ok());
+
+ if request_header_value != Some(&self.expected_header_value) {
+ let mut res = Response::new(ResBody::default());
+ *res.status_mut() = self.response_status_code;
+ return Err(res);
+ }
+
+ Ok(())
+ }
+}
+
#[cfg(test)]
mod tests {
#[allow(unused_imports)]