From 357f107d3cc3f50eb29be0023721186b1037eba1 Mon Sep 17 00:00:00 2001 From: Fangdun Tsai Date: Mon, 18 Dec 2023 07:45:20 +0800 Subject: [PATCH] chore: improve builtin middleware (#119) * chore: improve builtin middleware * fix(docs): type link * chore(core): improve realip type --- viz-core/src/middleware/compression.rs | 3 +- viz-core/src/middleware/cookie.rs | 38 ++++++++++++----------- viz-core/src/middleware/cors.rs | 23 ++++++++------ viz-core/src/middleware/csrf.rs | 2 +- viz-core/src/middleware/limits.rs | 2 +- viz-core/src/middleware/otel/metrics.rs | 2 +- viz-core/src/middleware/otel/tracing.rs | 11 ++++--- viz-core/src/middleware/session/config.rs | 22 +++++-------- viz-core/src/types/realip.rs | 36 ++++++++++++++------- viz-core/tests/handler.rs | 4 +-- 10 files changed, 79 insertions(+), 64 deletions(-) diff --git a/viz-core/src/middleware/compression.rs b/viz-core/src/middleware/compression.rs index f41a0e46..09b91e9a 100644 --- a/viz-core/src/middleware/compression.rs +++ b/viz-core/src/middleware/compression.rs @@ -44,7 +44,8 @@ where let accept_encoding = req .headers() .get(ACCEPT_ENCODING) - .and_then(|v| v.to_str().ok()) + .map(HeaderValue::to_str) + .and_then(Result::ok) .and_then(parse_accept_encoding); let raw = self.h.call(req).await?; diff --git a/viz-core/src/middleware/cookie.rs b/viz-core/src/middleware/cookie.rs index 4ba4a39b..e8ac8b64 100644 --- a/viz-core/src/middleware/cookie.rs +++ b/viz-core/src/middleware/cookie.rs @@ -5,20 +5,21 @@ use std::fmt; use crate::{ async_trait, header::{HeaderValue, COOKIE, SET_COOKIE}, - types, Handler, IntoResponse, Request, Response, Result, Transform, + types::{Cookie, CookieJar, CookieKey, Cookies}, + Handler, IntoResponse, Request, Response, Result, Transform, }; /// A configure for [`CookieMiddleware`]. pub struct Config { #[cfg(any(feature = "cookie-signed", feature = "cookie-private"))] - key: std::sync::Arc, + key: std::sync::Arc, } impl Config { - /// Creates a new config with the [`Key`][types::CookieKey]. + /// Creates a new config with the [`Key`][CookieKey]. #[cfg(any(feature = "cookie-signed", feature = "cookie-private"))] #[must_use] - pub fn with_key(key: types::CookieKey) -> Self { + pub fn with_key(key: CookieKey) -> Self { Self { key: std::sync::Arc::new(key), } @@ -29,7 +30,7 @@ impl Default for Config { fn default() -> Self { Self { #[cfg(any(feature = "cookie-signed", feature = "cookie-private"))] - key: std::sync::Arc::new(types::CookieKey::generate()), + key: std::sync::Arc::new(CookieKey::generate()), } } } @@ -65,7 +66,7 @@ where pub struct CookieMiddleware { h: H, #[cfg(any(feature = "cookie-signed", feature = "cookie-private"))] - key: std::sync::Arc, + key: std::sync::Arc, } impl fmt::Debug for CookieMiddleware { @@ -82,8 +83,8 @@ impl fmt::Debug for CookieMiddleware { #[async_trait] impl Handler for CookieMiddleware where - O: IntoResponse, H: Handler> + Clone, + O: IntoResponse, { type Output = Result; @@ -92,15 +93,15 @@ where .headers() .get_all(COOKIE) .iter() - .filter_map(|c| HeaderValue::to_str(c).ok()) - .fold(types::CookieJar::new(), add_cookie); + .map(HeaderValue::to_str) + .filter_map(Result::ok) + .fold(CookieJar::new(), add_cookie); - let cookies = types::Cookies::new(jar); + let cookies = Cookies::new(jar); #[cfg(any(feature = "cookie-signed", feature = "cookie-private"))] let cookies = cookies.with_key(self.key.clone()); - req.extensions_mut() - .insert::(cookies.clone()); + req.extensions_mut().insert::(cookies.clone()); self.h .call(req) @@ -109,9 +110,9 @@ where .map(|mut res| { if let Ok(c) = cookies.jar().lock() { c.delta() - .filter_map(|cookie| { - HeaderValue::from_str(&cookie.encoded().to_string()).ok() - }) + .map(Cookie::encoded) + .map(|cookie| HeaderValue::from_str(&cookie.to_string())) + .filter_map(Result::ok) .fold(res.headers_mut(), |headers, cookie| { headers.append(SET_COOKIE, cookie); headers @@ -123,9 +124,10 @@ where } #[inline] -fn add_cookie(mut jar: types::CookieJar, value: &str) -> types::CookieJar { - types::Cookie::split_parse_encoded(value) +fn add_cookie(mut jar: CookieJar, value: &str) -> CookieJar { + Cookie::split_parse_encoded(value) .filter_map(Result::ok) - .for_each(|cookie| jar.add_original(cookie.into_owned())); + .map(Cookie::into_owned) + .for_each(|cookie| jar.add_original(cookie)); jar } diff --git a/viz-core/src/middleware/cors.rs b/viz-core/src/middleware/cors.rs index 2790330a..d0327899 100644 --- a/viz-core/src/middleware/cors.rs +++ b/viz-core/src/middleware/cors.rs @@ -63,7 +63,8 @@ impl Config { { self.allow_methods = allow_methods .into_iter() - .filter_map(|m| m.try_into().ok()) + .map(TryInto::try_into) + .filter_map(Result::ok) .collect(); self } @@ -79,7 +80,8 @@ impl Config { { self.allow_headers = allow_headers .into_iter() - .filter_map(|h| h.try_into().ok()) + .map(TryInto::try_into) + .filter_map(Result::ok) .collect(); self } @@ -95,7 +97,8 @@ impl Config { { self.allow_origins = allow_origins .into_iter() - .filter_map(|h| h.try_into().ok()) + .map(TryInto::try_into) + .filter_map(Result::ok) .collect(); self } @@ -111,7 +114,8 @@ impl Config { { self.expose_headers = expose_headers .into_iter() - .filter_map(|h| h.try_into().ok()) + .map(TryInto::try_into) + .filter_map(Result::ok) .collect(); self } @@ -202,8 +206,8 @@ pub struct CorsMiddleware { #[async_trait] impl Handler for CorsMiddleware where - O: IntoResponse, H: Handler> + Clone, + O: IntoResponse, { type Output = Result; @@ -244,7 +248,9 @@ where hs.to_str() .map(|hs| { hs.split(',') - .filter_map(|h| HeaderName::from_bytes(h.as_bytes()).ok()) + .map(str::as_bytes) + .map(HeaderName::from_bytes) + .filter_map(Result::ok) .any(|header| self.config.allow_headers.contains(&header)) }) .unwrap_or(false), @@ -273,10 +279,7 @@ where headers.typed_insert(self.aceh.clone()); } - self.h - .call(req) - .await - .map_or_else(IntoResponse::into_response, IntoResponse::into_response) + self.h.call(req).await.map(IntoResponse::into_response)? }; // https://github.com/rs/cors/issues/10 diff --git a/viz-core/src/middleware/csrf.rs b/viz-core/src/middleware/csrf.rs index 40584707..321e8241 100644 --- a/viz-core/src/middleware/csrf.rs +++ b/viz-core/src/middleware/csrf.rs @@ -189,8 +189,8 @@ where #[async_trait] impl Handler for CsrfMiddleware where - O: IntoResponse, H: Handler> + Clone, + O: IntoResponse, S: Fn() -> Result> + Send + Sync + 'static, G: Fn(&[u8], Vec) -> Vec + Send + Sync + 'static, V: Fn(&[u8], String) -> bool + Send + Sync + 'static, diff --git a/viz-core/src/middleware/limits.rs b/viz-core/src/middleware/limits.rs index 2591ff41..9b4a7441 100644 --- a/viz-core/src/middleware/limits.rs +++ b/viz-core/src/middleware/limits.rs @@ -70,8 +70,8 @@ pub struct LimitsMiddleware { #[async_trait] impl Handler for LimitsMiddleware where - O: IntoResponse, H: Handler> + Clone, + O: IntoResponse, { type Output = Result; diff --git a/viz-core/src/middleware/otel/metrics.rs b/viz-core/src/middleware/otel/metrics.rs index 4a42b71a..16b3c51b 100644 --- a/viz-core/src/middleware/otel/metrics.rs +++ b/viz-core/src/middleware/otel/metrics.rs @@ -99,8 +99,8 @@ pub struct MetricsMiddleware { #[async_trait] impl Handler for MetricsMiddleware where - O: IntoResponse, H: Handler> + Clone, + O: IntoResponse, { type Output = Result; diff --git a/viz-core/src/middleware/otel/tracing.rs b/viz-core/src/middleware/otel/tracing.rs index d60af86c..a6c0d977 100644 --- a/viz-core/src/middleware/otel/tracing.rs +++ b/viz-core/src/middleware/otel/tracing.rs @@ -4,7 +4,7 @@ use std::{net::SocketAddr, sync::Arc}; -use http::uri::Scheme; +use http::{uri::Scheme, HeaderValue}; use opentelemetry::{ global, propagation::Extractor, @@ -61,10 +61,10 @@ pub struct TracingMiddleware { #[async_trait] impl Handler for TracingMiddleware where + H: Handler> + Clone, + O: IntoResponse, T: Tracer + Send + Sync + Clone + 'static, T::Span: Send + Sync + 'static, - O: IntoResponse, - H: Handler> + Clone, { type Output = Result; @@ -143,7 +143,10 @@ impl<'a> RequestHeaderCarrier<'a> { impl Extractor for RequestHeaderCarrier<'_> { fn get(&self, key: &str) -> Option<&str> { - self.headers.get(key).and_then(|v| v.to_str().ok()) + self.headers + .get(key) + .map(HeaderValue::to_str) + .and_then(Result::ok) } fn keys(&self) -> Vec<&str> { diff --git a/viz-core/src/middleware/session/config.rs b/viz-core/src/middleware/session/config.rs index 005ab8d3..c1fdec8c 100644 --- a/viz-core/src/middleware/session/config.rs +++ b/viz-core/src/middleware/session/config.rs @@ -7,7 +7,7 @@ use std::{ use crate::{ async_trait, middleware::helper::{CookieOptions, Cookieable}, - types::Session, + types::{Cookie, Session}, Error, Handler, IntoResponse, Request, RequestExt, Response, Result, StatusCode, Transform, }; @@ -92,8 +92,8 @@ where #[async_trait] impl Handler for SessionMiddleware where - O: IntoResponse, H: Handler> + Clone, + O: IntoResponse, S: Storage + 'static, G: Fn() -> String + Send + Sync + 'static, V: Fn(&str) -> bool + Send + Sync + 'static, @@ -101,10 +101,10 @@ where type Output = Result; async fn call(&self, mut req: Request) -> Self::Output { - let cookies = req.cookies().map_err(Into::::into)?; + let cookies = req.cookies().map_err(Error::from)?; let cookie = self.config.get_cookie(&cookies); - let mut session_id = cookie.map(|cookie| cookie.value().to_string()); + let mut session_id = cookie.as_ref().map(Cookie::value).map(ToString::to_string); let data = match &session_id { Some(sid) if (self.config.store().verify)(sid) => self.config.store().get(sid).await?, _ => None, @@ -125,11 +125,7 @@ where if status == PURGED { if let Some(sid) = &session_id { - self.config - .store() - .remove(sid) - .await - .map_err(Into::::into)?; + self.config.store().remove(sid).await.map_err(Error::from)?; self.config.remove_cookie(&cookies); } @@ -138,11 +134,7 @@ where if status == RENEWED { if let Some(sid) = &session_id.take() { - self.config - .store() - .remove(sid) - .await - .map_err(Into::::into)?; + self.config.store().remove(sid).await.map_err(Error::from)?; } } @@ -160,7 +152,7 @@ where &self.config.ttl().unwrap_or_else(max_age), ) .await - .map_err(Into::::into)?; + .map_err(Error::from)?; resp } diff --git a/viz-core/src/types/realip.rs b/viz-core/src/types/realip.rs index 2afbad77..ee918686 100644 --- a/viz-core/src/types/realip.rs +++ b/viz-core/src/types/realip.rs @@ -1,8 +1,14 @@ -use std::{net::IpAddr, str}; +use std::{ + net::{IpAddr, SocketAddr}, + str, +}; use rfc7239::{NodeIdentifier, NodeName}; -use crate::{header::FORWARDED, Request, RequestExt, Result}; +use crate::{ + header::{HeaderValue, FORWARDED}, + Request, RequestExt, Result, +}; /// Gets real ip remote addr from request headers. #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] @@ -19,15 +25,21 @@ impl RealIp { pub fn parse(req: &Request) -> Option { req.headers() .get(Self::X_REAL_IP) - .and_then(|value| value.to_str().ok()) - .and_then(|value| value.parse::().ok()) + .map(HeaderValue::to_str) + .and_then(Result::ok) + .map(str::parse) + .and_then(Result::ok) .or_else(|| { req.headers() .get(FORWARDED) - .and_then(|value| value.to_str().ok()) - .and_then(|value| rfc7239::parse(value).collect::, _>>().ok()) - .and_then(|value| { - value.into_iter().find_map(|item| match item.forwarded_for { + .map(HeaderValue::to_str) + .and_then(Result::ok) + .map(rfc7239::parse) + .map(Iterator::collect) + .and_then(Result::ok) + .map(Vec::into_iter) + .and_then(|mut value| { + value.find_map(|item| match item.forwarded_for { Some(NodeIdentifier { name: NodeName::Ip(ip_addr), .. @@ -39,15 +51,17 @@ impl RealIp { .or_else(|| { req.headers() .get(Self::X_FORWARDED_FOR) - .and_then(|value| value.to_str().ok()) + .map(HeaderValue::to_str) + .and_then(Result::ok) .and_then(|value| { value .split(',') .map(str::trim) - .find_map(|value| value.parse::().ok()) + .map(str::parse) + .find_map(Result::ok) }) }) .map(RealIp) - .or_else(|| req.remote_addr().map(|addr| RealIp(addr.ip()))) + .or_else(|| req.remote_addr().map(SocketAddr::ip).map(RealIp)) } } diff --git a/viz-core/tests/handler.rs b/viz-core/tests/handler.rs index 1976d5c1..49e4bf69 100644 --- a/viz-core/tests/handler.rs +++ b/viz-core/tests/handler.rs @@ -107,8 +107,8 @@ async fn handler() -> Result<()> { } async fn a(_: Request) -> Result { - Err(CustomError::NotFound)?; - Err(CustomError2::NotFound)?; + // Err(CustomError::NotFound)?; + // Err(CustomError2::NotFound)?; Ok(().into_response()) } async fn b(_: Request) -> Result {