From 61fe6f97d727e98f524cfc19bf5e683196e10a22 Mon Sep 17 00:00:00 2001 From: Spencer Ferris <3319370+spencewenski@users.noreply.github.com> Date: Sun, 31 Mar 2024 11:58:26 -0700 Subject: [PATCH] Allow using custom App::State in Initializer and Middleware traits --- .cargo-husky/hooks/pre-push | 6 +-- examples/minimal/src/app.rs | 3 -- src/app.rs | 29 ++++++++------- src/app_context.rs | 7 ++++ src/controller/middleware/catch_panic.rs | 8 ++-- src/controller/middleware/compression.rs | 17 +++++---- src/controller/middleware/default.rs | 2 +- src/controller/middleware/mod.rs | 36 ++++++++++++++++-- src/controller/middleware/request_id.rs | 16 ++++---- .../middleware/sensitive_headers.rs | 17 +++++---- src/controller/middleware/size_limit.rs | 8 ++-- src/controller/middleware/timeout.rs | 8 ++-- src/controller/middleware/tracing.rs | 10 ++--- src/initializer/default.rs | 2 +- src/initializer/mod.rs | 37 +++++++++++++++---- src/initializer/normalize_path.rs | 13 +++++-- 16 files changed, 142 insertions(+), 77 deletions(-) diff --git a/.cargo-husky/hooks/pre-push b/.cargo-husky/hooks/pre-push index 6a0f996f..1c9d577a 100755 --- a/.cargo-husky/hooks/pre-push +++ b/.cargo-husky/hooks/pre-push @@ -7,9 +7,6 @@ set -e echo "### fmt --all --check ###" cargo fmt --all --check -echo "### cargo doc --all-features --no-deps ###" -RUSTDOCFLAGS="-D rustdoc::all" cargo doc --all-features --no-deps - # With no features echo "### test --no-default-features --workspace ###" cargo test --no-default-features --workspace @@ -33,3 +30,6 @@ echo "### check --all-features --workspace ###" cargo check --all-features --workspace echo "### clippy --workspace --all-targets --all-features -- -D warnings ###" cargo clippy --workspace --all-targets --all-features -- -D warnings + +echo "### cargo doc --all-features --no-deps ###" +RUSTDOCFLAGS="-D rustdoc::all" cargo doc --all-features --no-deps diff --git a/examples/minimal/src/app.rs b/examples/minimal/src/app.rs index 174ee956..1e24b752 100644 --- a/examples/minimal/src/app.rs +++ b/examples/minimal/src/app.rs @@ -1,6 +1,3 @@ -// The RoadsterApp trait uses `AppContext`, so allow an exception in order to implement the trait. -#![allow(clippy::disallowed_types)] - use aide::axum::ApiRouter; use roadster::app::App as RoadsterApp; use roadster::config::app_config::AppConfig; diff --git a/src/app.rs b/src/app.rs index 712295c9..89d92d25 100644 --- a/src/app.rs +++ b/src/app.rs @@ -126,21 +126,21 @@ where let initializers = default_initializers() .into_iter() .chain(A::initializers(&context)) - .filter(|initializer| initializer.enabled(&context)) + .filter(|initializer| initializer.enabled(&context, &state)) .unique_by(|initializer| initializer.name()) - .sorted_by(|a, b| Ord::cmp(&a.priority(&context), &b.priority(&context))) + .sorted_by(|a, b| Ord::cmp(&a.priority(&context, &state), &b.priority(&context, &state))) .collect_vec(); let router = initializers .iter() .try_fold(router, |router, initializer| { - initializer.after_router(router, &context) + initializer.after_router(router, &context, &state) })?; let router = initializers .iter() .try_fold(router, |router, initializer| { - initializer.before_middleware(router, &context) + initializer.before_middleware(router, &context, &state) })?; // Install middleware, both the default middleware and any provided by the consumer. @@ -148,26 +148,26 @@ where let router = default_middleware() .into_iter() .chain(A::middleware(&context, &state).into_iter()) - .filter(|middleware| middleware.enabled(&context)) + .filter(|middleware| middleware.enabled(&context, &state)) .unique_by(|middleware| middleware.name()) - .sorted_by(|a, b| Ord::cmp(&a.priority(&context), &b.priority(&context))) + .sorted_by(|a, b| Ord::cmp(&a.priority(&context, &state), &b.priority(&context, &state))) // Reverse due to how Axum's `Router#layer` method adds middleware. .rev() .try_fold(router, |router, middleware| { info!("Installing middleware: `{}`", middleware.name()); - middleware.install(router, &context) + middleware.install(router, &context, &state) })?; let router = initializers .iter() .try_fold(router, |router, initializer| { - initializer.after_middleware(router, &context) + initializer.after_middleware(router, &context, &state) })?; let router = initializers .iter() .try_fold(router, |router, initializer| { - initializer.before_serve(router, &context) + initializer.before_serve(router, &context, &state) })?; #[cfg(feature = "sidekiq")] @@ -289,9 +289,9 @@ pub trait App { } /// Convert the [AppContext] to the custom [Self::State] that will be used throughout the app. - /// The conversion should mostly happen in a [`From`] implementation, but this + /// The conversion can simply happen in a [`From`] implementation, but this /// method is provided in case there's any additional work that needs to be done that the - /// consumer doesn't want to put in a [`From`] implementation. For example, any + /// consumer can't put in a [`From`] implementation. For example, any /// configuration that needs to happen in an async method. async fn context_to_state(context: Arc) -> anyhow::Result { let state = Self::State::from(context); @@ -312,11 +312,14 @@ pub trait App { } } - fn middleware(_context: &AppContext, _state: &Self::State) -> Vec> { + fn middleware( + _context: &AppContext, + _state: &Self::State, + ) -> Vec>> { Default::default() } - fn initializers(_context: &AppContext) -> Vec> { + fn initializers(_context: &AppContext) -> Vec>> { Default::default() } diff --git a/src/app_context.rs b/src/app_context.rs index 65b48b3e..8144fb8a 100644 --- a/src/app_context.rs +++ b/src/app_context.rs @@ -42,3 +42,10 @@ impl AppContext { Ok(context) } } + +/// Implemented so consumers can use [AppContext] as their [crate::app::App::State] if they want. +impl From> for AppContext { + fn from(value: Arc) -> Self { + value.as_ref().clone() + } +} diff --git a/src/controller/middleware/catch_panic.rs b/src/controller/middleware/catch_panic.rs index 8cf20705..c69115d2 100644 --- a/src/controller/middleware/catch_panic.rs +++ b/src/controller/middleware/catch_panic.rs @@ -9,12 +9,12 @@ use tower_http::catch_panic::CatchPanicLayer; pub struct CatchPanicConfig {} pub struct CatchPanicMiddleware; -impl Middleware for CatchPanicMiddleware { +impl Middleware for CatchPanicMiddleware { fn name(&self) -> String { "catch-panic".to_string() } - fn enabled(&self, context: &AppContext) -> bool { + fn enabled(&self, context: &AppContext, _state: &S) -> bool { context .config .middleware @@ -23,11 +23,11 @@ impl Middleware for CatchPanicMiddleware { .enabled(context) } - fn priority(&self, context: &AppContext) -> i32 { + fn priority(&self, context: &AppContext, _state: &S) -> i32 { context.config.middleware.catch_panic.common.priority } - fn install(&self, router: Router, _context: &AppContext) -> anyhow::Result { + fn install(&self, router: Router, _context: &AppContext, _state: &S) -> anyhow::Result { let router = router.layer(CatchPanicLayer::new()); Ok(router) diff --git a/src/controller/middleware/compression.rs b/src/controller/middleware/compression.rs index 35293f02..2f76ba7c 100644 --- a/src/controller/middleware/compression.rs +++ b/src/controller/middleware/compression.rs @@ -2,6 +2,7 @@ use crate::app_context::AppContext; use crate::controller::middleware::Middleware; use axum::Router; use serde_derive::{Deserialize, Serialize}; + use tower_http::compression::CompressionLayer; use tower_http::decompression::RequestDecompressionLayer; @@ -14,12 +15,12 @@ pub struct ResponseCompressionConfig {} pub struct RequestDecompressionConfig {} pub struct ResponseCompressionMiddleware; -impl Middleware for ResponseCompressionMiddleware { +impl Middleware for ResponseCompressionMiddleware { fn name(&self) -> String { "response-compression".to_string() } - fn enabled(&self, context: &AppContext) -> bool { + fn enabled(&self, context: &AppContext, _state: &S) -> bool { context .config .middleware @@ -28,7 +29,7 @@ impl Middleware for ResponseCompressionMiddleware { .enabled(context) } - fn priority(&self, context: &AppContext) -> i32 { + fn priority(&self, context: &AppContext, _state: &S) -> i32 { context .config .middleware @@ -37,7 +38,7 @@ impl Middleware for ResponseCompressionMiddleware { .priority } - fn install(&self, router: Router, _context: &AppContext) -> anyhow::Result { + fn install(&self, router: Router, _context: &AppContext, _state: &S) -> anyhow::Result { let router = router.layer(CompressionLayer::new()); Ok(router) @@ -45,12 +46,12 @@ impl Middleware for ResponseCompressionMiddleware { } pub struct RequestDecompressionMiddleware; -impl Middleware for RequestDecompressionMiddleware { +impl Middleware for RequestDecompressionMiddleware { fn name(&self) -> String { "request-decompression".to_string() } - fn enabled(&self, context: &AppContext) -> bool { + fn enabled(&self, context: &AppContext, _state: &S) -> bool { context .config .middleware @@ -59,7 +60,7 @@ impl Middleware for RequestDecompressionMiddleware { .enabled(context) } - fn priority(&self, context: &AppContext) -> i32 { + fn priority(&self, context: &AppContext, _state: &S) -> i32 { context .config .middleware @@ -68,7 +69,7 @@ impl Middleware for RequestDecompressionMiddleware { .priority } - fn install(&self, router: Router, _context: &AppContext) -> anyhow::Result { + fn install(&self, router: Router, _context: &AppContext, _state: &S) -> anyhow::Result { let router = router.layer(RequestDecompressionLayer::new()); Ok(router) diff --git a/src/controller/middleware/default.rs b/src/controller/middleware/default.rs index adcab7ac..16e3ffec 100644 --- a/src/controller/middleware/default.rs +++ b/src/controller/middleware/default.rs @@ -11,7 +11,7 @@ use crate::controller::middleware::timeout::TimeoutMiddleware; use crate::controller::middleware::tracing::TracingMiddleware; use crate::controller::middleware::Middleware; -pub fn default_middleware() -> Vec> { +pub fn default_middleware() -> Vec>> { vec![ Box::new(SensitiveRequestHeadersMiddleware), Box::new(SensitiveResponseHeadersMiddleware), diff --git a/src/controller/middleware/mod.rs b/src/controller/middleware/mod.rs index 59daead3..23e834c5 100644 --- a/src/controller/middleware/mod.rs +++ b/src/controller/middleware/mod.rs @@ -10,9 +10,37 @@ pub mod tracing; use crate::app_context::AppContext; use axum::Router; -pub trait Middleware { +/// Allows initializing and installing middleware on the app's [Router]. The type `S` is the +/// custom [crate::app::App::State] defined for the app. +/// +/// This trait is provided in addition to [crate::initializer::Initializer] because installing +/// middleware is a bit of a special case compared to a general initializer: +/// 1. The order in which middleware runs matters. For example, we want +/// [tower_http::sensitive_headers::SetSensitiveRequestHeadersLayer] to run before +/// [tower_http::trace::TraceLayer] to avoid logging sensitive headers. +/// 2. Because of how axum's [Router::layer] method installs middleware, the order in which +/// middleware is installed is the reverse of the order it will run when handling a request. +/// Therefore, we install the middleware in the reverse order that we want it to run (this +/// is done automatically by Roadster based on [Middleware::priority]). +pub trait Middleware { fn name(&self) -> String; - fn enabled(&self, context: &AppContext) -> bool; - fn priority(&self, context: &AppContext) -> i32; - fn install(&self, router: Router, context: &AppContext) -> anyhow::Result; + fn enabled(&self, context: &AppContext, state: &S) -> bool; + /// Used to determine the order in which the middleware will run when handling a request. Smaller + /// numbers will run before larger numbers. For example, a middleware with priority `-10` + /// will run before a middleware with priority `10`. + /// + /// If two middlewares have the same priority, they are not guaranteed to run or be installed + /// in any particular order relative to each other. This may be fine for many middlewares. + /// + /// If the order in which your middleware runs doesn't particularly matter, it's generally + /// safe to set its priority as `0`. + /// + /// Note: Because of how axum's [Router::layer] method installs middleware, the order in which + /// middleware is installed is the reverse of the order it will run when handling a request. + /// Therefore, we install the middleware in the reverse order that we want it to run (this + /// is done automatically by Roadster based on [Middleware::priority]). So, a middleware + /// with priority `-10` will be _installed after_ a middleware with priority `10`, which will + /// allow the middleware with priority `-10` to _run before_ a middleware with priority `10`. + fn priority(&self, context: &AppContext, state: &S) -> i32; + fn install(&self, router: Router, context: &AppContext, state: &S) -> anyhow::Result; } diff --git a/src/controller/middleware/request_id.rs b/src/controller/middleware/request_id.rs index 3607d8ae..195ee77e 100644 --- a/src/controller/middleware/request_id.rs +++ b/src/controller/middleware/request_id.rs @@ -37,12 +37,12 @@ pub struct PropagateRequestIdConfig { } pub struct SetRequestIdMiddleware; -impl Middleware for SetRequestIdMiddleware { +impl Middleware for SetRequestIdMiddleware { fn name(&self) -> String { "set-request-id".to_string() } - fn enabled(&self, context: &AppContext) -> bool { + fn enabled(&self, context: &AppContext, _state: &S) -> bool { context .config .middleware @@ -51,11 +51,11 @@ impl Middleware for SetRequestIdMiddleware { .enabled(context) } - fn priority(&self, context: &AppContext) -> i32 { + fn priority(&self, context: &AppContext, _state: &S) -> i32 { context.config.middleware.set_request_id.common.priority } - fn install(&self, router: Router, context: &AppContext) -> anyhow::Result { + fn install(&self, router: Router, context: &AppContext, _state: &S) -> anyhow::Result { let header_name = &context .config .middleware @@ -74,12 +74,12 @@ impl Middleware for SetRequestIdMiddleware { } pub struct PropagateRequestIdMiddleware; -impl Middleware for PropagateRequestIdMiddleware { +impl Middleware for PropagateRequestIdMiddleware { fn name(&self) -> String { "propagate-request-id".to_string() } - fn enabled(&self, context: &AppContext) -> bool { + fn enabled(&self, context: &AppContext, _state: &S) -> bool { context .config .middleware @@ -88,7 +88,7 @@ impl Middleware for PropagateRequestIdMiddleware { .enabled(context) } - fn priority(&self, context: &AppContext) -> i32 { + fn priority(&self, context: &AppContext, _state: &S) -> i32 { context .config .middleware @@ -97,7 +97,7 @@ impl Middleware for PropagateRequestIdMiddleware { .priority } - fn install(&self, router: Router, context: &AppContext) -> anyhow::Result { + fn install(&self, router: Router, context: &AppContext, _state: &S) -> anyhow::Result { let header_name = &context .config .middleware diff --git a/src/controller/middleware/sensitive_headers.rs b/src/controller/middleware/sensitive_headers.rs index 57f870fa..9d580621 100644 --- a/src/controller/middleware/sensitive_headers.rs +++ b/src/controller/middleware/sensitive_headers.rs @@ -5,6 +5,7 @@ use axum::Router; use itertools::Itertools; use serde_derive::{Deserialize, Serialize}; use std::str::FromStr; + use tower_http::sensitive_headers::{ SetSensitiveRequestHeadersLayer, SetSensitiveResponseHeadersLayer, }; @@ -55,12 +56,12 @@ pub struct SensitiveResponseHeadersConfig { pub struct SensitiveRequestHeadersMiddleware; -impl Middleware for SensitiveRequestHeadersMiddleware { +impl Middleware for SensitiveRequestHeadersMiddleware { fn name(&self) -> String { "sensitive-request-headers".to_string() } - fn enabled(&self, context: &AppContext) -> bool { + fn enabled(&self, context: &AppContext, _state: &S) -> bool { context .config .middleware @@ -69,7 +70,7 @@ impl Middleware for SensitiveRequestHeadersMiddleware { .enabled(context) } - fn priority(&self, context: &AppContext) -> i32 { + fn priority(&self, context: &AppContext, _state: &S) -> i32 { context .config .middleware @@ -77,7 +78,7 @@ impl Middleware for SensitiveRequestHeadersMiddleware { .common .priority } - fn install(&self, router: Router, context: &AppContext) -> anyhow::Result { + fn install(&self, router: Router, context: &AppContext, _state: &S) -> anyhow::Result { let headers = context .config .middleware @@ -94,12 +95,12 @@ impl Middleware for SensitiveRequestHeadersMiddleware { pub struct SensitiveResponseHeadersMiddleware; -impl Middleware for SensitiveResponseHeadersMiddleware { +impl Middleware for SensitiveResponseHeadersMiddleware { fn name(&self) -> String { "sensitive-response-headers".to_string() } - fn enabled(&self, context: &AppContext) -> bool { + fn enabled(&self, context: &AppContext, _state: &S) -> bool { context .config .middleware @@ -108,7 +109,7 @@ impl Middleware for SensitiveResponseHeadersMiddleware { .enabled(context) } - fn priority(&self, context: &AppContext) -> i32 { + fn priority(&self, context: &AppContext, _state: &S) -> i32 { context .config .middleware @@ -116,7 +117,7 @@ impl Middleware for SensitiveResponseHeadersMiddleware { .common .priority } - fn install(&self, router: Router, context: &AppContext) -> anyhow::Result { + fn install(&self, router: Router, context: &AppContext, _state: &S) -> anyhow::Result { let headers = context .config .middleware diff --git a/src/controller/middleware/size_limit.rs b/src/controller/middleware/size_limit.rs index 25fda618..d933cf44 100644 --- a/src/controller/middleware/size_limit.rs +++ b/src/controller/middleware/size_limit.rs @@ -23,20 +23,20 @@ impl Default for SizeLimitConfig { } pub struct RequestBodyLimitMiddleware; -impl Middleware for RequestBodyLimitMiddleware { +impl Middleware for RequestBodyLimitMiddleware { fn name(&self) -> String { "request-body-size-limit".to_string() } - fn enabled(&self, context: &AppContext) -> bool { + fn enabled(&self, context: &AppContext, _state: &S) -> bool { context.config.middleware.size_limit.common.enabled(context) } - fn priority(&self, context: &AppContext) -> i32 { + fn priority(&self, context: &AppContext, _state: &S) -> i32 { context.config.middleware.size_limit.common.priority } - fn install(&self, router: Router, context: &AppContext) -> anyhow::Result { + fn install(&self, router: Router, context: &AppContext, _state: &S) -> anyhow::Result { let limit = &context .config .middleware diff --git a/src/controller/middleware/timeout.rs b/src/controller/middleware/timeout.rs index 72d7ad9d..df43171a 100644 --- a/src/controller/middleware/timeout.rs +++ b/src/controller/middleware/timeout.rs @@ -23,20 +23,20 @@ impl Default for TimeoutConfig { } pub struct TimeoutMiddleware; -impl Middleware for TimeoutMiddleware { +impl Middleware for TimeoutMiddleware { fn name(&self) -> String { "timeout".to_string() } - fn enabled(&self, context: &AppContext) -> bool { + fn enabled(&self, context: &AppContext, _state: &S) -> bool { context.config.middleware.timeout.common.enabled(context) } - fn priority(&self, context: &AppContext) -> i32 { + fn priority(&self, context: &AppContext, _state: &S) -> i32 { context.config.middleware.timeout.common.priority } - fn install(&self, router: Router, context: &AppContext) -> anyhow::Result { + fn install(&self, router: Router, context: &AppContext, _state: &S) -> anyhow::Result { let timeout = &context.config.middleware.timeout.custom.timeout; let router = router.layer(TimeoutLayer::new(*timeout)); diff --git a/src/controller/middleware/tracing.rs b/src/controller/middleware/tracing.rs index e34ccb71..8ffd2d02 100644 --- a/src/controller/middleware/tracing.rs +++ b/src/controller/middleware/tracing.rs @@ -16,21 +16,21 @@ use tracing::{event, field, info_span, Level, Span, Value}; pub struct TracingConfig {} pub struct TracingMiddleware; -impl Middleware for TracingMiddleware { +impl Middleware for TracingMiddleware { fn name(&self) -> String { "tracing".to_string() } - fn enabled(&self, context: &AppContext) -> bool { + fn enabled(&self, context: &AppContext, _state: &S) -> bool { context.config.middleware.tracing.common.enabled(context) } - fn priority(&self, context: &AppContext) -> i32 { + fn priority(&self, context: &AppContext, _state: &S) -> i32 { context.config.middleware.tracing.common.priority } - fn install(&self, router: Router, _context: &AppContext) -> anyhow::Result { - let request_id_header_name = &_context + fn install(&self, router: Router, context: &AppContext, _state: &S) -> anyhow::Result { + let request_id_header_name = &context .config .middleware .set_request_id diff --git a/src/initializer/default.rs b/src/initializer/default.rs index 237d4c1c..694182bb 100644 --- a/src/initializer/default.rs +++ b/src/initializer/default.rs @@ -1,6 +1,6 @@ use crate::initializer::normalize_path::NormalizePathInitializer; use crate::initializer::Initializer; -pub fn default_initializers() -> Vec> { +pub fn default_initializers() -> Vec>> { vec![Box::new(NormalizePathInitializer)] } diff --git a/src/initializer/mod.rs b/src/initializer/mod.rs index 181b709f..e3f51776 100644 --- a/src/initializer/mod.rs +++ b/src/initializer/mod.rs @@ -4,26 +4,49 @@ pub mod normalize_path; use crate::app_context::AppContext; use axum::Router; -pub trait Initializer { +/// Provides hooks into various stages of the app's startup to allow initializing and installing +/// anything that needs to be done during a specific stage of startup. The type `S` is the +/// custom [crate::app::App::State] defined for the app. +pub trait Initializer { fn name(&self) -> String; - fn enabled(&self, context: &AppContext) -> bool; + fn enabled(&self, context: &AppContext, state: &S) -> bool; - fn priority(&self, context: &AppContext) -> i32; + fn priority(&self, context: &AppContext, state: &S) -> i32; - fn after_router(&self, router: Router, _context: &AppContext) -> anyhow::Result { + fn after_router( + &self, + router: Router, + _context: &AppContext, + _state: &S, + ) -> anyhow::Result { Ok(router) } - fn before_middleware(&self, router: Router, _context: &AppContext) -> anyhow::Result { + fn before_middleware( + &self, + router: Router, + _context: &AppContext, + _state: &S, + ) -> anyhow::Result { Ok(router) } - fn after_middleware(&self, router: Router, _context: &AppContext) -> anyhow::Result { + fn after_middleware( + &self, + router: Router, + _context: &AppContext, + _state: &S, + ) -> anyhow::Result { Ok(router) } - fn before_serve(&self, router: Router, _context: &AppContext) -> anyhow::Result { + fn before_serve( + &self, + router: Router, + _context: &AppContext, + _state: &S, + ) -> anyhow::Result { Ok(router) } } diff --git a/src/initializer/normalize_path.rs b/src/initializer/normalize_path.rs index cf38a490..4c611d24 100644 --- a/src/initializer/normalize_path.rs +++ b/src/initializer/normalize_path.rs @@ -11,12 +11,12 @@ pub struct NormalizePathConfig {} pub struct NormalizePathInitializer; -impl Initializer for NormalizePathInitializer { +impl Initializer for NormalizePathInitializer { fn name(&self) -> String { "normalize-path".to_string() } - fn enabled(&self, context: &AppContext) -> bool { + fn enabled(&self, context: &AppContext, _state: &S) -> bool { context .config .initializer @@ -25,11 +25,16 @@ impl Initializer for NormalizePathInitializer { .enabled(context) } - fn priority(&self, context: &AppContext) -> i32 { + fn priority(&self, context: &AppContext, _state: &S) -> i32 { context.config.initializer.normalize_path.common.priority } - fn before_serve(&self, router: Router, _context: &AppContext) -> anyhow::Result { + fn before_serve( + &self, + router: Router, + _context: &AppContext, + _state: &S, + ) -> anyhow::Result { let router = NormalizePathLayer::trim_trailing_slash().layer(router); let router = Router::new().nest_service("/", router); Ok(router)