Skip to content

Commit

Permalink
feat: Add AnyMiddleware to minimize boilerplate for Axum middleware
Browse files Browse the repository at this point in the history
If a consumer wants to use a middleware that's not already supported by
Roadster, they need to implement the `Middleware` trait, which is extra
boilerplate that could be annoying.

Add `AnyMiddleware` struct that implements the `Middleware` trait, so
consumers just need to provide the name of the middleware and the logic
to build/configure it.

Closes #470
  • Loading branch information
spencewenski committed Oct 21, 2024
1 parent e86381e commit 1786bf8
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 2 deletions.
4 changes: 4 additions & 0 deletions examples/full/config/default.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ name = "Full Example"
host = "127.0.0.1"
port = 3000

[service.http.middleware.hello-world]
enable = true
priority = 5

[service.grpc]
host = "127.0.0.1"
port = 3001
7 changes: 7 additions & 0 deletions examples/full/config/development.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,10 @@ secret = "secret-dev"

[database]
uri = "postgres://roadster:roadster@localhost:5432/example_dev"

# Sidekiq fails to connect a lot locally for some reason. Uncomment the below configs to disable it temporarily.
#[service.sidekiq]
#num-workers = 0
#queues = []
#[health-check.sidekiq]
#enable = false
12 changes: 12 additions & 0 deletions examples/full/src/api/http/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
use crate::app_state::AppState;
use aide::axum::ApiRouter;
use axum::extract::Request;
use axum::middleware::Next;
use axum::response::Response;
use roadster::service::http::middleware::any::AnyMiddleware;
use roadster::service::http::middleware::Middleware;
use tracing::info;

pub mod example;

pub fn routes(parent: &str) -> ApiRouter<AppState> {
ApiRouter::new().merge(example::routes(parent))
}

pub(crate) async fn hello_world_middleware_fn(request: Request, next: Next) -> Response {
info!("Running `hello-world` middleware");

next.run(request).await
}
13 changes: 12 additions & 1 deletion examples/full/src/app.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#[cfg(feature = "grpc")]
use crate::api::grpc::routes;
use crate::api::http;
use crate::api::http::hello_world_middleware_fn;
use crate::app_state::AppState;
use crate::cli::AppCli;
use crate::service::example::example_service;
Expand All @@ -15,6 +16,7 @@ use roadster::error::RoadsterResult;
use roadster::service::function::service::FunctionService;
#[cfg(feature = "grpc")]
use roadster::service::grpc::service::GrpcService;
use roadster::service::http::middleware::any::AnyMiddleware;
use roadster::service::http::service::HttpService;
use roadster::service::registry::ServiceRegistry;
use roadster::service::worker::sidekiq::app_worker::AppWorker;
Expand Down Expand Up @@ -47,7 +49,16 @@ impl RoadsterApp<AppState> for App {
) -> RoadsterResult<()> {
registry
.register_builder(
HttpService::builder(Some(BASE), state).api_router(http::routes(BASE)),
HttpService::builder(Some(BASE), state)
.api_router(http::routes(BASE))
.middleware(
AnyMiddleware::builder()
.name("hello-world")
.layer_provider(|_state| {
axum::middleware::from_fn(hello_world_middleware_fn)
})
.build(),
)?,
)
.await?;

Expand Down
2 changes: 1 addition & 1 deletion src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ pub const ENV_VAR_PREFIX: &str = "ROADSTER";
pub const ENV_VAR_SEPARATOR: &str = "__";

cfg_if! {
if #[cfg(feature = "config-yaml")] {
if #[cfg(feature = "config-yml")] {
pub const FILE_EXTENSIONS: [&str; 3] = ["toml", "yaml", "yml"];
} else {
pub const FILE_EXTENSIONS: [&str; 1] = ["toml"];
Expand Down
1 change: 1 addition & 0 deletions src/config/service/http/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ pub struct CommonConfig {
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(default)]
pub enable: Option<bool>,
#[serde(default)]
pub priority: i32,
}

Expand Down
95 changes: 95 additions & 0 deletions src/service/http/middleware/any.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use crate::app::context::AppContext;
use crate::error::RoadsterResult;
use crate::service::http::middleware::Middleware;
use axum::routing::Route;
use axum::Router;
use axum_core::extract::{FromRef, Request};
use axum_core::response::IntoResponse;
use std::convert::Infallible;
use tower::{Layer, Service};
use typed_builder::TypedBuilder;

#[derive(TypedBuilder)]
pub struct AnyMiddleware<S, L>
where
S: Clone + Send + Sync + 'static,
AppContext: FromRef<S>,
// Layer constrains copied from https://docs.rs/axum/0.7.7/axum/routing/struct.Router.html#method.layer
L: Layer<Route> + Clone + Send + 'static,
L::Service: Service<Request> + Clone + Send + 'static,
<L::Service as Service<Request>>::Response: IntoResponse + 'static,
<L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request>>::Future: Send + 'static,
{
#[builder(setter(into))]
name: String,
#[builder(default, setter(strip_option))]
enabled: Option<bool>,
#[builder(default, setter(strip_option))]
priority: Option<i32>,
#[builder(setter(transform = |p: impl Fn(&S) -> L + Send + 'static| to_box_fn(p) ))]
layer_provider: Box<dyn Fn(&S) -> L + Send>,
}

fn to_box_fn<S, L>(p: impl Fn(&S) -> L + Send + 'static) -> Box<dyn Fn(&S) -> L + Send> {
Box::new(p)
}

impl<S, L> Middleware<S> for AnyMiddleware<S, L>
where
S: Clone + Send + Sync + 'static,
AppContext: FromRef<S>,
// Layer constrains copied from https://docs.rs/axum/0.7.7/axum/routing/struct.Router.html#method.layer
L: Layer<Route> + Clone + Send + 'static,
L::Service: Service<Request> + Clone + Send + 'static,
<L::Service as Service<Request>>::Response: IntoResponse + 'static,
<L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request>>::Future: Send + 'static,
{
fn name(&self) -> String {
self.name.clone()
}

fn enabled(&self, state: &S) -> bool {
let context = AppContext::from_ref(state);
let config = context
.config()
.service
.http
.custom
.middleware
.custom
.get(&self.name);
if let Some(config) = config {
config.common.enabled(state)
} else {
context
.config()
.service
.http
.custom
.middleware
.default_enable
|| self.enabled.unwrap_or_default()
}
}

fn priority(&self, state: &S) -> i32 {
AppContext::from_ref(state)
.config()
.service
.http
.custom
.middleware
.custom
.get(&self.name)
.map(|config| config.common.priority)
.unwrap_or_else(|| self.priority.unwrap_or_default())
}

fn install(&self, router: Router, state: &S) -> RoadsterResult<Router> {
let router = router.layer((self.layer_provider)(state));

Ok(router)
}
}
1 change: 1 addition & 0 deletions src/service/http/middleware/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod any;
pub mod catch_panic;
pub mod compression;
pub mod cors;
Expand Down

0 comments on commit 1786bf8

Please sign in to comment.