Skip to content

Commit

Permalink
feat!: Use Axum's FromRef to allow custom state
Browse files Browse the repository at this point in the history
  • Loading branch information
spencewenski committed Jun 29, 2024
1 parent 3af6c2d commit 1dc5fbc
Show file tree
Hide file tree
Showing 56 changed files with 927 additions and 717 deletions.
6 changes: 4 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ rust-version = "1.74.1"

[features]
default = ["sidekiq", "db-sql", "open-api", "jwt-ietf", "cli", "otel"]
http = ["dep:axum", "dep:axum-extra", "dep:tower", "dep:tower-http"]
http = ["dep:axum-extra", "dep:tower", "dep:tower-http"]
open-api = ["http", "dep:aide", "dep:schemars"]
sidekiq = ["dep:rusty-sidekiq", "dep:bb8", "dep:num_cpus"]
db-sql = ["dep:sea-orm", "dep:sea-orm-migration"]
Expand Down Expand Up @@ -43,7 +43,9 @@ opentelemetry-otlp = { version = "0.16.0", features = ["metrics", "trace", "logs
tracing-opentelemetry = { version = "0.24.0", features = ["metrics"], optional = true }

# Controllers
axum = { workspace = true, optional = true }
# `axum` is not optional because we use the `FromRef` trait pretty extensively, even in parts of
# the code that wouldn't otherwise need `axum`.
axum = { workspace = true, features = ["macros"] }
axum-extra = { version = "0.9.0", features = ["typed-header"], optional = true }
tower = { version = "0.4.13", optional = true }
tower-http = { version = "0.5.0", features = ["trace", "timeout", "request-id", "util", "normalize-path", "sensitive-headers", "catch-panic", "compression-full", "decompression-full", "limit", "cors"], optional = true }
Expand Down
4 changes: 2 additions & 2 deletions examples/full/src/api/grpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ pub mod hello_world;

use crate::api::grpc::hello_world::greeter_server::GreeterServer;
use crate::api::grpc::hello_world::MyGreeter;
use crate::app_state::AppState;
use roadster::app::context::AppContext;
use tonic::transport::server::Router;
use tonic::transport::Server;

pub fn routes(_state: &AppState) -> anyhow::Result<Router> {
pub fn routes(_state: &AppContext) -> anyhow::Result<Router> {
let reflection_service = tonic_reflection::server::Builder::configure()
.register_encoded_file_descriptor_set(hello_world::FILE_DESCRIPTOR_SET)
.build()?;
Expand Down
6 changes: 3 additions & 3 deletions examples/full/src/api/http/example.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::app_state::AppState;
use crate::worker::example::ExampleWorker;
use aide::axum::routing::get_with;
use aide::axum::ApiRouter;
use aide::transform::TransformOperation;
use axum::extract::State;
use axum::Json;
use roadster::api::http::build_path;
use roadster::app::context::AppContext;
use roadster::error::RoadsterResult;
use roadster::service::worker::sidekiq::app_worker::AppWorker;
use schemars::JsonSchema;
Expand All @@ -15,7 +15,7 @@ use tracing::instrument;
const BASE: &str = "/example";
const TAG: &str = "Example";

pub fn routes(parent: &str) -> ApiRouter<AppState> {
pub fn routes(parent: &str) -> ApiRouter<AppContext> {
let root = build_path(parent, BASE);

ApiRouter::new().api_route(&root, get_with(example_get, example_get_docs))
Expand All @@ -26,7 +26,7 @@ pub fn routes(parent: &str) -> ApiRouter<AppState> {
pub struct ExampleResponse {}

#[instrument(skip_all)]
async fn example_get(State(state): State<AppState>) -> RoadsterResult<Json<ExampleResponse>> {
async fn example_get(State(state): State<AppContext>) -> RoadsterResult<Json<ExampleResponse>> {
ExampleWorker::enqueue(&state, "Example".to_string()).await?;
Ok(Json(ExampleResponse {}))
}
Expand Down
4 changes: 2 additions & 2 deletions examples/full/src/api/http/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::app_state::AppState;
use aide::axum::ApiRouter;
use roadster::app::context::AppContext;

pub mod example;

pub fn routes(parent: &str) -> ApiRouter<AppState> {
pub fn routes(parent: &str) -> ApiRouter<AppContext> {
ApiRouter::new().merge(example::routes(parent))
}
12 changes: 5 additions & 7 deletions examples/full/src/app.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#[cfg(feature = "grpc")]
use crate::api::grpc::routes;
use crate::api::http;
use crate::app_state::CustomAppContext;
use crate::cli::AppCli;
use crate::service::example::example_service;
use crate::worker::example::ExampleWorker;
Expand All @@ -26,8 +25,7 @@ const BASE: &str = "/api";
pub struct App;

#[async_trait]
impl RoadsterApp for App {
type State = CustomAppContext;
impl RoadsterApp<AppContext> for App {
type Cli = AppCli;
type M = Migrator;

Expand All @@ -37,13 +35,13 @@ impl RoadsterApp for App {
.build())
}

async fn with_state(_context: &AppContext) -> RoadsterResult<Self::State> {
Ok(())
async fn provide_state(_context: AppContext) -> RoadsterResult<Self::State> {
Ok(_context)
}

async fn services(
registry: &mut ServiceRegistry<Self>,
context: &AppContext<Self::State>,
registry: &mut ServiceRegistry<Self, AppContext>,
context: &AppContext,
) -> RoadsterResult<()> {
registry
.register_builder(
Expand Down
5 changes: 0 additions & 5 deletions examples/full/src/app_state.rs

This file was deleted.

1 change: 0 additions & 1 deletion examples/full/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
pub mod api;
pub mod app;
pub mod app_state;
pub mod cli;
pub mod service;
pub mod worker;
4 changes: 2 additions & 2 deletions examples/full/src/service/example.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use crate::app_state::AppState;
use roadster::app::context::AppContext;
use roadster::error::RoadsterResult;
use tokio_util::sync::CancellationToken;
use tracing::info;

pub async fn example_service(
_state: AppState,
_state: AppContext,
_cancel_token: CancellationToken,
) -> RoadsterResult<()> {
info!("Running example function-based service");
Expand Down
6 changes: 3 additions & 3 deletions examples/full/src/worker/example.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::app::App;
use crate::app_state::AppState;
use async_trait::async_trait;
use roadster::app::context::AppContext;
use roadster::service::worker::sidekiq::app_worker::AppWorker;
use sidekiq::Worker;
use tracing::{info, instrument};
Expand All @@ -17,8 +17,8 @@ impl Worker<String> for ExampleWorker {
}

#[async_trait]
impl AppWorker<App, String> for ExampleWorker {
fn build(_context: &AppState) -> Self {
impl AppWorker<_, String> for ExampleWorker {
fn build(_context: &AppContext) -> Self {
Self {}
}
}
84 changes: 57 additions & 27 deletions src/api/cli/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@ use crate::app::App;
use crate::app::MockApp;
use crate::error::RoadsterResult;
use async_trait::async_trait;
use axum::extract::FromRef;
use clap::{Args, Command, FromArgMatches};
use std::ffi::OsString;

pub mod roadster;

/// Implement to enable Roadster to run your custom CLI commands.
#[async_trait]
pub trait RunCommand<A>
pub trait RunCommand<A, S>
where
A: App + ?Sized + Sync,
S: Clone + Send + Sync + 'static,
AppContext: FromRef<S>,
A: App<S> + ?Sized + Sync,
{
/// Run the command.
///
Expand All @@ -25,17 +28,14 @@ where
/// continue execution after the command is complete.
/// * `Err(...)` - If the implementation experienced an error while handling the command. The
/// app should end execution after the command is complete.
async fn run(
&self,
app: &A,
cli: &A::Cli,
context: &AppContext<A::State>,
) -> RoadsterResult<bool>;
async fn run(&self, app: &A, cli: &A::Cli, context: &S) -> RoadsterResult<bool>;
}

pub(crate) fn parse_cli<A, I, T>(args: I) -> RoadsterResult<(RoadsterCli, A::Cli)>
pub(crate) fn parse_cli<A, S, I, T>(args: I) -> RoadsterResult<(RoadsterCli, A::Cli)>
where
A: App,
S: Clone + Send + Sync + 'static,
AppContext: FromRef<S>,
A: App<S>,
I: IntoIterator<Item = T>,
T: Into<OsString> + Clone,
{
Expand Down Expand Up @@ -78,14 +78,16 @@ where
Ok((roadster_cli, app_cli))
}

pub(crate) async fn handle_cli<A>(
pub(crate) async fn handle_cli<A, S>(
app: &A,
roadster_cli: &RoadsterCli,
app_cli: &A::Cli,
context: &AppContext<A::State>,
context: &S,
) -> RoadsterResult<bool>
where
A: App,
S: Clone + Send + Sync + 'static,
AppContext: FromRef<S>,
A: App<S>,
{
if roadster_cli.run(app, roadster_cli, context).await? {
return Ok(true);
Expand All @@ -96,29 +98,57 @@ where
Ok(false)
}

#[cfg(test)]
pub struct TestCli<S>
where
S: Clone + Send + Sync + 'static,
AppContext: FromRef<S>,
{
_state: std::marker::PhantomData<S>,
}

#[cfg(test)]
mockall::mock! {
pub Cli {}
pub TestCli<S>
where
S: Clone + Send + Sync + 'static,
AppContext: FromRef<S>,
{}

#[async_trait]
impl RunCommand<MockApp> for Cli {
async fn run(
&self,
app: &MockApp,
cli: &<MockApp as App>::Cli,
context: &AppContext<<MockApp as App>::State>,
) -> RoadsterResult<bool>;
impl<S> RunCommand<MockApp<S>, S> for TestCli<S>
where
S: Clone + Send + Sync + 'static,
AppContext: FromRef<S>,
{
async fn run(&self, app: &MockApp<S>, cli: &<MockApp<S> as App<S>>::Cli, context: &S) -> RoadsterResult<bool>;
}

impl clap::FromArgMatches for Cli {
impl<S> clap::FromArgMatches for TestCli<S>
where
S: Clone + Send + Sync + 'static,
AppContext: FromRef<S>,
{
fn from_arg_matches(matches: &clap::ArgMatches) -> Result<Self, clap::Error>;
fn update_from_arg_matches(&mut self, matches: &clap::ArgMatches) -> Result<(), clap::Error>;
}

impl clap::Args for Cli {
impl<S> clap::Args for TestCli<S>
where
S: Clone + Send + Sync + 'static,
AppContext: FromRef<S>,
{
fn augment_args(cmd: clap::Command) -> clap::Command;
fn augment_args_for_update(cmd: clap::Command) -> clap::Command;
}

impl<S> Clone for TestCli<S>
where
S: Clone + Send + Sync + 'static,
AppContext: FromRef<S>,
{
fn clone(&self) -> Self;
}
}

#[cfg(test)]
Expand Down Expand Up @@ -150,12 +180,12 @@ mod tests {
#[cfg_attr(coverage_nightly, coverage(off))]
fn parse_cli(_case: TestCase, #[case] args: Option<&str>, #[case] arg_list: Option<Vec<&str>>) {
// Arrange
let augment_args_context = MockCli::augment_args_context();
let augment_args_context = MockTestCli::<AppContext>::augment_args_context();
augment_args_context.expect().returning(|c| c);
let from_arg_matches_context = MockCli::from_arg_matches_context();
let from_arg_matches_context = MockTestCli::<AppContext>::from_arg_matches_context();
from_arg_matches_context
.expect()
.returning(|_| Ok(MockCli::default()));
.returning(|_| Ok(MockTestCli::<AppContext>::default()));

let args = if let Some(args) = args {
args.split(' ').collect_vec()
Expand All @@ -169,7 +199,7 @@ mod tests {
.collect_vec();

// Act
let (roadster_cli, _a) = super::parse_cli::<MockApp, _, _>(args).unwrap();
let (roadster_cli, _a) = super::parse_cli::<MockApp<AppContext>, _, _, _>(args).unwrap();

// Assert
assert_toml_snapshot!(roadster_cli);
Expand Down
15 changes: 7 additions & 8 deletions src/api/cli/roadster/health.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::app::context::AppContext;
use crate::app::App;
use crate::error::RoadsterResult;
use async_trait::async_trait;
use axum::extract::FromRef;
use clap::Parser;
use serde_derive::Serialize;
use tracing::info;
Expand All @@ -13,21 +14,19 @@ use tracing::info;
pub struct HealthArgs {}

#[async_trait]
impl<A> RunRoadsterCommand<A> for HealthArgs
impl<A, S> RunRoadsterCommand<A, S> for HealthArgs
where
A: App,
S: Clone + Send + Sync + 'static,
AppContext: FromRef<S>,
A: App<S>,
{
async fn run(
&self,
_app: &A,
_cli: &RoadsterCli,
#[allow(unused_variables)] context: &AppContext<A::State>,
#[allow(unused_variables)] context: &S,
) -> RoadsterResult<bool> {
let health = health_check::<A::State>(
#[cfg(any(feature = "sidekiq", feature = "db-sql"))]
context,
)
.await?;
let health = health_check(context).await?;
let health = serde_json::to_string_pretty(&health)?;
info!("\n{health}");
Ok(true)
Expand Down
Loading

0 comments on commit 1dc5fbc

Please sign in to comment.