From 8479839509bdd7f0b3d18116e255ebb4ab7ccc84 Mon Sep 17 00:00:00 2001 From: Nick Miller Date: Mon, 28 Aug 2023 14:04:51 -0500 Subject: [PATCH] http(serve): move state outside of serve function, use Extension for examples due to sqlx compile time checking --- Cargo.toml | 4 ++-- examples/demo/main.rs | 13 +++++++------ examples/echo/main.rs | 2 +- src/http/mod.rs | 18 +++++++----------- 4 files changed, 17 insertions(+), 20 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b4c0dcd..15794b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ name = "demo" name = "echo" [dependencies] -axum = { version = "0.6.20", features = ["json"] } +axum = { version = "0.6.20", features = ["json", "macros"] } clap = { version = "4", features = ["derive", "env"] } once_cell = "1.18" prometheus = "0.13" @@ -28,7 +28,7 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" tokio = { version = "1.32", features = ["full"] } tower = "0.4" -tower-http = { version = "0.4", features = ["cors", "trace", "map-request-body"] } +tower-http = { version = "0.4", features = ["cors", "trace", "map-request-body", "util"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["fmt", "std", "json", "env-filter"] } diff --git a/examples/demo/main.rs b/examples/demo/main.rs index 15f81aa..907111f 100644 --- a/examples/demo/main.rs +++ b/examples/demo/main.rs @@ -3,7 +3,7 @@ // SPDX-License-Identifier: AGPL-3.0-or-later use servus::axum::{ - extract::{self, State}, + extract::{self, Extension}, http::StatusCode, response::IntoResponse, routing::{get, post}, @@ -25,6 +25,7 @@ struct AppConfig { response: String, } +#[derive(Clone)] struct AppState { pool: sqlx::postgres::PgPool, } @@ -56,13 +57,13 @@ async fn main() -> anyhow::Result<()> { let router = Router::new() .route("/message", post(post_message)) - .route("/message/all", get(get_messages)); + .route("/message/all", get(get_messages)) + .layer(Extension(state)); servus::http::serve( config.servus.http_address, Some(config.servus.metrics_address), router, - state, ) .await; @@ -76,7 +77,7 @@ struct Message { } async fn post_message( - State(state): State>, + Extension(state): Extension>, extract::Json(payload): extract::Json, ) -> StatusCode { info!( @@ -88,7 +89,7 @@ async fn post_message( if let Err(e) = sqlx::query!( "INSERT INTO guestbook (author, message) VALUES ($1, $2)", payload.author, - payload.message + payload.message, ) .execute(&state.pool) .await @@ -100,7 +101,7 @@ async fn post_message( StatusCode::OK } -async fn get_messages(State(state): State>) -> impl IntoResponse { +async fn get_messages(Extension(state): Extension>) -> impl IntoResponse { info!(message = "got get messages request!"); let q = sqlx::query!("select * from guestbook") diff --git a/examples/echo/main.rs b/examples/echo/main.rs index 8f800dc..443a036 100644 --- a/examples/echo/main.rs +++ b/examples/echo/main.rs @@ -25,7 +25,7 @@ async fn main() -> anyhow::Result<()> { // Note, we pass the `metrics_address` parameter value as `None` to imply we don't want to // start the metrics server. Also, the `state` parameter is the unit type `()`, meaning we have // no global state and all handlers are stateless. - servus::http::serve(config.servus.http_address, None, router, ()).await; + servus::http::serve(config.servus.http_address, None, router).await; Ok(()) } diff --git a/src/http/mod.rs b/src/http/mod.rs index d133fd8..fa6e96f 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -32,18 +32,14 @@ use tracing::{error, Level}; /// /// Both the application server and metrics server will respond to a CTRL-C shutdown signal and /// terminate gracefully. -pub async fn serve( +pub async fn serve( http_address: SocketAddr, metrics_address: Option, - router: Router, - state: S, -) where - S: Send + Sync + Clone + 'static, -{ + router: Router<()>, +) { // create primary application router and server // applying handler state if we have it, and default metrics/tracing middleware - let r = router - .with_state(state) + let router = router .route_layer(middleware::from_fn(metrics::middleware)) // only record matched routes .layer( TraceLayer::new_for_http().make_span_with( @@ -54,17 +50,17 @@ pub async fn serve( ); let app = Server::bind(&http_address) - .serve(r.into_make_service()) + .serve(router.into_make_service()) .with_graceful_shutdown(shutdown_signal()); if let Some(metrics_address) = metrics_address { // create metrics router and server, also used for healthcheck - let r = Router::new() + let router = Router::new() .route("/metrics", routing::get(metrics::handler)) .route("/health", routing::get(health)); let metrics = Server::bind(&metrics_address) - .serve(r.into_make_service()) + .serve(router.into_make_service()) .with_graceful_shutdown(shutdown_signal()); // spawn each server instance (so they can be scheduled on separate threads as necessary)