Skip to content

Commit

Permalink
Fix tests for h3.
Browse files Browse the repository at this point in the history
  • Loading branch information
SergioBenitez committed Mar 19, 2024
1 parent 95222ae commit 058fd95
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 51 deletions.
1 change: 1 addition & 0 deletions contrib/sync_db_pools/lib/tests/shutdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ mod sqlite_shutdown_test {

let options = map!["url" => ":memory:"];
let config = Figment::from(rocket::Config::debug_default())
.merge(("port", 0))
.merge(("databases", map!["test" => &options]));

rocket::custom(config).attach(Pool::fairing())
Expand Down
66 changes: 39 additions & 27 deletions core/lib/src/listener/endpoint.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::fmt;
use std::any::Any;
use std::net::{self, Ipv4Addr, AddrParseError};
use std::net::{self, AddrParseError, IpAddr, Ipv4Addr};
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::Arc;
Expand All @@ -20,9 +20,9 @@ impl<T: fmt::Display + fmt::Debug + Sync + Send + Any> EndpointAddr for T {}
///
/// * [`&str`] - parse with [`FromStr`]
/// * [`tokio::net::unix::SocketAddr`] - must be path: [`Endpoint::Unix`]
/// * [`std::net::SocketAddr`] - infallibly as [Endpoint::Tcp]
/// * [`PathBuf`] - infallibly as [`Endpoint::Unix`]
#[derive(Debug)]
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Endpoint {
Tcp(net::SocketAddr),
Quic(net::SocketAddr),
Expand Down Expand Up @@ -52,6 +52,33 @@ impl Endpoint {
}
}

pub fn socket_addr(&self) -> Option<net::SocketAddr> {
match self {
Endpoint::Quic(addr) => Some(*addr),
Endpoint::Tcp(addr) => Some(*addr),
Endpoint::Tls(inner, _) => inner.socket_addr(),
_ => None,
}
}

pub fn ip(&self) -> Option<IpAddr> {
match self {
Endpoint::Quic(addr) => Some(addr.ip()),
Endpoint::Tcp(addr) => Some(addr.ip()),
Endpoint::Tls(inner, _) => inner.ip(),
_ => None,
}
}

pub fn port(&self) -> Option<u16> {
match self {
Endpoint::Quic(addr) => Some(addr.port()),
Endpoint::Tcp(addr) => Some(addr.port()),
Endpoint::Tls(inner, _) => inner.port(),
_ => None,
}
}

pub fn unix(&self) -> Option<&Path> {
match self {
Endpoint::Unix(addr) => Some(addr),
Expand Down Expand Up @@ -189,27 +216,29 @@ impl Default for Endpoint {
/// The syntax is:
///
/// ```text
/// endpoint = 'tcp' ':' tcp_addr | 'unix' ':' unix_addr | tcp_addr
/// tcp_addr := IP_ADDR | SOCKET_ADDR
/// unix_addr := PATH
/// endpoint = 'tcp' ':' socket | 'quic' ':' socket | 'unix' ':' path | socket
/// socket := IP_ADDR | SOCKET_ADDR
/// path := PATH
///
/// IP_ADDR := `std::net::IpAddr` string as defined by Rust
/// SOCKET_ADDR := `std::net::SocketAddr` string as defined by Rust
/// PATH := `PathBuf` (any UTF-8) string as defined by Rust
/// ```
///
/// If `IP_ADDR` is specified, the port defaults to `8000`.
/// If `IP_ADDR` is specified in socket, port defaults to `8000`.
impl FromStr for Endpoint {
type Err = AddrParseError;

fn from_str(string: &str) -> Result<Self, Self::Err> {
fn parse_tcp(string: &str, def_port: u16) -> Result<net::SocketAddr, AddrParseError> {
string.parse().or_else(|_| string.parse().map(|ip| net::SocketAddr::new(ip, def_port)))
fn parse_tcp(str: &str, def_port: u16) -> Result<net::SocketAddr, AddrParseError> {
str.parse().or_else(|_| str.parse().map(|ip| net::SocketAddr::new(ip, def_port)))
}

if let Some((proto, string)) = string.split_once(':') {
if proto.trim().as_uncased() == "tcp" {
return parse_tcp(string.trim(), 8000).map(Self::Tcp);
} else if proto.trim().as_uncased() == "quic" {
return parse_tcp(string.trim(), 8000).map(Self::Quic);
} else if proto.trim().as_uncased() == "unix" {
return Ok(Self::Unix(PathBuf::from(string.trim())));
}
Expand Down Expand Up @@ -245,6 +274,7 @@ impl PartialEq for Endpoint {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Tcp(l0), Self::Tcp(r0)) => l0 == r0,
(Self::Quic(l0), Self::Quic(r0)) => l0 == r0,
(Self::Unix(l0), Self::Unix(r0)) => l0 == r0,
(Self::Tls(l0, _), Self::Tls(r0, _)) => l0 == r0,
(Self::Custom(l0), Self::Custom(r0)) => l0.to_string() == r0.to_string(),
Expand All @@ -253,24 +283,6 @@ impl PartialEq for Endpoint {
}
}

impl PartialEq<std::net::SocketAddr> for Endpoint {
fn eq(&self, other: &std::net::SocketAddr) -> bool {
self.tcp() == Some(*other)
}
}

impl PartialEq<std::net::SocketAddrV4> for Endpoint {
fn eq(&self, other: &std::net::SocketAddrV4) -> bool {
self.tcp() == Some((*other).into())
}
}

impl PartialEq<std::net::SocketAddrV6> for Endpoint {
fn eq(&self, other: &std::net::SocketAddrV6) -> bool {
self.tcp() == Some((*other).into())
}
}

impl PartialEq<PathBuf> for Endpoint {
fn eq(&self, other: &PathBuf) -> bool {
self.unix() == Some(other.as_path())
Expand Down
9 changes: 5 additions & 4 deletions core/lib/src/local/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,17 @@ macro_rules! pub_request_impl {
/// Set the remote address to "8.8.8.8:80":
///
/// ```rust
/// use std::net::{SocketAddrV4, Ipv4Addr};
/// use std::net::Ipv4Addr;
///
#[doc = $import]
///
/// # Client::_test(|_, request, _| {
/// let request: LocalRequest = request;
/// let req = request.remote("8.8.8.8:80");
/// let req = request.remote("tcp:8.8.8.8:80");
///
/// let addr = SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8).into(), 80);
/// assert_eq!(req.inner().remote().unwrap(), &addr);
/// let remote = req.inner().remote().unwrap().tcp().unwrap();
/// assert_eq!(remote.ip(), Ipv4Addr::new(8, 8, 8, 8));
/// assert_eq!(remote.port(), 80);
/// # });
/// ```
#[inline]
Expand Down
2 changes: 1 addition & 1 deletion core/lib/src/request/from_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ impl<'r> FromRequest<'r> for std::net::SocketAddr {

async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Infallible> {
request.remote()
.and_then(|r| r.tcp())
.and_then(|r| r.socket_addr())
.or_forward(Status::InternalServerError)
}
}
Expand Down
37 changes: 20 additions & 17 deletions core/lib/src/request/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ pub struct Request<'r> {
/// Information derived from an incoming connection, if any.
#[derive(Clone, Default)]
pub(crate) struct ConnectionMeta {
pub peer_address: Option<Arc<Endpoint>>,
pub peer_endpoint: Option<Endpoint>,
#[cfg_attr(not(feature = "mtls"), allow(dead_code))]
pub peer_certs: Option<Arc<Certificates<'static>>>,
}

impl<C: Connection> From<&C> for ConnectionMeta {
fn from(conn: &C) -> Self {
ConnectionMeta {
peer_address: conn.endpoint().ok().map(Arc::new),
peer_endpoint: conn.endpoint().ok(),
peer_certs: conn.certificates().map(|c| c.into_owned()).map(Arc::new),
}
}
Expand Down Expand Up @@ -316,20 +316,21 @@ impl<'r> Request<'r> {
/// # Example
///
/// ```rust
/// use std::net::{SocketAddrV4, Ipv4Addr};
/// use std::net::{IpAddr, Ipv4Addr, SocketAddr};
/// use rocket::listener::Endpoint;
/// # let c = rocket::local::blocking::Client::debug_with(vec![]).unwrap();
/// # let mut req = c.get("/");
/// # let request = req.inner_mut();
///
/// assert_eq!(request.remote(), None);
///
/// let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8000);
/// request.set_remote(localhost);
/// assert_eq!(request.remote().unwrap(), &localhost);
/// let localhost = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8111);
/// request.set_remote(Endpoint::Tcp(localhost));
/// assert_eq!(request.remote().unwrap().tcp().unwrap(), localhost);
/// ```
#[inline(always)]
pub fn remote(&self) -> Option<&Endpoint> {
self.connection.peer_address.as_deref()
self.connection.peer_endpoint.as_ref()
}

/// Sets the remote address of `self` to `address`.
Expand All @@ -339,20 +340,21 @@ impl<'r> Request<'r> {
/// Set the remote address to be 127.0.0.1:8111:
///
/// ```rust
/// use std::net::{SocketAddrV4, Ipv4Addr};
/// use std::net::{IpAddr, Ipv4Addr, SocketAddr};
/// use rocket::listener::Endpoint;
/// # let c = rocket::local::blocking::Client::debug_with(vec![]).unwrap();
/// # let mut req = c.get("/");
/// # let request = req.inner_mut();
///
/// assert_eq!(request.remote(), None);
///
/// let localhost = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8111);
/// request.set_remote(localhost);
/// assert_eq!(request.remote().unwrap(), &localhost);
/// let localhost = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8111);
/// request.set_remote(Endpoint::Tcp(localhost));
/// assert_eq!(request.remote().unwrap().tcp().unwrap(), localhost);
/// ```
#[inline(always)]
pub fn set_remote<A: Into<Endpoint>>(&mut self, address: A) {
self.connection.peer_address = Some(Arc::new(address.into()));
pub fn set_remote(&mut self, endpoint: Endpoint) {
self.connection.peer_endpoint = Some(endpoint.into());
}

/// Returns the IP address of the configured
Expand Down Expand Up @@ -491,14 +493,15 @@ impl<'r> Request<'r> {
/// # let c = rocket::local::blocking::Client::debug_with(vec![]).unwrap();
/// # let mut req = c.get("/");
/// # let request = req.inner_mut();
/// # use std::net::{SocketAddrV4, Ipv4Addr};
/// # use std::net::{SocketAddr, IpAddr, Ipv4Addr};
/// # use rocket::listener::Endpoint;
///
/// // starting without an "X-Real-IP" header or remote address
/// assert!(request.client_ip().is_none());
///
/// // add a remote address; this is done by Rocket automatically
/// let localhost_9190 = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 9190);
/// request.set_remote(localhost_9190);
/// let localhost_9190 = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 9190);
/// request.set_remote(Endpoint::Tcp(localhost_9190));
/// assert_eq!(request.client_ip().unwrap(), Ipv4Addr::LOCALHOST);
///
/// // now with an X-Real-IP header, the default value for `ip_header`.
Expand All @@ -507,7 +510,7 @@ impl<'r> Request<'r> {
/// ```
#[inline]
pub fn client_ip(&self) -> Option<IpAddr> {
self.real_ip().or_else(|| Some(self.remote()?.tcp()?.ip()))
self.real_ip().or_else(|| self.remote()?.ip())
}

/// Returns a wrapped borrow to the cookies in `self`.
Expand Down
7 changes: 5 additions & 2 deletions examples/tls/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@ fn mutual(cert: Certificate<'_>) -> String {
}

#[get("/", rank = 2)]
fn hello(endpoint: &Endpoint) -> String {
format!("Hello, {endpoint}!")
fn hello(endpoint: Option<&Endpoint>) -> String {
match endpoint {
Some(endpoint) => format!("Hello, {endpoint}!"),
None => "Hello, world!".into(),
}
}

#[launch]
Expand Down

0 comments on commit 058fd95

Please sign in to comment.