diff --git a/examples/minimal/src/app.rs b/examples/minimal/src/app.rs index dac55895..fb069dd0 100644 --- a/examples/minimal/src/app.rs +++ b/examples/minimal/src/app.rs @@ -6,7 +6,6 @@ use roadster::service::http::service::HttpService; use roadster::service::registry::ServiceRegistry; use roadster::service::worker::sidekiq::app_worker::AppWorker; use roadster::service::worker::sidekiq::service::SidekiqWorkerService; -use std::sync::Arc; use crate::app_state::AppState; use crate::cli::AppCli; @@ -24,23 +23,23 @@ impl RoadsterApp for App { type Cli = AppCli; type M = Migrator; + async fn with_state(_context: &AppContext) -> anyhow::Result { + Ok(()) + } + async fn services( registry: &mut ServiceRegistry, - context: Arc, - state: Arc, + context: AppContext, ) -> anyhow::Result<()> { registry - .register_builder( - HttpService::builder(BASE, &context, state.as_ref()) - .router(controller::routes(BASE)), - ) + .register_builder(HttpService::builder(BASE, &context).router(controller::routes(BASE))) .await?; registry .register_builder( - SidekiqWorkerService::builder(context.clone(), state.clone()) + SidekiqWorkerService::builder(context.clone()) .await? - .register_app_worker(ExampleWorker::build(&state))?, + .register_app_worker(ExampleWorker::build(&context))?, ) .await?; diff --git a/examples/minimal/src/app_state.rs b/examples/minimal/src/app_state.rs index dbccbf2d..e65fb25e 100644 --- a/examples/minimal/src/app_state.rs +++ b/examples/minimal/src/app_state.rs @@ -2,29 +2,4 @@ // to implement the required traits used to convert it to/from `AppState`. #![allow(clippy::disallowed_types)] -use std::sync::Arc; - -use roadster::app_context::AppContext; - -#[derive(Debug, Clone)] -pub struct AppState { - context: Arc, -} - -impl AppState { - pub fn new(ctx: Arc) -> Self { - Self { context: ctx } - } -} - -impl From> for AppState { - fn from(value: Arc) -> Self { - AppState::new(value) - } -} - -impl From for Arc { - fn from(value: AppState) -> Self { - value.context - } -} +pub type AppState = (); diff --git a/examples/minimal/src/cli/mod.rs b/examples/minimal/src/cli/mod.rs index e70bbc4d..efbc8d4d 100644 --- a/examples/minimal/src/cli/mod.rs +++ b/examples/minimal/src/cli/mod.rs @@ -1,5 +1,6 @@ use async_trait::async_trait; use clap::{Parser, Subcommand}; +use roadster::app_context::AppContext; use roadster::cli::RunCommand; @@ -18,9 +19,14 @@ pub struct AppCli { #[async_trait] impl RunCommand for AppCli { #[allow(clippy::disallowed_types)] - async fn run(&self, app: &App, cli: &AppCli, state: &AppState) -> anyhow::Result { + async fn run( + &self, + app: &App, + cli: &AppCli, + context: &AppContext, + ) -> anyhow::Result { if let Some(command) = self.command.as_ref() { - command.run(app, cli, state).await + command.run(app, cli, context).await } else { Ok(false) } @@ -35,7 +41,12 @@ pub enum AppCommand {} #[async_trait] impl RunCommand for AppCommand { - async fn run(&self, _app: &App, _cli: &AppCli, _state: &AppState) -> anyhow::Result { + async fn run( + &self, + _app: &App, + _cli: &AppCli, + _context: &AppContext, + ) -> anyhow::Result { Ok(false) } } diff --git a/examples/minimal/src/controller/example.rs b/examples/minimal/src/controller/example.rs index fb036d6f..cec0ef31 100644 --- a/examples/minimal/src/controller/example.rs +++ b/examples/minimal/src/controller/example.rs @@ -5,6 +5,7 @@ use aide::axum::ApiRouter; use aide::transform::TransformOperation; use axum::extract::State; use axum::Json; +use roadster::app_context::AppContext; use roadster::controller::build_path; use roadster::service::worker::sidekiq::app_worker::AppWorker; use roadster::view::app_error::AppError; @@ -15,7 +16,7 @@ use tracing::instrument; const BASE: &str = "/example"; const TAG: &str = "Example"; -pub fn routes(parent: &str) -> ApiRouter { +pub fn routes(parent: &str) -> ApiRouter> { let root = build_path(parent, BASE); ApiRouter::new().api_route(&root, get_with(example_get, example_get_docs)) @@ -26,7 +27,9 @@ pub fn routes(parent: &str) -> ApiRouter { pub struct ExampleResponse {} #[instrument(skip_all)] -async fn example_get(State(state): State) -> Result, AppError> { +async fn example_get( + State(state): State>, +) -> Result, AppError> { ExampleWorker::enqueue(&state, "Example".to_string()).await?; Ok(Json(ExampleResponse {})) } diff --git a/examples/minimal/src/controller/mod.rs b/examples/minimal/src/controller/mod.rs index a790f6e6..da3866f0 100644 --- a/examples/minimal/src/controller/mod.rs +++ b/examples/minimal/src/controller/mod.rs @@ -1,8 +1,9 @@ use crate::app_state::AppState; use aide::axum::ApiRouter; +use roadster::app_context::AppContext; pub mod example; -pub fn routes(parent: &str) -> ApiRouter { +pub fn routes(parent: &str) -> ApiRouter> { ApiRouter::new().merge(example::routes(parent)) } diff --git a/examples/minimal/src/worker/example.rs b/examples/minimal/src/worker/example.rs index e5f6d0aa..af7bc615 100644 --- a/examples/minimal/src/worker/example.rs +++ b/examples/minimal/src/worker/example.rs @@ -1,6 +1,7 @@ use crate::app::App; use crate::app_state::AppState; use async_trait::async_trait; +use roadster::app_context::AppContext; use roadster::service::worker::sidekiq::app_worker::AppWorker; use sidekiq::Worker; use tracing::{info, instrument}; @@ -18,7 +19,7 @@ impl Worker for ExampleWorker { #[async_trait] impl AppWorker for ExampleWorker { - fn build(_state: &AppState) -> Self { + fn build(_context: &AppContext) -> Self { Self {} } } diff --git a/src/app.rs b/src/app.rs index d39f7743..558aa216 100644 --- a/src/app.rs +++ b/src/app.rs @@ -15,7 +15,6 @@ use sea_orm::{ConnectOptions, Database}; use sea_orm_migration::MigratorTrait; use std::future; use std::future::Future; -use std::sync::Arc; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; use tracing::{error, info, instrument, warn}; @@ -117,7 +116,7 @@ where (redis_enqueue, redis_fetch) }; - let context = AppContext::new( + let context = AppContext::<()>::new( config, #[cfg(feature = "db-sql")] db, @@ -125,30 +124,28 @@ where redis_enqueue.clone(), #[cfg(feature = "sidekiq")] redis_fetch.clone(), - ) - .await?; + )?; - let context = Arc::new(context); - let state = A::context_to_state(context.clone()).await?; - let state = Arc::new(state); + let state = A::with_state(&context).await?; + let context = context.with_custom(state); #[cfg(feature = "cli")] { if roadster_cli.run(&app, &roadster_cli, &context).await? { return Ok(()); } - if app_cli.run(&app, &app_cli, &state).await? { + if app_cli.run(&app, &app_cli, &context).await? { return Ok(()); } } - let mut service_registry = ServiceRegistry::new(context.clone(), state.clone()); - A::services(&mut service_registry, context.clone(), state.clone()).await?; + let mut service_registry = ServiceRegistry::new(context.clone()); + A::services(&mut service_registry, context.clone()).await?; #[cfg(feature = "cli")] for (_name, service) in service_registry.services.iter() { if service - .handle_cli(&roadster_cli, &app_cli, &context, &state) + .handle_cli(&roadster_cli, &app_cli, &context) .await? { return Ok(()); @@ -171,11 +168,10 @@ where // Spawn tasks for the app's services for (name, service) in service_registry.services { let context = context.clone(); - let state = state.clone(); let cancel_token = cancel_token.clone(); join_set.spawn(Box::pin(async move { info!(service=%name, "Running service"); - service.run(context, state, cancel_token).await + service.run(context, cancel_token).await })); } @@ -185,15 +181,14 @@ where context.clone(), graceful_shutdown( token_shutdown_signal(cancel_token.clone()), - A::graceful_shutdown(context.clone(), state.clone()), - #[cfg(feature = "db-sql")] + A::graceful_shutdown(context.clone()), context.clone(), ), )); // Task to listen for the signal to gracefully shutdown, and trigger other tasks to stop. let graceful_shutdown_signal = graceful_shutdown_signal( cancel_token.clone(), - A::graceful_shutdown_signal(context.clone(), state.clone()), + A::graceful_shutdown_signal(context.clone()), ); join_set.spawn(cancel_token_on_signal_received( graceful_shutdown_signal, @@ -223,7 +218,8 @@ where #[async_trait] pub trait App: Send + Sync { - type State: From> + Into> + Clone + Send + Sync + 'static; + // Todo: Are clone, etc necessary if we store it inside an Arc? + type State: Clone + Send + Sync + 'static; #[cfg(feature = "cli")] type Cli: clap::Args + RunCommand; #[cfg(feature = "db-sql")] @@ -258,16 +254,12 @@ pub trait App: Send + Sync { /// method is provided in case there's any additional work that needs to be done that the /// consumer can't put in a [`From`] implementation. For example, any /// configuration that needs to happen in an async method. - async fn context_to_state(context: Arc) -> anyhow::Result { - let state = Self::State::from(context); - Ok(state) - } + async fn with_state(context: &AppContext) -> anyhow::Result; /// Provide the services to run in the app. async fn services( _registry: &mut ServiceRegistry, - _context: Arc, - _state: Arc, + _context: AppContext, ) -> anyhow::Result<()> { Ok(()) } @@ -275,17 +267,14 @@ pub trait App: Send + Sync { /// Override to provide a custom shutdown signal. Roadster provides some default shutdown /// signals, but it may be desirable to provide a custom signal in order to, e.g., shutdown the /// server when a particular API is called. - async fn graceful_shutdown_signal(_context: Arc, _state: Arc) { + async fn graceful_shutdown_signal(_context: AppContext) { let _output: () = future::pending().await; } /// Override to provide custom graceful shutdown logic to clean up any resources created by /// the app. Roadster will take care of cleaning up the resources it created. #[instrument(skip_all)] - async fn graceful_shutdown( - _context: Arc, - _state: Arc, - ) -> anyhow::Result<()> { + async fn graceful_shutdown(_context: AppContext) -> anyhow::Result<()> { Ok(()) } } @@ -343,9 +332,9 @@ async fn token_shutdown_signal(cancellation_token: CancellationToken) { cancellation_token.cancelled().await } -async fn cancel_on_error( +async fn cancel_on_error( cancellation_token: CancellationToken, - context: Arc, + context: AppContext, f: F, ) -> anyhow::Result where @@ -359,10 +348,11 @@ where } #[instrument(skip_all)] -async fn graceful_shutdown( +async fn graceful_shutdown( shutdown_signal: F1, app_graceful_shutdown: F2, - #[cfg(feature = "db-sql")] context: Arc, + // This parameter is (currently) not used when no features are enabled. + #[allow(unused_variables)] context: AppContext, ) -> anyhow::Result<()> where F1: Future + Send + 'static, diff --git a/src/app_context.rs b/src/app_context.rs index aa92414e..1a59c3f3 100644 --- a/src/app_context.rs +++ b/src/app_context.rs @@ -5,18 +5,19 @@ use sea_orm::DatabaseConnection; use crate::config::app_config::AppConfig; -#[derive(Debug, Clone)] -pub struct AppContext { +#[derive(Clone)] +pub struct AppContext { inner: Arc, + custom: Arc, } -impl AppContext { - pub async fn new( +impl AppContext { + pub fn new( config: AppConfig, #[cfg(feature = "db-sql")] db: DatabaseConnection, #[cfg(feature = "sidekiq")] redis_enqueue: sidekiq::RedisPool, #[cfg(feature = "sidekiq")] redis_fetch: Option, - ) -> anyhow::Result { + ) -> anyhow::Result> { let inner = AppContextInner { config, #[cfg(feature = "db-sql")] @@ -26,11 +27,19 @@ impl AppContext { #[cfg(feature = "sidekiq")] redis_fetch, }; - Ok(Self { + Ok(AppContext { inner: Arc::new(inner), + custom: Arc::new(()), }) } + pub fn with_custom(self, custom: NewT) -> AppContext { + AppContext { + inner: self.inner, + custom: Arc::new(custom), + } + } + pub fn config(&self) -> &AppConfig { &self.inner.config } @@ -49,6 +58,10 @@ impl AppContext { pub fn redis_fetch(&self) -> Option<&sidekiq::RedisPool> { self.inner.redis_fetch.as_ref() } + + pub fn custom(&self) -> &T { + &self.custom + } } #[derive(Debug)] diff --git a/src/cli/migrate.rs b/src/cli/migrate.rs index 84713375..5cca2a39 100644 --- a/src/cli/migrate.rs +++ b/src/cli/migrate.rs @@ -19,7 +19,12 @@ impl RunRoadsterCommand for MigrateArgs where A: App, { - async fn run(&self, app: &A, cli: &RoadsterCli, context: &AppContext) -> anyhow::Result { + async fn run( + &self, + app: &A, + cli: &RoadsterCli, + context: &AppContext, + ) -> anyhow::Result { self.command.run(app, cli, context).await } } @@ -45,7 +50,12 @@ impl RunRoadsterCommand for MigrateCommand where A: App, { - async fn run(&self, _app: &A, cli: &RoadsterCli, context: &AppContext) -> anyhow::Result { + async fn run( + &self, + _app: &A, + cli: &RoadsterCli, + context: &AppContext, + ) -> anyhow::Result { if is_destructive(self) && !cli.allow_dangerous(context) { bail!("Running destructive command `{:?}` is not allowed in environment `{:?}`. To override, provide the `--allow-dangerous` CLI arg.", self, context.config().environment); } else if is_destructive(self) { diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 225093b2..2e2c7334 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -35,7 +35,12 @@ where /// continue execution after the command is complete. /// * `Err(...)` - If the implementation experienced an error while handling the command. The /// app should end execution after the command is complete. - async fn run(&self, app: &A, cli: &A::Cli, state: &A::State) -> anyhow::Result; + async fn run( + &self, + app: &A, + cli: &A::Cli, + context: &AppContext, + ) -> anyhow::Result; } /// Internal version of [RunCommand] that uses the [RoadsterCli] and [AppContext] instead of @@ -46,7 +51,12 @@ pub(crate) trait RunRoadsterCommand where A: App, { - async fn run(&self, app: &A, cli: &RoadsterCli, context: &AppContext) -> anyhow::Result; + async fn run( + &self, + app: &A, + cli: &RoadsterCli, + context: &AppContext, + ) -> anyhow::Result; } /// Roadster: The Roadster CLI provides various utilities for managing your application. If no subcommand @@ -70,7 +80,7 @@ pub struct RoadsterCli { } impl RoadsterCli { - pub fn allow_dangerous(&self, context: &AppContext) -> bool { + pub fn allow_dangerous(&self, context: &AppContext) -> bool { context.config().environment != Environment::Production || self.allow_dangerous } } @@ -80,7 +90,12 @@ impl RunRoadsterCommand for RoadsterCli where A: App, { - async fn run(&self, app: &A, cli: &RoadsterCli, context: &AppContext) -> anyhow::Result { + async fn run( + &self, + app: &A, + cli: &RoadsterCli, + context: &AppContext, + ) -> anyhow::Result { if let Some(command) = self.command.as_ref() { command.run(app, cli, context).await } else { @@ -102,7 +117,12 @@ impl RunRoadsterCommand for RoadsterCommand where A: App, { - async fn run(&self, app: &A, cli: &RoadsterCli, context: &AppContext) -> anyhow::Result { + async fn run( + &self, + app: &A, + cli: &RoadsterCli, + context: &AppContext, + ) -> anyhow::Result { match self { RoadsterCommand::Roadster(args) => args.run(app, cli, context).await, } @@ -120,7 +140,12 @@ impl RunRoadsterCommand for RoadsterArgs where A: App, { - async fn run(&self, app: &A, cli: &RoadsterCli, context: &AppContext) -> anyhow::Result { + async fn run( + &self, + app: &A, + cli: &RoadsterCli, + context: &AppContext, + ) -> anyhow::Result { self.command.run(app, cli, context).await } } @@ -151,7 +176,12 @@ impl RunRoadsterCommand for RoadsterSubCommand where A: App, { - async fn run(&self, app: &A, cli: &RoadsterCli, context: &AppContext) -> anyhow::Result { + async fn run( + &self, + app: &A, + cli: &RoadsterCli, + context: &AppContext, + ) -> anyhow::Result { match self { #[cfg(feature = "open-api")] RoadsterSubCommand::ListRoutes(_) => { diff --git a/src/cli/print_config.rs b/src/cli/print_config.rs index ae352ba0..6f73062a 100644 --- a/src/cli/print_config.rs +++ b/src/cli/print_config.rs @@ -37,7 +37,7 @@ where &self, _app: &A, _cli: &RoadsterCli, - context: &AppContext, + context: &AppContext, ) -> anyhow::Result { match self.format { Format::Debug => { diff --git a/src/config/service/http/initializer.rs b/src/config/service/http/initializer.rs index ed7c4d07..7702e527 100644 --- a/src/config/service/http/initializer.rs +++ b/src/config/service/http/initializer.rs @@ -74,7 +74,7 @@ impl CommonConfig { self } - pub fn enabled(&self, context: &AppContext) -> bool { + pub fn enabled(&self, context: &AppContext) -> bool { self.enable.unwrap_or( context .config() diff --git a/src/config/service/http/middleware.rs b/src/config/service/http/middleware.rs index eb6c0419..4be9e2a1 100644 --- a/src/config/service/http/middleware.rs +++ b/src/config/service/http/middleware.rs @@ -137,7 +137,7 @@ impl CommonConfig { self } - pub fn enabled(&self, context: &AppContext) -> bool { + pub fn enabled(&self, context: &AppContext) -> bool { self.enable.unwrap_or( context .config() diff --git a/src/config/service/mod.rs b/src/config/service/mod.rs index e0571ad0..d2757612 100644 --- a/src/config/service/mod.rs +++ b/src/config/service/mod.rs @@ -29,7 +29,7 @@ pub struct CommonConfig { } impl CommonConfig { - pub fn enabled(&self, context: &AppContext) -> bool { + pub fn enabled(&self, context: &AppContext) -> bool { self.enable .unwrap_or(context.config().service.default_enable) } diff --git a/src/controller/health.rs b/src/controller/health.rs index 6a726795..4472ad53 100644 --- a/src/controller/health.rs +++ b/src/controller/health.rs @@ -1,4 +1,3 @@ -use std::sync::Arc; #[cfg(feature = "sidekiq")] use std::time::Duration; use std::time::Instant; @@ -39,9 +38,9 @@ const BASE: &str = "/_health"; const TAG: &str = "Health"; #[cfg(not(feature = "open-api"))] -pub fn routes(parent: &str) -> Router +pub fn routes(parent: &str) -> Router> where - S: Clone + Send + Sync + 'static + Into>, + S: Clone + Send + Sync + 'static, { let root = build_path(parent, BASE); @@ -49,9 +48,9 @@ where } #[cfg(feature = "open-api")] -pub fn routes(parent: &str) -> ApiRouter +pub fn routes(parent: &str) -> ApiRouter> where - S: Clone + Send + Sync + 'static + Into>, + S: Clone + Send + Sync + 'static, { let root = build_path(parent, BASE); @@ -101,14 +100,13 @@ pub enum Status { #[instrument(skip_all)] async fn health_get( - #[cfg(any(feature = "sidekiq", feature = "db-sql"))] State(state): State, + #[cfg(any(feature = "sidekiq", feature = "db-sql"))] State(state): State>, ) -> Result, AppError> where - S: Clone + Send + Sync + 'static + Into>, + S: Clone + Send + Sync + 'static, { let timer = Instant::now(); #[cfg(any(feature = "sidekiq", feature = "db-sql"))] - let state: Arc = state.into(); #[cfg(feature = "db-sql")] let db = { let db_timer = Instant::now(); diff --git a/src/controller/mod.rs b/src/controller/mod.rs index 20cfc773..3fd268c8 100644 --- a/src/controller/mod.rs +++ b/src/controller/mod.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - #[cfg(feature = "open-api")] use aide::axum::ApiRouter; #[cfg(not(feature = "open-api"))] @@ -26,9 +24,9 @@ pub fn build_path(parent: &str, child: &str) -> String { } #[cfg(not(feature = "open-api"))] -pub fn default_routes(parent: &str, _config: &AppConfig) -> Router +pub fn default_routes(parent: &str, _config: &AppConfig) -> Router> where - S: Clone + Send + Sync + 'static + Into>, + S: Clone + Send + Sync + 'static, { Router::new() .merge(ping::routes(parent)) @@ -36,9 +34,9 @@ where } #[cfg(feature = "open-api")] -pub fn default_routes(parent: &str, config: &AppConfig) -> ApiRouter +pub fn default_routes(parent: &str, config: &AppConfig) -> ApiRouter> where - S: Clone + Send + Sync + 'static + Into>, + S: Clone + Send + Sync + 'static, { // Todo: Allow disabling the default routes ApiRouter::new() diff --git a/src/service/http/builder.rs b/src/service/http/builder.rs index 139288d7..891e1eae 100644 --- a/src/service/http/builder.rs +++ b/src/service/http/builder.rs @@ -27,9 +27,9 @@ use tracing::info; pub struct HttpServiceBuilder { #[cfg(not(feature = "open-api"))] - router: Router, + router: Router>, #[cfg(feature = "open-api")] - router: ApiRouter, + router: ApiRouter>, #[cfg(feature = "open-api")] api_docs: Box TransformOpenApi + Send>, middleware: BTreeMap>>, @@ -37,7 +37,7 @@ pub struct HttpServiceBuilder { } impl HttpServiceBuilder { - pub fn new(path_root: &str, context: &AppContext, state: &A::State) -> Self { + pub fn new(path_root: &str, context: &AppContext) -> Self { #[cfg(feature = "open-api")] let app_name = context.config().app.name.clone(); Self { @@ -46,19 +46,19 @@ impl HttpServiceBuilder { api_docs: Box::new(move |api| { api.title(&app_name).description(&format!("# {}", app_name)) }), - middleware: default_middleware(context, state), - initializers: default_initializers(context, state), + middleware: default_middleware(context), + initializers: default_initializers(context), } } #[cfg(not(feature = "open-api"))] - pub fn router(mut self, router: Router) -> Self { + pub fn router(mut self, router: Router>) -> Self { self.router = self.router.merge(router); self } #[cfg(feature = "open-api")] - pub fn router(mut self, router: ApiRouter) -> Self { + pub fn router(mut self, router: ApiRouter>) -> Self { self.router = self.router.merge(router); self } @@ -98,7 +98,7 @@ impl HttpServiceBuilder { #[async_trait] impl AppServiceBuilder for HttpServiceBuilder { - async fn build(self, context: &AppContext, state: &A::State) -> anyhow::Result { + async fn build(self, context: &AppContext) -> anyhow::Result { #[cfg(not(feature = "open-api"))] let router = self.router; @@ -113,50 +113,50 @@ impl AppServiceBuilder for HttpServiceBuilder { (router, api) }; - let router = router.with_state::<()>(state.clone()); + let router = router.with_state::<()>(context.clone()); let initializers = self .initializers .values() - .filter(|initializer| initializer.enabled(context, state)) - .sorted_by(|a, b| Ord::cmp(&a.priority(context, state), &b.priority(context, state))) + .filter(|initializer| initializer.enabled(context)) + .sorted_by(|a, b| Ord::cmp(&a.priority(context), &b.priority(context))) .collect_vec(); let router = initializers .iter() .try_fold(router, |router, initializer| { - initializer.after_router(router, context, state) + initializer.after_router(router, context) })?; let router = initializers .iter() .try_fold(router, |router, initializer| { - initializer.before_middleware(router, context, state) + initializer.before_middleware(router, context) })?; info!("Installing middleware. Note: the order of installation is the inverse of the order middleware will run when handling a request."); let router = self .middleware .values() - .filter(|middleware| middleware.enabled(context, state)) - .sorted_by(|a, b| Ord::cmp(&a.priority(context, state), &b.priority(context, state))) + .filter(|middleware| middleware.enabled(context)) + .sorted_by(|a, b| Ord::cmp(&a.priority(context), &b.priority(context))) // Reverse due to how Axum's `Router#layer` method adds middleware. .rev() .try_fold(router, |router, middleware| { info!(middleware=%middleware.name(), "Installing middleware"); - middleware.install(router, context, state) + middleware.install(router, context) })?; let router = initializers .iter() .try_fold(router, |router, initializer| { - initializer.after_middleware(router, context, state) + initializer.after_middleware(router, context) })?; let router = initializers .iter() .try_fold(router, |router, initializer| { - initializer.before_serve(router, context, state) + initializer.before_serve(router, context) })?; let service = HttpService { diff --git a/src/service/http/initializer/default.rs b/src/service/http/initializer/default.rs index ed996271..e76c4b62 100644 --- a/src/service/http/initializer/default.rs +++ b/src/service/http/initializer/default.rs @@ -4,13 +4,12 @@ use crate::service::http::initializer::Initializer; use std::collections::BTreeMap; pub fn default_initializers( - context: &AppContext, - state: &S, + context: &AppContext, ) -> BTreeMap>> { let initializers: Vec>> = vec![Box::new(NormalizePathInitializer)]; initializers .into_iter() - .filter(|initializer| initializer.enabled(context, state)) + .filter(|initializer| initializer.enabled(context)) .map(|initializer| (initializer.name(), initializer)) .collect() } diff --git a/src/service/http/initializer/mod.rs b/src/service/http/initializer/mod.rs index b71acd63..f5023ce6 100644 --- a/src/service/http/initializer/mod.rs +++ b/src/service/http/initializer/mod.rs @@ -10,7 +10,7 @@ use axum::Router; pub trait Initializer: Send { fn name(&self) -> String; - fn enabled(&self, context: &AppContext, state: &S) -> bool; + fn enabled(&self, context: &AppContext) -> bool; /// Used to determine the order in which the initializer will run during app initialization. /// Smaller numbers will run before larger numbers. For example, an initializer with priority @@ -21,41 +21,25 @@ pub trait Initializer: Send { /// /// If the order in which your initializer runs doesn't particularly matter, it's generally /// safe to set its priority as `0`. - fn priority(&self, context: &AppContext, state: &S) -> i32; + fn priority(&self, context: &AppContext) -> i32; - fn after_router( - &self, - router: Router, - _context: &AppContext, - _state: &S, - ) -> anyhow::Result { + fn after_router(&self, router: Router, _context: &AppContext) -> anyhow::Result { Ok(router) } fn before_middleware( &self, router: Router, - _context: &AppContext, - _state: &S, + _context: &AppContext, ) -> anyhow::Result { Ok(router) } - fn after_middleware( - &self, - router: Router, - _context: &AppContext, - _state: &S, - ) -> anyhow::Result { + fn after_middleware(&self, router: Router, _context: &AppContext) -> anyhow::Result { Ok(router) } - fn before_serve( - &self, - router: Router, - _context: &AppContext, - _state: &S, - ) -> anyhow::Result { + fn before_serve(&self, router: Router, _context: &AppContext) -> anyhow::Result { Ok(router) } } diff --git a/src/service/http/initializer/normalize_path.rs b/src/service/http/initializer/normalize_path.rs index a2f436e3..162d6476 100644 --- a/src/service/http/initializer/normalize_path.rs +++ b/src/service/http/initializer/normalize_path.rs @@ -16,7 +16,7 @@ impl Initializer for NormalizePathInitializer { "normalize-path".to_string() } - fn enabled(&self, context: &AppContext, _state: &S) -> bool { + fn enabled(&self, context: &AppContext) -> bool { context .config() .service @@ -28,7 +28,7 @@ impl Initializer for NormalizePathInitializer { .enabled(context) } - fn priority(&self, context: &AppContext, _state: &S) -> i32 { + fn priority(&self, context: &AppContext) -> i32 { context .config() .service @@ -40,12 +40,7 @@ impl Initializer for NormalizePathInitializer { .priority } - fn before_serve( - &self, - router: Router, - _context: &AppContext, - _state: &S, - ) -> anyhow::Result { + fn before_serve(&self, router: Router, _context: &AppContext) -> anyhow::Result { let router = NormalizePathLayer::trim_trailing_slash().layer(router); let router = Router::new().nest_service("/", router); Ok(router) diff --git a/src/service/http/middleware/catch_panic.rs b/src/service/http/middleware/catch_panic.rs index 823d38ff..4d810154 100644 --- a/src/service/http/middleware/catch_panic.rs +++ b/src/service/http/middleware/catch_panic.rs @@ -14,7 +14,7 @@ impl Middleware for CatchPanicMiddleware { "catch-panic".to_string() } - fn enabled(&self, context: &AppContext, _state: &S) -> bool { + fn enabled(&self, context: &AppContext) -> bool { context .config() .service @@ -26,7 +26,7 @@ impl Middleware for CatchPanicMiddleware { .enabled(context) } - fn priority(&self, context: &AppContext, _state: &S) -> i32 { + fn priority(&self, context: &AppContext) -> i32 { context .config() .service @@ -38,7 +38,7 @@ impl Middleware for CatchPanicMiddleware { .priority } - fn install(&self, router: Router, _context: &AppContext, _state: &S) -> anyhow::Result { + fn install(&self, router: Router, _context: &AppContext) -> anyhow::Result { let router = router.layer(CatchPanicLayer::new()); Ok(router) diff --git a/src/service/http/middleware/compression.rs b/src/service/http/middleware/compression.rs index e4406294..2e5477bb 100644 --- a/src/service/http/middleware/compression.rs +++ b/src/service/http/middleware/compression.rs @@ -20,7 +20,7 @@ impl Middleware for ResponseCompressionMiddleware { "response-compression".to_string() } - fn enabled(&self, context: &AppContext, _state: &S) -> bool { + fn enabled(&self, context: &AppContext) -> bool { context .config() .service @@ -32,7 +32,7 @@ impl Middleware for ResponseCompressionMiddleware { .enabled(context) } - fn priority(&self, context: &AppContext, _state: &S) -> i32 { + fn priority(&self, context: &AppContext) -> i32 { context .config() .service @@ -44,7 +44,7 @@ impl Middleware for ResponseCompressionMiddleware { .priority } - fn install(&self, router: Router, _context: &AppContext, _state: &S) -> anyhow::Result { + fn install(&self, router: Router, _context: &AppContext) -> anyhow::Result { let router = router.layer(CompressionLayer::new()); Ok(router) @@ -57,7 +57,7 @@ impl Middleware for RequestDecompressionMiddleware { "request-decompression".to_string() } - fn enabled(&self, context: &AppContext, _state: &S) -> bool { + fn enabled(&self, context: &AppContext) -> bool { context .config() .service @@ -69,7 +69,7 @@ impl Middleware for RequestDecompressionMiddleware { .enabled(context) } - fn priority(&self, context: &AppContext, _state: &S) -> i32 { + fn priority(&self, context: &AppContext) -> i32 { context .config() .service @@ -81,7 +81,7 @@ impl Middleware for RequestDecompressionMiddleware { .priority } - fn install(&self, router: Router, _context: &AppContext, _state: &S) -> anyhow::Result { + fn install(&self, router: Router, _context: &AppContext) -> anyhow::Result { let router = router.layer(RequestDecompressionLayer::new()); Ok(router) diff --git a/src/service/http/middleware/default.rs b/src/service/http/middleware/default.rs index 78cc035f..ce6a670b 100644 --- a/src/service/http/middleware/default.rs +++ b/src/service/http/middleware/default.rs @@ -13,10 +13,7 @@ use crate::service::http::middleware::tracing::TracingMiddleware; use crate::service::http::middleware::Middleware; use std::collections::BTreeMap; -pub fn default_middleware( - context: &AppContext, - state: &S, -) -> BTreeMap>> { +pub fn default_middleware(context: &AppContext) -> BTreeMap>> { let middleware: Vec>> = vec![ Box::new(SensitiveRequestHeadersMiddleware), Box::new(SensitiveResponseHeadersMiddleware), @@ -30,7 +27,7 @@ pub fn default_middleware( ]; middleware .into_iter() - .filter(|middleware| middleware.enabled(context, state)) + .filter(|middleware| middleware.enabled(context)) .map(|middleware| (middleware.name(), middleware)) .collect() } diff --git a/src/service/http/middleware/mod.rs b/src/service/http/middleware/mod.rs index f0a58301..d889030c 100644 --- a/src/service/http/middleware/mod.rs +++ b/src/service/http/middleware/mod.rs @@ -24,7 +24,7 @@ use axum::Router; /// is done automatically by Roadster based on [Middleware::priority]). pub trait Middleware: Send { fn name(&self) -> String; - fn enabled(&self, context: &AppContext, state: &S) -> bool; + fn enabled(&self, context: &AppContext) -> bool; /// Used to determine the order in which the middleware will run when handling a request. Smaller /// numbers will run before larger numbers. For example, a middleware with priority `-10` /// will run before a middleware with priority `10`. @@ -41,6 +41,6 @@ pub trait Middleware: Send { /// is done automatically by Roadster based on [Middleware::priority]). So, a middleware /// with priority `-10` will be _installed after_ a middleware with priority `10`, which will /// allow the middleware with priority `-10` to _run before_ a middleware with priority `10`. - fn priority(&self, context: &AppContext, state: &S) -> i32; - fn install(&self, router: Router, context: &AppContext, state: &S) -> anyhow::Result; + fn priority(&self, context: &AppContext) -> i32; + fn install(&self, router: Router, context: &AppContext) -> anyhow::Result; } diff --git a/src/service/http/middleware/request_id.rs b/src/service/http/middleware/request_id.rs index 76b88b17..fd05c741 100644 --- a/src/service/http/middleware/request_id.rs +++ b/src/service/http/middleware/request_id.rs @@ -42,7 +42,7 @@ impl Middleware for SetRequestIdMiddleware { "set-request-id".to_string() } - fn enabled(&self, context: &AppContext, _state: &S) -> bool { + fn enabled(&self, context: &AppContext) -> bool { context .config() .service @@ -54,7 +54,7 @@ impl Middleware for SetRequestIdMiddleware { .enabled(context) } - fn priority(&self, context: &AppContext, _state: &S) -> i32 { + fn priority(&self, context: &AppContext) -> i32 { context .config() .service @@ -66,7 +66,7 @@ impl Middleware for SetRequestIdMiddleware { .priority } - fn install(&self, router: Router, context: &AppContext, _state: &S) -> anyhow::Result { + fn install(&self, router: Router, context: &AppContext) -> anyhow::Result { let header_name = &context .config() .service @@ -93,7 +93,7 @@ impl Middleware for PropagateRequestIdMiddleware { "propagate-request-id".to_string() } - fn enabled(&self, context: &AppContext, _state: &S) -> bool { + fn enabled(&self, context: &AppContext) -> bool { context .config() .service @@ -105,7 +105,7 @@ impl Middleware for PropagateRequestIdMiddleware { .enabled(context) } - fn priority(&self, context: &AppContext, _state: &S) -> i32 { + fn priority(&self, context: &AppContext) -> i32 { context .config() .service @@ -117,7 +117,7 @@ impl Middleware for PropagateRequestIdMiddleware { .priority } - fn install(&self, router: Router, context: &AppContext, _state: &S) -> anyhow::Result { + fn install(&self, router: Router, context: &AppContext) -> anyhow::Result { let header_name = &context .config() .service diff --git a/src/service/http/middleware/sensitive_headers.rs b/src/service/http/middleware/sensitive_headers.rs index f4d64945..259b8261 100644 --- a/src/service/http/middleware/sensitive_headers.rs +++ b/src/service/http/middleware/sensitive_headers.rs @@ -61,7 +61,7 @@ impl Middleware for SensitiveRequestHeadersMiddleware { "sensitive-request-headers".to_string() } - fn enabled(&self, context: &AppContext, _state: &S) -> bool { + fn enabled(&self, context: &AppContext) -> bool { context .config() .service @@ -73,7 +73,7 @@ impl Middleware for SensitiveRequestHeadersMiddleware { .enabled(context) } - fn priority(&self, context: &AppContext, _state: &S) -> i32 { + fn priority(&self, context: &AppContext) -> i32 { context .config() .service @@ -84,7 +84,7 @@ impl Middleware for SensitiveRequestHeadersMiddleware { .common .priority } - fn install(&self, router: Router, context: &AppContext, _state: &S) -> anyhow::Result { + fn install(&self, router: Router, context: &AppContext) -> anyhow::Result { let headers = context .config() .service @@ -109,7 +109,7 @@ impl Middleware for SensitiveResponseHeadersMiddleware { "sensitive-response-headers".to_string() } - fn enabled(&self, context: &AppContext, _state: &S) -> bool { + fn enabled(&self, context: &AppContext) -> bool { context .config() .service @@ -121,7 +121,7 @@ impl Middleware for SensitiveResponseHeadersMiddleware { .enabled(context) } - fn priority(&self, context: &AppContext, _state: &S) -> i32 { + fn priority(&self, context: &AppContext) -> i32 { context .config() .service @@ -132,7 +132,7 @@ impl Middleware for SensitiveResponseHeadersMiddleware { .common .priority } - fn install(&self, router: Router, context: &AppContext, _state: &S) -> anyhow::Result { + fn install(&self, router: Router, context: &AppContext) -> anyhow::Result { let headers = context .config() .service diff --git a/src/service/http/middleware/size_limit.rs b/src/service/http/middleware/size_limit.rs index f7021138..3a1cf9c3 100644 --- a/src/service/http/middleware/size_limit.rs +++ b/src/service/http/middleware/size_limit.rs @@ -28,7 +28,7 @@ impl Middleware for RequestBodyLimitMiddleware { "request-body-size-limit".to_string() } - fn enabled(&self, context: &AppContext, _state: &S) -> bool { + fn enabled(&self, context: &AppContext) -> bool { context .config() .service @@ -40,7 +40,7 @@ impl Middleware for RequestBodyLimitMiddleware { .enabled(context) } - fn priority(&self, context: &AppContext, _state: &S) -> i32 { + fn priority(&self, context: &AppContext) -> i32 { context .config() .service @@ -52,7 +52,7 @@ impl Middleware for RequestBodyLimitMiddleware { .priority } - fn install(&self, router: Router, context: &AppContext, _state: &S) -> anyhow::Result { + fn install(&self, router: Router, context: &AppContext) -> anyhow::Result { let limit = &context .config() .service diff --git a/src/service/http/middleware/timeout.rs b/src/service/http/middleware/timeout.rs index 0c048abc..856229d2 100644 --- a/src/service/http/middleware/timeout.rs +++ b/src/service/http/middleware/timeout.rs @@ -28,7 +28,7 @@ impl Middleware for TimeoutMiddleware { "timeout".to_string() } - fn enabled(&self, context: &AppContext, _state: &S) -> bool { + fn enabled(&self, context: &AppContext) -> bool { context .config() .service @@ -40,7 +40,7 @@ impl Middleware for TimeoutMiddleware { .enabled(context) } - fn priority(&self, context: &AppContext, _state: &S) -> i32 { + fn priority(&self, context: &AppContext) -> i32 { context .config() .service @@ -52,7 +52,7 @@ impl Middleware for TimeoutMiddleware { .priority } - fn install(&self, router: Router, context: &AppContext, _state: &S) -> anyhow::Result { + fn install(&self, router: Router, context: &AppContext) -> anyhow::Result { let timeout = &context .config() .service diff --git a/src/service/http/middleware/tracing.rs b/src/service/http/middleware/tracing.rs index 8888cb67..c0e72658 100644 --- a/src/service/http/middleware/tracing.rs +++ b/src/service/http/middleware/tracing.rs @@ -21,7 +21,7 @@ impl Middleware for TracingMiddleware { "tracing".to_string() } - fn enabled(&self, context: &AppContext, _state: &S) -> bool { + fn enabled(&self, context: &AppContext) -> bool { context .config() .service @@ -33,7 +33,7 @@ impl Middleware for TracingMiddleware { .enabled(context) } - fn priority(&self, context: &AppContext, _state: &S) -> i32 { + fn priority(&self, context: &AppContext) -> i32 { context .config() .service @@ -45,7 +45,7 @@ impl Middleware for TracingMiddleware { .priority } - fn install(&self, router: Router, context: &AppContext, _state: &S) -> anyhow::Result { + fn install(&self, router: Router, context: &AppContext) -> anyhow::Result { let request_id_header_name = &context .config() .service diff --git a/src/service/http/service.rs b/src/service/http/service.rs index 32944473..188766f3 100644 --- a/src/service/http/service.rs +++ b/src/service/http/service.rs @@ -18,6 +18,7 @@ use std::fs::File; use std::io::Write; #[cfg(feature = "open-api")] use std::path::PathBuf; +#[cfg(feature = "open-api")] use std::sync::Arc; use tokio_util::sync::CancellationToken; use tracing::info; @@ -34,7 +35,7 @@ impl AppService for HttpService { "http".to_string() } - fn enabled(context: &AppContext, _state: &A::State) -> bool { + fn enabled(context: &AppContext) -> bool { context.config().service.http.common.enabled(context) } @@ -43,8 +44,7 @@ impl AppService for HttpService { &self, roadster_cli: &RoadsterCli, _app_cli: &A::Cli, - _app_context: &AppContext, - _app_state: &A::State, + _app_context: &AppContext, ) -> anyhow::Result { if let Some(command) = roadster_cli.command.as_ref() { match command { @@ -68,8 +68,7 @@ impl AppService for HttpService { async fn run( &self, - app_context: Arc, - _app_state: Arc, + app_context: AppContext, cancel_token: CancellationToken, ) -> anyhow::Result<()> { let server_addr = app_context.config().service.http.custom.address.url(); @@ -88,10 +87,9 @@ impl HttpService { /// Create a new [HttpServiceBuilder]. pub fn builder( path_root: &str, - context: &AppContext, - state: &A::State, + context: &AppContext, ) -> HttpServiceBuilder { - HttpServiceBuilder::new(path_root, context, state) + HttpServiceBuilder::new(path_root, context) } /// List the available HTTP API routes. diff --git a/src/service/mod.rs b/src/service/mod.rs index e9f23365..e32002c2 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -3,7 +3,6 @@ use crate::app_context::AppContext; #[cfg(feature = "cli")] use crate::cli::RoadsterCli; use async_trait::async_trait; -use std::sync::Arc; use tokio_util::sync::CancellationToken; pub mod http; @@ -21,7 +20,7 @@ pub trait AppService: Send + Sync { Self: Sized; /// Whether the service is enabled. If the service is not enabled, it will not be run. - fn enabled(context: &AppContext, state: &A::State) -> bool + fn enabled(context: &AppContext) -> bool where Self: Sized; @@ -37,8 +36,7 @@ pub trait AppService: Send + Sync { &self, _roadster_cli: &RoadsterCli, _app_cli: &A::Cli, - _app_context: &AppContext, - _app_state: &A::State, + _app_context: &AppContext, ) -> anyhow::Result { Ok(false) } @@ -49,8 +47,7 @@ pub trait AppService: Send + Sync { /// the service. async fn run( &self, - app_context: Arc, - app_state: Arc, + app_context: AppContext, cancel_token: CancellationToken, ) -> anyhow::Result<()>; } @@ -66,9 +63,9 @@ where A: App, S: AppService, { - fn enabled(&self, app_context: &AppContext, app_state: &A::State) -> bool { - S::enabled(app_context, app_state) + fn enabled(&self, app_context: &AppContext) -> bool { + S::enabled(app_context) } - async fn build(self, context: &AppContext, state: &A::State) -> anyhow::Result; + async fn build(self, context: &AppContext) -> anyhow::Result; } diff --git a/src/service/registry.rs b/src/service/registry.rs index a6c3fd08..4eace7f1 100644 --- a/src/service/registry.rs +++ b/src/service/registry.rs @@ -3,7 +3,6 @@ use crate::app_context::AppContext; use crate::service::{AppService, AppServiceBuilder}; use anyhow::bail; use std::collections::BTreeMap; -use std::sync::Arc; use tracing::info; /// Registry for [AppService]s that will be run in the app. @@ -11,16 +10,14 @@ pub struct ServiceRegistry where A: App + ?Sized, { - pub(crate) context: Arc, - pub(crate) state: Arc, + pub(crate) context: AppContext, pub(crate) services: BTreeMap>>, } impl ServiceRegistry { - pub(crate) fn new(context: Arc, state: Arc) -> Self { + pub(crate) fn new(context: AppContext) -> Self { Self { context, - state, services: Default::default(), } } @@ -31,7 +28,7 @@ impl ServiceRegistry { where S: AppService + 'static, { - if !S::enabled(&self.context, &self.state) { + if !S::enabled(&self.context) { info!(service = %S::name(), "Service is not enabled, skipping registration"); return Ok(()); } @@ -45,13 +42,13 @@ impl ServiceRegistry { S: AppService + 'static, B: AppServiceBuilder, { - if !S::enabled(&self.context, &self.state) || !builder.enabled(&self.context, &self.state) { + if !S::enabled(&self.context) || !builder.enabled(&self.context) { info!(service = %S::name(), "Service is not enabled, skipping building and registration"); return Ok(()); } info!(service = %S::name(), "Building service"); - let service = builder.build(&self.context, &self.state).await?; + let service = builder.build(&self.context).await?; self.register_internal(service) } diff --git a/src/service/worker/sidekiq/app_worker.rs b/src/service/worker/sidekiq/app_worker.rs index e641eb81..8f164591 100644 --- a/src/service/worker/sidekiq/app_worker.rs +++ b/src/service/worker/sidekiq/app_worker.rs @@ -4,7 +4,6 @@ use async_trait::async_trait; use serde_derive::{Deserialize, Serialize}; use serde_with::{serde_as, skip_serializing_none}; use sidekiq::Worker; -use std::sync::Arc; use std::time::Duration; use typed_builder::TypedBuilder; @@ -52,13 +51,12 @@ where Args: Send + Sync + serde::Serialize + 'static, { /// Build a new instance of the [worker][Self]. - fn build(state: &A::State) -> Self; + fn build(context: &AppContext) -> Self; /// Enqueue the worker into its Sidekiq queue. This is a helper method around [Worker::perform_async] /// so the caller can simply provide the [state][App::State] instead of needing to access the /// [sidekiq::RedisPool] from inside the [state][App::State]. - async fn enqueue(state: &A::State, args: Args) -> anyhow::Result<()> { - let context: Arc = state.clone().into(); + async fn enqueue(context: &AppContext, args: Args) -> anyhow::Result<()> { Self::perform_async(context.redis_enqueue(), args).await?; Ok(()) } @@ -66,20 +64,19 @@ where /// Provide the [AppWorkerConfig] for [Self]. The default implementation populates the /// [AppWorkerConfig] using the values from the corresponding methods on [Self], e.g., /// [Self::max_retries]. - fn config(&self, state: &A::State) -> AppWorkerConfig { + fn config(&self, context: &AppContext) -> AppWorkerConfig { AppWorkerConfig::builder() - .max_retries(AppWorker::max_retries(self, state)) - .timeout(self.timeout(state)) - .max_duration(self.max_duration(state)) - .disable_argument_coercion(AppWorker::disable_argument_coercion(self, state)) + .max_retries(AppWorker::max_retries(self, context)) + .timeout(self.timeout(context)) + .max_duration(self.max_duration(context)) + .disable_argument_coercion(AppWorker::disable_argument_coercion(self, context)) .build() } /// See [AppWorkerConfig::max_retries]. /// /// The default implementation uses the value from the app's config file. - fn max_retries(&self, state: &A::State) -> usize { - let context: Arc = state.clone().into(); + fn max_retries(&self, context: &AppContext) -> usize { context .config() .service @@ -92,8 +89,7 @@ where /// See [AppWorkerConfig::timeout]. /// /// The default implementation uses the value from the app's config file. - fn timeout(&self, state: &A::State) -> bool { - let context: Arc = state.clone().into(); + fn timeout(&self, context: &AppContext) -> bool { context .config() .service @@ -106,8 +102,7 @@ where /// See [AppWorkerConfig::max_duration]. /// /// The default implementation uses the value from the app's config file. - fn max_duration(&self, state: &A::State) -> Duration { - let context: Arc = state.clone().into(); + fn max_duration(&self, context: &AppContext) -> Duration { context .config() .service @@ -120,8 +115,7 @@ where /// See [AppWorkerConfig::disable_argument_coercion]. /// /// The default implementation uses the value from the app's config file. - fn disable_argument_coercion(&self, state: &A::State) -> bool { - let context: Arc = state.clone().into(); + fn disable_argument_coercion(&self, context: &AppContext) -> bool { context .config() .service diff --git a/src/service/worker/sidekiq/builder.rs b/src/service/worker/sidekiq/builder.rs index 69cc4806..c1fe1a43 100644 --- a/src/service/worker/sidekiq/builder.rs +++ b/src/service/worker/sidekiq/builder.rs @@ -12,7 +12,6 @@ use num_traits::ToPrimitive; use serde::Serialize; use sidekiq::{periodic, Processor, ProcessorConfig}; use std::collections::HashSet; -use std::sync::Arc; use tracing::{debug, info, warn}; const PERIODIC_KEY: &str = "periodic"; @@ -27,8 +26,7 @@ where enum BuilderState { Enabled { processor: Processor, - context: Arc, - state: Arc, + context: AppContext, registered_workers: HashSet, registered_periodic_workers: HashSet, }, @@ -40,20 +38,16 @@ impl AppServiceBuilder for SidekiqWorkerServiceBuild where A: App + 'static, { - fn enabled(&self, app_context: &AppContext, app_state: &A::State) -> bool { + fn enabled(&self, app_context: &AppContext) -> bool { match self.state { BuilderState::Enabled { .. } => { - >::enabled(app_context, app_state) + >::enabled(app_context) } BuilderState::Disabled => false, } } - async fn build( - self, - context: &AppContext, - _state: &A::State, - ) -> anyhow::Result { + async fn build(self, context: &AppContext) -> anyhow::Result { let service = match self.state { BuilderState::Enabled { processor, @@ -77,23 +71,21 @@ where A: App + 'static, { pub async fn with_processor( - context: Arc, - state: Arc, + context: &AppContext, processor: Processor, ) -> anyhow::Result { - Self::new(context, state, Some(processor)).await + Self::new(context.clone(), Some(processor)).await } pub async fn with_default_processor( - context: Arc, - state: Arc, + context: &AppContext, worker_queues: Option>, ) -> anyhow::Result { - let processor = if !>::enabled(&context, &state) { + let processor = if !>::enabled(context) { debug!("Sidekiq service not enabled, not creating the Sidekiq processor"); None } else if let Some(redis_fetch) = context.redis_fetch() { - Self::auto_clean_periodic(&context).await?; + Self::auto_clean_periodic(context).await?; let queues = context .config() .service @@ -136,15 +128,14 @@ where None }; - Self::new(context, state, processor).await + Self::new(context.clone(), processor).await } async fn new( - context: Arc, - state: Arc, + context: AppContext, processor: Option, ) -> anyhow::Result { - let processor = if >::enabled(&context, &state) { + let processor = if >::enabled(&context) { processor } else { None @@ -154,7 +145,6 @@ where BuilderState::Enabled { processor, context, - state, registered_workers: Default::default(), registered_periodic_workers: Default::default(), } @@ -165,7 +155,7 @@ where Ok(Self { state }) } - async fn auto_clean_periodic(context: &AppContext) -> anyhow::Result<()> { + async fn auto_clean_periodic(context: &AppContext) -> anyhow::Result<()> { if context .config() .service @@ -218,8 +208,8 @@ where { if let BuilderState::Enabled { processor, - state, registered_workers, + context, .. } = &mut self.state { @@ -228,7 +218,7 @@ where if !registered_workers.insert(class_name.clone()) { bail!("Worker `{class_name}` was already registered"); } - let roadster_worker = RoadsterWorker::new(worker, state.clone()); + let roadster_worker = RoadsterWorker::new(worker, context); processor.register(roadster_worker); } @@ -254,14 +244,14 @@ where { if let BuilderState::Enabled { processor, - state, + context, registered_periodic_workers, .. } = &mut self.state { let class_name = W::class_name(); debug!(worker = %class_name, "Registering periodic worker"); - let roadster_worker = RoadsterWorker::new(worker, state.clone()); + let roadster_worker = RoadsterWorker::new(worker, context); let builder = builder.args(args)?; let job_json = serde_json::to_string(&builder.into_periodic_job(class_name.clone())?)?; if !registered_periodic_workers.insert(job_json.clone()) { @@ -284,7 +274,7 @@ where /// /// This is run after all the app's periodic jobs have been registered. pub(crate) async fn remove_stale_periodic_jobs( - context: &AppContext, + context: &AppContext, registered_periodic_workers: &HashSet, ) -> anyhow::Result<()> { let mut conn = context.redis_enqueue().get().await?; diff --git a/src/service/worker/sidekiq/roadster_worker.rs b/src/service/worker/sidekiq/roadster_worker.rs index 591cf630..415fc569 100644 --- a/src/service/worker/sidekiq/roadster_worker.rs +++ b/src/service/worker/sidekiq/roadster_worker.rs @@ -5,9 +5,9 @@ use crate::service::worker::sidekiq::app_worker::AppWorkerConfig; use async_trait::async_trait; use serde::Serialize; +use crate::app_context::AppContext; use sidekiq::{RedisPool, Worker, WorkerOpts}; use std::marker::PhantomData; -use std::sync::Arc; use std::time::Duration; use tracing::{error, instrument}; @@ -32,8 +32,8 @@ where Args: Send + Sync + Serialize, W: AppWorker, { - pub(crate) fn new(inner: W, state: Arc) -> Self { - let config = inner.config(&state); + pub(crate) fn new(inner: W, context: &AppContext) -> Self { + let config = inner.config(context); Self { inner, inner_config: config, diff --git a/src/service/worker/sidekiq/service.rs b/src/service/worker/sidekiq/service.rs index b8c218a0..874489ce 100644 --- a/src/service/worker/sidekiq/service.rs +++ b/src/service/worker/sidekiq/service.rs @@ -4,7 +4,6 @@ use crate::service::worker::sidekiq::builder::SidekiqWorkerServiceBuilder; use crate::service::AppService; use async_trait::async_trait; use sidekiq::Processor; -use std::sync::Arc; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; use tracing::{debug, error}; @@ -22,7 +21,7 @@ impl AppService for SidekiqWorkerService { "sidekiq".to_string() } - fn enabled(context: &AppContext, _state: &A::State) -> bool + fn enabled(context: &AppContext) -> bool where Self: Sized, { @@ -48,8 +47,7 @@ impl AppService for SidekiqWorkerService { async fn run( &self, - _app_context: Arc, - _app_state: Arc, + _app_context: AppContext, cancel_token: CancellationToken, ) -> anyhow::Result<()> { let processor = self.processor.clone(); @@ -82,12 +80,11 @@ impl AppService for SidekiqWorkerService { impl SidekiqWorkerService { pub async fn builder( - context: Arc, - state: Arc, + context: AppContext, ) -> anyhow::Result> where A: App + 'static, { - SidekiqWorkerServiceBuilder::with_default_processor(context, state, None).await + SidekiqWorkerServiceBuilder::with_default_processor(&context, None).await } }