Skip to content

Commit

Permalink
Disallow registering things multiple times
Browse files Browse the repository at this point in the history
Update the following to thrown an error if a duplicate resource is
registered:
- Worker service builder -- disallow duplicate workers or periodic
  jobs from being registered
- Http service builder -- disallow duplicate middleware or initializers
  from being registered. Note: Axum's built-in behavior is to disallow
  duplicate routes, so we don't need to handle that ourselves.
- Service registry -- disallow duplicate services from being registered.
  Note: we may want to allow this in the future -- it _may_ be useful,
  but will wait until we get a feature request for this before allowing
  it by default.
  • Loading branch information
spencewenski committed May 6, 2024
1 parent 3a9e9e8 commit dece541
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 35 deletions.
7 changes: 5 additions & 2 deletions examples/minimal/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,17 @@ impl RoadsterApp for App {
state: Arc<Self::State>,
) -> anyhow::Result<()> {
registry
.register_builder(HttpService::builder(BASE, &context).router(controller::routes(BASE)))
.register_builder(
HttpService::builder(BASE, &context, state.as_ref())
.router(controller::routes(BASE)),
)
.await?;

registry
.register_builder(
SidekiqWorkerService::builder(context.clone(), state.clone())
.await?
.register_app_worker(ExampleWorker::build(&state)),
.register_app_worker(ExampleWorker::build(&state))?,
)
.await?;

Expand Down
1 change: 1 addition & 0 deletions src/controller/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pub fn default_routes<S>(parent: &str, config: &AppConfig) -> ApiRouter<S>
where
S: Clone + Send + Sync + 'static + Into<Arc<AppContext>>,
{
// Todo: Allow disabling the default routes
ApiRouter::new()
.merge(ping::routes(parent))
.merge(health::routes(parent))
Expand Down
47 changes: 30 additions & 17 deletions src/service/http/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ use aide::axum::ApiRouter;
use aide::openapi::OpenApi;
#[cfg(feature = "open-api")]
use aide::transform::TransformOpenApi;
use anyhow::bail;
use async_trait::async_trait;
#[cfg(feature = "open-api")]
use axum::Extension;
#[cfg(not(feature = "open-api"))]
use axum::Router;
use itertools::Itertools;
use std::collections::BTreeMap;
#[cfg(feature = "open-api")]
use std::sync::Arc;
use tracing::info;
Expand All @@ -30,22 +32,22 @@ pub struct HttpServiceBuilder<A: App> {
router: ApiRouter<A::State>,
#[cfg(feature = "open-api")]
api_docs: Box<dyn Fn(TransformOpenApi) -> TransformOpenApi + Send>,
middleware: Vec<Box<dyn Middleware<A::State>>>,
initializers: Vec<Box<dyn Initializer<A::State>>>,
middleware: BTreeMap<String, Box<dyn Middleware<A::State>>>,
initializers: BTreeMap<String, Box<dyn Initializer<A::State>>>,
}

impl<A: App> HttpServiceBuilder<A> {
pub fn new(path_root: &str, app_context: &AppContext) -> Self {
pub fn new(path_root: &str, context: &AppContext, state: &A::State) -> Self {
#[cfg(feature = "open-api")]
let app_name = app_context.config.app.name.clone();
let app_name = context.config.app.name.clone();
Self {
router: default_routes(path_root, &app_context.config),
router: default_routes(path_root, &context.config),
#[cfg(feature = "open-api")]
api_docs: Box::new(move |api| {
api.title(&app_name).description(&format!("# {}", app_name))
}),
middleware: default_middleware(),
initializers: default_initializers(),
middleware: default_middleware(context, state),
initializers: default_initializers(context, state),
}
}

Expand All @@ -70,14 +72,27 @@ impl<A: App> HttpServiceBuilder<A> {
self
}

pub fn initializer(mut self, initializer: Box<dyn Initializer<A::State>>) -> Self {
self.initializers.push(initializer);
self
pub fn initializer(
mut self,
initializer: Box<dyn Initializer<A::State>>,
) -> anyhow::Result<Self> {
let name = initializer.name();
if self
.initializers
.insert(name.clone(), initializer)
.is_some()
{
bail!("Initializer `{name}` was already registered");
}
Ok(self)
}

pub fn middleware(mut self, middleware: Box<dyn Middleware<A::State>>) -> Self {
self.middleware.push(middleware);
self
pub fn middleware(mut self, middleware: Box<dyn Middleware<A::State>>) -> anyhow::Result<Self> {
let name = middleware.name();
if self.middleware.insert(name.clone(), middleware).is_some() {
bail!("Middleware `{name}` was already registered");
}
Ok(self)
}
}

Expand All @@ -102,9 +117,8 @@ impl<A: App> AppServiceBuilder<A, HttpService> for HttpServiceBuilder<A> {

let initializers = self
.initializers
.into_iter()
.values()
.filter(|initializer| initializer.enabled(context, state))
.unique_by(|initializer| initializer.name())
.sorted_by(|a, b| Ord::cmp(&a.priority(context, state), &b.priority(context, state)))
.collect_vec();

Expand All @@ -123,9 +137,8 @@ impl<A: App> AppServiceBuilder<A, HttpService> for HttpServiceBuilder<A> {
info!("Installing middleware. Note: the order of installation is the inverse of the order middleware will run when handling a request.");
let router = self
.middleware
.into_iter()
.values()
.filter(|middleware| middleware.enabled(context, state))
.unique_by(|middleware| middleware.name())
.sorted_by(|a, b| Ord::cmp(&a.priority(context, state), &b.priority(context, state)))
// Reverse due to how Axum's `Router#layer` method adds middleware.
.rev()
Expand Down
14 changes: 12 additions & 2 deletions src/service/http/initializer/default.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
use crate::app_context::AppContext;
use crate::service::http::initializer::normalize_path::NormalizePathInitializer;
use crate::service::http::initializer::Initializer;
use std::collections::BTreeMap;

pub fn default_initializers<S>() -> Vec<Box<dyn Initializer<S>>> {
vec![Box::new(NormalizePathInitializer)]
pub fn default_initializers<S>(
context: &AppContext,
state: &S,
) -> BTreeMap<String, Box<dyn Initializer<S>>> {
let initializers: Vec<Box<dyn Initializer<S>>> = vec![Box::new(NormalizePathInitializer)];
initializers
.into_iter()
.filter(|initializer| initializer.enabled(context, state))
.map(|initializer| (initializer.name(), initializer))
.collect()
}
16 changes: 13 additions & 3 deletions src/service/http/middleware/default.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::app_context::AppContext;
use crate::service::http::middleware::catch_panic::CatchPanicMiddleware;
use crate::service::http::middleware::compression::RequestDecompressionMiddleware;
use crate::service::http::middleware::request_id::{
Expand All @@ -10,9 +11,13 @@ use crate::service::http::middleware::size_limit::RequestBodyLimitMiddleware;
use crate::service::http::middleware::timeout::TimeoutMiddleware;
use crate::service::http::middleware::tracing::TracingMiddleware;
use crate::service::http::middleware::Middleware;
use std::collections::BTreeMap;

pub fn default_middleware<S>() -> Vec<Box<dyn Middleware<S>>> {
vec![
pub fn default_middleware<S>(
context: &AppContext,
state: &S,
) -> BTreeMap<String, Box<dyn Middleware<S>>> {
let middleware: Vec<Box<dyn Middleware<S>>> = vec![
Box::new(SensitiveRequestHeadersMiddleware),
Box::new(SensitiveResponseHeadersMiddleware),
Box::new(SetRequestIdMiddleware),
Expand All @@ -22,5 +27,10 @@ pub fn default_middleware<S>() -> Vec<Box<dyn Middleware<S>>> {
Box::new(RequestDecompressionMiddleware),
Box::new(TimeoutMiddleware),
Box::new(RequestBodyLimitMiddleware),
]
];
middleware
.into_iter()
.filter(|middleware| middleware.enabled(context, state))
.map(|middleware| (middleware.name(), middleware))
.collect()
}
8 changes: 6 additions & 2 deletions src/service/http/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,12 @@ impl<A: App> AppService<A> for HttpService {

impl HttpService {
/// Create a new [HttpServiceBuilder].
pub fn builder<A: App>(path_root: &str, context: &AppContext) -> HttpServiceBuilder<A> {
HttpServiceBuilder::new(path_root, context)
pub fn builder<A: App>(
path_root: &str,
context: &AppContext,
state: &A::State,
) -> HttpServiceBuilder<A> {
HttpServiceBuilder::new(path_root, context, state)
}

/// List the available HTTP API routes.
Expand Down
11 changes: 7 additions & 4 deletions src/service/registry.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::app::App;
use crate::app_context::AppContext;
use crate::service::{AppService, AppServiceBuilder};
use anyhow::bail;
use std::collections::BTreeMap;
use std::sync::Arc;
use tracing::info;
Expand Down Expand Up @@ -34,7 +35,7 @@ impl<A: App> ServiceRegistry<A> {
info!(service = %S::name(), "Service is not enabled, skipping registration");
return Ok(());
}
self.register_unchecked(service)
self.register_internal(service)
}

/// Build and register a new service. If the service is not enabled (e.g.,
Expand All @@ -52,16 +53,18 @@ impl<A: App> ServiceRegistry<A> {
info!(service = %S::name(), "Building service");
let service = builder.build(&self.context, &self.state).await?;

self.register_unchecked(service)
self.register_internal(service)
}

fn register_unchecked<S>(&mut self, service: S) -> anyhow::Result<()>
fn register_internal<S>(&mut self, service: S) -> anyhow::Result<()>
where
S: AppService<A> + 'static,
{
info!(service = %S::name(), "Registering service");

self.services.insert(S::name(), Box::new(service));
if self.services.insert(S::name(), Box::new(service)).is_some() {
bail!("Service `{}` was already registered", S::name());
}
Ok(())
}
}
16 changes: 11 additions & 5 deletions src/service/worker/sidekiq/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ where
///
/// The worker will be wrapped by a [RoadsterWorker], which provides some common behavior, such
/// as enforcing a timeout/max duration of worker jobs.
pub fn register_app_worker<Args, W>(mut self, worker: W) -> Self
pub fn register_app_worker<Args, W>(mut self, worker: W) -> anyhow::Result<Self>
where
Args: Sync + Send + Serialize + for<'de> serde::Deserialize<'de> + 'static,
W: AppWorker<A, Args> + 'static,
Expand All @@ -219,12 +219,14 @@ where
{
let class_name = W::class_name();
debug!(worker = %class_name, "Registering worker");
registered_workers.insert(class_name.clone());
if !registered_workers.insert(class_name.clone()) {
bail!("Worker `{class_name}` was already registered");
}
let roadster_worker = RoadsterWorker::new(worker, state.clone());
processor.register(roadster_worker);
}

self
Ok(self)
}

/// Register a periodic [worker][AppWorker] that will run with the provided args. The cadence
Expand Down Expand Up @@ -255,8 +257,12 @@ where
debug!(worker = %class_name, "Registering periodic worker");
let roadster_worker = RoadsterWorker::new(worker, state.clone());
let builder = builder.args(args)?;
let job_json = serde_json::to_string(&builder.into_periodic_job(class_name)?)?;
registered_periodic_workers.insert(job_json);
let job_json = serde_json::to_string(&builder.into_periodic_job(class_name.clone())?)?;
if !registered_periodic_workers.insert(job_json.clone()) {
bail!(
"Periodic worker `{class_name}` was already registered; full job: {job_json}"
);
}
builder.register(processor, roadster_worker).await?;
}

Expand Down

0 comments on commit dece541

Please sign in to comment.