Skip to content

Commit

Permalink
Initial work to support x-forwarded-proto
Browse files Browse the repository at this point in the history
Fix tests

Undo automatic formatting

Fix some typos

Cleanup, fix docs

Add proxy_secure

Rename Protocol to ForwardedProtocol and improve config docs

Rename ForwardedProtocol and forwarded_protocol to ProxyProto and proxy_proto. Implemented suggestions from PR

Fix tests, remove unused config parameter in set_defaults

Reorder from_hyp, fix local CookieJar

Use UncasedStr in ProxyProtocol instead of owned variant
  • Loading branch information
atezet authored and SergioBenitez committed Dec 16, 2023
1 parent b5278de commit d1e8bc4
Show file tree
Hide file tree
Showing 12 changed files with 317 additions and 51 deletions.
2 changes: 2 additions & 0 deletions core/http/src/header/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ mod media_type;
mod content_type;
mod accept;
mod header;
mod proxy_proto;

pub use self::content_type::ContentType;
pub use self::accept::{Accept, QMediaType};
pub use self::media_type::MediaType;
pub use self::header::{Header, HeaderMap};
pub use self::proxy_proto::ProxyProto;

pub(crate) use self::media_type::Source;
33 changes: 33 additions & 0 deletions core/http/src/header/proxy_proto.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use std::fmt;

/// A protocol used to identify a specific protocol forwarded by an HTTP proxy.
// Names are case-insensitive
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ProxyProto<'a> {
/// `http` value, Hypertext Transfer Protocol
Http,
/// `https` value, Hypertext Transfer Protocol Secure
Https,
/// Any other protocol name not known to us
Unknown(&'a uncased::UncasedStr),
}

impl<'a> From<&'a str> for ProxyProto<'a> {
fn from(s: &'a str) -> ProxyProto<'a> {
match s.to_lowercase().as_str() {
"http" => ProxyProto::Http,
"https" => ProxyProto::Https,
_ => ProxyProto::Unknown(s.into()),
}
}
}

impl fmt::Display for ProxyProto<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match *self {
ProxyProto::Http => "http",
ProxyProto::Https => "https",
ProxyProto::Unknown(s) => s.as_str(),
})
}
}
23 changes: 22 additions & 1 deletion core/lib/src/config/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,23 @@ pub struct Config {
/// that the value must syntactically be a valid HTTP header name.
///
/// **(default: `"X-Real-IP"`)**
#[serde(deserialize_with = "crate::config::ip_header::deserialize")]
#[serde(deserialize_with = "crate::config::http_header::deserialize")]
pub ip_header: Option<Uncased<'static>>,
/// The name of a header, whose value is typically set by an intermediary
/// server or proxy, which contains the protocol (HTTP or HTTPS) used by the
/// connecting client. This should probably be [`X-Forwarded-Proto`], as
/// that is the de facto standard. Used by [`Request::forwarded_proto()`]
/// to determine the forwarded protocol and [`Request::forwarded_secure()`]
/// to determine whether a request is handled in a secure context.
///
/// [`X-Forwarded-Proto`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-Proto
///
/// To disable using any header for this purpose, set this value to `false`.
/// Deserialization semantics are identical to those of [`ip_header`].
///
/// **(default: `None`)**
#[serde(deserialize_with = "crate::config::http_header::deserialize")]
pub proxy_proto_header: Option<Uncased<'static>>,
/// Streaming read size limits. **(default: [`Limits::default()`])**
pub limits: Limits,
/// Directory to store temporary files in. **(default:
Expand Down Expand Up @@ -189,6 +204,7 @@ impl Config {
max_blocking: 512,
ident: Ident::default(),
ip_header: Some(Uncased::from_borrowed("X-Real-IP")),
proxy_proto_header: None,
limits: Limits::default(),
temp_dir: std::env::temp_dir().into(),
keep_alive: 5,
Expand Down Expand Up @@ -409,6 +425,11 @@ impl Config {
None => launch_meta_!("IP header: {}", "disabled".paint(VAL))
}

match self.proxy_proto_header {
Some(ref name) => launch_meta_!("Protocol header: {}", name.paint(VAL)),
None => launch_meta_!("Protocol header: {}", "disabled".paint(VAL))
}

launch_meta_!("limits: {}", (&self.limits).paint(VAL));
launch_meta_!("temp dir: {}", self.temp_dir.relative().display().paint(VAL));
launch_meta_!("http/2: {}", (cfg!(feature = "http2").paint(VAL)));
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion core/lib/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@
mod ident;
mod config;
mod shutdown;
mod ip_header;
mod http_header;

#[cfg(feature = "tls")]
mod tls;
Expand Down
38 changes: 25 additions & 13 deletions core/lib/src/cookies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ pub struct CookieJar<'a> {
jar: cookie::CookieJar,
ops: Mutex<Vec<Op>>,
config: &'a Config,
secure_context: bool,
}

impl<'a> Clone for CookieJar<'a> {
Expand All @@ -163,6 +164,7 @@ impl<'a> Clone for CookieJar<'a> {
jar: self.jar.clone(),
ops: Mutex::new(self.ops.lock().clone()),
config: self.config,
secure_context: self.secure_context,
}
}
}
Expand All @@ -183,12 +185,21 @@ impl Op {

impl<'a> CookieJar<'a> {
#[inline(always)]
pub(crate) fn new(config: &'a Config) -> Self {
CookieJar::from(cookie::CookieJar::new(), config)
pub(crate) fn new(config: &'a Config, secure_context: bool) -> Self {
CookieJar::from(cookie::CookieJar::new(), config, secure_context)
}

pub(crate) fn from(jar: cookie::CookieJar, config: &'a Config) -> Self {
CookieJar { jar, config, ops: Mutex::new(Vec::new()) }
pub(crate) fn from(
jar: cookie::CookieJar,
config: &'a Config,
secure_context: bool,
) -> Self {
CookieJar {
jar,
config,
ops: Mutex::new(Vec::new()),
secure_context,
}
}

/// Returns a reference to the _original_ `Cookie` inside this container
Expand Down Expand Up @@ -309,7 +320,7 @@ impl<'a> CookieJar<'a> {
/// ```
pub fn add<C: Into<Cookie<'static>>>(&self, cookie: C) {
let mut cookie = cookie.into();
Self::set_defaults(self.config, &mut cookie);
Self::set_defaults(self.secure_context, &mut cookie);
self.ops.lock().push(Op::Add(cookie, false));
}

Expand Down Expand Up @@ -346,7 +357,7 @@ impl<'a> CookieJar<'a> {
#[cfg_attr(nightly, doc(cfg(feature = "secrets")))]
pub fn add_private<C: Into<Cookie<'static>>>(&self, cookie: C) {
let mut cookie = cookie.into();
Self::set_private_defaults(self.config, &mut cookie);
Self::set_private_defaults(self.secure_context, &mut cookie);
self.ops.lock().push(Op::Add(cookie, true));
}

Expand Down Expand Up @@ -515,8 +526,9 @@ impl<'a> CookieJar<'a> {
/// * `path`: `"/"`
/// * `SameSite`: `Strict`
///
/// Furthermore, if TLS is enabled, the `Secure` cookie flag is set.
fn set_defaults(config: &Config, cookie: &mut Cookie<'static>) {
/// Furthermore, if TLS is enabled or handled by a proxy, the `Secure`
/// cookie flag is set.
fn set_defaults(secure_context: bool, cookie: &mut Cookie<'static>) {
if cookie.path().is_none() {
cookie.set_path("/");
}
Expand All @@ -525,7 +537,7 @@ impl<'a> CookieJar<'a> {
cookie.set_same_site(SameSite::Strict);
}

if cookie.secure().is_none() && config.tls_enabled() {
if cookie.secure().is_none() && secure_context {
cookie.set_secure(true);
}
}
Expand Down Expand Up @@ -554,10 +566,11 @@ impl<'a> CookieJar<'a> {
/// * `HttpOnly`: `true`
/// * `Expires`: 1 week from now
///
/// Furthermore, if TLS is enabled, the `Secure` cookie flag is set.
/// Furthermore, if TLS is enabled or handled by a proxy, the `Secure`
/// cookie flag is set.
#[cfg(feature = "secrets")]
#[cfg_attr(nightly, doc(cfg(feature = "secrets")))]
fn set_private_defaults(config: &Config, cookie: &mut Cookie<'static>) {
fn set_private_defaults(secure_context: bool, cookie: &mut Cookie<'static>) {
if cookie.path().is_none() {
cookie.set_path("/");
}
Expand All @@ -574,7 +587,7 @@ impl<'a> CookieJar<'a> {
cookie.set_expires(time::OffsetDateTime::now_utc() + time::Duration::weeks(1));
}

if cookie.secure().is_none() && config.tls_enabled() {
if cookie.secure().is_none() && secure_context {
cookie.set_secure(true);
}
}
Expand All @@ -593,5 +606,4 @@ impl fmt::Debug for CookieJar<'_> {
.field("pending", &pending)
.finish()
}

}
5 changes: 4 additions & 1 deletion core/lib/src/local/asynchronous/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ impl<'c> LocalResponse<'c> {

async move {
let response: Response<'c> = f(request).await;
let mut cookies = CookieJar::new(request.rocket().config());
let mut cookies = CookieJar::new(
request.rocket().config(),
request.context_is_likely_secure(),
);
for cookie in response.cookies() {
cookies.add_original(cookie.into_owned());
}
Expand Down
2 changes: 1 addition & 1 deletion core/lib/src/local/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ macro_rules! pub_client_impl {
pub fn cookies(&self) -> crate::http::CookieJar<'_> {
let config = &self.rocket().config();
let jar = self._with_raw_cookies(|jar| jar.clone());
crate::http::CookieJar::from(jar, config)
crate::http::CookieJar::from(jar, config, config.tls_enabled())
}

req_method!($import, "GET", get, Method::Get);
Expand Down
20 changes: 19 additions & 1 deletion core/lib/src/request/from_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{Request, Route};
use crate::outcome::{self, Outcome::*};

use crate::http::uri::{Host, Origin};
use crate::http::{Status, ContentType, Accept, Method, CookieJar};
use crate::http::{Status, ContentType, Accept, Method, ProxyProto, CookieJar};

/// Type alias for the `Outcome` of a `FromRequest` conversion.
pub type Outcome<S, E> = outcome::Outcome<S, (Status, E), Status>;
Expand Down Expand Up @@ -160,6 +160,12 @@ pub type Outcome<S, E> = outcome::Outcome<S, (Status, E), Status>;
/// via [`Request::client_ip()`]. If the client's IP address is not known,
/// the request is forwarded with a 500 Internal Server Error status.
///
/// * **Protocol**
///
/// Extracts the protocol of the incoming request as a [`Protocol`] via
/// [`Request::forwarded_proto()`] (HTTP or HTTPS). If the used protocol is
/// not known, the request is forwarded.
///
/// * **SocketAddr**
///
/// Extracts the remote address of the incoming request as a [`SocketAddr`]
Expand Down Expand Up @@ -470,6 +476,18 @@ impl<'r> FromRequest<'r> for IpAddr {
}
}

#[crate::async_trait]
impl<'r> FromRequest<'r> for ProxyProto<'r> {
type Error = std::convert::Infallible;

async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match request.proxy_proto() {
Some(proto) => Success(proto),
None => Forward(())
}
}
}

#[crate::async_trait]
impl<'r> FromRequest<'r> for SocketAddr {
type Error = Infallible;
Expand Down
91 changes: 75 additions & 16 deletions core/lib/src/request/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::request::{FromParam, FromSegments, FromRequest, Outcome};
use crate::form::{self, ValueField, FromForm};
use crate::data::Limits;

use crate::http::{hyper, Method, Header, HeaderMap};
use crate::http::{hyper, Method, Header, HeaderMap, ProxyProto};
use crate::http::{ContentType, Accept, MediaType, CookieJar, Cookie};
use crate::http::uncased::UncasedStr;
use crate::http::private::Certificates;
Expand Down Expand Up @@ -97,7 +97,7 @@ impl<'r> Request<'r> {
state: RequestState {
rocket,
route: Atomic::new(None),
cookies: CookieJar::new(rocket.config()),
cookies: CookieJar::new(rocket.config(), rocket.config().tls_enabled()),
accept: InitCell::new(),
content_type: InitCell::new(),
cache: Arc::new(<TypeMap![Send + Sync]>::new()),
Expand Down Expand Up @@ -386,6 +386,60 @@ impl<'r> Request<'r> {
})
}

/// Returns what protocol for a connection was handled by a proxy by
/// inspecting [`proxy_proto_header`](crate::Config::proxy_proto_header) of
/// the request if such a header is configured, exists and contains a known
/// setting ("http" or "https").
///
/// # Example
///
/// ```rust
/// use rocket::http::{Header, ProxyProto};
///
/// # let client = rocket::local::blocking::Client::debug_with(vec![]).unwrap();
/// # let request = client.get("/");
/// assert_eq!(request.proxy_proto(), None);
///
/// // `proxy_proto_header` defaults to `None`. `x-forwarded-proto` is considered the de-facto standard
/// # let figment = rocket::Config::figment().merge(("proxy_proto_header", "x-forwarded-proto"));
/// # let client = rocket::local::blocking::Client::debug(rocket::custom(figment)).unwrap();
/// # let request = client.get("/");
/// let request = request.header(Header::new("x-forwarded-proto", "https"));
/// assert_eq!(request.proxy_proto(), Some(ProxyProto::Https));
/// ```
pub fn proxy_proto(&self) -> Option<ProxyProto<'_>> {
let proxy_proto_header = self.rocket().config.proxy_proto_header.as_ref()?.as_str();
self.headers()
.get_one(proxy_proto_header)
.and_then(|proto| Some(proto.into()))
}

/// Returns whether we are *likely* in a secure context; either because tls
/// is enabled in the Rocket config, or a secure connection was handled by
/// a proxy. The later is determined by inspecting the header configured in
/// [`proxy_proto_header`](crate::Config::proxy_proto_header) of the request
/// if such a header is configured, exists and contains the value "https"
///
/// # Example
///
/// ```rust
/// use rocket::http::{Header, ProxyProto};
///
/// # let client = rocket::local::blocking::Client::debug_with(vec![]).unwrap();
/// # let request = client.get("/");
/// assert_eq!(request.context_is_likely_secure(), false);
///
/// // `proxy_proto_header` defaults to `None`. `x-forwarded-proto` is considered the de-facto standard
/// # let figment = rocket::Config::figment().merge(("proxy_proto_header", "x-forwarded-proto"));
/// # let client = rocket::local::blocking::Client::debug(rocket::custom(figment)).unwrap();
/// # let request = client.get("/");
/// let request = request.header(Header::new("x-forwarded-proto", "https"));
/// assert_eq!(request.context_is_likely_secure(), true);
/// ```
pub fn context_is_likely_secure(&self) -> bool {
self.rocket().config.tls_enabled() || self.proxy_proto() == Some(ProxyProto::Https)
}

/// Attempts to return the client's IP address by first inspecting the
/// [`ip_header`](crate::Config::ip_header) and then using the remote
/// connection's IP address. Note that the built-in `IpAddr` request guard
Expand Down Expand Up @@ -1037,20 +1091,6 @@ impl<'r> Request<'r> {
hyper.uri.host().map(|h| Host::new(Authority::new(None, h, hyper.uri.port_u16())))
};

// Set the request cookies, if they exist.
for header in hyper.headers.get_all("Cookie") {
let raw_str = match std::str::from_utf8(header.as_bytes()) {
Ok(string) => string,
Err(_) => continue
};

for cookie_str in raw_str.split(';').map(|s| s.trim()) {
if let Ok(cookie) = Cookie::parse_encoded(cookie_str) {
request.state.cookies.add_original(cookie.into_owned());
}
}
}

// Set the rest of the headers. This is rather unfortunate and slow.
for (name, value) in hyper.headers.iter() {
// FIXME: This is rather unfortunate. Header values needn't be UTF8.
Expand All @@ -1066,6 +1106,25 @@ impl<'r> Request<'r> {
request.add_header(Header::new(name.as_str(), value));
}

request.state.cookies = CookieJar::new(
rocket.config(),
request.context_is_likely_secure(),
);

// Set the request cookies, if they exist.
for header in hyper.headers.get_all("Cookie") {
let raw_str = match std::str::from_utf8(header.as_bytes()) {
Ok(string) => string,
Err(_) => continue
};

for cookie_str in raw_str.split(';').map(|s| s.trim()) {
if let Ok(cookie) = Cookie::parse_encoded(cookie_str) {
request.state.cookies.add_original(cookie.into_owned());
}
}
}

if errors.is_empty() {
Ok(request)
} else {
Expand Down
Loading

0 comments on commit d1e8bc4

Please sign in to comment.