Skip to content

Commit

Permalink
Allow using custom App::State in Initializer and Middleware traits
Browse files Browse the repository at this point in the history
  • Loading branch information
spencewenski committed Mar 31, 2024
1 parent af913fb commit 61fe6f9
Show file tree
Hide file tree
Showing 16 changed files with 142 additions and 77 deletions.
6 changes: 3 additions & 3 deletions .cargo-husky/hooks/pre-push
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@ set -e
echo "### fmt --all --check ###"
cargo fmt --all --check

echo "### cargo doc --all-features --no-deps ###"
RUSTDOCFLAGS="-D rustdoc::all" cargo doc --all-features --no-deps

# With no features
echo "### test --no-default-features --workspace ###"
cargo test --no-default-features --workspace
Expand All @@ -33,3 +30,6 @@ echo "### check --all-features --workspace ###"
cargo check --all-features --workspace
echo "### clippy --workspace --all-targets --all-features -- -D warnings ###"
cargo clippy --workspace --all-targets --all-features -- -D warnings

echo "### cargo doc --all-features --no-deps ###"
RUSTDOCFLAGS="-D rustdoc::all" cargo doc --all-features --no-deps
3 changes: 0 additions & 3 deletions examples/minimal/src/app.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
// The RoadsterApp trait uses `AppContext`, so allow an exception in order to implement the trait.
#![allow(clippy::disallowed_types)]

use aide::axum::ApiRouter;
use roadster::app::App as RoadsterApp;
use roadster::config::app_config::AppConfig;
Expand Down
29 changes: 16 additions & 13 deletions src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,48 +126,48 @@ where
let initializers = default_initializers()
.into_iter()
.chain(A::initializers(&context))
.filter(|initializer| initializer.enabled(&context))
.filter(|initializer| initializer.enabled(&context, &state))
.unique_by(|initializer| initializer.name())
.sorted_by(|a, b| Ord::cmp(&a.priority(&context), &b.priority(&context)))
.sorted_by(|a, b| Ord::cmp(&a.priority(&context, &state), &b.priority(&context, &state)))
.collect_vec();

let router = initializers
.iter()
.try_fold(router, |router, initializer| {
initializer.after_router(router, &context)
initializer.after_router(router, &context, &state)
})?;

let router = initializers
.iter()
.try_fold(router, |router, initializer| {
initializer.before_middleware(router, &context)
initializer.before_middleware(router, &context, &state)
})?;

// Install middleware, both the default middleware and any provided by the consumer.
info!("Installing middleware. Note: the order of installation is the inverse of the order middleware will run when handling a request.");
let router = default_middleware()
.into_iter()
.chain(A::middleware(&context, &state).into_iter())
.filter(|middleware| middleware.enabled(&context))
.filter(|middleware| middleware.enabled(&context, &state))
.unique_by(|middleware| middleware.name())
.sorted_by(|a, b| Ord::cmp(&a.priority(&context), &b.priority(&context)))
.sorted_by(|a, b| Ord::cmp(&a.priority(&context, &state), &b.priority(&context, &state)))
// Reverse due to how Axum's `Router#layer` method adds middleware.
.rev()
.try_fold(router, |router, middleware| {
info!("Installing middleware: `{}`", middleware.name());
middleware.install(router, &context)
middleware.install(router, &context, &state)
})?;

let router = initializers
.iter()
.try_fold(router, |router, initializer| {
initializer.after_middleware(router, &context)
initializer.after_middleware(router, &context, &state)
})?;

let router = initializers
.iter()
.try_fold(router, |router, initializer| {
initializer.before_serve(router, &context)
initializer.before_serve(router, &context, &state)
})?;

#[cfg(feature = "sidekiq")]
Expand Down Expand Up @@ -289,9 +289,9 @@ pub trait App {
}

/// Convert the [AppContext] to the custom [Self::State] that will be used throughout the app.
/// The conversion should mostly happen in a [`From<AppContext>`] implementation, but this
/// The conversion can simply happen in a [`From<AppContext>`] implementation, but this
/// method is provided in case there's any additional work that needs to be done that the
/// consumer doesn't want to put in a [`From<AppContext>`] implementation. For example, any
/// consumer can't put in a [`From<AppContext>`] implementation. For example, any
/// configuration that needs to happen in an async method.
async fn context_to_state(context: Arc<AppContext>) -> anyhow::Result<Self::State> {
let state = Self::State::from(context);
Expand All @@ -312,11 +312,14 @@ pub trait App {
}
}

fn middleware(_context: &AppContext, _state: &Self::State) -> Vec<Box<dyn Middleware>> {
fn middleware(
_context: &AppContext,
_state: &Self::State,
) -> Vec<Box<dyn Middleware<Self::State>>> {
Default::default()
}

fn initializers(_context: &AppContext) -> Vec<Box<dyn Initializer>> {
fn initializers(_context: &AppContext) -> Vec<Box<dyn Initializer<Self::State>>> {
Default::default()
}

Expand Down
7 changes: 7 additions & 0 deletions src/app_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,10 @@ impl AppContext {
Ok(context)
}
}

/// Implemented so consumers can use [AppContext] as their [crate::app::App::State] if they want.
impl From<Arc<AppContext>> for AppContext {
fn from(value: Arc<AppContext>) -> Self {
value.as_ref().clone()
}
}
8 changes: 4 additions & 4 deletions src/controller/middleware/catch_panic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ use tower_http::catch_panic::CatchPanicLayer;
pub struct CatchPanicConfig {}

pub struct CatchPanicMiddleware;
impl Middleware for CatchPanicMiddleware {
impl<S> Middleware<S> for CatchPanicMiddleware {
fn name(&self) -> String {
"catch-panic".to_string()
}

fn enabled(&self, context: &AppContext) -> bool {
fn enabled(&self, context: &AppContext, _state: &S) -> bool {
context
.config
.middleware
Expand All @@ -23,11 +23,11 @@ impl Middleware for CatchPanicMiddleware {
.enabled(context)
}

fn priority(&self, context: &AppContext) -> i32 {
fn priority(&self, context: &AppContext, _state: &S) -> i32 {
context.config.middleware.catch_panic.common.priority
}

fn install(&self, router: Router, _context: &AppContext) -> anyhow::Result<Router> {
fn install(&self, router: Router, _context: &AppContext, _state: &S) -> anyhow::Result<Router> {
let router = router.layer(CatchPanicLayer::new());

Ok(router)
Expand Down
17 changes: 9 additions & 8 deletions src/controller/middleware/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::app_context::AppContext;
use crate::controller::middleware::Middleware;
use axum::Router;
use serde_derive::{Deserialize, Serialize};

use tower_http::compression::CompressionLayer;
use tower_http::decompression::RequestDecompressionLayer;

Expand All @@ -14,12 +15,12 @@ pub struct ResponseCompressionConfig {}
pub struct RequestDecompressionConfig {}

pub struct ResponseCompressionMiddleware;
impl Middleware for ResponseCompressionMiddleware {
impl<S> Middleware<S> for ResponseCompressionMiddleware {
fn name(&self) -> String {
"response-compression".to_string()
}

fn enabled(&self, context: &AppContext) -> bool {
fn enabled(&self, context: &AppContext, _state: &S) -> bool {
context
.config
.middleware
Expand All @@ -28,7 +29,7 @@ impl Middleware for ResponseCompressionMiddleware {
.enabled(context)
}

fn priority(&self, context: &AppContext) -> i32 {
fn priority(&self, context: &AppContext, _state: &S) -> i32 {
context
.config
.middleware
Expand All @@ -37,20 +38,20 @@ impl Middleware for ResponseCompressionMiddleware {
.priority
}

fn install(&self, router: Router, _context: &AppContext) -> anyhow::Result<Router> {
fn install(&self, router: Router, _context: &AppContext, _state: &S) -> anyhow::Result<Router> {
let router = router.layer(CompressionLayer::new());

Ok(router)
}
}

pub struct RequestDecompressionMiddleware;
impl Middleware for RequestDecompressionMiddleware {
impl<S> Middleware<S> for RequestDecompressionMiddleware {
fn name(&self) -> String {
"request-decompression".to_string()
}

fn enabled(&self, context: &AppContext) -> bool {
fn enabled(&self, context: &AppContext, _state: &S) -> bool {
context
.config
.middleware
Expand All @@ -59,7 +60,7 @@ impl Middleware for RequestDecompressionMiddleware {
.enabled(context)
}

fn priority(&self, context: &AppContext) -> i32 {
fn priority(&self, context: &AppContext, _state: &S) -> i32 {
context
.config
.middleware
Expand All @@ -68,7 +69,7 @@ impl Middleware for RequestDecompressionMiddleware {
.priority
}

fn install(&self, router: Router, _context: &AppContext) -> anyhow::Result<Router> {
fn install(&self, router: Router, _context: &AppContext, _state: &S) -> anyhow::Result<Router> {
let router = router.layer(RequestDecompressionLayer::new());

Ok(router)
Expand Down
2 changes: 1 addition & 1 deletion src/controller/middleware/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::controller::middleware::timeout::TimeoutMiddleware;
use crate::controller::middleware::tracing::TracingMiddleware;
use crate::controller::middleware::Middleware;

pub fn default_middleware() -> Vec<Box<dyn Middleware>> {
pub fn default_middleware<S>() -> Vec<Box<dyn Middleware<S>>> {
vec![
Box::new(SensitiveRequestHeadersMiddleware),
Box::new(SensitiveResponseHeadersMiddleware),
Expand Down
36 changes: 32 additions & 4 deletions src/controller/middleware/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,37 @@ pub mod tracing;
use crate::app_context::AppContext;
use axum::Router;

pub trait Middleware {
/// Allows initializing and installing middleware on the app's [Router]. The type `S` is the
/// custom [crate::app::App::State] defined for the app.
///
/// This trait is provided in addition to [crate::initializer::Initializer] because installing
/// middleware is a bit of a special case compared to a general initializer:
/// 1. The order in which middleware runs matters. For example, we want
/// [tower_http::sensitive_headers::SetSensitiveRequestHeadersLayer] to run before
/// [tower_http::trace::TraceLayer] to avoid logging sensitive headers.
/// 2. Because of how axum's [Router::layer] method installs middleware, the order in which
/// middleware is installed is the reverse of the order it will run when handling a request.
/// Therefore, we install the middleware in the reverse order that we want it to run (this
/// is done automatically by Roadster based on [Middleware::priority]).
pub trait Middleware<S> {
fn name(&self) -> String;
fn enabled(&self, context: &AppContext) -> bool;
fn priority(&self, context: &AppContext) -> i32;
fn install(&self, router: Router, context: &AppContext) -> anyhow::Result<Router>;
fn enabled(&self, context: &AppContext, state: &S) -> bool;
/// Used to determine the order in which the middleware will run when handling a request. Smaller
/// numbers will run before larger numbers. For example, a middleware with priority `-10`
/// will run before a middleware with priority `10`.
///
/// If two middlewares have the same priority, they are not guaranteed to run or be installed
/// in any particular order relative to each other. This may be fine for many middlewares.
///
/// If the order in which your middleware runs doesn't particularly matter, it's generally
/// safe to set its priority as `0`.
///
/// Note: Because of how axum's [Router::layer] method installs middleware, the order in which
/// middleware is installed is the reverse of the order it will run when handling a request.
/// Therefore, we install the middleware in the reverse order that we want it to run (this
/// is done automatically by Roadster based on [Middleware::priority]). So, a middleware
/// with priority `-10` will be _installed after_ a middleware with priority `10`, which will
/// allow the middleware with priority `-10` to _run before_ a middleware with priority `10`.
fn priority(&self, context: &AppContext, state: &S) -> i32;
fn install(&self, router: Router, context: &AppContext, state: &S) -> anyhow::Result<Router>;
}
16 changes: 8 additions & 8 deletions src/controller/middleware/request_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ pub struct PropagateRequestIdConfig {
}

pub struct SetRequestIdMiddleware;
impl Middleware for SetRequestIdMiddleware {
impl<S> Middleware<S> for SetRequestIdMiddleware {
fn name(&self) -> String {
"set-request-id".to_string()
}

fn enabled(&self, context: &AppContext) -> bool {
fn enabled(&self, context: &AppContext, _state: &S) -> bool {
context
.config
.middleware
Expand All @@ -51,11 +51,11 @@ impl Middleware for SetRequestIdMiddleware {
.enabled(context)
}

fn priority(&self, context: &AppContext) -> i32 {
fn priority(&self, context: &AppContext, _state: &S) -> i32 {
context.config.middleware.set_request_id.common.priority
}

fn install(&self, router: Router, context: &AppContext) -> anyhow::Result<Router> {
fn install(&self, router: Router, context: &AppContext, _state: &S) -> anyhow::Result<Router> {
let header_name = &context
.config
.middleware
Expand All @@ -74,12 +74,12 @@ impl Middleware for SetRequestIdMiddleware {
}

pub struct PropagateRequestIdMiddleware;
impl Middleware for PropagateRequestIdMiddleware {
impl<S> Middleware<S> for PropagateRequestIdMiddleware {
fn name(&self) -> String {
"propagate-request-id".to_string()
}

fn enabled(&self, context: &AppContext) -> bool {
fn enabled(&self, context: &AppContext, _state: &S) -> bool {
context
.config
.middleware
Expand All @@ -88,7 +88,7 @@ impl Middleware for PropagateRequestIdMiddleware {
.enabled(context)
}

fn priority(&self, context: &AppContext) -> i32 {
fn priority(&self, context: &AppContext, _state: &S) -> i32 {
context
.config
.middleware
Expand All @@ -97,7 +97,7 @@ impl Middleware for PropagateRequestIdMiddleware {
.priority
}

fn install(&self, router: Router, context: &AppContext) -> anyhow::Result<Router> {
fn install(&self, router: Router, context: &AppContext, _state: &S) -> anyhow::Result<Router> {
let header_name = &context
.config
.middleware
Expand Down
Loading

0 comments on commit 61fe6f9

Please sign in to comment.