Skip to content

Commit

Permalink
chore(volo-http): format codes of client ip (#539)
Browse files Browse the repository at this point in the history
Signed-off-by: Yu Li <[email protected]>
  • Loading branch information
yukiiiteru authored Nov 29, 2024
1 parent 70b9da7 commit 519c1f4
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 74 deletions.
13 changes: 5 additions & 8 deletions volo-http/src/server/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::{
context::ServerContext,
error::server::{body_collection_error, ExtractBodyError},
request::{Request, RequestPartsExt},
server::utils::client_ip::ClientIP,
server::utils::client_ip::ClientIp,
utils::macros::impl_deref_and_deref_mut,
};

Expand Down Expand Up @@ -291,18 +291,15 @@ impl FromContext for Method {
}
}

impl FromContext for ClientIP {
impl FromContext for ClientIp {
type Rejection = Infallible;

async fn from_context(
cx: &mut ServerContext,
_: &mut Parts,
) -> Result<ClientIP, Self::Rejection> {
Ok(ClientIP(
async fn from_context(cx: &mut ServerContext, _: &mut Parts) -> Result<Self, Self::Rejection> {
Ok(ClientIp(
cx.rpc_info
.caller()
.tags
.get::<ClientIP>()
.get::<ClientIp>()
.and_then(|v| v.0),
))
}
Expand Down
136 changes: 70 additions & 66 deletions volo-http/src/server/utils/client_ip.rs
Original file line number Diff line number Diff line change
@@ -1,43 +1,46 @@
//! Utilities for extracting original client ip
//!
//! See [`ClientIP`] for more details.
use std::{net::IpAddr, str::FromStr};
//! See [`ClientIp`] for more details.
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr},
str::FromStr,
};

use http::{HeaderMap, HeaderName};
use ipnet::IpNet;
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use motore::{layer::Layer, Service};
use volo::{context::Context, net::Address};

use crate::{context::ServerContext, request::Request, utils::macros::impl_deref_and_deref_mut};
use crate::{context::ServerContext, request::Request};

/// [`Layer`] for extracting client ip
///
/// See [`ClientIP`] for more details.
#[derive(Clone, Default)]
pub struct ClientIPLayer {
config: ClientIPConfig,
/// See [`ClientIp`] for more details.
#[derive(Clone, Debug, Default)]
pub struct ClientIpLayer {
config: ClientIpConfig,
}

impl ClientIPLayer {
/// Create a new [`ClientIPLayer`] with default config
impl ClientIpLayer {
/// Create a new [`ClientIpLayer`] with default config
pub fn new() -> Self {
Default::default()
}

/// Create a new [`ClientIPLayer`] with the given [`ClientIPConfig`]
pub fn with_config(self, config: ClientIPConfig) -> Self {
/// Create a new [`ClientIpLayer`] with the given [`ClientIpConfig`]
pub fn with_config(self, config: ClientIpConfig) -> Self {
Self { config }
}
}

impl<S> Layer<S> for ClientIPLayer
impl<S> Layer<S> for ClientIpLayer
where
S: Send + Sync + 'static,
{
type Service = ClientIPService<S>;
type Service = ClientIpService<S>;

fn layer(self, inner: S) -> Self::Service {
ClientIPService {
ClientIpService {
service: inner,
config: self.config,
}
Expand All @@ -46,25 +49,31 @@ where

/// Config for extract client ip
#[derive(Clone, Debug)]
pub struct ClientIPConfig {
pub struct ClientIpConfig {
remote_ip_headers: Vec<HeaderName>,
trusted_cidrs: Vec<IpNet>,
}

impl Default for ClientIPConfig {
impl Default for ClientIpConfig {
fn default() -> Self {
Self {
remote_ip_headers: vec![
HeaderName::from_static("x-real-ip"),
HeaderName::from_static("x-forwarded-for"),
],
trusted_cidrs: vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()],
trusted_cidrs: vec![
IpNet::V4(Ipv4Net::new_assert(Ipv4Addr::new(0, 0, 0, 0), 0)),
IpNet::V6(Ipv6Net::new_assert(
Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0),
0,
)),
],
}
}
}

impl ClientIPConfig {
/// Create a new [`ClientIPConfig`] with default values
impl ClientIpConfig {
/// Create a new [`ClientIpConfig`] with default values
///
/// default remote ip headers: `["X-Real-IP", "X-Forwarded-For"]`
///
Expand All @@ -75,15 +84,15 @@ impl ClientIPConfig {

/// Get Real Client IP by parsing the given headers.
///
/// See [`ClientIP`] for more details.
/// See [`ClientIp`] for more details.
///
/// # Example
///
/// ```rust
/// use volo_http::server::utils::client_ip::ClientIPConfig;
/// use volo_http::server::utils::client_ip::ClientIpConfig;
///
/// let client_ip_config =
/// ClientIPConfig::new().with_remote_ip_headers(vec!["X-Real-IP", "X-Forwarded-For"]);
/// ClientIpConfig::new().with_remote_ip_headers(vec!["X-Real-IP", "X-Forwarded-For"]);
/// ```
pub fn with_remote_ip_headers<I>(
self,
Expand All @@ -108,14 +117,14 @@ impl ClientIPConfig {

/// Get Real Client IP if it is trusted, otherwise it will just return caller ip.
///
/// See [`ClientIP`] for more details.
/// See [`ClientIp`] for more details.
///
/// # Example
///
/// ```rust
/// use volo_http::server::utils::client_ip::ClientIPConfig;
/// use volo_http::server::utils::client_ip::ClientIpConfig;
///
/// let client_ip_config = ClientIPConfig::new()
/// let client_ip_config = ClientIpConfig::new()
/// .with_trusted_cidrs(vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()]);
/// ```
pub fn with_trusted_cidrs<H>(self, cidrs: H) -> Self
Expand All @@ -132,11 +141,11 @@ impl ClientIPConfig {
/// Return original client IP Address
///
/// If you want to get client IP by retrieving specific headers, you can use
/// [`with_remote_ip_headers`](ClientIPConfig::with_remote_ip_headers) to set the
/// [`with_remote_ip_headers`](ClientIpConfig::with_remote_ip_headers) to set the
/// headers.
///
/// If you want to get client IP that is trusted with specific cidrs, you can use
/// [`with_trusted_cidrs`](ClientIPConfig::with_trusted_cidrs) to set the cidrs.
/// [`with_trusted_cidrs`](ClientIpConfig::with_trusted_cidrs) to set the cidrs.
///
/// # Example
///
Expand All @@ -148,20 +157,20 @@ impl ClientIPConfig {
///
/// ```rust
/// ///
/// use volo_http::server::utils::client_ip::ClientIP;
/// use volo_http::server::utils::client_ip::ClientIp;
/// use volo_http::server::{
/// route::{get, Router},
/// utils::client_ip::{ClientIPConfig, ClientIPLayer},
/// utils::client_ip::{ClientIpConfig, ClientIpLayer},
/// Server,
/// };
///
/// async fn handler(client_ip: ClientIP) -> String {
/// async fn handler(ClientIp(client_ip): ClientIp) -> String {
/// client_ip.unwrap().to_string()
/// }
///
/// let router: Router = Router::new()
/// .route("/", get(handler))
/// .layer(ClientIPLayer::new());
/// .layer(ClientIpLayer::new());
/// ```
///
/// ## With custom config
Expand All @@ -172,85 +181,80 @@ impl ClientIPConfig {
/// context::ServerContext,
/// server::{
/// route::{get, Router},
/// utils::client_ip::{ClientIP, ClientIPConfig, ClientIPLayer},
/// utils::client_ip::{ClientIp, ClientIpConfig, ClientIpLayer},
/// Server,
/// },
/// };
///
/// async fn handler(client_ip: ClientIP) -> String {
/// async fn handler(ClientIp(client_ip): ClientIp) -> String {
/// client_ip.unwrap().to_string()
/// }
///
/// let router: Router = Router::new().route("/", get(handler)).layer(
/// ClientIPLayer::new().with_config(
/// ClientIPConfig::new()
/// ClientIpLayer::new().with_config(
/// ClientIpConfig::new()
/// .with_remote_ip_headers(vec!["x-real-ip", "x-forwarded-for"])
/// .unwrap()
/// .with_trusted_cidrs(vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()]),
/// ),
/// );
/// ```
pub struct ClientIP(pub Option<IpAddr>);

impl_deref_and_deref_mut!(ClientIP, Option<IpAddr>, 0);
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ClientIp(pub Option<IpAddr>);

/// [`ClientIPLayer`] generated [`Service`]
/// [`ClientIpLayer`] generated [`Service`]
///
/// See [`ClientIP`] for more details.
#[derive(Clone)]
pub struct ClientIPService<S> {
/// See [`ClientIp`] for more details.
#[derive(Clone, Debug)]
pub struct ClientIpService<S> {
service: S,
config: ClientIPConfig,
config: ClientIpConfig,
}

impl<S> ClientIPService<S> {
fn get_client_ip(&self, cx: &ServerContext, headers: &HeaderMap) -> ClientIP {
impl<S> ClientIpService<S> {
fn get_client_ip(&self, cx: &ServerContext, headers: &HeaderMap) -> ClientIp {
let remote_ip = match &cx.rpc_info().caller().address {
Some(Address::Ip(socket_addr)) => Some(socket_addr.ip()),
#[cfg(target_family = "unix")]
Some(Address::Unix(_)) => None,
None => return ClientIP(None),
None => return ClientIp(None),
};

if let Some(remote_ip) = remote_ip {
if let Some(remote_ip) = &remote_ip {
if !self
.config
.trusted_cidrs
.iter()
.any(|cidr| cidr.contains(&IpNet::from(remote_ip)))
.any(|cidr| cidr.contains(remote_ip))
{
return ClientIP(None);
return ClientIp(None);
}
}

for remote_ip_header in self.config.remote_ip_headers.iter() {
let remote_ips = match headers
.get(remote_ip_header)
.and_then(|v| v.to_str().ok())
.map(|v| v.split(',').map(|s| s.trim()).collect::<Vec<_>>())
{
Some(remote_ips) => remote_ips,
None => continue,
let Some(remote_ips) = headers.get(remote_ip_header).and_then(|v| v.to_str().ok())
else {
continue;
};
for remote_ip in remote_ips.iter() {
for remote_ip in remote_ips.split(',').map(str::trim) {
if let Ok(remote_ip_addr) = IpAddr::from_str(remote_ip) {
if self
.config
.trusted_cidrs
.iter()
.any(|cidr| cidr.contains(&remote_ip_addr))
{
return ClientIP(Some(remote_ip_addr));
return ClientIp(Some(remote_ip_addr));
}
}
}
}

ClientIP(remote_ip)
ClientIp(remote_ip)
}
}

impl<S, B> Service<ServerContext, Request<B>> for ClientIPService<S>
impl<S, B> Service<ServerContext, Request<B>> for ClientIpService<S>
where
S: Service<ServerContext, Request<B>> + Send + Sync + 'static,
B: Send,
Expand All @@ -264,7 +268,7 @@ where
req: Request<B>,
) -> Result<Self::Response, Self::Error> {
let client_ip = self.get_client_ip(cx, req.headers());
cx.rpc_info_mut().caller_mut().tags.insert(client_ip);
cx.rpc_info_mut().caller_mut().insert(client_ip);

self.service.call(cx, req).await
}
Expand All @@ -283,21 +287,21 @@ mod client_ip_tests {
context::ServerContext,
server::{
route::{get, Route},
utils::client_ip::{ClientIP, ClientIPConfig, ClientIPLayer},
utils::client_ip::{ClientIp, ClientIpConfig, ClientIpLayer},
},
utils::test_helpers::simple_req,
};

#[tokio::test]
async fn test_client_ip() {
async fn handler(client_ip: ClientIP) -> String {
async fn handler(ClientIp(client_ip): ClientIp) -> String {
client_ip.unwrap().to_string()
}

let route: Route<&str> = Route::new(get(handler));
let service = ClientIPLayer::new()
let service = ClientIpLayer::new()
.with_config(
ClientIPConfig::default().with_trusted_cidrs(vec!["10.0.0.0/8".parse().unwrap()]),
ClientIpConfig::default().with_trusted_cidrs(vec!["10.0.0.0/8".parse().unwrap()]),
)
.layer(route);

Expand Down
1 change: 1 addition & 0 deletions volo-http/src/server/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod serve_dir;

pub use file_response::FileResponse;
pub use serve_dir::ServeDir;

pub mod client_ip;
#[cfg(feature = "multipart")]
pub mod multipart;
Expand Down

0 comments on commit 519c1f4

Please sign in to comment.