diff --git a/examples/minimal/src/app.rs b/examples/minimal/src/app.rs index e50df891..dac55895 100644 --- a/examples/minimal/src/app.rs +++ b/examples/minimal/src/app.rs @@ -30,14 +30,17 @@ impl RoadsterApp for App { state: Arc, ) -> anyhow::Result<()> { registry - .register_builder(HttpService::builder(BASE, &context).router(controller::routes(BASE))) + .register_builder( + HttpService::builder(BASE, &context, state.as_ref()) + .router(controller::routes(BASE)), + ) .await?; registry .register_builder( SidekiqWorkerService::builder(context.clone(), state.clone()) .await? - .register_app_worker(ExampleWorker::build(&state)), + .register_app_worker(ExampleWorker::build(&state))?, ) .await?; diff --git a/src/controller/mod.rs b/src/controller/mod.rs index 4ded12fc..20cfc773 100644 --- a/src/controller/mod.rs +++ b/src/controller/mod.rs @@ -40,6 +40,7 @@ pub fn default_routes(parent: &str, config: &AppConfig) -> ApiRouter where S: Clone + Send + Sync + 'static + Into>, { + // Todo: Allow disabling the default routes ApiRouter::new() .merge(ping::routes(parent)) .merge(health::routes(parent)) diff --git a/src/service/http/builder.rs b/src/service/http/builder.rs index 1d6a1ec1..f8555abb 100644 --- a/src/service/http/builder.rs +++ b/src/service/http/builder.rs @@ -13,12 +13,14 @@ use aide::axum::ApiRouter; use aide::openapi::OpenApi; #[cfg(feature = "open-api")] use aide::transform::TransformOpenApi; +use anyhow::bail; use async_trait::async_trait; #[cfg(feature = "open-api")] use axum::Extension; #[cfg(not(feature = "open-api"))] use axum::Router; use itertools::Itertools; +use std::collections::BTreeMap; #[cfg(feature = "open-api")] use std::sync::Arc; use tracing::info; @@ -30,22 +32,22 @@ pub struct HttpServiceBuilder { router: ApiRouter, #[cfg(feature = "open-api")] api_docs: Box TransformOpenApi + Send>, - middleware: Vec>>, - initializers: Vec>>, + middleware: BTreeMap>>, + initializers: BTreeMap>>, } impl HttpServiceBuilder { - pub fn new(path_root: &str, app_context: &AppContext) -> Self { + pub fn new(path_root: &str, context: &AppContext, state: &A::State) -> Self { #[cfg(feature = "open-api")] - let app_name = app_context.config.app.name.clone(); + let app_name = context.config.app.name.clone(); Self { - router: default_routes(path_root, &app_context.config), + router: default_routes(path_root, &context.config), #[cfg(feature = "open-api")] api_docs: Box::new(move |api| { api.title(&app_name).description(&format!("# {}", app_name)) }), - middleware: default_middleware(), - initializers: default_initializers(), + middleware: default_middleware(context, state), + initializers: default_initializers(context, state), } } @@ -70,14 +72,27 @@ impl HttpServiceBuilder { self } - pub fn initializer(mut self, initializer: Box>) -> Self { - self.initializers.push(initializer); - self + pub fn initializer( + mut self, + initializer: Box>, + ) -> anyhow::Result { + let name = initializer.name(); + if self + .initializers + .insert(name.clone(), initializer) + .is_some() + { + bail!("Initializer `{name}` was already registered"); + } + Ok(self) } - pub fn middleware(mut self, middleware: Box>) -> Self { - self.middleware.push(middleware); - self + pub fn middleware(mut self, middleware: Box>) -> anyhow::Result { + let name = middleware.name(); + if self.middleware.insert(name.clone(), middleware).is_some() { + bail!("Middleware `{name}` was already registered"); + } + Ok(self) } } @@ -102,9 +117,8 @@ impl AppServiceBuilder for HttpServiceBuilder { let initializers = self .initializers - .into_iter() + .values() .filter(|initializer| initializer.enabled(context, state)) - .unique_by(|initializer| initializer.name()) .sorted_by(|a, b| Ord::cmp(&a.priority(context, state), &b.priority(context, state))) .collect_vec(); @@ -123,9 +137,8 @@ impl AppServiceBuilder for HttpServiceBuilder { info!("Installing middleware. Note: the order of installation is the inverse of the order middleware will run when handling a request."); let router = self .middleware - .into_iter() + .values() .filter(|middleware| middleware.enabled(context, state)) - .unique_by(|middleware| middleware.name()) .sorted_by(|a, b| Ord::cmp(&a.priority(context, state), &b.priority(context, state))) // Reverse due to how Axum's `Router#layer` method adds middleware. .rev() diff --git a/src/service/http/initializer/default.rs b/src/service/http/initializer/default.rs index 80268e3c..ed996271 100644 --- a/src/service/http/initializer/default.rs +++ b/src/service/http/initializer/default.rs @@ -1,6 +1,16 @@ +use crate::app_context::AppContext; use crate::service::http::initializer::normalize_path::NormalizePathInitializer; use crate::service::http::initializer::Initializer; +use std::collections::BTreeMap; -pub fn default_initializers() -> Vec>> { - vec![Box::new(NormalizePathInitializer)] +pub fn default_initializers( + context: &AppContext, + state: &S, +) -> BTreeMap>> { + let initializers: Vec>> = vec![Box::new(NormalizePathInitializer)]; + initializers + .into_iter() + .filter(|initializer| initializer.enabled(context, state)) + .map(|initializer| (initializer.name(), initializer)) + .collect() } diff --git a/src/service/http/middleware/default.rs b/src/service/http/middleware/default.rs index 796e6faf..78cc035f 100644 --- a/src/service/http/middleware/default.rs +++ b/src/service/http/middleware/default.rs @@ -1,3 +1,4 @@ +use crate::app_context::AppContext; use crate::service::http::middleware::catch_panic::CatchPanicMiddleware; use crate::service::http::middleware::compression::RequestDecompressionMiddleware; use crate::service::http::middleware::request_id::{ @@ -10,9 +11,13 @@ use crate::service::http::middleware::size_limit::RequestBodyLimitMiddleware; use crate::service::http::middleware::timeout::TimeoutMiddleware; use crate::service::http::middleware::tracing::TracingMiddleware; use crate::service::http::middleware::Middleware; +use std::collections::BTreeMap; -pub fn default_middleware() -> Vec>> { - vec![ +pub fn default_middleware( + context: &AppContext, + state: &S, +) -> BTreeMap>> { + let middleware: Vec>> = vec![ Box::new(SensitiveRequestHeadersMiddleware), Box::new(SensitiveResponseHeadersMiddleware), Box::new(SetRequestIdMiddleware), @@ -22,5 +27,10 @@ pub fn default_middleware() -> Vec>> { Box::new(RequestDecompressionMiddleware), Box::new(TimeoutMiddleware), Box::new(RequestBodyLimitMiddleware), - ] + ]; + middleware + .into_iter() + .filter(|middleware| middleware.enabled(context, state)) + .map(|middleware| (middleware.name(), middleware)) + .collect() } diff --git a/src/service/http/service.rs b/src/service/http/service.rs index fc09918b..d5569d73 100644 --- a/src/service/http/service.rs +++ b/src/service/http/service.rs @@ -86,8 +86,12 @@ impl AppService for HttpService { impl HttpService { /// Create a new [HttpServiceBuilder]. - pub fn builder(path_root: &str, context: &AppContext) -> HttpServiceBuilder { - HttpServiceBuilder::new(path_root, context) + pub fn builder( + path_root: &str, + context: &AppContext, + state: &A::State, + ) -> HttpServiceBuilder { + HttpServiceBuilder::new(path_root, context, state) } /// List the available HTTP API routes. diff --git a/src/service/registry.rs b/src/service/registry.rs index c2a7f92d..a6c3fd08 100644 --- a/src/service/registry.rs +++ b/src/service/registry.rs @@ -1,6 +1,7 @@ use crate::app::App; use crate::app_context::AppContext; use crate::service::{AppService, AppServiceBuilder}; +use anyhow::bail; use std::collections::BTreeMap; use std::sync::Arc; use tracing::info; @@ -34,7 +35,7 @@ impl ServiceRegistry { info!(service = %S::name(), "Service is not enabled, skipping registration"); return Ok(()); } - self.register_unchecked(service) + self.register_internal(service) } /// Build and register a new service. If the service is not enabled (e.g., @@ -52,16 +53,18 @@ impl ServiceRegistry { info!(service = %S::name(), "Building service"); let service = builder.build(&self.context, &self.state).await?; - self.register_unchecked(service) + self.register_internal(service) } - fn register_unchecked(&mut self, service: S) -> anyhow::Result<()> + fn register_internal(&mut self, service: S) -> anyhow::Result<()> where S: AppService + 'static, { info!(service = %S::name(), "Registering service"); - self.services.insert(S::name(), Box::new(service)); + if self.services.insert(S::name(), Box::new(service)).is_some() { + bail!("Service `{}` was already registered", S::name()); + } Ok(()) } } diff --git a/src/service/worker/sidekiq/builder.rs b/src/service/worker/sidekiq/builder.rs index 8e3a6158..f238ad76 100644 --- a/src/service/worker/sidekiq/builder.rs +++ b/src/service/worker/sidekiq/builder.rs @@ -205,7 +205,7 @@ where /// /// The worker will be wrapped by a [RoadsterWorker], which provides some common behavior, such /// as enforcing a timeout/max duration of worker jobs. - pub fn register_app_worker(mut self, worker: W) -> Self + pub fn register_app_worker(mut self, worker: W) -> anyhow::Result where Args: Sync + Send + Serialize + for<'de> serde::Deserialize<'de> + 'static, W: AppWorker + 'static, @@ -219,12 +219,14 @@ where { let class_name = W::class_name(); debug!(worker = %class_name, "Registering worker"); - registered_workers.insert(class_name.clone()); + if !registered_workers.insert(class_name.clone()) { + bail!("Worker `{class_name}` was already registered"); + } let roadster_worker = RoadsterWorker::new(worker, state.clone()); processor.register(roadster_worker); } - self + Ok(self) } /// Register a periodic [worker][AppWorker] that will run with the provided args. The cadence @@ -255,8 +257,12 @@ where debug!(worker = %class_name, "Registering periodic worker"); let roadster_worker = RoadsterWorker::new(worker, state.clone()); let builder = builder.args(args)?; - let job_json = serde_json::to_string(&builder.into_periodic_job(class_name)?)?; - registered_periodic_workers.insert(job_json); + let job_json = serde_json::to_string(&builder.into_periodic_job(class_name.clone())?)?; + if !registered_periodic_workers.insert(job_json.clone()) { + bail!( + "Periodic worker `{class_name}` was already registered; full job: {job_json}" + ); + } builder.register(processor, roadster_worker).await?; }