Skip to content

Commit

Permalink
Merge pull request #696 from loco-rs/request-id-middleware-ordering
Browse files Browse the repository at this point in the history
add: request id + test more effective middleware ordering
  • Loading branch information
jondot authored Aug 11, 2024
2 parents 723bfeb + 96ef9ad commit 595531f
Show file tree
Hide file tree
Showing 10 changed files with 236 additions and 88 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ cargo_metadata = "0.18.1"

cfg-if = "1"

uuid = { version = "1.6", features = ["v4"] }
uuid = { version = "1.10.0", features = ["v4", "fast-rng"] }
requestty = "0.5.0"

# A socket.io server implementation
Expand Down
170 changes: 83 additions & 87 deletions src/controller/app_routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use tower_http::{
add_extension::AddExtensionLayer,
catch_panic::CatchPanicLayer,
compression::CompressionLayer,
cors,
services::{ServeDir, ServeFile},
set_header::SetResponseHeaderLayer,
timeout::TimeoutLayer,
Expand All @@ -20,9 +19,15 @@ use tower_http::{

#[cfg(feature = "channels")]
use super::channels::AppChannels;
use super::routes::Routes;
use super::{middleware::cors::cors_middleware, routes::Routes};
use crate::{
app::AppContext, config, controller::middleware::etag::EtagLayer, environment::Environment,
app::AppContext,
config,
controller::middleware::{
etag::EtagLayer,
request_id::{request_id_middleware, LocoRequestId},
},
environment::Environment,
errors, Result,
};

Expand Down Expand Up @@ -179,26 +184,86 @@ impl AppRoutes {
/// [`axum::Router`].
#[allow(clippy::cognitive_complexity)]
pub fn to_router(&self, ctx: AppContext, mut app: AXRouter<AppContext>) -> Result<AXRouter> {
//
// IMPORTANT: middleware ordering in this function is opposite to what you
// intuitively may think. when using `app.layer` to add individual middleware,
// the LAST middleware is the FIRST to meet the outside world (a user request
// starting), or "LIFO" order.
// We build the "onion" from the inside (start of this function),
// outwards (end of this function). This is why routes is first in coding order
// here (the core of the onion), and request ID is amongst the last
// (because every request is assigned with a unique ID, which starts its
// "life").
//
// NOTE: when using ServiceBuilder#layer the order is FIRST to LAST (but we
// don't use ServiceBuilder because it requires too complex generic typing for
// this function). ServiceBuilder is recommended to save compile times, but that
// may be a thing of the past as we don't notice any issues with compile times
// using the router directly, and ServiceBuilder has been reported to give
// issues in compile times itself (https://github.com/rust-lang/crates.io/pull/7443).
//
for router in self.collect() {
tracing::info!("{}", router.to_string());

app = app.route(&router.uri, router.method);
}

app = Self::add_powered_by_header(app, &ctx.config.server);
#[cfg(feature = "channels")]
if let Some(channels) = self.channels.as_ref() {
tracing::info!("[Middleware] Adding channels");
let channel_layer_app = tower::ServiceBuilder::new().layer(channels.layer.clone());
if let Some(cors) = &ctx
.config
.server
.middlewares
.cors
.as_ref()
.filter(|c| c.enable)
{
app = app.layer(
tower::ServiceBuilder::new()
.layer(cors_middleware(cors)?)
.layer(channel_layer_app),
);
} else {
app = app.layer(
tower::ServiceBuilder::new()
.layer(tower_http::cors::CorsLayer::permissive())
.layer(channel_layer_app),
);
}
}

if let Some(catch_panic) = &ctx.config.server.middlewares.catch_panic {
if catch_panic.enable {
app = Self::add_catch_panic(app);
}
}

if let Some(etag) = &ctx.config.server.middlewares.etag {
if etag.enable {
app = Self::add_etag_middleware(app);
}
}

if let Some(compression) = &ctx.config.server.middlewares.compression {
if compression.enable {
app = Self::add_compression_middleware(app);
}
}

if let Some(timeout_request) = &ctx.config.server.middlewares.timeout_request {
if timeout_request.enable {
app = Self::add_timeout_middleware(app, timeout_request);
}
}

if let Some(cors) = &ctx.config.server.middlewares.cors {
if cors.enable {
app = app.layer(cors_middleware(cors)?);
}
}

if let Some(limit) = &ctx.config.server.middlewares.limit_payload {
if limit.enable {
app = Self::add_limit_payload_middleware(app, limit)?;
Expand All @@ -211,62 +276,28 @@ impl AppRoutes {
}
}

if let Some(timeout_request) = &ctx.config.server.middlewares.timeout_request {
if timeout_request.enable {
app = Self::add_timeout_middleware(app, timeout_request);
}
}

let cors = ctx
.config
.server
.middlewares
.cors
.as_ref()
.filter(|cors| cors.enable)
.map(Self::get_cors_middleware)
.transpose()?;

if let Some(cors) = &cors {
app = app.layer(cors.clone());
tracing::info!("[Middleware] Adding cors");
}

if let Some(static_assets) = &ctx.config.server.middlewares.static_assets {
if static_assets.enable {
app = Self::add_static_asset_middleware(app, static_assets)?;
}
}

if let Some(etag) = &ctx.config.server.middlewares.etag {
if etag.enable {
app = Self::add_etag_middleware(app);
}
}
// XXX todo: remote IP middleware here

#[cfg(feature = "channels")]
if let Some(channels) = self.channels.as_ref() {
tracing::info!("[Middleware] Adding channels");
let channel_layer_app = tower::ServiceBuilder::new().layer(channels.layer.clone());
if let Some(cors) = cors {
app = app.layer(
tower::ServiceBuilder::new()
.layer(cors)
.layer(channel_layer_app),
);
} else {
app = app.layer(
tower::ServiceBuilder::new()
.layer(tower_http::cors::CorsLayer::permissive())
.layer(channel_layer_app),
);
}
}
app = Self::add_powered_by_header(app, &ctx.config.server);

app = Self::add_request_id_middleware(app);

let router = app.with_state(ctx);
Ok(router)
}

fn add_request_id_middleware(app: AXRouter<AppContext>) -> AXRouter<AppContext> {
let app = app.layer(axum::middleware::from_fn(request_id_middleware));
tracing::info!("[Middleware] Adding request_id middleware");
app
}

fn add_static_asset_middleware(
app: AXRouter<AppContext>,
config: &config::StaticAssetsMiddleware,
Expand Down Expand Up @@ -307,44 +338,6 @@ impl AppRoutes {
app
}

fn get_cors_middleware(config: &config::CorsMiddleware) -> Result<cors::CorsLayer> {
let mut cors: cors::CorsLayer = cors::CorsLayer::permissive();

if let Some(allow_origins) = &config.allow_origins {
// testing CORS, assuming https://example.com in the allow list:
// $ curl -v --request OPTIONS 'localhost:5150/api/_ping' -H 'Origin: https://example.com' -H 'Access-Control-Request-Method: GET'
// look for '< access-control-allow-origin: https://example.com' in response.
// if it doesn't appear (test with a bogus domain), it is not allowed.
let mut list = vec![];
for origins in allow_origins {
list.push(origins.parse()?);
}
cors = cors.allow_origin(list);
}

if let Some(allow_headers) = &config.allow_headers {
let mut headers = vec![];
for header in allow_headers {
headers.push(header.parse()?);
}
cors = cors.allow_headers(headers);
}

if let Some(allow_methods) = &config.allow_methods {
let mut methods = vec![];
for method in allow_methods {
methods.push(method.parse()?);
}
cors = cors.allow_methods(methods);
}

if let Some(max_age) = config.max_age {
cors = cors.max_age(Duration::from_secs(max_age));
}

Ok(cors)
}

fn add_catch_panic(app: AXRouter<AppContext>) -> AXRouter<AppContext> {
app.layer(CatchPanicLayer::custom(handle_panic))
}
Expand Down Expand Up @@ -372,7 +365,10 @@ impl AppRoutes {
let app = app
.layer(
TraceLayer::new_for_http().make_span_with(|request: &http::Request<_>| {
let request_id = uuid::Uuid::new_v4();
let ext = request.extensions();
let request_id = ext
.get::<LocoRequestId>()
.map_or_else(|| "req-id-none".to_string(), |r| r.get().to_string());
let user_agent = request
.headers()
.get(axum::http::header::USER_AGENT)
Expand Down
48 changes: 48 additions & 0 deletions src/controller/middleware/cors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use std::time::Duration;

use tower_http::cors;

use crate::{config, Result};

/// Create a CORS layer
///
/// # Errors
///
/// This function will return an error if parsing of header config fail
pub fn cors_middleware(config: &config::CorsMiddleware) -> Result<cors::CorsLayer> {
let mut cors: cors::CorsLayer = cors::CorsLayer::permissive();

if let Some(allow_origins) = &config.allow_origins {
// testing CORS, assuming https://example.com in the allow list:
// $ curl -v --request OPTIONS 'localhost:5150/api/_ping' -H 'Origin: https://example.com' -H 'Access-Control-Request-Method: GET'
// look for '< access-control-allow-origin: https://example.com' in response.
// if it doesn't appear (test with a bogus domain), it is not allowed.
let mut list = vec![];
for origins in allow_origins {
list.push(origins.parse()?);
}
cors = cors.allow_origin(list);
}

if let Some(allow_headers) = &config.allow_headers {
let mut headers = vec![];
for header in allow_headers {
headers.push(header.parse()?);
}
cors = cors.allow_headers(headers);
}

if let Some(allow_methods) = &config.allow_methods {
let mut methods = vec![];
for method in allow_methods {
methods.push(method.parse()?);
}
cors = cors.allow_methods(methods);
}

if let Some(max_age) = config.max_age {
cors = cors.max_age(Duration::from_secs(max_age));
}

Ok(cors)
}
2 changes: 2 additions & 0 deletions src/controller/middleware/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#[cfg(all(feature = "auth_jwt", feature = "with-db"))]
pub mod auth;
pub mod cors;
pub mod etag;
pub mod format;
pub mod request_id;
77 changes: 77 additions & 0 deletions src/controller/middleware/request_id.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
use axum::{extract::Request, http::HeaderValue, middleware::Next, response::Response};
use lazy_static::lazy_static;
use regex::Regex;
use tracing::warn;
use uuid::Uuid;

#[derive(Debug, Clone)]
pub struct LocoRequestId(String);

impl LocoRequestId {
/// Get the request id
#[must_use]
pub fn get(&self) -> &str {
self.0.as_str()
}
}

const X_REQUEST_ID: &str = "x-request-id";
const MAX_LEN: usize = 255;
lazy_static! {
static ref ID_CLEANUP: Regex = Regex::new(r"[^\w\-@]").unwrap();
}

pub async fn request_id_middleware(mut request: Request, next: Next) -> Response {
let header_request_id = request.headers().get(X_REQUEST_ID).cloned();
let request_id = make_request_id(header_request_id);
request
.extensions_mut()
.insert(LocoRequestId(request_id.clone()));
let mut res = next.run(request).await;

if let Ok(v) = HeaderValue::from_str(request_id.as_str()) {
res.headers_mut().insert(X_REQUEST_ID, v);
} else {
warn!("could not set request ID into response headers: `{request_id}`",);
}
res
}

fn make_request_id(maybe_request_id: Option<HeaderValue>) -> String {
maybe_request_id
.and_then(|hdr| {
// see: https://github.com/rails/rails/blob/main/actionpack/lib/action_dispatch/middleware/request_id.rb#L39
let id: Option<String> = hdr.to_str().ok().map(|s| {
ID_CLEANUP
.replace_all(s, "")
.chars()
.take(MAX_LEN)
.collect()
});
id.filter(|s| !s.is_empty())
})
.unwrap_or_else(|| Uuid::new_v4().to_string())
}

#[cfg(test)]
mod tests {
use axum::http::HeaderValue;
use insta::assert_debug_snapshot;

use super::make_request_id;

#[test]
fn create_or_fetch_request_id() {
let id = make_request_id(Some(HeaderValue::from_static("foo-bar=baz")));
assert_debug_snapshot!(id);
let id = make_request_id(Some(HeaderValue::from_static("")));
assert_debug_snapshot!(id.len());
let id = make_request_id(Some(HeaderValue::from_static("==========")));
assert_debug_snapshot!(id.len());
let long_id = "x".repeat(1000);
let id = make_request_id(Some(HeaderValue::from_str(&long_id).unwrap()));
assert_debug_snapshot!(id.len());
let id = make_request_id(None);
assert_debug_snapshot!(id.len());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
source: src/controller/middleware/request_id.rs
expression: id.len()
---
36
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
source: src/controller/middleware/request_id.rs
expression: id.len()
---
36
Loading

0 comments on commit 595531f

Please sign in to comment.