diff --git a/examples/full/config/default.toml b/examples/full/config/default.toml index 7ebe4858..bf4b786b 100644 --- a/examples/full/config/default.toml +++ b/examples/full/config/default.toml @@ -8,6 +8,10 @@ name = "Full Example" host = "127.0.0.1" port = 3000 +[service.http.middleware.hello-world] +enable = true +priority = 5 + [service.grpc] host = "127.0.0.1" port = 3001 diff --git a/examples/full/config/development.toml b/examples/full/config/development.toml index b12d21a1..bf45e9d1 100644 --- a/examples/full/config/development.toml +++ b/examples/full/config/development.toml @@ -3,3 +3,10 @@ secret = "secret-dev" [database] uri = "postgres://roadster:roadster@localhost:5432/example_dev" + +# Sidekiq fails to connect a lot locally for some reason. Uncomment the below configs to disable it temporarily. +#[service.sidekiq] +#num-workers = 0 +#queues = [] +#[health-check.sidekiq] +#enable = false diff --git a/examples/full/src/api/http/mod.rs b/examples/full/src/api/http/mod.rs index a790f6e6..c381d13e 100644 --- a/examples/full/src/api/http/mod.rs +++ b/examples/full/src/api/http/mod.rs @@ -1,8 +1,20 @@ use crate::app_state::AppState; use aide::axum::ApiRouter; +use axum::extract::Request; +use axum::middleware::Next; +use axum::response::Response; +use roadster::service::http::middleware::any::AnyMiddleware; +use roadster::service::http::middleware::Middleware; +use tracing::info; pub mod example; pub fn routes(parent: &str) -> ApiRouter { ApiRouter::new().merge(example::routes(parent)) } + +pub(crate) async fn hello_world_middleware_fn(request: Request, next: Next) -> Response { + info!("Running `hello-world` middleware"); + + next.run(request).await +} diff --git a/examples/full/src/app.rs b/examples/full/src/app.rs index d2a246cc..3975bd01 100644 --- a/examples/full/src/app.rs +++ b/examples/full/src/app.rs @@ -1,6 +1,7 @@ #[cfg(feature = "grpc")] use crate::api::grpc::routes; use crate::api::http; +use crate::api::http::hello_world_middleware_fn; use crate::app_state::AppState; use crate::cli::AppCli; use crate::service::example::example_service; @@ -15,6 +16,7 @@ use roadster::error::RoadsterResult; use roadster::service::function::service::FunctionService; #[cfg(feature = "grpc")] use roadster::service::grpc::service::GrpcService; +use roadster::service::http::middleware::any::AnyMiddleware; use roadster::service::http::service::HttpService; use roadster::service::registry::ServiceRegistry; use roadster::service::worker::sidekiq::app_worker::AppWorker; @@ -47,7 +49,16 @@ impl RoadsterApp for App { ) -> RoadsterResult<()> { registry .register_builder( - HttpService::builder(Some(BASE), state).api_router(http::routes(BASE)), + HttpService::builder(Some(BASE), state) + .api_router(http::routes(BASE)) + .middleware( + AnyMiddleware::builder() + .name("hello-world") + .layer_provider(|_state| { + axum::middleware::from_fn(hello_world_middleware_fn) + }) + .build(), + )?, ) .await?; diff --git a/src/config/mod.rs b/src/config/mod.rs index 60d0c21d..b7f18af7 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -86,7 +86,7 @@ pub const ENV_VAR_PREFIX: &str = "ROADSTER"; pub const ENV_VAR_SEPARATOR: &str = "__"; cfg_if! { - if #[cfg(feature = "config-yaml")] { + if #[cfg(feature = "config-yml")] { pub const FILE_EXTENSIONS: [&str; 3] = ["toml", "yaml", "yml"]; } else { pub const FILE_EXTENSIONS: [&str; 1] = ["toml"]; diff --git a/src/config/service/http/middleware.rs b/src/config/service/http/middleware.rs index 450ee521..a617e799 100644 --- a/src/config/service/http/middleware.rs +++ b/src/config/service/http/middleware.rs @@ -93,6 +93,7 @@ pub struct CommonConfig { #[serde(skip_serializing_if = "Option::is_none")] #[serde(default)] pub enable: Option, + #[serde(default)] pub priority: i32, } diff --git a/src/service/http/middleware/any.rs b/src/service/http/middleware/any.rs new file mode 100644 index 00000000..bcc29e7b --- /dev/null +++ b/src/service/http/middleware/any.rs @@ -0,0 +1,95 @@ +use crate::app::context::AppContext; +use crate::error::RoadsterResult; +use crate::service::http::middleware::Middleware; +use axum::routing::Route; +use axum::Router; +use axum_core::extract::{FromRef, Request}; +use axum_core::response::IntoResponse; +use std::convert::Infallible; +use tower::{Layer, Service}; +use typed_builder::TypedBuilder; + +#[derive(TypedBuilder)] +pub struct AnyMiddleware +where + S: Clone + Send + Sync + 'static, + AppContext: FromRef, + // Layer constrains copied from https://docs.rs/axum/0.7.7/axum/routing/struct.Router.html#method.layer + L: Layer + Clone + Send + 'static, + L::Service: Service + Clone + Send + 'static, + >::Response: IntoResponse + 'static, + >::Error: Into + 'static, + >::Future: Send + 'static, +{ + #[builder(setter(into))] + name: String, + #[builder(default, setter(strip_option))] + enabled: Option, + #[builder(default, setter(strip_option))] + priority: Option, + #[builder(setter(transform = |p: impl Fn(&S) -> L + Send + 'static| to_box_fn(p) ))] + layer_provider: Box L + Send>, +} + +fn to_box_fn(p: impl Fn(&S) -> L + Send + 'static) -> Box L + Send> { + Box::new(p) +} + +impl Middleware for AnyMiddleware +where + S: Clone + Send + Sync + 'static, + AppContext: FromRef, + // Layer constrains copied from https://docs.rs/axum/0.7.7/axum/routing/struct.Router.html#method.layer + L: Layer + Clone + Send + 'static, + L::Service: Service + Clone + Send + 'static, + >::Response: IntoResponse + 'static, + >::Error: Into + 'static, + >::Future: Send + 'static, +{ + fn name(&self) -> String { + self.name.clone() + } + + fn enabled(&self, state: &S) -> bool { + let context = AppContext::from_ref(state); + let config = context + .config() + .service + .http + .custom + .middleware + .custom + .get(&self.name); + if let Some(config) = config { + config.common.enabled(state) + } else { + context + .config() + .service + .http + .custom + .middleware + .default_enable + || self.enabled.unwrap_or_default() + } + } + + fn priority(&self, state: &S) -> i32 { + AppContext::from_ref(state) + .config() + .service + .http + .custom + .middleware + .custom + .get(&self.name) + .map(|config| config.common.priority) + .unwrap_or_else(|| self.priority.unwrap_or_default()) + } + + fn install(&self, router: Router, state: &S) -> RoadsterResult { + let router = router.layer((self.layer_provider)(state)); + + Ok(router) + } +} diff --git a/src/service/http/middleware/mod.rs b/src/service/http/middleware/mod.rs index 332e0506..e1d0eabd 100644 --- a/src/service/http/middleware/mod.rs +++ b/src/service/http/middleware/mod.rs @@ -1,3 +1,4 @@ +pub mod any; pub mod catch_panic; pub mod compression; pub mod cors;