From 50c6ca240ced41f2063e3fe783c8beecae009147 Mon Sep 17 00:00:00 2001 From: Spencer Ferris <3319370+spencewenski@users.noreply.github.com> Date: Sun, 9 Jun 2024 01:31:31 -0700 Subject: [PATCH] Add support for tower's CORS middleware Closes https://github.com/roadster-rs/roadster/issues/47 --- Cargo.toml | 2 +- src/config/service/http/default.toml | 3 + src/config/service/http/middleware.rs | 6 +- ...ster__config__app_config__tests__test.snap | 5 + src/error/axum.rs | 20 +- src/middleware/http/auth/jwt/ietf.rs | 8 +- src/middleware/http/auth/jwt/mod.rs | 7 +- src/middleware/http/auth/jwt/openid.rs | 7 +- src/service/http/middleware/catch_panic.rs | 3 +- src/service/http/middleware/compression.rs | 5 +- src/service/http/middleware/cors.rs | 441 ++++++++++++++++++ src/service/http/middleware/default.rs | 22 +- src/service/http/middleware/mod.rs | 1 + src/service/http/middleware/request_id.rs | 7 +- .../http/middleware/sensitive_headers.rs | 8 +- src/service/http/middleware/size_limit.rs | 3 +- ...deserialize_cors_allow_headers@case_1.snap | 6 + ...deserialize_cors_allow_headers@case_2.snap | 6 + ...deserialize_cors_allow_headers@case_3.snap | 10 + ...deserialize_cors_allow_methods@case_1.snap | 6 + ...deserialize_cors_allow_methods@case_2.snap | 6 + ...deserialize_cors_allow_methods@case_3.snap | 7 + ...deserialize_cors_allow_methods@case_4.snap | 10 + ...deserialize_cors_allow_origins@case_1.snap | 6 + ...deserialize_cors_allow_origins@case_2.snap | 6 + ...deserialize_cors_allow_origins@case_3.snap | 7 + ...deserialize_cors_allow_origins@case_4.snap | 10 + ...eserialize_cors_expose_headers@case_1.snap | 6 + ...eserialize_cors_expose_headers@case_2.snap | 10 + ...ult__tests__default_middleware@case_1.snap | 5 + ...ult__tests__default_middleware@case_2.snap | 16 + src/service/http/middleware/timeout.rs | 3 +- src/service/http/middleware/tracing.rs | 3 +- src/service/worker/sidekiq/app_worker.rs | 6 +- src/util/serde_util.rs | 12 +- 35 files changed, 639 insertions(+), 50 deletions(-) create mode 100644 src/service/http/middleware/cors.rs create mode 100644 src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_headers@case_1.snap create mode 100644 src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_headers@case_2.snap create mode 100644 src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_headers@case_3.snap create mode 100644 src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_methods@case_1.snap create mode 100644 src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_methods@case_2.snap create mode 100644 src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_methods@case_3.snap create mode 100644 src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_methods@case_4.snap create mode 100644 src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_origins@case_1.snap create mode 100644 src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_origins@case_2.snap create mode 100644 src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_origins@case_3.snap create mode 100644 src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_origins@case_4.snap create mode 100644 src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_expose_headers@case_1.snap create mode 100644 src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_expose_headers@case_2.snap create mode 100644 src/service/http/middleware/snapshots/roadster__service__http__middleware__default__tests__default_middleware@case_1.snap create mode 100644 src/service/http/middleware/snapshots/roadster__service__http__middleware__default__tests__default_middleware@case_2.snap diff --git a/Cargo.toml b/Cargo.toml index efbbeb58..021786cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,7 @@ tracing-opentelemetry = { version = "0.24.0", features = ["metrics"], optional = axum = { workspace = true, optional = true } axum-extra = { version = "0.9.0", features = ["typed-header"], optional = true } tower = { version = "0.4.13", optional = true } -tower-http = { version = "0.5.0", features = ["trace", "timeout", "request-id", "util", "normalize-path", "sensitive-headers", "catch-panic", "compression-full", "decompression-full", "limit"], optional = true } +tower-http = { version = "0.5.0", features = ["trace", "timeout", "request-id", "util", "normalize-path", "sensitive-headers", "catch-panic", "compression-full", "decompression-full", "limit", "cors"], optional = true } aide = { workspace = true, features = ["axum", "redoc", "scalar", "macros"], optional = true } schemars = { workspace = true, optional = true } diff --git a/src/config/service/http/default.toml b/src/config/service/http/default.toml index eeedf39e..08b5be69 100644 --- a/src/config/service/http/default.toml +++ b/src/config/service/http/default.toml @@ -38,6 +38,9 @@ timeout = 10000 priority = -9970 limit = "5 MB" +[service.http.middleware.cors] +priority = -9950 + # Initializers [service.http.initializer] default-enable = true diff --git a/src/config/service/http/middleware.rs b/src/config/service/http/middleware.rs index 3600871f..200c5d2c 100644 --- a/src/config/service/http/middleware.rs +++ b/src/config/service/http/middleware.rs @@ -4,6 +4,7 @@ use crate::service::http::middleware::catch_panic::CatchPanicConfig; use crate::service::http::middleware::compression::{ RequestDecompressionConfig, ResponseCompressionConfig, }; +use crate::service::http::middleware::cors::CorsConfig; use crate::service::http::middleware::request_id::{PropagateRequestIdConfig, SetRequestIdConfig}; use crate::service::http::middleware::sensitive_headers::{ SensitiveRequestHeadersConfig, SensitiveResponseHeadersConfig, @@ -44,6 +45,9 @@ pub struct Middleware { pub timeout: MiddlewareConfig, pub size_limit: MiddlewareConfig, + + pub cors: MiddlewareConfig, + /// Allows providing configs for custom middleware. Any configs that aren't pre-defined above /// will be collected here. /// @@ -102,7 +106,7 @@ impl CommonConfig { } } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Validate, Serialize, Deserialize)] #[serde(rename_all = "kebab-case")] pub struct MiddlewareConfig { #[serde(flatten)] diff --git a/src/config/snapshots/roadster__config__app_config__tests__test.snap b/src/config/snapshots/roadster__config__app_config__tests__test.snap index 7b985848..497eeb01 100644 --- a/src/config/snapshots/roadster__config__app_config__tests__test.snap +++ b/src/config/snapshots/roadster__config__app_config__tests__test.snap @@ -64,6 +64,11 @@ timeout = 10000 priority = -9970 limit = '5 MB' +[service.http.middleware.cors] +priority = -9950 +preset = 'restrictive' +max-age = 3600000 + [service.http.initializer] default-enable = true diff --git a/src/error/axum.rs b/src/error/axum.rs index 585b72b2..4b2d2dac 100644 --- a/src/error/axum.rs +++ b/src/error/axum.rs @@ -3,7 +3,13 @@ use crate::error::Error; #[derive(Debug, Error)] pub enum AxumError { #[error(transparent)] - InvalidHeader(#[from] axum::http::header::InvalidHeaderName), + InvalidHeaderName(#[from] axum::http::header::InvalidHeaderName), + + #[error(transparent)] + InvalidHeaderValue(#[from] axum::http::header::InvalidHeaderValue), + + #[error(transparent)] + InvalidMethod(#[from] axum::http::method::InvalidMethod), #[cfg(feature = "jwt")] #[error(transparent)] @@ -19,6 +25,18 @@ impl From for Error { } } +impl From for Error { + fn from(value: axum::http::header::InvalidHeaderValue) -> Self { + Self::Axum(AxumError::from(value)) + } +} + +impl From for Error { + fn from(value: axum::http::method::InvalidMethod) -> Self { + Self::Axum(AxumError::from(value)) + } +} + #[cfg(feature = "jwt")] impl From for Error { fn from(value: axum_extra::typed_header::TypedHeaderRejection) -> Self { diff --git a/src/middleware/http/auth/jwt/ietf.rs b/src/middleware/http/auth/jwt/ietf.rs index b2d208df..e8629003 100644 --- a/src/middleware/http/auth/jwt/ietf.rs +++ b/src/middleware/http/auth/jwt/ietf.rs @@ -54,10 +54,9 @@ mod tests { use super::*; use crate::error::RoadsterResult; use crate::middleware::http::auth::jwt::decode_auth_token; - use crate::util::serde_util::UriOrString; + use crate::util::serde_util::{UriOrString, Wrapper}; use chrono::{TimeDelta, Utc}; use jsonwebtoken::{encode, EncodingKey, Header, TokenData}; - use serde_derive::{Deserialize, Serialize}; use serde_json::from_str; use std::ops::{Add, Sub}; use std::str::FromStr; @@ -132,11 +131,6 @@ mod tests { (claims, token) } - #[derive(Debug, Deserialize, Serialize)] - struct Wrapper { - inner: T, - } - #[test] #[cfg_attr(coverage_nightly, coverage(off))] fn deserialize_audience_as_vec() { diff --git a/src/middleware/http/auth/jwt/mod.rs b/src/middleware/http/auth/jwt/mod.rs index 0bae2f3a..cfac6527 100644 --- a/src/middleware/http/auth/jwt/mod.rs +++ b/src/middleware/http/auth/jwt/mod.rs @@ -129,16 +129,11 @@ pub enum Subject { #[cfg(test)] mod tests { use super::*; - use serde_derive::{Deserialize, Serialize}; + use crate::util::serde_util::Wrapper; use serde_json::from_str; use std::str::FromStr; use url::Url; - #[derive(Debug, Deserialize, Serialize)] - struct Wrapper { - inner: T, - } - #[test] #[cfg_attr(coverage_nightly, coverage(off))] fn deserialize_subject_as_uri() { diff --git a/src/middleware/http/auth/jwt/openid.rs b/src/middleware/http/auth/jwt/openid.rs index aefcafcc..135d1e4b 100644 --- a/src/middleware/http/auth/jwt/openid.rs +++ b/src/middleware/http/auth/jwt/openid.rs @@ -69,16 +69,11 @@ pub enum Acr { #[cfg(test)] mod tests { use super::*; - use serde_derive::{Deserialize, Serialize}; + use crate::util::serde_util::Wrapper; use serde_json::from_str; use std::str::FromStr; use url::Url; - #[derive(Debug, Deserialize, Serialize)] - struct Wrapper { - inner: T, - } - #[test] #[cfg_attr(coverage_nightly, coverage(off))] fn deserialize_acr_as_uri() { diff --git a/src/service/http/middleware/catch_panic.rs b/src/service/http/middleware/catch_panic.rs index 61adb7ba..0f4b51a7 100644 --- a/src/service/http/middleware/catch_panic.rs +++ b/src/service/http/middleware/catch_panic.rs @@ -4,8 +4,9 @@ use crate::service::http::middleware::Middleware; use axum::Router; use serde_derive::{Deserialize, Serialize}; use tower_http::catch_panic::CatchPanicLayer; +use validator::Validate; -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[derive(Debug, Clone, Default, Validate, Serialize, Deserialize)] #[serde(rename_all = "kebab-case", default)] pub struct CatchPanicConfig {} diff --git a/src/service/http/middleware/compression.rs b/src/service/http/middleware/compression.rs index 659ca814..850bbb81 100644 --- a/src/service/http/middleware/compression.rs +++ b/src/service/http/middleware/compression.rs @@ -6,12 +6,13 @@ use serde_derive::{Deserialize, Serialize}; use crate::error::RoadsterResult; use tower_http::compression::CompressionLayer; use tower_http::decompression::RequestDecompressionLayer; +use validator::Validate; -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[derive(Debug, Clone, Default, Validate, Serialize, Deserialize)] #[serde(rename_all = "kebab-case", default)] pub struct ResponseCompressionConfig {} -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[derive(Debug, Clone, Default, Validate, Serialize, Deserialize)] #[serde(rename_all = "kebab-case", default)] pub struct RequestDecompressionConfig {} diff --git a/src/service/http/middleware/cors.rs b/src/service/http/middleware/cors.rs new file mode 100644 index 00000000..9b7ead24 --- /dev/null +++ b/src/service/http/middleware/cors.rs @@ -0,0 +1,441 @@ +use crate::app::context::AppContext; +use crate::error::RoadsterResult; +use crate::service::http::middleware::Middleware; +use axum::http::{HeaderName, HeaderValue, Method}; +use axum::Router; +use itertools::Itertools; +use serde_derive::{Deserialize, Serialize}; +use serde_with::{serde_as, skip_serializing_none}; +use std::str::FromStr; +use std::time::Duration; +use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin, CorsLayer, ExposeHeaders}; +use validator::Validate; + +#[serde_as] +#[skip_serializing_none] +#[derive(Debug, Clone, Validate, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +#[non_exhaustive] +pub struct CorsConfig { + #[serde(default)] + pub preset: CorsPreset, + + /// See + #[serde(default)] + pub allow_credentials: Option, + + /// See + #[serde(default)] + pub allow_private_network: Option, + + /// Duration in milliseconds. If a value less than one second (1000 ms) is provided, the + /// header will not be set by the middleware. + /// See + #[serde(default = "default_max_age")] + #[serde_as(as = "serde_with::DurationMilliSeconds")] + pub max_age: Duration, + + /// See + #[serde(default)] + pub allow_headers: Option, + + /// See + #[serde(default)] + pub allow_methods: Option, + + /// See + #[serde(default)] + pub allow_origins: Option, + + /// See + #[serde(default)] + pub expose_headers: Option, + + /// See + // Todo: deserialize as HeaderName directly instead of string + #[serde(default)] + pub vary: Option>, +} + +fn default_max_age() -> Duration { + Duration::from_secs(60 * 60) +} + +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub enum CorsPreset { + /// See + #[default] + Restrictive, + /// See + Permissive, + /// See + VeryPermissive, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "kebab-case")] +pub enum CorsAllowHeaders { + Any, + MirrorRequest, + // Todo: deserialize as HeaderName directly instead of string + List { headers: Vec }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "kebab-case")] +pub enum CorsAllowMethods { + Any, + MirrorRequest, + // Todo: deserialize as Method directly instead of string + Exact { method: String }, + // Todo: deserialize as Method directly instead of string + List { methods: Vec }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "kebab-case")] +pub enum CorsAllowOrigins { + Any, + MirrorRequest, + // Todo: deserialize as HeaderValue directly instead of string + Exact { origin: String }, + // Todo: deserialize as HeaderValue directly instead of string + List { origins: Vec }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "kebab-case")] +pub enum CorsExposeHeaders { + Any, + // Todo: deserialize as HeaderName directly instead of string + List { headers: Vec }, +} + +pub fn parse_header_names(header_names: &[String]) -> RoadsterResult> { + let header_names = header_names + .iter() + .map(|header_name| HeaderName::from_str(header_name)) + .try_collect()?; + Ok(header_names) +} + +pub fn parse_header_values(header_values: &[String]) -> RoadsterResult> { + let header_values = header_values + .iter() + .map(|header_value| HeaderValue::from_str(header_value)) + .try_collect()?; + Ok(header_values) +} + +pub fn parse_methods(methods: &[String]) -> RoadsterResult> { + let methods = methods + .iter() + .map(|method| Method::from_str(method)) + .try_collect()?; + Ok(methods) +} + +pub struct CorsMiddleware; +impl Middleware for CorsMiddleware { + fn name(&self) -> String { + "cors".to_string() + } + + fn enabled(&self, context: &AppContext) -> bool { + context + .config() + .service + .http + .custom + .middleware + .cors + .common + .enabled(context) + } + + fn priority(&self, context: &AppContext) -> i32 { + context + .config() + .service + .http + .custom + .middleware + .cors + .common + .priority + } + + fn install(&self, router: Router, context: &AppContext) -> RoadsterResult { + let config = &context.config().service.http.custom.middleware.cors.custom; + let layer = match config.preset { + CorsPreset::Restrictive => CorsLayer::new(), + CorsPreset::Permissive => CorsLayer::permissive(), + CorsPreset::VeryPermissive => CorsLayer::very_permissive(), + }; + + let layer = if config.max_age > Duration::from_secs(1) { + layer.max_age(config.max_age) + } else { + layer + }; + + let layer = config + .allow_credentials + .iter() + .fold(layer, |layer, allow| layer.allow_credentials(*allow)); + + let layer = config + .allow_private_network + .iter() + .fold(layer, |layer, allow| layer.allow_private_network(*allow)); + + let layer = config.allow_headers.iter().try_fold( + layer, + |layer, allow| -> RoadsterResult { + let layer = match allow { + CorsAllowHeaders::Any => layer.allow_headers(AllowHeaders::any()), + CorsAllowHeaders::MirrorRequest => { + layer.allow_headers(AllowHeaders::mirror_request()) + } + CorsAllowHeaders::List { headers } => { + layer.allow_headers(AllowHeaders::list(parse_header_names(headers)?)) + } + }; + Ok(layer) + }, + )?; + + let layer = config.expose_headers.iter().try_fold( + layer, + |layer, allow| -> RoadsterResult { + let layer = match allow { + CorsExposeHeaders::Any => layer.expose_headers(ExposeHeaders::any()), + CorsExposeHeaders::List { headers } => { + layer.expose_headers(ExposeHeaders::list(parse_header_names(headers)?)) + } + }; + Ok(layer) + }, + )?; + + let layer = config.vary.iter().try_fold( + layer, + |layer, header_names| -> RoadsterResult { + let layer = layer.vary(parse_header_names(header_names)?); + Ok(layer) + }, + )?; + + let layer = config.allow_origins.iter().try_fold( + layer, + |layer, allow| -> RoadsterResult { + let layer = match allow { + CorsAllowOrigins::Any => layer.allow_origin(AllowOrigin::any()), + CorsAllowOrigins::MirrorRequest => { + layer.allow_origin(AllowOrigin::mirror_request()) + } + CorsAllowOrigins::Exact { origin } => { + layer.allow_origin(AllowOrigin::exact(HeaderValue::from_str(origin)?)) + } + CorsAllowOrigins::List { origins } => { + layer.allow_origin(AllowOrigin::list(parse_header_values(origins)?)) + } + }; + Ok(layer) + }, + )?; + + let layer = config.allow_methods.iter().try_fold( + layer, + |layer, allow| -> RoadsterResult { + let layer = match allow { + CorsAllowMethods::Any => layer.allow_methods(AllowMethods::any()), + CorsAllowMethods::MirrorRequest => { + layer.allow_methods(AllowMethods::mirror_request()) + } + CorsAllowMethods::Exact { method } => { + layer.allow_methods(AllowMethods::exact(Method::from_str(method)?)) + } + CorsAllowMethods::List { methods } => { + layer.allow_methods(AllowMethods::list(parse_methods(methods)?)) + } + }; + Ok(layer) + }, + )?; + + let router = router.layer(layer); + + Ok(router) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::app_config::AppConfig; + use crate::util::serde_util::Wrapper; + use crate::util::test_util::TestCase; + use insta::assert_toml_snapshot; + use rstest::{fixture, rstest}; + + #[fixture] + #[cfg_attr(coverage_nightly, coverage(off))] + fn case() -> TestCase { + Default::default() + } + + #[rstest] + #[case(false, Some(true), true)] + #[case(false, Some(false), false)] + #[cfg_attr(coverage_nightly, coverage(off))] + fn cors_enabled( + #[case] default_enable: bool, + #[case] enable: Option, + #[case] expected_enabled: bool, + ) { + // Arrange + let mut config = AppConfig::test(None).unwrap(); + config.service.http.custom.middleware.default_enable = default_enable; + config.service.http.custom.middleware.cors.common.enable = enable; + + let context = AppContext::<()>::test(Some(config), None, None).unwrap(); + + let middleware = CorsMiddleware; + + // Act/Assert + assert_eq!(middleware.enabled(&context), expected_enabled); + } + + #[rstest] + #[case(None, -9950)] + #[case(Some(1234), 1234)] + #[cfg_attr(coverage_nightly, coverage(off))] + fn cors_priority(#[case] override_priority: Option, #[case] expected_priority: i32) { + // Arrange + let mut config = AppConfig::test(None).unwrap(); + if let Some(priority) = override_priority { + config.service.http.custom.middleware.cors.common.priority = priority; + } + + let context = AppContext::<()>::test(Some(config), None, None).unwrap(); + + let middleware = CorsMiddleware; + + // Act/Assert + assert_eq!(middleware.priority(&context), expected_priority); + } + + #[rstest] + #[case( + r#" + [inner] + type = 'any' + "# + )] + #[case( + r#" + [inner] + type = 'mirror-request' + "# + )] + #[case( + r#" + [inner] + type = 'list' + headers = ["foo", "bar"] + "# + )] + #[cfg_attr(coverage_nightly, coverage(off))] + fn deserialize_cors_allow_headers(_case: TestCase, #[case] serialized: &str) { + let value: Wrapper = toml::from_str(serialized).unwrap(); + assert_toml_snapshot!(value); + } + + #[rstest] + #[case( + r#" + [inner] + type = 'any' + "# + )] + #[case( + r#" + [inner] + type = 'mirror-request' + "# + )] + #[case( + r#" + [inner] + type = 'exact' + method = "foo" + "# + )] + #[case( + r#" + [inner] + type = 'list' + methods = ["foo", "bar"] + "# + )] + #[cfg_attr(coverage_nightly, coverage(off))] + fn deserialize_cors_allow_methods(_case: TestCase, #[case] serialized: &str) { + let value: Wrapper = toml::from_str(serialized).unwrap(); + assert_toml_snapshot!(value); + } + + #[rstest] + #[case( + r#" + [inner] + type = 'any' + "# + )] + #[case( + r#" + [inner] + type = 'mirror-request' + "# + )] + #[case( + r#" + [inner] + type = 'exact' + origin = "foo" + "# + )] + #[case( + r#" + [inner] + type = 'list' + origins = ["foo", "bar"] + "# + )] + #[cfg_attr(coverage_nightly, coverage(off))] + fn deserialize_cors_allow_origins(_case: TestCase, #[case] serialized: &str) { + let value: Wrapper = toml::from_str(serialized).unwrap(); + assert_toml_snapshot!(value); + } + + #[rstest] + #[case( + r#" + [inner] + type = 'any' + "# + )] + #[case( + r#" + [inner] + type = 'list' + headers = ["foo", "bar"] + "# + )] + #[cfg_attr(coverage_nightly, coverage(off))] + fn deserialize_cors_expose_headers(_case: TestCase, #[case] serialized: &str) { + let value: Wrapper = toml::from_str(serialized).unwrap(); + assert_toml_snapshot!(value); + } +} diff --git a/src/service/http/middleware/default.rs b/src/service/http/middleware/default.rs index a03fbf5f..0a03dd72 100644 --- a/src/service/http/middleware/default.rs +++ b/src/service/http/middleware/default.rs @@ -1,6 +1,7 @@ use crate::app::context::AppContext; use crate::service::http::middleware::catch_panic::CatchPanicMiddleware; use crate::service::http::middleware::compression::RequestDecompressionMiddleware; +use crate::service::http::middleware::cors::CorsMiddleware; use crate::service::http::middleware::request_id::{ PropagateRequestIdMiddleware, SetRequestIdMiddleware, }; @@ -26,6 +27,7 @@ pub fn default_middleware( Box::new(RequestDecompressionMiddleware), Box::new(TimeoutMiddleware), Box::new(RequestBodyLimitMiddleware), + Box::new(CorsMiddleware), ]; middleware .into_iter() @@ -38,13 +40,22 @@ pub fn default_middleware( mod tests { use crate::app::context::AppContext; use crate::config::app_config::AppConfig; - use rstest::rstest; + use crate::util::test_util::TestCase; + use insta::assert_toml_snapshot; + use itertools::Itertools; + use rstest::{fixture, rstest}; + + #[fixture] + #[cfg_attr(coverage_nightly, coverage(off))] + fn case() -> TestCase { + Default::default() + } #[rstest] - #[case(true, 9)] - #[case(false, 0)] + #[case(false)] + #[case(true)] #[cfg_attr(coverage_nightly, coverage(off))] - fn default_middleware(#[case] default_enable: bool, #[case] expected_size: usize) { + fn default_middleware(_case: TestCase, #[case] default_enable: bool) { // Arrange let mut config = AppConfig::test(None).unwrap(); config.service.http.custom.middleware.default_enable = default_enable; @@ -53,8 +64,9 @@ mod tests { // Act let middleware = super::default_middleware(&context); + let middleware = middleware.keys().collect_vec(); // Assert - assert_eq!(middleware.len(), expected_size); + assert_toml_snapshot!(middleware); } } diff --git a/src/service/http/middleware/mod.rs b/src/service/http/middleware/mod.rs index c543d664..e641b2cf 100644 --- a/src/service/http/middleware/mod.rs +++ b/src/service/http/middleware/mod.rs @@ -1,5 +1,6 @@ pub mod catch_panic; pub mod compression; +pub mod cors; pub mod default; pub mod request_id; pub mod sensitive_headers; diff --git a/src/service/http/middleware/request_id.rs b/src/service/http/middleware/request_id.rs index 70e213a7..baabb714 100644 --- a/src/service/http/middleware/request_id.rs +++ b/src/service/http/middleware/request_id.rs @@ -6,10 +6,11 @@ use axum::Router; use serde_derive::{Deserialize, Serialize}; use std::str::FromStr; use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer}; +use validator::Validate; pub const REQUEST_ID_HEADER_NAME: &str = "request-id"; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Validate, Serialize, Deserialize)] #[serde(rename_all = "kebab-case", default)] pub struct CommonRequestIdConfig { pub header_name: String, @@ -23,14 +24,14 @@ impl Default for CommonRequestIdConfig { } } -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[derive(Debug, Clone, Default, Validate, Serialize, Deserialize)] #[serde(rename_all = "kebab-case", default)] pub struct SetRequestIdConfig { #[serde(flatten)] pub common: CommonRequestIdConfig, } -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[derive(Debug, Clone, Default, Validate, Serialize, Deserialize)] #[serde(rename_all = "kebab-case", default)] pub struct PropagateRequestIdConfig { #[serde(flatten)] diff --git a/src/service/http/middleware/sensitive_headers.rs b/src/service/http/middleware/sensitive_headers.rs index 3c60324e..2b39c976 100644 --- a/src/service/http/middleware/sensitive_headers.rs +++ b/src/service/http/middleware/sensitive_headers.rs @@ -10,8 +10,9 @@ use crate::error::RoadsterResult; use tower_http::sensitive_headers::{ SetSensitiveRequestHeadersLayer, SetSensitiveResponseHeadersLayer, }; +use validator::Validate; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Validate, Serialize, Deserialize)] #[serde(rename_all = "kebab-case", default)] pub struct CommonSensitiveHeadersConfig { pub header_names: Vec, @@ -41,14 +42,14 @@ impl CommonSensitiveHeadersConfig { } } -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[derive(Debug, Clone, Default, Validate, Serialize, Deserialize)] #[serde(rename_all = "kebab-case", default)] pub struct SensitiveRequestHeadersConfig { #[serde(flatten)] pub common: CommonSensitiveHeadersConfig, } -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[derive(Debug, Clone, Default, Validate, Serialize, Deserialize)] #[serde(rename_all = "kebab-case", default)] pub struct SensitiveResponseHeadersConfig { #[serde(flatten)] @@ -131,6 +132,7 @@ impl Middleware for SensitiveResponseHeadersMiddlew .common .priority } + fn install(&self, router: Router, context: &AppContext) -> RoadsterResult { let headers = context .config() diff --git a/src/service/http/middleware/size_limit.rs b/src/service/http/middleware/size_limit.rs index 871c81e5..36148bed 100644 --- a/src/service/http/middleware/size_limit.rs +++ b/src/service/http/middleware/size_limit.rs @@ -8,8 +8,9 @@ use byte_unit::Byte; use byte_unit::Unit::MB; use serde_derive::{Deserialize, Serialize}; use tower_http::limit::RequestBodyLimitLayer; +use validator::Validate; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Validate, Serialize, Deserialize)] #[serde(rename_all = "kebab-case", default)] pub struct SizeLimitConfig { pub limit: Byte, diff --git a/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_headers@case_1.snap b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_headers@case_1.snap new file mode 100644 index 00000000..d3187d02 --- /dev/null +++ b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_headers@case_1.snap @@ -0,0 +1,6 @@ +--- +source: src/service/http/middleware/cors.rs +expression: value +--- +[inner] +type = 'any' diff --git a/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_headers@case_2.snap b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_headers@case_2.snap new file mode 100644 index 00000000..07cdb0cf --- /dev/null +++ b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_headers@case_2.snap @@ -0,0 +1,6 @@ +--- +source: src/service/http/middleware/cors.rs +expression: value +--- +[inner] +type = 'mirror-request' diff --git a/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_headers@case_3.snap b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_headers@case_3.snap new file mode 100644 index 00000000..48c3cb5f --- /dev/null +++ b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_headers@case_3.snap @@ -0,0 +1,10 @@ +--- +source: src/service/http/middleware/cors.rs +expression: value +--- +[inner] +type = 'list' +headers = [ + 'foo', + 'bar', +] diff --git a/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_methods@case_1.snap b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_methods@case_1.snap new file mode 100644 index 00000000..d3187d02 --- /dev/null +++ b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_methods@case_1.snap @@ -0,0 +1,6 @@ +--- +source: src/service/http/middleware/cors.rs +expression: value +--- +[inner] +type = 'any' diff --git a/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_methods@case_2.snap b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_methods@case_2.snap new file mode 100644 index 00000000..07cdb0cf --- /dev/null +++ b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_methods@case_2.snap @@ -0,0 +1,6 @@ +--- +source: src/service/http/middleware/cors.rs +expression: value +--- +[inner] +type = 'mirror-request' diff --git a/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_methods@case_3.snap b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_methods@case_3.snap new file mode 100644 index 00000000..60c5a31f --- /dev/null +++ b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_methods@case_3.snap @@ -0,0 +1,7 @@ +--- +source: src/service/http/middleware/cors.rs +expression: value +--- +[inner] +type = 'exact' +method = 'foo' diff --git a/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_methods@case_4.snap b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_methods@case_4.snap new file mode 100644 index 00000000..940a736d --- /dev/null +++ b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_methods@case_4.snap @@ -0,0 +1,10 @@ +--- +source: src/service/http/middleware/cors.rs +expression: value +--- +[inner] +type = 'list' +methods = [ + 'foo', + 'bar', +] diff --git a/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_origins@case_1.snap b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_origins@case_1.snap new file mode 100644 index 00000000..d3187d02 --- /dev/null +++ b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_origins@case_1.snap @@ -0,0 +1,6 @@ +--- +source: src/service/http/middleware/cors.rs +expression: value +--- +[inner] +type = 'any' diff --git a/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_origins@case_2.snap b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_origins@case_2.snap new file mode 100644 index 00000000..07cdb0cf --- /dev/null +++ b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_origins@case_2.snap @@ -0,0 +1,6 @@ +--- +source: src/service/http/middleware/cors.rs +expression: value +--- +[inner] +type = 'mirror-request' diff --git a/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_origins@case_3.snap b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_origins@case_3.snap new file mode 100644 index 00000000..4bcdd100 --- /dev/null +++ b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_origins@case_3.snap @@ -0,0 +1,7 @@ +--- +source: src/service/http/middleware/cors.rs +expression: value +--- +[inner] +type = 'exact' +origin = 'foo' diff --git a/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_origins@case_4.snap b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_origins@case_4.snap new file mode 100644 index 00000000..8c5e1a6b --- /dev/null +++ b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_allow_origins@case_4.snap @@ -0,0 +1,10 @@ +--- +source: src/service/http/middleware/cors.rs +expression: value +--- +[inner] +type = 'list' +origins = [ + 'foo', + 'bar', +] diff --git a/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_expose_headers@case_1.snap b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_expose_headers@case_1.snap new file mode 100644 index 00000000..d3187d02 --- /dev/null +++ b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_expose_headers@case_1.snap @@ -0,0 +1,6 @@ +--- +source: src/service/http/middleware/cors.rs +expression: value +--- +[inner] +type = 'any' diff --git a/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_expose_headers@case_2.snap b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_expose_headers@case_2.snap new file mode 100644 index 00000000..48c3cb5f --- /dev/null +++ b/src/service/http/middleware/snapshots/roadster__service__http__middleware__cors__tests__deserialize_cors_expose_headers@case_2.snap @@ -0,0 +1,10 @@ +--- +source: src/service/http/middleware/cors.rs +expression: value +--- +[inner] +type = 'list' +headers = [ + 'foo', + 'bar', +] diff --git a/src/service/http/middleware/snapshots/roadster__service__http__middleware__default__tests__default_middleware@case_1.snap b/src/service/http/middleware/snapshots/roadster__service__http__middleware__default__tests__default_middleware@case_1.snap new file mode 100644 index 00000000..3b396331 --- /dev/null +++ b/src/service/http/middleware/snapshots/roadster__service__http__middleware__default__tests__default_middleware@case_1.snap @@ -0,0 +1,5 @@ +--- +source: src/service/http/middleware/default.rs +expression: middleware +--- +[] diff --git a/src/service/http/middleware/snapshots/roadster__service__http__middleware__default__tests__default_middleware@case_2.snap b/src/service/http/middleware/snapshots/roadster__service__http__middleware__default__tests__default_middleware@case_2.snap new file mode 100644 index 00000000..5b9a2efd --- /dev/null +++ b/src/service/http/middleware/snapshots/roadster__service__http__middleware__default__tests__default_middleware@case_2.snap @@ -0,0 +1,16 @@ +--- +source: src/service/http/middleware/default.rs +expression: middleware +--- +[ + 'catch-panic', + 'cors', + 'propagate-request-id', + 'request-body-size-limit', + 'request-decompression', + 'sensitive-request-headers', + 'sensitive-response-headers', + 'set-request-id', + 'timeout', + 'tracing', +] diff --git a/src/service/http/middleware/timeout.rs b/src/service/http/middleware/timeout.rs index 954434bf..7d254c66 100644 --- a/src/service/http/middleware/timeout.rs +++ b/src/service/http/middleware/timeout.rs @@ -6,9 +6,10 @@ use serde_derive::{Deserialize, Serialize}; use serde_with::serde_as; use std::time::Duration; use tower_http::timeout::TimeoutLayer; +use validator::Validate; #[serde_as] -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Validate, Serialize, Deserialize)] #[serde(rename_all = "kebab-case", default)] pub struct TimeoutConfig { #[serde_as(as = "serde_with::DurationMilliSeconds")] diff --git a/src/service/http/middleware/tracing.rs b/src/service/http/middleware/tracing.rs index 75e780a1..7200d54c 100644 --- a/src/service/http/middleware/tracing.rs +++ b/src/service/http/middleware/tracing.rs @@ -11,8 +11,9 @@ use serde_derive::{Deserialize, Serialize}; use std::time::Duration; use tower_http::trace::{DefaultOnResponse, MakeSpan, OnRequest, OnResponse, TraceLayer}; use tracing::{event, field, info_span, Level, Span, Value}; +use validator::Validate; -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[derive(Debug, Clone, Default, Validate, Serialize, Deserialize)] #[serde(rename_all = "kebab-case", default)] pub struct TracingConfig {} diff --git a/src/service/worker/sidekiq/app_worker.rs b/src/service/worker/sidekiq/app_worker.rs index e285a86d..ac71c518 100644 --- a/src/service/worker/sidekiq/app_worker.rs +++ b/src/service/worker/sidekiq/app_worker.rs @@ -125,13 +125,9 @@ where #[cfg(test)] mod tests { use super::*; + use crate::util::serde_util::Wrapper; use serde_json::from_str; - #[derive(Debug, Deserialize, Serialize)] - struct Wrapper { - inner: T, - } - #[test] #[cfg_attr(coverage_nightly, coverage(off))] fn deserialize_config_override_max_retries() { diff --git a/src/util/serde_util.rs b/src/util/serde_util.rs index cef0bf08..5bcc022f 100644 --- a/src/util/serde_util.rs +++ b/src/util/serde_util.rs @@ -62,19 +62,19 @@ pub(crate) fn empty_json_object() -> impl for<'de> Deserializer<'de> { Value::Object(Map::new()).into_deserializer() } +#[cfg(test)] +#[derive(Debug, Deserialize, Serialize)] +pub(crate) struct Wrapper { + pub inner: T, +} + #[cfg(test)] mod tests { use super::*; - use serde_derive::{Deserialize, Serialize}; use serde_json::from_str; use std::str::FromStr; use url::Url; - #[derive(Debug, Deserialize, Serialize)] - struct Wrapper { - inner: T, - } - #[test] #[cfg_attr(coverage_nightly, coverage(off))] fn deserialize_uri_or_string_as_uri() {