diff --git a/Cargo.lock b/Cargo.lock index bbace764..dc57c3bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4093,6 +4093,7 @@ dependencies = [ "http-body-util", "hyper 1.5.0", "hyper-util", + "ipnet", "itoa", "libc", "matchit 0.8.5", diff --git a/Cargo.toml b/Cargo.toml index 30a7abbe..33a8ea36 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -69,6 +69,7 @@ http-body-util = "0.1" hyper = "1" hyper-timeout = "0.5" hyper-util = "0.1" +ipnet = "2.10" itertools = "0.13" itoa = "1" libc = "0.2" diff --git a/volo-http/Cargo.toml b/volo-http/Cargo.toml index d39cc8f5..44603a18 100644 --- a/volo-http/Cargo.toml +++ b/volo-http/Cargo.toml @@ -34,6 +34,7 @@ http-body.workspace = true http-body-util.workspace = true hyper.workspace = true hyper-util = { workspace = true, features = ["tokio"] } +ipnet.workspace = true itoa.workspace = true memchr.workspace = true metainfo.workspace = true diff --git a/volo-http/src/server/extract.rs b/volo-http/src/server/extract.rs index d616c32e..315aaae3 100644 --- a/volo-http/src/server/extract.rs +++ b/volo-http/src/server/extract.rs @@ -23,6 +23,7 @@ use crate::{ context::ServerContext, error::server::{body_collection_error, ExtractBodyError}, request::{Request, RequestPartsExt}, + server::utils::client_ip::ClientIP, utils::macros::impl_deref_and_deref_mut, }; @@ -290,6 +291,23 @@ impl FromContext for Method { } } +impl FromContext for ClientIP { + type Rejection = Infallible; + + async fn from_context( + cx: &mut ServerContext, + _: &mut Parts, + ) -> Result { + Ok(ClientIP( + cx.rpc_info + .caller() + .tags + .get::() + .and_then(|v| v.0), + )) + } +} + #[cfg(feature = "query")] impl FromContext for Query where diff --git a/volo-http/src/server/layer/timeout.rs b/volo-http/src/server/layer/timeout.rs index 67119f12..9aabfc9f 100644 --- a/volo-http/src/server/layer/timeout.rs +++ b/volo-http/src/server/layer/timeout.rs @@ -111,7 +111,7 @@ where tokio::select! { resp = fut_service => resp.map(IntoResponse::into_response), _ = fut_timeout => { - Ok((self.handler.clone()).call(cx)) + Ok(self.handler.clone().call(cx)) }, } } diff --git a/volo-http/src/server/utils/client_ip.rs b/volo-http/src/server/utils/client_ip.rs new file mode 100644 index 00000000..8869e3f6 --- /dev/null +++ b/volo-http/src/server/utils/client_ip.rs @@ -0,0 +1,333 @@ +//! Utilities for extracting original client ip +//! +//! See [`ClientIP`] for more details. +use std::{net::IpAddr, str::FromStr}; + +use http::{HeaderMap, HeaderName}; +use ipnet::IpNet; +use motore::{layer::Layer, Service}; +use volo::{context::Context, net::Address}; + +use crate::{context::ServerContext, request::Request, utils::macros::impl_deref_and_deref_mut}; + +/// [`Layer`] for extracting client ip +/// +/// See [`ClientIP`] for more details. +#[derive(Clone, Default)] +pub struct ClientIPLayer { + config: ClientIPConfig, +} + +impl ClientIPLayer { + /// Create a new [`ClientIPLayer`] with default config + pub fn new() -> Self { + Default::default() + } + + /// Create a new [`ClientIPLayer`] with the given [`ClientIPConfig`] + pub fn with_config(self, config: ClientIPConfig) -> Self { + Self { config } + } +} + +impl Layer for ClientIPLayer +where + S: Send + Sync + 'static, +{ + type Service = ClientIPService; + + fn layer(self, inner: S) -> Self::Service { + ClientIPService { + service: inner, + config: self.config, + } + } +} + +/// Config for extract client ip +#[derive(Clone, Debug)] +pub struct ClientIPConfig { + remote_ip_headers: Vec, + trusted_cidrs: Vec, +} + +impl Default for ClientIPConfig { + fn default() -> Self { + Self { + remote_ip_headers: vec![ + HeaderName::from_static("x-real-ip"), + HeaderName::from_static("x-forwarded-for"), + ], + trusted_cidrs: vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()], + } + } +} + +impl ClientIPConfig { + /// Create a new [`ClientIPConfig`] with default values + /// + /// default remote ip headers: `["X-Real-IP", "X-Forwarded-For"]` + /// + /// default trusted cidrs: `["0.0.0.0/0", "::/0"]` + pub fn new() -> Self { + Default::default() + } + + /// Get Real Client IP by parsing the given headers. + /// + /// See [`ClientIP`] for more details. + /// + /// # Example + /// + /// ```rust + /// use volo_http::server::utils::client_ip::ClientIPConfig; + /// + /// let client_ip_config = + /// ClientIPConfig::new().with_remote_ip_headers(vec!["X-Real-IP", "X-Forwarded-For"]); + /// ``` + pub fn with_remote_ip_headers( + self, + headers: I, + ) -> Result + where + I: IntoIterator, + I::Item: AsRef, + { + let headers = headers.into_iter().collect::>(); + let mut remote_ip_headers = Vec::with_capacity(headers.len()); + for header_str in headers { + let header_value = HeaderName::from_str(header_str.as_ref())?; + remote_ip_headers.push(header_value); + } + + Ok(Self { + remote_ip_headers, + trusted_cidrs: self.trusted_cidrs, + }) + } + + /// Get Real Client IP if it is trusted, otherwise it will just return caller ip. + /// + /// See [`ClientIP`] for more details. + /// + /// # Example + /// + /// ```rust + /// use volo_http::server::utils::client_ip::ClientIPConfig; + /// + /// let client_ip_config = ClientIPConfig::new() + /// .with_trusted_cidrs(vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()]); + /// ``` + pub fn with_trusted_cidrs(self, cidrs: H) -> Self + where + H: IntoIterator, + { + Self { + remote_ip_headers: self.remote_ip_headers, + trusted_cidrs: cidrs.into_iter().collect(), + } + } +} + +/// Return original client IP Address +/// +/// If you want to get client IP by retrieving specific headers, you can use +/// [`with_remote_ip_headers`](ClientIPConfig::with_remote_ip_headers) to set the +/// headers. +/// +/// If you want to get client IP that is trusted with specific cidrs, you can use +/// [`with_trusted_cidrs`](ClientIPConfig::with_trusted_cidrs) to set the cidrs. +/// +/// # Example +/// +/// ## Default config +/// +/// default remote ip headers: `["X-Real-IP", "X-Forwarded-For"]` +/// +/// default trusted cidrs: `["0.0.0.0/0", "::/0"]` +/// +/// ```rust +/// /// +/// use volo_http::server::utils::client_ip::ClientIP; +/// use volo_http::server::{ +/// route::{get, Router}, +/// utils::client_ip::{ClientIPConfig, ClientIPLayer}, +/// Server, +/// }; +/// +/// async fn handler(client_ip: ClientIP) -> String { +/// client_ip.unwrap().to_string() +/// } +/// +/// let router: Router = Router::new() +/// .route("/", get(handler)) +/// .layer(ClientIPLayer::new()); +/// ``` +/// +/// ## With custom config +/// +/// ```rust +/// use http::HeaderMap; +/// use volo_http::{ +/// context::ServerContext, +/// server::{ +/// route::{get, Router}, +/// utils::client_ip::{ClientIP, ClientIPConfig, ClientIPLayer}, +/// Server, +/// }, +/// }; +/// +/// async fn handler(client_ip: ClientIP) -> String { +/// client_ip.unwrap().to_string() +/// } +/// +/// let router: Router = Router::new().route("/", get(handler)).layer( +/// ClientIPLayer::new().with_config( +/// ClientIPConfig::new() +/// .with_remote_ip_headers(vec!["x-real-ip", "x-forwarded-for"]) +/// .unwrap() +/// .with_trusted_cidrs(vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()]), +/// ), +/// ); +/// ``` +pub struct ClientIP(pub Option); + +impl_deref_and_deref_mut!(ClientIP, Option, 0); + +/// [`ClientIPLayer`] generated [`Service`] +/// +/// See [`ClientIP`] for more details. +#[derive(Clone)] +pub struct ClientIPService { + service: S, + config: ClientIPConfig, +} + +impl ClientIPService { + fn get_client_ip(&self, cx: &ServerContext, headers: &HeaderMap) -> ClientIP { + let remote_ip = match &cx.rpc_info().caller().address { + Some(Address::Ip(socket_addr)) => Some(socket_addr.ip()), + #[cfg(target_family = "unix")] + Some(Address::Unix(_)) => None, + None => return ClientIP(None), + }; + + if let Some(remote_ip) = remote_ip { + if !self + .config + .trusted_cidrs + .iter() + .any(|cidr| cidr.contains(&IpNet::from(remote_ip))) + { + return ClientIP(None); + } + } + + for remote_ip_header in self.config.remote_ip_headers.iter() { + let remote_ips = match headers + .get(remote_ip_header) + .and_then(|v| v.to_str().ok()) + .map(|v| v.split(',').map(|s| s.trim()).collect::>()) + { + Some(remote_ips) => remote_ips, + None => continue, + }; + for remote_ip in remote_ips.iter() { + if let Ok(remote_ip_addr) = IpAddr::from_str(remote_ip) { + if self + .config + .trusted_cidrs + .iter() + .any(|cidr| cidr.contains(&remote_ip_addr)) + { + return ClientIP(Some(remote_ip_addr)); + } + } + } + } + + ClientIP(remote_ip) + } +} + +impl Service> for ClientIPService +where + S: Service> + Send + Sync + 'static, + B: Send, +{ + type Response = S::Response; + type Error = S::Error; + + async fn call( + &self, + cx: &mut ServerContext, + req: Request, + ) -> Result { + let client_ip = self.get_client_ip(cx, req.headers()); + cx.rpc_info_mut().caller_mut().tags.insert(client_ip); + + self.service.call(cx, req).await + } +} + +#[cfg(test)] +mod client_ip_tests { + use std::{net::SocketAddr, str::FromStr}; + + use http::{HeaderValue, Method}; + use motore::{layer::Layer, Service}; + use volo::net::Address; + + use crate::{ + body::BodyConversion, + context::ServerContext, + server::{ + route::{get, Route}, + utils::client_ip::{ClientIP, ClientIPConfig, ClientIPLayer}, + }, + utils::test_helpers::simple_req, + }; + + #[tokio::test] + async fn test_client_ip() { + async fn handler(client_ip: ClientIP) -> String { + client_ip.unwrap().to_string() + } + + let route: Route<&str> = Route::new(get(handler)); + let service = ClientIPLayer::new() + .with_config( + ClientIPConfig::default().with_trusted_cidrs(vec!["10.0.0.0/8".parse().unwrap()]), + ) + .layer(route); + + let mut cx = ServerContext::new(Address::from( + SocketAddr::from_str("10.0.0.1:8080").unwrap(), + )); + + // Test case 1: no remote ip header + let req = simple_req(Method::GET, "/", ""); + let resp = service.call(&mut cx, req).await.unwrap(); + assert_eq!("10.0.0.1", resp.into_string().await.unwrap()); + + // Test case 2: with remote ip header + let mut req = simple_req(Method::GET, "/", ""); + req.headers_mut() + .insert("X-Real-IP", HeaderValue::from_static("10.0.0.2")); + let resp = service.call(&mut cx, req).await.unwrap(); + assert_eq!("10.0.0.2", resp.into_string().await.unwrap()); + + let mut req = simple_req(Method::GET, "/", ""); + req.headers_mut() + .insert("X-Forwarded-For", HeaderValue::from_static("10.0.1.0")); + let resp = service.call(&mut cx, req).await.unwrap(); + assert_eq!("10.0.1.0", resp.into_string().await.unwrap()); + + // Test case 3: with untrusted remote ip + let mut req = simple_req(Method::GET, "/", ""); + req.headers_mut() + .insert("X-Real-IP", HeaderValue::from_static("11.0.0.1")); + let resp = service.call(&mut cx, req).await.unwrap(); + assert_eq!("10.0.0.1", resp.into_string().await.unwrap()); + } +} diff --git a/volo-http/src/server/utils/mod.rs b/volo-http/src/server/utils/mod.rs index 38f75ff8..c9160cf5 100644 --- a/volo-http/src/server/utils/mod.rs +++ b/volo-http/src/server/utils/mod.rs @@ -5,6 +5,7 @@ mod serve_dir; pub use file_response::FileResponse; pub use serve_dir::ServeDir; +pub mod client_ip; #[cfg(feature = "multipart")] pub mod multipart; #[cfg(feature = "ws")]