Skip to content

Commit

Permalink
chore: improve builtin middleware (#119)
Browse files Browse the repository at this point in the history
* chore: improve builtin middleware

* fix(docs): type link

* chore(core): improve realip type
  • Loading branch information
fundon committed Dec 18, 2023
1 parent 28d0943 commit 9336b1c
Show file tree
Hide file tree
Showing 10 changed files with 79 additions and 64 deletions.
3 changes: 2 additions & 1 deletion viz-core/src/middleware/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ where
let accept_encoding = req
.headers()
.get(ACCEPT_ENCODING)
.and_then(|v| v.to_str().ok())
.map(HeaderValue::to_str)
.and_then(Result::ok)
.and_then(parse_accept_encoding);

let raw = self.h.call(req).await?;
Expand Down
38 changes: 20 additions & 18 deletions viz-core/src/middleware/cookie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,21 @@ use std::fmt;
use crate::{
async_trait,
header::{HeaderValue, COOKIE, SET_COOKIE},
types, Handler, IntoResponse, Request, Response, Result, Transform,
types::{Cookie, CookieJar, CookieKey, Cookies},
Handler, IntoResponse, Request, Response, Result, Transform,
};

/// A configure for [`CookieMiddleware`].
pub struct Config {
#[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
key: std::sync::Arc<types::CookieKey>,
key: std::sync::Arc<CookieKey>,
}

impl Config {
/// Creates a new config with the [`Key`][types::CookieKey].
/// Creates a new config with the [`Key`][CookieKey].
#[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
#[must_use]
pub fn with_key(key: types::CookieKey) -> Self {
pub fn with_key(key: CookieKey) -> Self {
Self {
key: std::sync::Arc::new(key),
}
Expand All @@ -29,7 +30,7 @@ impl Default for Config {
fn default() -> Self {
Self {
#[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
key: std::sync::Arc::new(types::CookieKey::generate()),
key: std::sync::Arc::new(CookieKey::generate()),
}
}
}
Expand Down Expand Up @@ -65,7 +66,7 @@ where
pub struct CookieMiddleware<H> {
h: H,
#[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
key: std::sync::Arc<types::CookieKey>,
key: std::sync::Arc<CookieKey>,
}

impl<H> fmt::Debug for CookieMiddleware<H> {
Expand All @@ -82,8 +83,8 @@ impl<H> fmt::Debug for CookieMiddleware<H> {
#[async_trait]
impl<H, O> Handler<Request> for CookieMiddleware<H>
where
O: IntoResponse,
H: Handler<Request, Output = Result<O>> + Clone,
O: IntoResponse,
{
type Output = Result<Response>;

Expand All @@ -92,15 +93,15 @@ where
.headers()
.get_all(COOKIE)
.iter()
.filter_map(|c| HeaderValue::to_str(c).ok())
.fold(types::CookieJar::new(), add_cookie);
.map(HeaderValue::to_str)
.filter_map(Result::ok)
.fold(CookieJar::new(), add_cookie);

let cookies = types::Cookies::new(jar);
let cookies = Cookies::new(jar);
#[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
let cookies = cookies.with_key(self.key.clone());

req.extensions_mut()
.insert::<types::Cookies>(cookies.clone());
req.extensions_mut().insert::<Cookies>(cookies.clone());

self.h
.call(req)
Expand All @@ -109,9 +110,9 @@ where
.map(|mut res| {
if let Ok(c) = cookies.jar().lock() {
c.delta()
.filter_map(|cookie| {
HeaderValue::from_str(&cookie.encoded().to_string()).ok()
})
.map(Cookie::encoded)
.map(|cookie| HeaderValue::from_str(&cookie.to_string()))
.filter_map(Result::ok)
.fold(res.headers_mut(), |headers, cookie| {
headers.append(SET_COOKIE, cookie);
headers
Expand All @@ -123,9 +124,10 @@ where
}

#[inline]
fn add_cookie(mut jar: types::CookieJar, value: &str) -> types::CookieJar {
types::Cookie::split_parse_encoded(value)
fn add_cookie(mut jar: CookieJar, value: &str) -> CookieJar {
Cookie::split_parse_encoded(value)
.filter_map(Result::ok)
.for_each(|cookie| jar.add_original(cookie.into_owned()));
.map(Cookie::into_owned)
.for_each(|cookie| jar.add_original(cookie));
jar
}
23 changes: 13 additions & 10 deletions viz-core/src/middleware/cors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ impl Config {
{
self.allow_methods = allow_methods
.into_iter()
.filter_map(|m| m.try_into().ok())
.map(TryInto::try_into)
.filter_map(Result::ok)
.collect();
self
}
Expand All @@ -79,7 +80,8 @@ impl Config {
{
self.allow_headers = allow_headers
.into_iter()
.filter_map(|h| h.try_into().ok())
.map(TryInto::try_into)
.filter_map(Result::ok)
.collect();
self
}
Expand All @@ -95,7 +97,8 @@ impl Config {
{
self.allow_origins = allow_origins
.into_iter()
.filter_map(|h| h.try_into().ok())
.map(TryInto::try_into)
.filter_map(Result::ok)
.collect();
self
}
Expand All @@ -111,7 +114,8 @@ impl Config {
{
self.expose_headers = expose_headers
.into_iter()
.filter_map(|h| h.try_into().ok())
.map(TryInto::try_into)
.filter_map(Result::ok)
.collect();
self
}
Expand Down Expand Up @@ -202,8 +206,8 @@ pub struct CorsMiddleware<H> {
#[async_trait]
impl<H, O> Handler<Request> for CorsMiddleware<H>
where
O: IntoResponse,
H: Handler<Request, Output = Result<O>> + Clone,
O: IntoResponse,
{
type Output = Result<Response>;

Expand Down Expand Up @@ -244,7 +248,9 @@ where
hs.to_str()
.map(|hs| {
hs.split(',')
.filter_map(|h| HeaderName::from_bytes(h.as_bytes()).ok())
.map(str::as_bytes)
.map(HeaderName::from_bytes)
.filter_map(Result::ok)
.any(|header| self.config.allow_headers.contains(&header))
})
.unwrap_or(false),
Expand Down Expand Up @@ -273,10 +279,7 @@ where
headers.typed_insert(self.aceh.clone());
}

self.h
.call(req)
.await
.map_or_else(IntoResponse::into_response, IntoResponse::into_response)
self.h.call(req).await.map(IntoResponse::into_response)?
};

// https://github.com/rs/cors/issues/10
Expand Down
2 changes: 1 addition & 1 deletion viz-core/src/middleware/csrf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ where
#[async_trait]
impl<H, O, S, G, V> Handler<Request> for CsrfMiddleware<H, S, G, V>
where
O: IntoResponse,
H: Handler<Request, Output = Result<O>> + Clone,
O: IntoResponse,
S: Fn() -> Result<Vec<u8>> + Send + Sync + 'static,
G: Fn(&[u8], Vec<u8>) -> Vec<u8> + Send + Sync + 'static,
V: Fn(&[u8], String) -> bool + Send + Sync + 'static,
Expand Down
2 changes: 1 addition & 1 deletion viz-core/src/middleware/limits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ pub struct LimitsMiddleware<H> {
#[async_trait]
impl<H, O> Handler<Request> for LimitsMiddleware<H>
where
O: IntoResponse,
H: Handler<Request, Output = Result<O>> + Clone,
O: IntoResponse,
{
type Output = Result<Response>;

Expand Down
2 changes: 1 addition & 1 deletion viz-core/src/middleware/otel/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ pub struct MetricsMiddleware<H> {
#[async_trait]
impl<H, O> Handler<Request> for MetricsMiddleware<H>
where
O: IntoResponse,
H: Handler<Request, Output = Result<O>> + Clone,
O: IntoResponse,
{
type Output = Result<Response>;

Expand Down
11 changes: 7 additions & 4 deletions viz-core/src/middleware/otel/tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use std::{net::SocketAddr, sync::Arc};

use http::uri::Scheme;
use http::{uri::Scheme, HeaderValue};
use opentelemetry::{
global,
propagation::Extractor,
Expand Down Expand Up @@ -61,10 +61,10 @@ pub struct TracingMiddleware<H, T> {
#[async_trait]
impl<H, O, T> Handler<Request> for TracingMiddleware<H, T>
where
H: Handler<Request, Output = Result<O>> + Clone,
O: IntoResponse,
T: Tracer + Send + Sync + Clone + 'static,
T::Span: Send + Sync + 'static,
O: IntoResponse,
H: Handler<Request, Output = Result<O>> + Clone,
{
type Output = Result<Response>;

Expand Down Expand Up @@ -143,7 +143,10 @@ impl<'a> RequestHeaderCarrier<'a> {

impl Extractor for RequestHeaderCarrier<'_> {
fn get(&self, key: &str) -> Option<&str> {
self.headers.get(key).and_then(|v| v.to_str().ok())
self.headers
.get(key)
.map(HeaderValue::to_str)
.and_then(Result::ok)
}

fn keys(&self) -> Vec<&str> {
Expand Down
22 changes: 7 additions & 15 deletions viz-core/src/middleware/session/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{
use crate::{
async_trait,
middleware::helper::{CookieOptions, Cookieable},
types::Session,
types::{Cookie, Session},
Error, Handler, IntoResponse, Request, RequestExt, Response, Result, StatusCode, Transform,
};

Expand Down Expand Up @@ -92,19 +92,19 @@ where
#[async_trait]
impl<H, O, S, G, V> Handler<Request> for SessionMiddleware<H, S, G, V>
where
O: IntoResponse,
H: Handler<Request, Output = Result<O>> + Clone,
O: IntoResponse,
S: Storage + 'static,
G: Fn() -> String + Send + Sync + 'static,
V: Fn(&str) -> bool + Send + Sync + 'static,
{
type Output = Result<Response>;

async fn call(&self, mut req: Request) -> Self::Output {
let cookies = req.cookies().map_err(Into::<Error>::into)?;
let cookies = req.cookies().map_err(Error::from)?;
let cookie = self.config.get_cookie(&cookies);

let mut session_id = cookie.map(|cookie| cookie.value().to_string());
let mut session_id = cookie.as_ref().map(Cookie::value).map(ToString::to_string);
let data = match &session_id {
Some(sid) if (self.config.store().verify)(sid) => self.config.store().get(sid).await?,
_ => None,
Expand All @@ -125,11 +125,7 @@ where

if status == PURGED {
if let Some(sid) = &session_id {
self.config
.store()
.remove(sid)
.await
.map_err(Into::<Error>::into)?;
self.config.store().remove(sid).await.map_err(Error::from)?;
self.config.remove_cookie(&cookies);
}

Expand All @@ -138,11 +134,7 @@ where

if status == RENEWED {
if let Some(sid) = &session_id.take() {
self.config
.store()
.remove(sid)
.await
.map_err(Into::<Error>::into)?;
self.config.store().remove(sid).await.map_err(Error::from)?;
}
}

Expand All @@ -160,7 +152,7 @@ where
&self.config.ttl().unwrap_or_else(max_age),
)
.await
.map_err(Into::<Error>::into)?;
.map_err(Error::from)?;

resp
}
Expand Down
36 changes: 25 additions & 11 deletions viz-core/src/types/realip.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
use std::{net::IpAddr, str};
use std::{
net::{IpAddr, SocketAddr},
str,
};

use rfc7239::{NodeIdentifier, NodeName};

use crate::{header::FORWARDED, Request, RequestExt, Result};
use crate::{
header::{HeaderValue, FORWARDED},
Request, RequestExt, Result,
};

/// Gets real ip remote addr from request headers.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
Expand All @@ -19,15 +25,21 @@ impl RealIp {
pub fn parse(req: &Request) -> Option<Self> {
req.headers()
.get(Self::X_REAL_IP)
.and_then(|value| value.to_str().ok())
.and_then(|value| value.parse::<IpAddr>().ok())
.map(HeaderValue::to_str)
.and_then(Result::ok)
.map(str::parse)
.and_then(Result::ok)
.or_else(|| {
req.headers()
.get(FORWARDED)
.and_then(|value| value.to_str().ok())
.and_then(|value| rfc7239::parse(value).collect::<Result<Vec<_>, _>>().ok())
.and_then(|value| {
value.into_iter().find_map(|item| match item.forwarded_for {
.map(HeaderValue::to_str)
.and_then(Result::ok)
.map(rfc7239::parse)
.map(Iterator::collect)
.and_then(Result::ok)
.map(Vec::into_iter)
.and_then(|mut value| {
value.find_map(|item| match item.forwarded_for {
Some(NodeIdentifier {
name: NodeName::Ip(ip_addr),
..
Expand All @@ -39,15 +51,17 @@ impl RealIp {
.or_else(|| {
req.headers()
.get(Self::X_FORWARDED_FOR)
.and_then(|value| value.to_str().ok())
.map(HeaderValue::to_str)
.and_then(Result::ok)
.and_then(|value| {
value
.split(',')
.map(str::trim)
.find_map(|value| value.parse::<IpAddr>().ok())
.map(str::parse)
.find_map(Result::ok)
})
})
.map(RealIp)
.or_else(|| req.remote_addr().map(|addr| RealIp(addr.ip())))
.or_else(|| req.remote_addr().map(SocketAddr::ip).map(RealIp))
}
}
4 changes: 2 additions & 2 deletions viz-core/tests/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ async fn handler() -> Result<()> {
}

async fn a(_: Request) -> Result<Response> {
Err(CustomError::NotFound)?;
Err(CustomError2::NotFound)?;
// Err(CustomError::NotFound)?;
// Err(CustomError2::NotFound)?;
Ok(().into_response())
}
async fn b(_: Request) -> Result<Response> {
Expand Down

0 comments on commit 9336b1c

Please sign in to comment.