From 519c1f469957f67b4d5974c37a6d10ce31cd0d96 Mon Sep 17 00:00:00 2001 From: Yukiteru Date: Fri, 29 Nov 2024 16:50:43 +0800 Subject: [PATCH] chore(volo-http): format codes of client ip (#539) Signed-off-by: Yu Li --- volo-http/src/server/extract.rs | 13 +-- volo-http/src/server/utils/client_ip.rs | 136 ++++++++++++------------ volo-http/src/server/utils/mod.rs | 1 + 3 files changed, 76 insertions(+), 74 deletions(-) diff --git a/volo-http/src/server/extract.rs b/volo-http/src/server/extract.rs index 3aa81c89..4aaafa14 100644 --- a/volo-http/src/server/extract.rs +++ b/volo-http/src/server/extract.rs @@ -23,7 +23,7 @@ use crate::{ context::ServerContext, error::server::{body_collection_error, ExtractBodyError}, request::{Request, RequestPartsExt}, - server::utils::client_ip::ClientIP, + server::utils::client_ip::ClientIp, utils::macros::impl_deref_and_deref_mut, }; @@ -291,18 +291,15 @@ impl FromContext for Method { } } -impl FromContext for ClientIP { +impl FromContext for ClientIp { type Rejection = Infallible; - async fn from_context( - cx: &mut ServerContext, - _: &mut Parts, - ) -> Result { - Ok(ClientIP( + async fn from_context(cx: &mut ServerContext, _: &mut Parts) -> Result { + Ok(ClientIp( cx.rpc_info .caller() .tags - .get::() + .get::() .and_then(|v| v.0), )) } diff --git a/volo-http/src/server/utils/client_ip.rs b/volo-http/src/server/utils/client_ip.rs index 8869e3f6..e8a8c14a 100644 --- a/volo-http/src/server/utils/client_ip.rs +++ b/volo-http/src/server/utils/client_ip.rs @@ -1,43 +1,46 @@ //! Utilities for extracting original client ip //! -//! See [`ClientIP`] for more details. -use std::{net::IpAddr, str::FromStr}; +//! See [`ClientIp`] for more details. +use std::{ + net::{IpAddr, Ipv4Addr, Ipv6Addr}, + str::FromStr, +}; use http::{HeaderMap, HeaderName}; -use ipnet::IpNet; +use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use motore::{layer::Layer, Service}; use volo::{context::Context, net::Address}; -use crate::{context::ServerContext, request::Request, utils::macros::impl_deref_and_deref_mut}; +use crate::{context::ServerContext, request::Request}; /// [`Layer`] for extracting client ip /// -/// See [`ClientIP`] for more details. -#[derive(Clone, Default)] -pub struct ClientIPLayer { - config: ClientIPConfig, +/// See [`ClientIp`] for more details. +#[derive(Clone, Debug, Default)] +pub struct ClientIpLayer { + config: ClientIpConfig, } -impl ClientIPLayer { - /// Create a new [`ClientIPLayer`] with default config +impl ClientIpLayer { + /// Create a new [`ClientIpLayer`] with default config pub fn new() -> Self { Default::default() } - /// Create a new [`ClientIPLayer`] with the given [`ClientIPConfig`] - pub fn with_config(self, config: ClientIPConfig) -> Self { + /// Create a new [`ClientIpLayer`] with the given [`ClientIpConfig`] + pub fn with_config(self, config: ClientIpConfig) -> Self { Self { config } } } -impl Layer for ClientIPLayer +impl Layer for ClientIpLayer where S: Send + Sync + 'static, { - type Service = ClientIPService; + type Service = ClientIpService; fn layer(self, inner: S) -> Self::Service { - ClientIPService { + ClientIpService { service: inner, config: self.config, } @@ -46,25 +49,31 @@ where /// Config for extract client ip #[derive(Clone, Debug)] -pub struct ClientIPConfig { +pub struct ClientIpConfig { remote_ip_headers: Vec, trusted_cidrs: Vec, } -impl Default for ClientIPConfig { +impl Default for ClientIpConfig { fn default() -> Self { Self { remote_ip_headers: vec![ HeaderName::from_static("x-real-ip"), HeaderName::from_static("x-forwarded-for"), ], - trusted_cidrs: vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()], + trusted_cidrs: vec![ + IpNet::V4(Ipv4Net::new_assert(Ipv4Addr::new(0, 0, 0, 0), 0)), + IpNet::V6(Ipv6Net::new_assert( + Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0), + 0, + )), + ], } } } -impl ClientIPConfig { - /// Create a new [`ClientIPConfig`] with default values +impl ClientIpConfig { + /// Create a new [`ClientIpConfig`] with default values /// /// default remote ip headers: `["X-Real-IP", "X-Forwarded-For"]` /// @@ -75,15 +84,15 @@ impl ClientIPConfig { /// Get Real Client IP by parsing the given headers. /// - /// See [`ClientIP`] for more details. + /// See [`ClientIp`] for more details. /// /// # Example /// /// ```rust - /// use volo_http::server::utils::client_ip::ClientIPConfig; + /// use volo_http::server::utils::client_ip::ClientIpConfig; /// /// let client_ip_config = - /// ClientIPConfig::new().with_remote_ip_headers(vec!["X-Real-IP", "X-Forwarded-For"]); + /// ClientIpConfig::new().with_remote_ip_headers(vec!["X-Real-IP", "X-Forwarded-For"]); /// ``` pub fn with_remote_ip_headers( self, @@ -108,14 +117,14 @@ impl ClientIPConfig { /// Get Real Client IP if it is trusted, otherwise it will just return caller ip. /// - /// See [`ClientIP`] for more details. + /// See [`ClientIp`] for more details. /// /// # Example /// /// ```rust - /// use volo_http::server::utils::client_ip::ClientIPConfig; + /// use volo_http::server::utils::client_ip::ClientIpConfig; /// - /// let client_ip_config = ClientIPConfig::new() + /// let client_ip_config = ClientIpConfig::new() /// .with_trusted_cidrs(vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()]); /// ``` pub fn with_trusted_cidrs(self, cidrs: H) -> Self @@ -132,11 +141,11 @@ impl ClientIPConfig { /// Return original client IP Address /// /// If you want to get client IP by retrieving specific headers, you can use -/// [`with_remote_ip_headers`](ClientIPConfig::with_remote_ip_headers) to set the +/// [`with_remote_ip_headers`](ClientIpConfig::with_remote_ip_headers) to set the /// headers. /// /// If you want to get client IP that is trusted with specific cidrs, you can use -/// [`with_trusted_cidrs`](ClientIPConfig::with_trusted_cidrs) to set the cidrs. +/// [`with_trusted_cidrs`](ClientIpConfig::with_trusted_cidrs) to set the cidrs. /// /// # Example /// @@ -148,20 +157,20 @@ impl ClientIPConfig { /// /// ```rust /// /// -/// use volo_http::server::utils::client_ip::ClientIP; +/// use volo_http::server::utils::client_ip::ClientIp; /// use volo_http::server::{ /// route::{get, Router}, -/// utils::client_ip::{ClientIPConfig, ClientIPLayer}, +/// utils::client_ip::{ClientIpConfig, ClientIpLayer}, /// Server, /// }; /// -/// async fn handler(client_ip: ClientIP) -> String { +/// async fn handler(ClientIp(client_ip): ClientIp) -> String { /// client_ip.unwrap().to_string() /// } /// /// let router: Router = Router::new() /// .route("/", get(handler)) -/// .layer(ClientIPLayer::new()); +/// .layer(ClientIpLayer::new()); /// ``` /// /// ## With custom config @@ -172,67 +181,62 @@ impl ClientIPConfig { /// context::ServerContext, /// server::{ /// route::{get, Router}, -/// utils::client_ip::{ClientIP, ClientIPConfig, ClientIPLayer}, +/// utils::client_ip::{ClientIp, ClientIpConfig, ClientIpLayer}, /// Server, /// }, /// }; /// -/// async fn handler(client_ip: ClientIP) -> String { +/// async fn handler(ClientIp(client_ip): ClientIp) -> String { /// client_ip.unwrap().to_string() /// } /// /// let router: Router = Router::new().route("/", get(handler)).layer( -/// ClientIPLayer::new().with_config( -/// ClientIPConfig::new() +/// ClientIpLayer::new().with_config( +/// ClientIpConfig::new() /// .with_remote_ip_headers(vec!["x-real-ip", "x-forwarded-for"]) /// .unwrap() /// .with_trusted_cidrs(vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()]), /// ), /// ); /// ``` -pub struct ClientIP(pub Option); - -impl_deref_and_deref_mut!(ClientIP, Option, 0); +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ClientIp(pub Option); -/// [`ClientIPLayer`] generated [`Service`] +/// [`ClientIpLayer`] generated [`Service`] /// -/// See [`ClientIP`] for more details. -#[derive(Clone)] -pub struct ClientIPService { +/// See [`ClientIp`] for more details. +#[derive(Clone, Debug)] +pub struct ClientIpService { service: S, - config: ClientIPConfig, + config: ClientIpConfig, } -impl ClientIPService { - fn get_client_ip(&self, cx: &ServerContext, headers: &HeaderMap) -> ClientIP { +impl ClientIpService { + fn get_client_ip(&self, cx: &ServerContext, headers: &HeaderMap) -> ClientIp { let remote_ip = match &cx.rpc_info().caller().address { Some(Address::Ip(socket_addr)) => Some(socket_addr.ip()), #[cfg(target_family = "unix")] Some(Address::Unix(_)) => None, - None => return ClientIP(None), + None => return ClientIp(None), }; - if let Some(remote_ip) = remote_ip { + if let Some(remote_ip) = &remote_ip { if !self .config .trusted_cidrs .iter() - .any(|cidr| cidr.contains(&IpNet::from(remote_ip))) + .any(|cidr| cidr.contains(remote_ip)) { - return ClientIP(None); + return ClientIp(None); } } for remote_ip_header in self.config.remote_ip_headers.iter() { - let remote_ips = match headers - .get(remote_ip_header) - .and_then(|v| v.to_str().ok()) - .map(|v| v.split(',').map(|s| s.trim()).collect::>()) - { - Some(remote_ips) => remote_ips, - None => continue, + let Some(remote_ips) = headers.get(remote_ip_header).and_then(|v| v.to_str().ok()) + else { + continue; }; - for remote_ip in remote_ips.iter() { + for remote_ip in remote_ips.split(',').map(str::trim) { if let Ok(remote_ip_addr) = IpAddr::from_str(remote_ip) { if self .config @@ -240,17 +244,17 @@ impl ClientIPService { .iter() .any(|cidr| cidr.contains(&remote_ip_addr)) { - return ClientIP(Some(remote_ip_addr)); + return ClientIp(Some(remote_ip_addr)); } } } } - ClientIP(remote_ip) + ClientIp(remote_ip) } } -impl Service> for ClientIPService +impl Service> for ClientIpService where S: Service> + Send + Sync + 'static, B: Send, @@ -264,7 +268,7 @@ where req: Request, ) -> Result { let client_ip = self.get_client_ip(cx, req.headers()); - cx.rpc_info_mut().caller_mut().tags.insert(client_ip); + cx.rpc_info_mut().caller_mut().insert(client_ip); self.service.call(cx, req).await } @@ -283,21 +287,21 @@ mod client_ip_tests { context::ServerContext, server::{ route::{get, Route}, - utils::client_ip::{ClientIP, ClientIPConfig, ClientIPLayer}, + utils::client_ip::{ClientIp, ClientIpConfig, ClientIpLayer}, }, utils::test_helpers::simple_req, }; #[tokio::test] async fn test_client_ip() { - async fn handler(client_ip: ClientIP) -> String { + async fn handler(ClientIp(client_ip): ClientIp) -> String { client_ip.unwrap().to_string() } let route: Route<&str> = Route::new(get(handler)); - let service = ClientIPLayer::new() + let service = ClientIpLayer::new() .with_config( - ClientIPConfig::default().with_trusted_cidrs(vec!["10.0.0.0/8".parse().unwrap()]), + ClientIpConfig::default().with_trusted_cidrs(vec!["10.0.0.0/8".parse().unwrap()]), ) .layer(route); diff --git a/volo-http/src/server/utils/mod.rs b/volo-http/src/server/utils/mod.rs index c9160cf5..4ed1be34 100644 --- a/volo-http/src/server/utils/mod.rs +++ b/volo-http/src/server/utils/mod.rs @@ -5,6 +5,7 @@ mod serve_dir; pub use file_response::FileResponse; pub use serve_dir::ServeDir; + pub mod client_ip; #[cfg(feature = "multipart")] pub mod multipart;