From b21225dc29426d8539c9c76347a1ad9a11682eff Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Tue, 16 Apr 2024 02:39:52 -0700 Subject: [PATCH] wip: tls-resolver --- contrib/sync_db_pools/lib/tests/shutdown.rs | 2 +- core/http/src/uri/authority.rs | 24 +- core/lib/Cargo.toml | 5 +- core/lib/src/config/mod.rs | 3 - core/lib/src/error.rs | 10 +- core/lib/src/listener/bind.rs | 10 + core/lib/src/listener/bindable.rs | 52 -- core/lib/src/listener/connection.rs | 3 +- core/lib/src/listener/default.rs | 140 ++++-- core/lib/src/listener/endpoint.rs | 95 +++- core/lib/src/listener/listener.rs | 2 +- core/lib/src/listener/mod.rs | 7 +- core/lib/src/listener/quic.rs | 27 +- core/lib/src/listener/tcp.rs | 45 +- core/lib/src/listener/tls.rs | 119 ----- core/lib/src/listener/unix.rs | 65 +-- core/lib/src/log.rs | 8 +- core/lib/src/mtls/certificate.rs | 5 +- core/lib/src/request/request.rs | 12 +- core/lib/src/rocket.rs | 66 ++- core/lib/src/server.rs | 61 ++- core/lib/src/tls/config.rs | 67 ++- core/lib/src/tls/error.rs | 9 + core/lib/src/tls/listener.rs | 104 ++++ core/lib/src/tls/mod.rs | 9 +- core/lib/src/tls/resolver.rs | 82 ++++ core/lib/src/util/mod.rs | 46 +- .../on_launch_fairing_can_inspect_port.rs | 4 +- docs/tests/Cargo.toml | 2 +- examples/tls/private/client.pem | 88 ++++ examples/tls/private/gen_certs.sh | 12 + examples/tls/src/redirector.rs | 3 +- examples/tls/src/tests.rs | 8 +- testbench/Cargo.toml | 5 +- testbench/src/client.rs | 224 ++------- testbench/src/lib.rs | 34 +- testbench/src/main.rs | 443 ++++++++++++++++-- testbench/src/server.rs | 178 +++++++ 38 files changed, 1424 insertions(+), 655 deletions(-) create mode 100644 core/lib/src/listener/bind.rs delete mode 100644 core/lib/src/listener/bindable.rs delete mode 100644 core/lib/src/listener/tls.rs create mode 100644 core/lib/src/tls/listener.rs create mode 100644 core/lib/src/tls/resolver.rs create mode 100644 examples/tls/private/client.pem create mode 100644 testbench/src/server.rs diff --git a/contrib/sync_db_pools/lib/tests/shutdown.rs b/contrib/sync_db_pools/lib/tests/shutdown.rs index d21e6f87cd..eb1c53d1e7 100644 --- a/contrib/sync_db_pools/lib/tests/shutdown.rs +++ b/contrib/sync_db_pools/lib/tests/shutdown.rs @@ -1,5 +1,5 @@ -#[cfg(all(feature = "diesel_sqlite_pool"))] #[cfg(test)] +#[cfg(all(feature = "diesel_sqlite_pool"))] mod sqlite_shutdown_test { use rocket::{async_test, Build, Rocket}; use rocket_sync_db_pools::database; diff --git a/core/http/src/uri/authority.rs b/core/http/src/uri/authority.rs index 9175607df2..6fdc01a61f 100644 --- a/core/http/src/uri/authority.rs +++ b/core/http/src/uri/authority.rs @@ -185,7 +185,7 @@ impl<'a> Authority<'a> { self.host.from_cow_source(&self.source) } - /// Returns the port part of the authority URI, if there is one. + /// Returns the `port` part of the authority URI, if there is one. /// /// # Example /// @@ -206,6 +206,28 @@ impl<'a> Authority<'a> { pub fn port(&self) -> Option { self.port } + + /// Set the `port` of the authority URI. + /// + /// # Example + /// + /// ```rust + /// # #[macro_use] extern crate rocket; + /// let mut uri = uri!("username:password@host:123"); + /// assert_eq!(uri.port(), Some(123)); + /// + /// uri.set_port(1024); + /// assert_eq!(uri.port(), Some(1024)); + /// assert_eq!(uri, "username:password@host:1024"); + /// + /// uri.set_port(None); + /// assert_eq!(uri.port(), None); + /// assert_eq!(uri, "username:password@host"); + /// ``` + #[inline(always)] + pub fn set_port>>(&mut self, port: T) { + self.port = port.into(); + } } impl_serde!(Authority<'a>, "an authority-form URI"); diff --git a/core/lib/Cargo.toml b/core/lib/Cargo.toml index 4f08113e44..f39446bc69 100644 --- a/core/lib/Cargo.toml +++ b/core/lib/Cargo.toml @@ -69,7 +69,7 @@ ref-swap = "0.1.2" parking_lot = "0.12" ubyte = {version = "0.10.2", features = ["serde"] } serde = { version = "1.0", features = ["derive"] } -figment = { version = "0.10.13", features = ["toml", "env"] } +figment = { version = "0.10.17", features = ["toml", "env"] } rand = "0.8" either = "1" pin-project-lite = "0.2" @@ -140,5 +140,6 @@ version_check = "0.9.1" [dev-dependencies] tokio = { version = "1", features = ["macros", "io-std"] } -figment = { version = "0.10", features = ["test"] } +figment = { version = "0.10.17", features = ["test"] } pretty_assertions = "1" +arc-swap = "1.7" diff --git a/core/lib/src/config/mod.rs b/core/lib/src/config/mod.rs index 9f07e9192c..c003a43a4c 100644 --- a/core/lib/src/config/mod.rs +++ b/core/lib/src/config/mod.rs @@ -137,9 +137,6 @@ mod secret_key; #[cfg(unix)] pub use crate::shutdown::Sig; -#[cfg(unix)] -pub use crate::listener::unix::UdsConfig; - #[cfg(feature = "secrets")] pub use secret_key::SecretKey; diff --git a/core/lib/src/error.rs b/core/lib/src/error.rs index 473eb8a885..85867017dd 100644 --- a/core/lib/src/error.rs +++ b/core/lib/src/error.rs @@ -178,9 +178,13 @@ impl Error { self.mark_handled(); match self.kind() { ErrorKind::Bind(ref a, ref e) => { - match a { - Some(a) => error!("Binding to {} failed.", a.primary().underline()), - None => error!("Binding to network interface failed."), + if let Some(e) = e.downcast_ref::() { + e.pretty_print(); + } else { + match a { + Some(a) => error!("Binding to {} failed.", a.primary().underline()), + None => error!("Binding to network interface failed."), + } } info_!("{}", e); diff --git a/core/lib/src/listener/bind.rs b/core/lib/src/listener/bind.rs new file mode 100644 index 0000000000..67e4cf7dfe --- /dev/null +++ b/core/lib/src/listener/bind.rs @@ -0,0 +1,10 @@ +use crate::listener::{Endpoint, Listener}; + +pub trait Bind: Listener + 'static { + type Error: std::error::Error + Send + 'static; + + #[crate::async_bound(Send)] + async fn bind(to: T) -> Result; + + fn bind_endpoint(to: &T) -> Result; +} diff --git a/core/lib/src/listener/bindable.rs b/core/lib/src/listener/bindable.rs deleted file mode 100644 index 09cd78f20c..0000000000 --- a/core/lib/src/listener/bindable.rs +++ /dev/null @@ -1,52 +0,0 @@ -use std::io; -use futures::TryFutureExt; - -use crate::listener::{Listener, Endpoint}; - -pub trait Bindable: Sized { - type Listener: Listener + 'static; - - type Error: std::error::Error + Send + 'static; - - async fn bind(self) -> Result; - - /// The endpoint that `self` binds on. - fn bind_endpoint(&self) -> io::Result; -} - -impl Bindable for L { - type Listener = L; - - type Error = std::convert::Infallible; - - async fn bind(self) -> Result { - Ok(self) - } - - fn bind_endpoint(&self) -> io::Result { - L::endpoint(self) - } -} - -impl Bindable for either::Either { - type Listener = tokio_util::either::Either; - - type Error = either::Either; - - async fn bind(self) -> Result { - match self { - either::Either::Left(a) => a.bind() - .map_ok(tokio_util::either::Either::Left) - .map_err(either::Either::Left) - .await, - either::Either::Right(b) => b.bind() - .map_ok(tokio_util::either::Either::Right) - .map_err(either::Either::Right) - .await, - } - } - - fn bind_endpoint(&self) -> io::Result { - either::for_both!(self, a => a.bind_endpoint()) - } -} diff --git a/core/lib/src/listener/connection.rs b/core/lib/src/listener/connection.rs index 2a3c72c034..c838980f78 100644 --- a/core/lib/src/listener/connection.rs +++ b/core/lib/src/listener/connection.rs @@ -1,6 +1,7 @@ use std::io; use std::borrow::Cow; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::either::Either; use super::Endpoint; @@ -9,7 +10,7 @@ use super::Endpoint; #[derive(Clone)] pub struct Certificates<'r>(Cow<'r, [der::CertificateDer<'r>]>); -pub trait Connection: Send + Unpin { +pub trait Connection: AsyncRead + AsyncWrite + Send + Unpin { fn endpoint(&self) -> io::Result; /// DER-encoded X.509 certificate chain presented by the client, if any. diff --git a/core/lib/src/listener/default.rs b/core/lib/src/listener/default.rs index ffb1f3e070..d44bddc0b8 100644 --- a/core/lib/src/listener/default.rs +++ b/core/lib/src/listener/default.rs @@ -1,64 +1,110 @@ -use either::Either; +use serde::Deserialize; +use tokio_util::either::{Either, Either::{Left, Right}}; +use futures::TryFutureExt; -use crate::listener::{Bindable, Endpoint}; -use crate::error::{Error, ErrorKind}; +use crate::error::ErrorKind; +use crate::{Ignite, Rocket}; +use crate::listener::{Bind, Endpoint, tcp::TcpListener}; -#[derive(serde::Deserialize)] -pub struct DefaultListener { +#[cfg(unix)] use crate::listener::unix::UnixListener; +#[cfg(feature = "tls")] use crate::tls::{TlsListener, TlsConfig}; + +mod private { + use super::{Either, TcpListener}; + + #[cfg(feature = "tls")] pub type TlsListener = super::TlsListener; + #[cfg(not(feature = "tls"))] pub type TlsListener = T; + #[cfg(unix)] pub type UnixListener = super::UnixListener; + #[cfg(not(unix))] pub type UnixListener = super::TcpListener; + + pub type Listener = Either< + Either, TlsListener>, + Either, + >; +} + +#[derive(Deserialize)] +struct Config { #[serde(default)] - pub address: Endpoint, - pub port: Option, - pub reuse: Option, + address: Endpoint, #[cfg(feature = "tls")] - pub tls: Option, + tls: Option, } -#[cfg(not(unix))] type BaseBindable = Either; -#[cfg(unix)] type BaseBindable = Either; +pub type DefaultListener = private::Listener; -#[cfg(not(feature = "tls"))] type TlsBindable = Either; -#[cfg(feature = "tls")] type TlsBindable = Either, T>; +impl<'r> Bind<&'r Rocket> for DefaultListener { + type Error = crate::Error; -impl DefaultListener { - pub(crate) fn base_bindable(&self) -> Result { - match &self.address { - Endpoint::Tcp(mut address) => { - if let Some(port) = self.port { - address.set_port(port); - } + async fn bind(rocket: &'r Rocket) -> Result { + let config: Config = rocket.figment().extract()?; + match config.address { + #[cfg(feature = "tls")] + endpoint@Endpoint::Tcp(_) if config.tls.is_some() => { + let listener = as Bind<_>>::bind(rocket) + .map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e))) + .await?; - Ok(BaseBindable::Left(address)) - }, + Ok(Left(Left(listener))) + } + endpoint@Endpoint::Tcp(_) => { + let listener = >::bind(rocket) + .map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e))) + .await?; + + Ok(Right(Left(listener))) + } + #[cfg(all(unix, feature = "tls"))] + endpoint@Endpoint::Unix(_) if config.tls.is_some() => { + let listener = as Bind<_>>::bind(rocket) + .map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e))) + .await?; + + Ok(Left(Right(listener))) + } #[cfg(unix)] - Endpoint::Unix(path) => { - let uds = super::unix::UdsConfig { path: path.clone(), reuse: self.reuse, }; - Ok(BaseBindable::Right(uds)) - }, - #[cfg(not(unix))] - e@Endpoint::Unix(_) => { - let msg = "Unix domain sockets unavailable on non-unix platforms."; - let boxed = Box::::from(msg); - Err(Error::new(ErrorKind::Bind(Some(e.clone()), boxed))) - }, - other => { - let msg = format!("unsupported default listener address: {other}"); - let boxed = Box::::from(msg); - Err(Error::new(ErrorKind::Bind(Some(other.clone()), boxed))) + endpoint@Endpoint::Unix(_) => { + let listener = >::bind(rocket) + .map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e))) + .await?; + + Ok(Right(Right(listener))) + } + endpoint => { + let msg = format!("unsupported bind endpoint: {endpoint}"); + let error = Box::::from(msg); + Err(ErrorKind::Bind(Some(endpoint), error).into()) } } } - pub(crate) fn tls_bindable(&self, inner: T) -> TlsBindable { - #[cfg(feature = "tls")] - if let Some(tls) = self.tls.clone() { - return TlsBindable::Left(super::tls::TlsBindable { inner, tls }); + fn bind_endpoint(rocket: &&'r Rocket) -> Result { + let config: Config = rocket.figment().extract()?; + match config.address { + #[cfg(feature = "tls")] + endpoint@Endpoint::Tcp(_) if config.tls.is_some() => { + as Bind<_>>::bind_endpoint(rocket) + .map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)).into()) + } + endpoint@Endpoint::Tcp(_) => { + >::bind_endpoint(rocket) + .map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)).into()) + } + #[cfg(all(unix, feature = "tls"))] + endpoint@Endpoint::Unix(_) if config.tls.is_some() => { + as Bind<_>>::bind_endpoint(rocket) + .map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)).into()) + } + #[cfg(unix)] + endpoint@Endpoint::Unix(_) => { + >::bind_endpoint(rocket) + .map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)).into()) + } + endpoint => { + let msg = format!("unsupported bind endpoint: {endpoint}"); + let error = Box::::from(msg); + Err(ErrorKind::Bind(Some(endpoint), error).into()) + } } - - TlsBindable::Right(inner) - } - - pub fn bindable(&self) -> Result { - self.base_bindable() - .map(|b| b.map_either(|b| self.tls_bindable(b), |b| self.tls_bindable(b))) } } diff --git a/core/lib/src/listener/endpoint.rs b/core/lib/src/listener/endpoint.rs index 788c08d0d1..2864e75c60 100644 --- a/core/lib/src/listener/endpoint.rs +++ b/core/lib/src/listener/endpoint.rs @@ -5,6 +5,7 @@ use std::path::{Path, PathBuf}; use std::str::FromStr; use std::sync::Arc; +use figment::Figment; use serde::de; use crate::http::uncased::AsUncased; @@ -21,7 +22,7 @@ impl EndpointAddr for T {} /// * [`&str`] - parse with [`FromStr`] /// * [`tokio::net::unix::SocketAddr`] - must be path: [`Endpoint::Unix`] /// * [`PathBuf`] - infallibly as [`Endpoint::Unix`] -#[derive(Debug, Clone)] +#[derive(Clone)] #[non_exhaustive] pub enum Endpoint { Tcp(net::SocketAddr), @@ -152,6 +153,29 @@ impl Endpoint { Self::Tls(Arc::new(self), None) } + + /// Fetch the endpoint at `path` in `figment` of kind `kind` (e.g, "tcp") + /// then map the value using `f(Some(value))` if present and `f(None)` if + /// missing into a different value of typr `T`. + /// + /// If the conversion succeeds, returns `Ok(value)`. If the conversion fails + /// and `Some` value was passed in, returns an error indicating the endpoint + /// was an invalid `kind` and otherwise returns a "missing field" error. + pub(crate) fn fetch(figment: &Figment, kind: &str, path: &str, f: F) -> figment::Result + where F: FnOnce(Option<&Endpoint>) -> Option + { + match figment.extract_inner::(path) { + Ok(endpoint) => f(Some(&endpoint)).ok_or_else(|| { + let msg = format!("invalid {kind} endpoint: {endpoint:?}"); + let mut error = figment::Error::from(msg).with_path(path); + error.profile = Some(figment.profile().clone()); + error.metadata = figment.find_metadata(path).cloned(); + error + }), + Err(e) if e.missing() => f(None).ok_or(e), + Err(e) => Err(e) + } + } } impl fmt::Display for Endpoint { @@ -180,28 +204,15 @@ impl fmt::Display for Endpoint { } } -impl From for Endpoint { - fn from(value: PathBuf) -> Self { - Self::Unix(value) - } -} - -#[cfg(unix)] -impl TryFrom for Endpoint { - type Error = std::io::Error; - - fn try_from(v: tokio::net::unix::SocketAddr) -> Result { - v.as_pathname() - .ok_or_else(|| std::io::Error::other("unix socket is not path")) - .map(|path| Endpoint::Unix(path.to_path_buf())) - } -} - -impl TryFrom<&str> for Endpoint { - type Error = AddrParseError; - - fn try_from(value: &str) -> Result { - value.parse() +impl fmt::Debug for Endpoint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Tcp(a) => write!(f, "tcp:{a}"), + Self::Quic(a) => write!(f, "quic:{a}]"), + Self::Unix(a) => write!(f, "unix:{}", a.display()), + Self::Tls(e, _) => write!(f, "unix:{:?}", &**e), + Self::Custom(e) => e.fmt(f), + } } } @@ -237,8 +248,6 @@ impl FromStr for Endpoint { if let Some((proto, string)) = string.split_once(':') { if proto.trim().as_uncased() == "tcp" { return parse_tcp(string.trim(), 8000).map(Self::Tcp); - } else if proto.trim().as_uncased() == "quic" { - return parse_tcp(string.trim(), 8000).map(Self::Quic); } else if proto.trim().as_uncased() == "unix" { return Ok(Self::Unix(PathBuf::from(string.trim()))); } @@ -256,7 +265,7 @@ impl<'de> de::Deserialize<'de> for Endpoint { type Value = Endpoint; fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - formatter.write_str("TCP or Unix address") + formatter.write_str("valid TCP (ip) or unix (path) endpoint") } fn visit_str(self, v: &str) -> Result { @@ -294,3 +303,37 @@ impl PartialEq for Endpoint { self.unix() == Some(other) } } + +#[cfg(unix)] +impl TryFrom for Endpoint { + type Error = std::io::Error; + + fn try_from(v: tokio::net::unix::SocketAddr) -> Result { + v.as_pathname() + .ok_or_else(|| std::io::Error::other("unix socket is not path")) + .map(|path| Endpoint::Unix(path.to_path_buf())) + } +} + +impl TryFrom<&str> for Endpoint { + type Error = AddrParseError; + + fn try_from(value: &str) -> Result { + value.parse() + } +} + +macro_rules! impl_from { + ($T:ty => $V:ident) => { + impl From<$T> for Endpoint { + fn from(value: $T) -> Self { + Self::$V(value.into()) + } + } + } +} + +impl_from!(std::net::SocketAddr => Tcp); +impl_from!(std::net::SocketAddrV4 => Tcp); +impl_from!(std::net::SocketAddrV6 => Tcp); +impl_from!(PathBuf => Unix); diff --git a/core/lib/src/listener/listener.rs b/core/lib/src/listener/listener.rs index a272b699c8..a4c8c54830 100644 --- a/core/lib/src/listener/listener.rs +++ b/core/lib/src/listener/listener.rs @@ -5,7 +5,7 @@ use tokio_util::either::Either; use crate::listener::{Connection, Endpoint}; -pub trait Listener: Send + Sync { +pub trait Listener: Sized + Send + Sync { type Accept: Send; type Connection: Connection; diff --git a/core/lib/src/listener/mod.rs b/core/lib/src/listener/mod.rs index 4e0ea0c852..5c8dff4ed9 100644 --- a/core/lib/src/listener/mod.rs +++ b/core/lib/src/listener/mod.rs @@ -3,15 +3,12 @@ mod bounced; mod listener; mod endpoint; mod connection; -mod bindable; +mod bind; mod default; #[cfg(unix)] #[cfg_attr(nightly, doc(cfg(unix)))] pub mod unix; -#[cfg(feature = "tls")] -#[cfg_attr(nightly, doc(cfg(feature = "tls")))] -pub mod tls; pub mod tcp; #[cfg(feature = "http3-preview")] pub mod quic; @@ -19,7 +16,7 @@ pub mod quic; pub use endpoint::*; pub use listener::*; pub use connection::*; -pub use bindable::*; +pub use bind::*; pub use default::*; pub(crate) use cancellable::*; diff --git a/core/lib/src/listener/quic.rs b/core/lib/src/listener/quic.rs index e866f03652..9ba98d4056 100644 --- a/core/lib/src/listener/quic.rs +++ b/core/lib/src/listener/quic.rs @@ -38,7 +38,7 @@ use tokio::sync::Mutex; use tokio_stream::StreamExt; use crate::tls::{TlsConfig, Error}; -use crate::listener::{Listener, Connection, Endpoint}; +use crate::listener::Endpoint; type H3Conn = h3::server::Connection; @@ -94,25 +94,20 @@ impl QuicListener { } } -impl Listener for QuicListener { - type Accept = quic::Connection; - - type Connection = H3Stream; - - async fn accept(&self) -> io::Result { +impl QuicListener { + pub(crate) async fn accept(&self) -> Option { self.listener .lock().await .accept().await - .ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "closed")) } - async fn connect(&self, accept: Self::Accept) -> io::Result { + pub(crate) async fn connect(&self, accept: quic::Connection) -> io::Result { let quic_conn = quic_h3::Connection::new(accept); let conn = H3Conn::new(quic_conn).await.map_err(io::Error::other)?; Ok(H3Stream(conn)) } - fn endpoint(&self) -> io::Result { + pub(crate) fn endpoint(&self) -> io::Result { Ok(Endpoint::Quic(self.endpoint).with_tls(&self.tls)) } } @@ -159,16 +154,8 @@ impl QuicTx { } // FIXME: Expose certificates when possible. -impl Connection for H3Stream { - fn endpoint(&self) -> io::Result { - let addr = self.0.inner.conn.handle().remote_addr()?; - Ok(Endpoint::Quic(addr).assume_tls()) - } -} - -// FIXME: Expose certificates when possible. -impl Connection for H3Connection { - fn endpoint(&self) -> io::Result { +impl H3Connection { + pub fn endpoint(&self) -> io::Result { let addr = self.handle.remote_addr()?; Ok(Endpoint::Quic(addr).assume_tls()) } diff --git a/core/lib/src/listener/tcp.rs b/core/lib/src/listener/tcp.rs index af54ff7d47..6b62751b41 100644 --- a/core/lib/src/listener/tcp.rs +++ b/core/lib/src/listener/tcp.rs @@ -1,21 +1,50 @@ use std::io; +use std::net::{Ipv4Addr, SocketAddr}; + +use either::{Either, Left, Right}; #[doc(inline)] pub use tokio::net::{TcpListener, TcpStream}; -use crate::listener::{Listener, Bindable, Connection, Endpoint}; +use crate::{Ignite, Rocket}; +use crate::listener::{Bind, Connection, Endpoint, Listener}; -impl Bindable for std::net::SocketAddr { - type Listener = TcpListener; +impl Bind for TcpListener { + type Error = std::io::Error; - type Error = io::Error; + async fn bind(addr: SocketAddr) -> Result { + Self::bind(addr).await + } - async fn bind(self) -> Result { - TcpListener::bind(self).await + fn bind_endpoint(addr: &SocketAddr) -> Result { + Ok(Endpoint::Tcp(*addr)) } +} + +impl<'r> Bind<&'r Rocket> for TcpListener { + type Error = Either; + + async fn bind(rocket: &'r Rocket) -> Result { + let endpoint = Self::bind_endpoint(&rocket)?; + let addr = endpoint.tcp() + .ok_or_else(|| io::Error::other("internal error: invalid endpoint")) + .map_err(Right)?; + + Self::bind(addr).await.map_err(Right) + } + + fn bind_endpoint(rocket: &&'r Rocket) -> Result { + let figment = rocket.figment(); + let mut address = dbg!(Endpoint::fetch(figment, "tcp", "address", |e| { + let default = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8000); + dbg!(e.map(|e| e.tcp()).unwrap_or(Some(default))) + })).map_err(Left)?; + + if let Some(port) = figment.extract_inner("port").map_err(Left)? { + address.set_port(port); + } - fn bind_endpoint(&self) -> io::Result { - Ok(Endpoint::Tcp(*self)) + Ok(Endpoint::Tcp(address)) } } diff --git a/core/lib/src/listener/tls.rs b/core/lib/src/listener/tls.rs deleted file mode 100644 index 220011e4f3..0000000000 --- a/core/lib/src/listener/tls.rs +++ /dev/null @@ -1,119 +0,0 @@ -use std::io; -use std::sync::Arc; - -use serde::Deserialize; -use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_rustls::TlsAcceptor; - -use crate::tls::{TlsConfig, Error}; -use crate::listener::{Listener, Bindable, Connection, Certificates, Endpoint}; - -#[doc(inline)] -pub use tokio_rustls::server::TlsStream; - -/// A TLS listener over some listener interface L. -pub struct TlsListener { - listener: L, - acceptor: TlsAcceptor, - config: TlsConfig, -} - -#[derive(Clone, Deserialize)] -pub struct TlsBindable { - #[serde(flatten)] - pub inner: I, - pub tls: TlsConfig, -} - -impl TlsConfig { - pub(crate) fn server_config(&self) -> Result { - let provider = Arc::new(self.default_crypto_provider()); - - #[cfg(feature = "mtls")] - let verifier = match self.mutual { - Some(ref mtls) => { - let ca = Arc::new(mtls.load_ca_certs()?); - let verifier = WebPkiClientVerifier::builder_with_provider(ca, provider.clone()); - match mtls.mandatory { - true => verifier.build()?, - false => verifier.allow_unauthenticated().build()?, - } - }, - None => WebPkiClientVerifier::no_client_auth(), - }; - - #[cfg(not(feature = "mtls"))] - let verifier = WebPkiClientVerifier::no_client_auth(); - - let mut tls_config = ServerConfig::builder_with_provider(provider) - .with_safe_default_protocol_versions()? - .with_client_cert_verifier(verifier) - .with_single_cert(self.load_certs()?, self.load_key()?)?; - - tls_config.ignore_client_order = self.prefer_server_cipher_order; - tls_config.session_storage = ServerSessionMemoryCache::new(1024); - tls_config.ticketer = rustls::crypto::ring::Ticketer::new()?; - tls_config.alpn_protocols = vec![b"http/1.1".to_vec()]; - if cfg!(feature = "http2") { - tls_config.alpn_protocols.insert(0, b"h2".to_vec()); - } - - Ok(tls_config) - } -} - -impl Bindable for TlsBindable - where I::Listener: Listener::Connection>, - ::Connection: AsyncRead + AsyncWrite -{ - type Listener = TlsListener; - - type Error = Error; - - async fn bind(self) -> Result { - Ok(TlsListener { - acceptor: TlsAcceptor::from(Arc::new(self.tls.server_config()?)), - listener: self.inner.bind().await.map_err(|e| Error::Bind(Box::new(e)))?, - config: self.tls, - }) - } - - fn bind_endpoint(&self) -> io::Result { - let inner = self.inner.bind_endpoint()?; - Ok(inner.with_tls(&self.tls)) - } -} - -impl Listener for TlsListener - where L: Listener::Connection>, - L::Connection: AsyncRead + AsyncWrite -{ - type Accept = L::Connection; - - type Connection = TlsStream; - - async fn accept(&self) -> io::Result { - self.listener.accept().await - } - - async fn connect(&self, conn: L::Connection) -> io::Result { - self.acceptor.accept(conn).await - } - - fn endpoint(&self) -> io::Result { - Ok(self.listener.endpoint()?.with_tls(&self.config)) - } -} - -impl Connection for TlsStream { - fn endpoint(&self) -> io::Result { - Ok(self.get_ref().0.endpoint()?.assume_tls()) - } - - #[cfg(feature = "mtls")] - fn certificates(&self) -> Option> { - let cert_chain = self.get_ref().1.peer_certificates()?; - Some(Certificates::from(cert_chain)) - } -} diff --git a/core/lib/src/listener/unix.rs b/core/lib/src/listener/unix.rs index a6db180170..92e6326497 100644 --- a/core/lib/src/listener/unix.rs +++ b/core/lib/src/listener/unix.rs @@ -1,48 +1,39 @@ use std::io; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; +use either::{Either, Left, Right}; use tokio::time::{sleep, Duration}; use crate::fs::NamedFile; -use crate::listener::{Listener, Bindable, Connection, Endpoint}; +use crate::listener::{Listener, Bind, Connection, Endpoint}; use crate::util::unix; +use crate::{Ignite, Rocket}; pub use tokio::net::UnixStream; -#[derive(Debug, Clone)] -pub struct UdsConfig { - /// Socket address. - pub path: PathBuf, - /// Recreate a socket that already exists. - pub reuse: Option, -} - -pub struct UdsListener { +pub struct UnixListener { path: PathBuf, lock: Option, listener: tokio::net::UnixListener, } -impl Bindable for UdsConfig { - type Listener = UdsListener; - - type Error = io::Error; - - async fn bind(self) -> Result { - let lock = if self.reuse.unwrap_or(true) { - let lock_ext = match self.path.extension().and_then(|s| s.to_str()) { +impl UnixListener { + async fn bind>(path: P, reuse: bool) -> io::Result { + let path = path.as_ref(); + let lock = if reuse { + let lock_ext = match path.extension().and_then(|s| s.to_str()) { Some(ext) if !ext.is_empty() => format!("{}.lock", ext), _ => "lock".to_string() }; let mut opts = tokio::fs::File::options(); opts.create(true).write(true); - let lock_path = self.path.with_extension(lock_ext); + let lock_path = path.with_extension(lock_ext); let lock_file = NamedFile::open_with(lock_path, &opts).await?; unix::lock_exclusive_nonblocking(lock_file.file())?; - if self.path.exists() { - tokio::fs::remove_file(&self.path).await?; + if path.exists() { + tokio::fs::remove_file(&path).await?; } Some(lock_file) @@ -55,9 +46,9 @@ impl Bindable for UdsConfig { // and this will succeed. So let's try a few times. let mut retries = 5; let listener = loop { - match tokio::net::UnixListener::bind(&self.path) { + match tokio::net::UnixListener::bind(&path) { Ok(listener) => break listener, - Err(e) if self.path.exists() && lock.is_none() => return Err(e), + Err(e) if path.exists() && lock.is_none() => return Err(e), Err(_) if retries > 0 => { retries -= 1; sleep(Duration::from_millis(100)).await; @@ -66,15 +57,31 @@ impl Bindable for UdsConfig { } }; - Ok(UdsListener { lock, listener, path: self.path, }) + Ok(UnixListener { lock, listener, path: path.into() }) + } +} + +impl<'r> Bind<&'r Rocket> for UnixListener { + type Error = Either; + + async fn bind(rocket: &'r Rocket) -> Result { + let endpoint = Self::bind_endpoint(&rocket)?; + let path = endpoint.unix() + .ok_or_else(|| Right(io::Error::other("internal error: invalid endpoint")))?; + + let reuse: Option = rocket.figment().extract_inner("reuse").map_err(Left)?; + Ok(Self::bind(path, reuse.unwrap_or(true)).await.map_err(Right)?) } - fn bind_endpoint(&self) -> io::Result { - Ok(Endpoint::Unix(self.path.clone())) + fn bind_endpoint(rocket: &&'r Rocket) -> Result { + let as_pathbuf = |e: Option<&Endpoint>| e.and_then(|e| e.unix().map(|p| p.to_path_buf())); + Endpoint::fetch(rocket.figment(), "unix", "address", as_pathbuf) + .map(Endpoint::Unix) + .map_err(Left) } } -impl Listener for UdsListener { +impl Listener for UnixListener { type Accept = UnixStream; type Connection = Self::Accept; @@ -98,7 +105,7 @@ impl Connection for UnixStream { } } -impl Drop for UdsListener { +impl Drop for UnixListener { fn drop(&mut self) { if let Some(lock) = &self.lock { let _ = std::fs::remove_file(&self.path); diff --git a/core/lib/src/log.rs b/core/lib/src/log.rs index d49742fc38..69f368c754 100644 --- a/core/lib/src/log.rs +++ b/core/lib/src/log.rs @@ -154,13 +154,15 @@ impl log::Log for RocketLogger { } } +static ROCKET_LOGGER_SET: AtomicBool = AtomicBool::new(false); + pub(crate) fn init_default() { - crate::log::init(&crate::Config::debug_default()) + if !ROCKET_LOGGER_SET.load(Ordering::Acquire) { + crate::log::init(&crate::Config::debug_default()) + } } pub(crate) fn init(config: &crate::Config) { - static ROCKET_LOGGER_SET: AtomicBool = AtomicBool::new(false); - // Try to initialize Rocket's logger, recording if we succeeded. if log::set_boxed_logger(Box::new(RocketLogger)).is_ok() { ROCKET_LOGGER_SET.store(true, Ordering::Release); diff --git a/core/lib/src/mtls/certificate.rs b/core/lib/src/mtls/certificate.rs index 906d821f21..3bccb7bb8d 100644 --- a/core/lib/src/mtls/certificate.rs +++ b/core/lib/src/mtls/certificate.rs @@ -14,8 +14,9 @@ use crate::http::Status; /// /// The request guard implementation succeeds if: /// +/// * MTLS is [configured](crate::mtls). /// * The client presents certificates. -/// * The certificates are active and not yet expired. +/// * The certificates are valid and not expired. /// * The client's certificate chain was signed by the CA identified by the /// configured `ca_certs` and with respect to SNI, if any. See [module level /// docs](crate::mtls) for configuration details. @@ -24,7 +25,7 @@ use crate::http::Status; /// status of 401 Unauthorized. /// /// If the certificate chain fails to validate or verify, the guard _fails_ with -/// the respective [`Error`]. +/// the respective [`Error`] a status of 401 Unauthorized. /// /// # Wrapping /// diff --git a/core/lib/src/request/request.rs b/core/lib/src/request/request.rs index b69b27408c..b0d5282985 100644 --- a/core/lib/src/request/request.rs +++ b/core/lib/src/request/request.rs @@ -1,4 +1,4 @@ -use std::fmt; +use std::{io, fmt}; use std::ops::RangeFrom; use std::sync::{Arc, atomic::Ordering}; use std::borrow::Cow; @@ -18,7 +18,7 @@ use crate::data::Limits; use crate::http::ProxyProto; use crate::http::{Method, Header, HeaderMap, ContentType, Accept, MediaType, CookieJar, Cookie}; use crate::http::uri::{fmt::Path, Origin, Segments, Host, Authority}; -use crate::listener::{Certificates, Endpoint, Connection}; +use crate::listener::{Certificates, Endpoint}; /// The type of an incoming web request. /// @@ -44,11 +44,11 @@ pub(crate) struct ConnectionMeta { pub peer_certs: Option>>, } -impl From<&C> for ConnectionMeta { - fn from(conn: &C) -> Self { +impl ConnectionMeta { + pub fn new(endpoint: io::Result, certs: Option>) -> Self { ConnectionMeta { - peer_endpoint: conn.endpoint().ok(), - peer_certs: conn.certificates().map(|c| c.into_owned()).map(Arc::new), + peer_endpoint: endpoint.ok(), + peer_certs: certs.map(|c| c.into_owned()).map(Arc::new), } } } diff --git a/core/lib/src/rocket.rs b/core/lib/src/rocket.rs index bfff32b835..a25eeef5dd 100644 --- a/core/lib/src/rocket.rs +++ b/core/lib/src/rocket.rs @@ -2,15 +2,16 @@ use std::fmt; use std::ops::{Deref, DerefMut}; use std::sync::Arc; use std::time::Duration; +use std::any::Any; +use futures::TryFutureExt; use yansi::Paint; use either::Either; use figment::{Figment, Provider}; -use tokio::io::{AsyncRead, AsyncWrite}; use crate::shutdown::{Stages, Shutdown}; use crate::{sentinel, shield::Shield, Catcher, Config, Route}; -use crate::listener::{Bindable, DefaultListener, Endpoint, Listener}; +use crate::listener::{Bind, DefaultListener, Endpoint, Listener}; use crate::router::Router; use crate::fairing::{Fairing, Fairings}; use crate::phase::{Phase, Build, Building, Ignite, Igniting, Orbit, Orbiting}; @@ -681,19 +682,36 @@ impl Rocket { rocket } - async fn _launch(self) -> Result, Error> { - let config = self.figment().extract::()?; - either::for_both!(config.base_bindable()?, base => { - either::for_both!(config.tls_bindable(base), bindable => { - self._launch_on(bindable).await - }) - }) + async fn _launch_with(self) -> Result, Error> + where B: for<'r> Bind<&'r Rocket> + { + let bind_endpoint = B::bind_endpoint(&&self) + .map_err(|e| ErrorKind::Bind(None, Box::new(e)))?; + + let listener: B = B::bind(&self).await + .map_err(|e| ErrorKind::Bind(Some(bind_endpoint), Box::new(e)))?; + + let any: Box = Box::new(listener); + match any.downcast::() { + Ok(listener) => { + let listener = *listener; + crate::util::for_both!(listener, listener => { + crate::util::for_both!(listener, listener => { + self._launch_on(listener).await + }) + }) + } + Err(any) => { + let listener = *any.downcast::().unwrap(); + self._launch_on(listener).await + } + } } - async fn _launch_on(self, bindable: B) -> Result, Error> - where ::Connection: AsyncRead + AsyncWrite + async fn _launch_on(self, listener: L) -> Result, Error> + where L: Listener + 'static, { - let rocket = self.bind_and_serve(bindable, |rocket| async move { + let rocket = self.listen_and_serve(listener, |rocket| async move { let rocket = Arc::new(rocket); rocket.shutdown.spawn_listener(&rocket.config.shutdown); @@ -996,19 +1014,31 @@ impl Rocket

{ /// } /// ``` pub async fn launch(self) -> Result, Error> { + self.launch_with::().await + } + + pub async fn bind_launch>(self, value: T) -> Result, Error> { + let endpoint = B::bind_endpoint(&value).map_err(|e| ErrorKind::Bind(None, Box::new(e)))?; + let listener = B::bind(value).map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e))); + self.launch_on(listener.await?).await + } + + pub async fn launch_with(self) -> Result, Error> + where B: for<'r> Bind<&'r Rocket> + { match self.0.into_state() { - State::Build(s) => Rocket::from(s).ignite().await?._launch().await, - State::Ignite(s) => Rocket::from(s)._launch().await, + State::Build(s) => Rocket::from(s).ignite().await?._launch_with::().await, + State::Ignite(s) => Rocket::from(s)._launch_with::().await, State::Orbit(s) => Ok(Rocket::from(s).into_ignite()) } } - pub async fn launch_on(self, bindable: B) -> Result, Error> - where ::Connection: AsyncRead + AsyncWrite + pub async fn launch_on(self, listener: L) -> Result, Error> + where L: Listener + 'static, { match self.0.into_state() { - State::Build(s) => Rocket::from(s).ignite().await?._launch_on(bindable).await, - State::Ignite(s) => Rocket::from(s)._launch_on(bindable).await, + State::Build(s) => Rocket::from(s).ignite().await?._launch_on(listener).await, + State::Ignite(s) => Rocket::from(s)._launch_on(listener).await, State::Orbit(s) => Ok(Rocket::from(s).into_ignite()) } } diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index e008c7bc13..f6796b38b9 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -6,14 +6,14 @@ use std::time::Duration; use hyper::service::service_fn; use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; use hyper_util::server::conn::auto::Builder; -use futures::{Future, TryFutureExt, future::Either::*}; +use futures::{Future, TryFutureExt}; use tokio::io::{AsyncRead, AsyncWrite}; use crate::{Ignite, Orbit, Request, Rocket}; use crate::request::ConnectionMeta; use crate::erased::{ErasedRequest, ErasedResponse, ErasedIoHandler}; -use crate::listener::{Bindable, BouncedExt, CancellableExt, Listener}; -use crate::error::{log_server_error, ErrorKind}; +use crate::listener::{Listener, Connection, BouncedExt, CancellableExt}; +use crate::error::log_server_error; use crate::data::{IoStream, RawStream}; use crate::util::{spawn_inspect, FutureExt, ReaderStream}; use crate::http::Status; @@ -91,31 +91,28 @@ async fn io_handler_task(stream: S, mut handler: ErasedIoHandler) } impl Rocket { - pub(crate) async fn bind_and_serve( + pub(crate) async fn listen_and_serve( self, - bindable: B, - post_bind_callback: impl FnOnce(Rocket) -> R, + listener: L, + orbit_callback: impl FnOnce(Rocket) -> R, ) -> Result>> - where B: Bindable, - ::Connection: AsyncRead + AsyncWrite, + where L: Listener + 'static, R: Future>>> { - let binding_endpoint = bindable.bind_endpoint().ok(); - let h12listener = bindable.bind() - .map_err(|e| ErrorKind::Bind(binding_endpoint, Box::new(e))) - .await?; + let endpoint = listener.endpoint()?; - let endpoint = h12listener.endpoint()?; #[cfg(feature = "http3-preview")] if let (Some(addr), Some(tls)) = (endpoint.tcp(), endpoint.tls_config()) { + use crate::error::ErrorKind; + let h3listener = crate::listener::quic::QuicListener::bind(addr, tls.clone()) .map_err(|e| ErrorKind::Bind(Some(endpoint.clone()), Box::new(e))) .await?; let rocket = self.into_orbit(vec![h3listener.endpoint()?, endpoint]); - let rocket = post_bind_callback(rocket).await?; + let rocket = orbit_callback(rocket).await?; - let http12 = tokio::task::spawn(rocket.clone().serve12(h12listener)); + let http12 = tokio::task::spawn(rocket.clone().serve12(listener)); let http3 = tokio::task::spawn(rocket.clone().serve3(h3listener)); let (r1, r2) = tokio::join!(http12, http3); r1.map_err(|e| ErrorKind::Liftoff(Err(rocket.clone()), Box::new(e)))??; @@ -129,8 +126,8 @@ impl Rocket { } let rocket = self.into_orbit(vec![endpoint]); - let rocket = post_bind_callback(rocket).await?; - rocket.clone().serve12(h12listener).await?; + let rocket = orbit_callback(rocket).await?; + rocket.clone().serve12(listener).await?; Ok(rocket) } } @@ -160,11 +157,11 @@ impl Rocket { } let (listener, server) = (Arc::new(listener.bounced()), Arc::new(builder)); - while let Some(accept) = listener.accept().unless(self.shutdown()).await? { + while let Some(accept) = listener.accept().race(self.shutdown()).await.left().transpose()? { let (listener, rocket, server) = (listener.clone(), self.clone(), server.clone()); spawn_inspect(|e| log_server_error(&**e), async move { - let conn = listener.connect(accept).io_unless(rocket.shutdown()).await?; - let meta = ConnectionMeta::from(&conn); + let conn = listener.connect(accept).race_io(rocket.shutdown()).await?; + let meta = ConnectionMeta::new(conn.endpoint(), conn.certificates()); let service = service_fn(|mut req| { let upgrade = hyper::upgrade::on(&mut req); let (parts, incoming) = req.into_parts(); @@ -173,9 +170,9 @@ impl Rocket { let io = TokioIo::new(conn.cancellable(rocket.shutdown.clone())); let mut server = pin!(server.serve_connection_with_upgrades(io, service)); - match server.as_mut().or(rocket.shutdown()).await { - Left(result) => result, - Right(()) => { + match server.as_mut().race(rocket.shutdown()).await.left() { + Some(result) => result, + None => { server.as_mut().graceful_shutdown(); server.await }, @@ -189,26 +186,26 @@ impl Rocket { #[cfg(feature = "http3-preview")] async fn serve3(self: Arc, listener: crate::listener::quic::QuicListener) -> Result<()> { let rocket = self.clone(); - let listener = Arc::new(listener.bounced()); - while let Some(accept) = listener.accept().unless(rocket.shutdown()).await? { + let listener = Arc::new(listener); + while let Some(Some(accept)) = listener.accept().race(rocket.shutdown()).await.left() { let (listener, rocket) = (listener.clone(), rocket.clone()); spawn_inspect(|e: &io::Error| log_server_error(e), async move { - let mut stream = listener.connect(accept).io_unless(rocket.shutdown()).await?; - while let Some(mut conn) = stream.accept().io_unless(rocket.shutdown()).await? { + let mut stream = listener.connect(accept).race_io(rocket.shutdown()).await?; + while let Some(mut conn) = stream.accept().race_io(rocket.shutdown()).await? { let rocket = rocket.clone(); spawn_inspect(|e: &io::Error| log_server_error(e), async move { - let meta = ConnectionMeta::from(&conn); + let meta = ConnectionMeta::new(conn.endpoint(), None); let rx = conn.rx.cancellable(rocket.shutdown.clone()); let response = rocket.clone() .service(conn.parts, rx, None, meta) .map_err(io::Error::other) - .io_unless(rocket.shutdown.mercy.clone()) + .race_io(rocket.shutdown.mercy.clone()) .await?; let grace = rocket.shutdown.grace.clone(); - match conn.tx.send_response(response).or(grace).await { - Left(result) => result, - Right(_) => Ok(conn.tx.cancel()), + match conn.tx.send_response(response).race(grace).await.left() { + Some(result) => result, + None => Ok(conn.tx.cancel()), } }); } diff --git a/core/lib/src/tls/config.rs b/core/lib/src/tls/config.rs index 5762e8f182..0fe0d2b268 100644 --- a/core/lib/src/tls/config.rs +++ b/core/lib/src/tls/config.rs @@ -1,11 +1,15 @@ use std::io; +use std::sync::Arc; -use rustls::crypto::{ring, CryptoProvider}; -use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use futures::TryFutureExt; use figment::value::magic::{Either, RelativePathBuf}; use serde::{Deserialize, Serialize}; use indexmap::IndexSet; +use rustls::crypto::{ring, CryptoProvider}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier}; +use crate::tls::resolver::DynResolver; use crate::tls::error::{Result, Error, KeyError}; /// TLS configuration: certificate chain, key, and ciphersuites. @@ -78,7 +82,7 @@ use crate::tls::error::{Result, Error, KeyError}; /// # assert_eq!(tls_config.ciphers().count(), 9); /// # assert!(!tls_config.prefer_server_cipher_order()); /// ``` -#[derive(PartialEq, Debug, Clone, Deserialize, Serialize)] +#[derive(Debug, PartialEq, Clone, Deserialize, Serialize)] pub struct TlsConfig { /// Path to a PEM file with, or raw bytes for, a DER-encoded X.509 TLS /// certificate chain. @@ -97,6 +101,8 @@ pub struct TlsConfig { #[cfg(feature = "mtls")] #[cfg_attr(nightly, doc(cfg(feature = "mtls")))] pub(crate) mutual: Option, + #[serde(skip)] + pub(crate) resolver: Option, } /// A supported TLS cipher suite. @@ -134,6 +140,7 @@ impl Default for TlsConfig { prefer_server_cipher_order: false, #[cfg(feature = "mtls")] mutual: None, + resolver: None, } } } @@ -430,8 +437,58 @@ impl TlsConfig { self.mutual.as_ref() } - pub fn validate(&self) -> Result<(), crate::tls::Error> { - self.server_config().map(|_| ()) + /// Try to convert `self` into a [rustls] [`ServerConfig`]. + /// + /// [`ServerConfig`]: rustls::server::ServerConfig + pub async fn server_config(&self) -> Result { + let this = self.clone(); + tokio::task::spawn_blocking(move || this._server_config()) + .map_err(io::Error::other) + .await? + } + + /// Try to convert `self` into a [rustls] [`ServerConfig`]. + /// + /// [`ServerConfig`]: rustls::server::ServerConfig + pub(crate) fn _server_config(&self) -> Result { + let provider = Arc::new(self.default_crypto_provider()); + + #[cfg(feature = "mtls")] + let verifier = match self.mutual { + Some(ref mtls) => { + let ca = Arc::new(mtls.load_ca_certs()?); + let verifier = WebPkiClientVerifier::builder_with_provider(ca, provider.clone()); + match mtls.mandatory { + true => verifier.build()?, + false => verifier.allow_unauthenticated().build()?, + } + }, + None => WebPkiClientVerifier::no_client_auth(), + }; + + #[cfg(not(feature = "mtls"))] + let verifier = WebPkiClientVerifier::no_client_auth(); + + let mut tls_config = ServerConfig::builder_with_provider(provider) + .with_safe_default_protocol_versions()? + .with_client_cert_verifier(verifier) + .with_single_cert(self.load_certs()?, self.load_key()?)?; + + tls_config.ignore_client_order = self.prefer_server_cipher_order; + tls_config.session_storage = ServerSessionMemoryCache::new(1024); + tls_config.ticketer = rustls::crypto::ring::Ticketer::new()?; + tls_config.alpn_protocols = vec![b"http/1.1".to_vec()]; + if cfg!(feature = "http2") { + tls_config.alpn_protocols.insert(0, b"h2".to_vec()); + } + + Ok(tls_config) + + } + + /// NOTE: This is a blocking function. + pub fn validate(&self) -> Result<()> { + self._server_config().map(|_| ()) } } diff --git a/core/lib/src/tls/error.rs b/core/lib/src/tls/error.rs index 5fd506157c..1a6612f002 100644 --- a/core/lib/src/tls/error.rs +++ b/core/lib/src/tls/error.rs @@ -17,6 +17,7 @@ pub enum Error { CertChain(std::io::Error), PrivKey(KeyError), CertAuth(rustls::Error), + Config(figment::Error), } impl std::fmt::Display for Error { @@ -31,6 +32,7 @@ impl std::fmt::Display for Error { PrivKey(e) => write!(f, "failed to process private key: {e}"), CertAuth(e) => write!(f, "failed to process certificate authority: {e}"), Bind(e) => write!(f, "failed to bind to network interface: {e}"), + Config(e) => write!(f, "failed to read tls configuration: {e}"), } } } @@ -69,6 +71,7 @@ impl std::error::Error for Error { Error::PrivKey(e) => Some(e), Error::CertAuth(e) => Some(e), Error::Bind(e) => Some(&**e), + Error::Config(e) => Some(e), } } } @@ -102,3 +105,9 @@ impl From for Error { v.into() } } + +impl From for Error { + fn from(value: figment::Error) -> Self { + Error::Config(value) + } +} diff --git a/core/lib/src/tls/listener.rs b/core/lib/src/tls/listener.rs new file mode 100644 index 0000000000..aeda6afee2 --- /dev/null +++ b/core/lib/src/tls/listener.rs @@ -0,0 +1,104 @@ +use std::io; +use std::sync::Arc; + +use futures::TryFutureExt; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_rustls::LazyConfigAcceptor; +use rustls::server::{Acceptor, ServerConfig}; + +use crate::{Ignite, Rocket}; +use crate::listener::{Bind, Certificates, Connection, Endpoint, Listener}; +use crate::tls::{Error, TlsConfig}; +use super::resolver::DynResolver; + +#[doc(inline)] +pub use tokio_rustls::server::TlsStream; + +/// A TLS listener over some listener interface L. +pub struct TlsListener { + listener: L, + config: TlsConfig, + default: Arc, +} + +impl> Bind<(T, TlsConfig)> for TlsListener + where L: Listener::Connection>, +{ + type Error = Error; + + async fn bind((inner, config): (T, TlsConfig)) -> Result { + Ok(TlsListener { + default: Arc::new(config.server_config().await?), + listener: L::bind(inner).map_err(|e| Error::Bind(Box::new(e))).await?, + config, + }) + } + + fn bind_endpoint((inner, config): &(T, TlsConfig)) -> Result { + L::bind_endpoint(inner) + .map(|e| e.with_tls(config)) + .map_err(|e| Error::Bind(Box::new(e))) + } +} + +impl<'r, L> Bind<&'r Rocket> for TlsListener + where L: Bind<&'r Rocket> + Listener::Connection> +{ + type Error = Error; + + async fn bind(rocket: &'r Rocket) -> Result { + let mut config: TlsConfig = rocket.figment().extract_inner("tls")?; + config.resolver = DynResolver::extract(rocket); + >::bind((rocket, config)).await + } + + fn bind_endpoint(rocket: &&'r Rocket) -> Result { + let config: TlsConfig = rocket.figment().extract_inner("tls")?; + >::bind_endpoint(&(*rocket, config)) + } +} + +impl Listener for TlsListener + where L: Listener::Connection>, + L::Connection: AsyncRead + AsyncWrite +{ + type Accept = L::Connection; + + type Connection = TlsStream; + + async fn accept(&self) -> io::Result { + self.listener.accept().await + } + + async fn connect(&self, conn: L::Connection) -> io::Result { + let acceptor = LazyConfigAcceptor::new(Acceptor::default(), conn); + let handshake = acceptor.await?; + let hello = handshake.client_hello(); + let config = match &self.config.resolver { + Some(r) => r.resolve(hello).await.unwrap_or_else(|| self.default.clone()), + None => self.default.clone(), + }; + + handshake.into_stream(config).await + } + + fn endpoint(&self) -> io::Result { + Ok(self.listener.endpoint()?.with_tls(&self.config)) + } +} + +impl Connection for TlsStream { + fn endpoint(&self) -> io::Result { + Ok(self.get_ref().0.endpoint()?.assume_tls()) + } + + fn certificates(&self) -> Option> { + #[cfg(feature = "mtls")] { + let cert_chain = self.get_ref().1.peer_certificates()?; + Some(Certificates::from(cert_chain)) + } + + #[cfg(not(feature = "mtls"))] + None + } +} diff --git a/core/lib/src/tls/mod.rs b/core/lib/src/tls/mod.rs index 7f5a05deaa..df9899dfb9 100644 --- a/core/lib/src/tls/mod.rs +++ b/core/lib/src/tls/mod.rs @@ -1,6 +1,11 @@ mod error; +mod resolver; +mod listener; pub(crate) mod config; -pub use error::Result; +pub use rustls; + +pub use error::{Error, Result}; pub use config::{TlsConfig, CipherSuite}; -pub use error::Error; +pub use resolver::{Resolver, ClientHello, ServerConfig}; +pub use listener::{TlsListener, TlsStream}; diff --git a/core/lib/src/tls/resolver.rs b/core/lib/src/tls/resolver.rs new file mode 100644 index 0000000000..475ec2b8e9 --- /dev/null +++ b/core/lib/src/tls/resolver.rs @@ -0,0 +1,82 @@ +use std::fmt; +use std::marker::PhantomData; +use std::ops::Deref; +use std::sync::Arc; + +pub use rustls::server::{ClientHello, ServerConfig}; + +use crate::{Build, Ignite, Rocket}; +use crate::fairing::{self, Info, Kind}; + +/// Proxy type to get PartialEq + Debug impls. +#[derive(Clone)] +pub(crate) struct DynResolver(Arc); + +pub struct Fairing(PhantomData); + +/// A dynamic TLS configuration resolver. +#[crate::async_trait] +pub trait Resolver: Send + Sync + 'static { + async fn init(rocket: &Rocket) -> crate::tls::Result where Self: Sized { + let _rocket = rocket; + let type_name = std::any::type_name::(); + Err(figment::Error::from(format!("{type_name}: Resolver::init() unimplemented")).into()) + } + + async fn resolve(&self, hello: ClientHello<'_>) -> Option>; + + fn fairing() -> Fairing where Self: Sized { + Fairing(PhantomData) + } +} + +#[crate::async_trait] +impl fairing::Fairing for Fairing { + fn info(&self) -> Info { + Info { + name: "Resolver Fairing", + kind: Kind::Ignite | Kind::Singleton + } + } + + async fn on_ignite(&self, rocket: Rocket) -> fairing::Result { + use yansi::Paint; + + let result = T::init(&rocket).await; + match result { + Ok(resolver) => Ok(rocket.manage(Arc::new(resolver) as Arc)), + Err(e) => { + let name = std::any::type_name::(); + error!("TLS resolver {} failed to initialize.", name.primary().bold()); + error_!("{e}"); + Err(rocket) + } + } + } +} + +impl DynResolver { + pub fn extract(rocket: &Rocket) -> Option { + rocket.state::>().map(|r| Self(r.clone())) + } +} + +impl fmt::Debug for DynResolver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Resolver").finish() + } +} + +impl PartialEq for DynResolver { + fn eq(&self, _: &Self) -> bool { + false + } +} + +impl Deref for DynResolver { + type Target = dyn Resolver; + + fn deref(&self) -> &Self::Target { + &*self.0 + } +} diff --git a/core/lib/src/util/mod.rs b/core/lib/src/util/mod.rs index 6b5df12e7d..276ff64dc0 100644 --- a/core/lib/src/util/mod.rs +++ b/core/lib/src/util/mod.rs @@ -22,39 +22,39 @@ pub fn spawn_inspect(or: F, future: Fut) use std::io; use std::pin::pin; use std::future::Future; -use futures::future::{select, Either}; +use either::Either; +use futures::future; pub trait FutureExt: Future + Sized { /// Await `self` or `other`, whichever finishes first. - async fn or(self, other: B) -> Either { - match futures::future::select(pin!(self), pin!(other)).await { - Either::Left((v, _)) => Either::Left(v), - Either::Right((v, _)) => Either::Right(v), + async fn race(self, other: B) -> Either { + match future::select(pin!(self), pin!(other)).await { + future::Either::Left((v, _)) => Either::Left(v), + future::Either::Right((v, _)) => Either::Right(v), } } - /// Await `self` unless `trigger` completes. Returns `Ok(Some(T))` if `self` - /// completes successfully before `trigger`, `Err(E)` if `self` completes - /// unsuccessfully, and `Ok(None)` if `trigger` completes before `self`. - async fn unless(self, trigger: K) -> Result, E> - where Self: Future> + async fn race_io(self, trigger: K) -> io::Result + where Self: Future> { - match select(pin!(self), pin!(trigger)).await { - Either::Left((v, _)) => Ok(Some(v?)), - Either::Right((_, _)) => Ok(None), + match future::select(pin!(self), pin!(trigger)).await { + future::Either::Left((v, _)) => v, + future::Either::Right((_, _)) => Err(io::Error::other("i/o terminated")), } } +} - /// Await `self` unless `trigger` completes. If `self` completes before - /// `trigger`, returns the result. Otherwise, always returns an `Err`. - async fn io_unless(self, trigger: K) -> std::io::Result - where Self: Future> - { - match select(pin!(self), pin!(trigger)).await { - Either::Left((v, _)) => v, - Either::Right((_, _)) => Err(io::Error::other("I/O terminated")), +impl FutureExt for F { } + +#[doc(hidden)] +#[macro_export] +macro_rules! for_both { + ($value:expr, $pattern:pat => $result:expr) => { + match $value { + tokio_util::either::Either::Left($pattern) => $result, + tokio_util::either::Either::Right($pattern) => $result, } - } + }; } -impl FutureExt for F { } +pub use for_both; diff --git a/core/lib/tests/on_launch_fairing_can_inspect_port.rs b/core/lib/tests/on_launch_fairing_can_inspect_port.rs index 9631e57cdf..a9c42cdc8f 100644 --- a/core/lib/tests/on_launch_fairing_can_inspect_port.rs +++ b/core/lib/tests/on_launch_fairing_can_inspect_port.rs @@ -3,6 +3,7 @@ use std::net::{SocketAddr, Ipv4Addr}; use rocket::config::Config; use rocket::fairing::AdHoc; use rocket::futures::channel::oneshot; +use rocket::listener::tcp::TcpListener; #[rocket::async_test] async fn on_ignite_fairing_can_inspect_port() { @@ -15,6 +16,7 @@ async fn on_ignite_fairing_can_inspect_port() { }) })); - rocket::tokio::spawn(rocket.launch_on(SocketAddr::from((Ipv4Addr::LOCALHOST, 0)))); + let addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 0)); + rocket::tokio::spawn(rocket.bind_launch::<_, TcpListener>(addr)); assert_ne!(rx.await.unwrap(), 0); } diff --git a/docs/tests/Cargo.toml b/docs/tests/Cargo.toml index bf45422c10..9daf3bff1d 100644 --- a/docs/tests/Cargo.toml +++ b/docs/tests/Cargo.toml @@ -10,7 +10,7 @@ rocket = { path = "../../core/lib", features = ["secrets"] } [dev-dependencies] rocket = { path = "../../core/lib", features = ["secrets", "json", "mtls"] } -figment = { version = "0.10", features = ["toml", "env"] } +figment = { version = "0.10.17", features = ["toml", "env"] } tokio = { version = "1", features = ["macros", "io-std"] } rand = "0.8" diff --git a/examples/tls/private/client.pem b/examples/tls/private/client.pem new file mode 100644 index 0000000000..a5ba33910b --- /dev/null +++ b/examples/tls/private/client.pem @@ -0,0 +1,88 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCjle13u/R/0+zw +eycXhdF7ZNYQfqXfkMpw9GlerbqRrxSLEc/YXXBuIO5AZKkXYeP8iM9KbSBD4p8F +wZD7LL47601c5WwWpNfOravCaSjYgvaYyhnoNzmG8NYaVYKB9kup6lOyQmesNXEK +NGNSrKpsoaQ7jBk+l+VV1jNBjMhNVWuz4AdFMsVD09QyL1GvQ0OvT/BbUKypaFFw +YcHruYvHuKGnrlXkvw05aZmKtKiSE6UQoDKtZWfV8yV2M6Sr75i9GKaGMyUZIl88 +MxVLGcGwO6To2wNFKfLkHLOGIWrKA7m/Bb2n1k2OT+6iOnDzU62BoAzG/j8dhNPL +mZ6a7cZfAgMBAAECggEANwiZe06gUuDZNY44+JDsiLbDzYjOBQiREq8nQ9LukVR1 +dNPpOME2sdYiUUeMG3GzYaIlGsTbtfrnxOf5/oZu+XmP7VDBrFyIvd9viVgXhb+J +dp2HWbg6gktDvFhIL7DMg71xqubsOeNAxE4bnBS6wREgT2gylfxECzykwci7Gki4 +AkeihvaxqdHk9WP8dtFOuCYhX5pyKd9veS1/L01dVMpoFrq72PHupplKYb3HIo+v +ga02DhNVcH3fomEbXzazC64k2h5Vz+8mgpu5/V1thKiB2izOwt/hv4tkf2iDNz43 +xdSYUEFsk80M97VI1dM1+TBe/JO0auZvKLkuOWUjAQKBgQDlBMr+d+guajgQ863I +uEFK4veEXrD51L6AKT+cqFUi894fhOodnnmK8l3JBKzO0zjgsaez8exKZPRN8An8 +4MejM+hMYciJsP7uDpPkhlI5zHd9CR7EFPWXXpt4PecQLvBbnJ/lDnWCrE4m5Zhs +9OR7izLMBAmaiPlTNAaXj22iqwKBgQC226wzXGr//lnTggZX+u9UdkZKewAYlgnB +Ywj3+JB6Q/kDDS8C6fdlAvWyHShxtO3gx2pJSI3hk7J8fZu/kbojlLF16ayO+tgg +t3EoTZxN5zncygPaULstdKHhnMp8a4AO8lLrHtackFbbX7fuUJft0w457FpARvM8 +DONjWI8LHQKBgBBY5TyAxpv5jQL4weDf9hkoVk6mi69plieDyjyeb2VNTv+k9yki +FL7sSfF9WfBxd0/innvjuuAckKu3hJ7+VIG7xMse97eMYMYRWFEpnVju1WChdAa/ +EEC7yhEtKf8nupRve6JYA99N+U4heV3dpSmEaB3T8/OJ73IW9pl+7W59AoGADxM/ +OCDHZYF3sFtI4Jn8fy8dDmjjkiNUfJAInkDs0FeoQNsmZAwb7ET5Moz615z9+4kV +NyN3JwDBN0g3vexqtyI8Gyd/pW4CwXe+KX90gmustoolFSuQsueprOr7OpS2QwUx +Vtb9BH1V29IhXNFiJSZARwA4VJJE3U+Gs5sKd/UCgYEAoCPE3gVaa89nOqQtalhT +9SISOGQxxMknjNFrEuF3UaGuR0cxDRLX6lSEneAATEpho0QB2Fj4vO8PiyYOGvH+ +5ouJD97rcU77OOixlLFt4+TAWI9AvT0mN7y+SHJ22RkwWGQyF4TIfkg0tQvu36D+ +35W26Li1WteB2O4wV9qVReA= +-----END PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIIEwDCCAqigAwIBAgIUay5Z8sVQUkSTFpacn6o4iq2ElGowDQYJKoZIhvcNAQEL +BQAwRzELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMRIwEAYDVQQKDAlSb2NrZXQg +Q0ExFzAVBgNVBAMMDlJvY2tldCBSb290IENBMB4XDTI0MDQxNDA4MTU0MVoXDTI1 +MDQxNDA4MTU0MVowgY4xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlh +MRcwFQYDVQQHDA5TaWxpY29uIFZhbGxleTEPMA0GA1UECgwGUm9ja2V0MRswGQYD +VQQDDBJSb2NrZXQgVExTIEV4YW1wbGUxIzAhBgkqhkiG9w0BCQEWFGV4YW1wbGVA +cm9ja2V0LmxvY2FsMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAo5Xt +d7v0f9Ps8HsnF4XRe2TWEH6l35DKcPRpXq26ka8UixHP2F1wbiDuQGSpF2Hj/IjP +Sm0gQ+KfBcGQ+yy+O+tNXOVsFqTXzq2rwmko2IL2mMoZ6Dc5hvDWGlWCgfZLqepT +skJnrDVxCjRjUqyqbKGkO4wZPpflVdYzQYzITVVrs+AHRTLFQ9PUMi9Rr0NDr0/w +W1CsqWhRcGHB67mLx7ihp65V5L8NOWmZirSokhOlEKAyrWVn1fMldjOkq++YvRim +hjMlGSJfPDMVSxnBsDuk6NsDRSny5ByzhiFqygO5vwW9p9ZNjk/uojpw81OtgaAM +xv4/HYTTy5memu3GXwIDAQABo1wwWjAYBgNVHREEETAPgg1ETlM6bG9jYWxob3N0 +MB0GA1UdDgQWBBSowDBXM26C7VogwXNB1F0vLpYO7DAfBgNVHSMEGDAWgBREAyUj +0lTwopZ2B1VmnvMPfUtCkzANBgkqhkiG9w0BAQsFAAOCAgEAbjF11+t8qVEF72ey +19p1sRkG9ygb0gE2UpLzVpPilucioIOwQuT4rvsVYZQxK+smQZURDI4uNXODIeoS +r3maL82VryLSYbkQADyShYjF0uCX8AfCI0YtOKOschNZDcZEJ5mUpHjJE0lEZnkO +x8ZVXwWf4pv1/8DZoCkMN3gDHwhQGPtrls4q7O38rI7zK9DNrzu7R1ZdGjQSDasL +6DqHee90O2ejpELUxO6lRl2EUosfklRvjV7hfrDHlpN9EuweXt0JiaKw3WZzHSLa +dKS8wtTMq5aWzOWrew1ZEhRr+B3KS6BSC5o9xSQMfcDyS0KJcIJI9bNh3nElWFhM +IBVtGxM/EYAwNJ++jLD10WHvaqW0epMV2cUu+dGJX+TPuI0c/wNehisS4ahvR64m +UpjAwNUBlYpR/Gb15/i2fVk2BbUyU3AcpZfWFDopQ8UqC8ALVcNjbNHq+yVkuTpj +gn5iiTTcTqb6qNfie4oDX4KR6ZgpNiTl/PWZo58qxSwdGiJwrINACkPJ6Qg6Qrpd +hp3CanTWjioHfvTSdiubqw5/XRnqa2Iav0Sttc6TPnTimodmtWkaYA8mvjS+jq8N +f9l2UYQz8yLabMkn98BM+gRJYwrVt6sCbVuEaHgPwq/qX9mQFhUrfw3iEPKlmezt +T3AhgPhybUpMFpu+4Tp8JE2JlKQ= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIFbzCCA1egAwIBAgIURX345HUrWikAysSTFd8xoV5GSIYwDQYJKoZIhvcNAQEL +BQAwRzELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMRIwEAYDVQQKDAlSb2NrZXQg +Q0ExFzAVBgNVBAMMDlJvY2tldCBSb290IENBMB4XDTIxMDcwOTIzMzMzM1oXDTMx +MDcwNzIzMzMzM1owRzELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMRIwEAYDVQQK +DAlSb2NrZXQgQ0ExFzAVBgNVBAMMDlJvY2tldCBSb290IENBMIICIjANBgkqhkiG +9w0BAQEFAAOCAg8AMIICCgKCAgEAybxw0cVrq8yn9W9xRDHdci8rnA5CxPcxAyM5 +y5LCFOV/tTY0IZgrmegRLo4blmz8QNiFta9Ydt7zsm7XUTm6BhJ7TkOUAcfP7eSv +3jNEIEJQLU+k5SepV7pwFPRjUr6+a7yypS2xXAkDEVoyvzsuKYwzj+x6HvDuVhOF +2zv4Kk0sLfS/3UelMdilKa5VBCL/WMEXaCpb7/BMUUwn868LVU8E9+1H6uDQMxKo +ZH1mH98yeFODjzM9Ko6n2ghXx8qbe+wab4mSHn/SPgFnDFU+ujyPXIQqrS4PSQW3 +5lkCn70hOw2K+8LHDBmgxOLk2Jb8o8PJWX6v346dlRcZr9VzMqCyKvEf1i5oT2hg +NZrkDdUOgyMZeq6H7pQpSxSFSMtkaombSm816V0rg7/sXwS66KyaYJY7x8eYEpgd +GuQKXkyIwp687TGLul97amoy/J3jIDnQOuf/YEcdyHCKojh20E5AERC4sCg6l+qs +5Nbol7jZclzBFf+70JOsUFmCfVYd5e0LKWdYV9UhYABc3yQqJyzy/eyihWihUNZU +LXStjd+XIkhKs+b7uKaBp1poFfgjpdboxmREyppWexua1t0eAReBgMU43bEGoy+B +iWoTFjyeQijd6M++npzsqwknYyv+7VjX3EfijyTFgIpZUL196PTJ5SGJMf7eJmaG +BO0g2W0CAwEAAaNTMFEwHQYDVR0OBBYEFEQDJSPSVPCilnYHVWae8w99S0KTMB8G +A1UdIwQYMBaAFEQDJSPSVPCilnYHVWae8w99S0KTMA8GA1UdEwEB/wQFMAMBAf8w +DQYJKoZIhvcNAQELBQADggIBACCArR/ArOyoh97Pgie37miFJEJNtAe+ApzhDevh +11P0Vn5hbu+dR7vftCJ7e+0u0irbPgfdxuW3IpEDgL+fttqCIRdAT6MTKLiUmrVS +x0fQJqC4Hw4o+bIeRsaNAfB/SAEvOxBbyu5szOFak1r/sXVs4vzBINIF3NdtbNtj +Bhac0Fiy/+DlfTHJSRGvzYo+GljXHkrG02mF4aOWx9x97y/6UzbLqHJPINgyAIlN +ts29QIHVNtQQyUN292xC1F4TSrBNB+GziGt3XZ8YEASCkMEnIvs3Lpzsjjm9TrkE +W/b9ee3C6RWg+RW3pokORMM7Q/lSOMWUmPrzI7CBCKaQUNN9g+iimLkPyp386sCS +zXJDd0OKb0xkpxhrauEvzNfEJxGDQbxs8s598ZofhVo9ehdmmXcJAw/zUZjHSrI2 +PW+vHJ4kslBmKtH1oyAW3zYiFyYYPu4ohkeSrq8z8351upxwJUm4m/ndByXTrPwz +Yj6dEHaysjoRl0wOJgQ7G2ikw1QtWja2apJN9Q66i98vEDmtoEyOqOLMSjKjFL7c +sSJ6vAittYtIziIeMK7E8lDc1rtzMT5MOAoTriVyIGBgHFs96YOoL0Vi5QmVtQtc +8dkFUapFAUj8pREVxnJoLGose/FxBvF2FQZ5Sb25pyTPAeXk7y56noF78nusiVSF +xRjI +-----END CERTIFICATE----- diff --git a/examples/tls/private/gen_certs.sh b/examples/tls/private/gen_certs.sh index d98d152e55..cb68ac5a1f 100755 --- a/examples/tls/private/gen_certs.sh +++ b/examples/tls/private/gen_certs.sh @@ -9,6 +9,7 @@ # ecdsa_nistp256_sha256 # ecdsa_nistp384_sha384 # ecdsa_nistp521_sha512 +# client # # Generate a certificate of the [cert-kind] key type, or if no cert-kind is # specified, all of the certificates. @@ -136,12 +137,23 @@ function gen_ecdsa_nistp521_sha512() { rm ca_cert.srl server.csr ecdsa_nistp521_sha512_key.pem } +function gen_client_cert() { + openssl req -newkey rsa:2048 -nodes -keyout client.key -out client.csr + openssl x509 -req -extfile <(printf "subjectAltName=DNS:${ALT}") -days 365 \ + -in client.csr -CA ca_cert.pem -CAkey ca_key.pem -CAcreateserial \ + -out client.crt + + cat client.key client.crt ca_cert.pem > client.pem + rm client.key client.crt client.csr ca_cert.srl +} + case $1 in ed25519) gen_ed25519 ;; rsa_sha256) gen_rsa_sha256 ;; ecdsa_nistp256_sha256) gen_ecdsa_nistp256_sha256 ;; ecdsa_nistp384_sha384) gen_ecdsa_nistp384_sha384 ;; ecdsa_nistp521_sha512) gen_ecdsa_nistp521_sha512 ;; + client) gen_client_cert ;; *) gen_ed25519 gen_rsa_sha256 diff --git a/examples/tls/src/redirector.rs b/examples/tls/src/redirector.rs index 43892108f7..20276e29f0 100644 --- a/examples/tls/src/redirector.rs +++ b/examples/tls/src/redirector.rs @@ -7,6 +7,7 @@ use rocket::log::LogLevel; use rocket::{route, Error, Request, Data, Route, Orbit, Rocket, Ignite}; use rocket::fairing::{Fairing, Info, Kind}; use rocket::response::Redirect; +use rocket::listener::tcp::TcpListener; use yansi::Paint; @@ -59,7 +60,7 @@ impl Redirector { rocket::custom(&config.server) .manage(config) .mount("/", redirects) - .launch_on(addr) + .bind_launch::<_, TcpListener>(addr) .await } } diff --git a/examples/tls/src/tests.rs b/examples/tls/src/tests.rs index 413e64fd9d..7cf93d5e33 100644 --- a/examples/tls/src/tests.rs +++ b/examples/tls/src/tests.rs @@ -66,8 +66,7 @@ fn insecure_cookies() { } fn validate_profiles(profiles: &[&str]) { - use rocket::listener::DefaultListener; - use rocket::config::{Config, SecretKey}; + use rocket::config::{Config, TlsConfig, SecretKey}; for profile in profiles { let config = Config { @@ -81,9 +80,8 @@ fn validate_profiles(profiles: &[&str]) { assert_eq!(response.into_string().unwrap(), "Hello, world!"); let figment = client.rocket().figment(); - let listener: DefaultListener = figment.extract().unwrap(); - assert_eq!(figment.profile(), profile); - listener.tls.as_ref().unwrap().validate().expect("valid TLS config"); + let config: TlsConfig = figment.extract_inner("tls").unwrap(); + config.validate().expect("valid TLS config"); } } diff --git a/testbench/Cargo.toml b/testbench/Cargo.toml index 57b5e9ffb5..c640a9720c 100644 --- a/testbench/Cargo.toml +++ b/testbench/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "rocket-testbench" -description = "end-to-end HTTP testbench for Rocket" +name = "testbench" +description = "End-to-end HTTP Rocket testbench." version = "0.0.0" edition = "2021" publish = false @@ -12,6 +12,7 @@ thiserror = "1.0" procspawn = "1" pretty_assertions = "1.4.0" ipc-channel = "0.18" +rustls-pemfile = "2.1" [dependencies.nix] version = "0.28" diff --git a/testbench/src/client.rs b/testbench/src/client.rs index ec11374de1..c3f3fda3d2 100644 --- a/testbench/src/client.rs +++ b/testbench/src/client.rs @@ -1,206 +1,64 @@ use std::time::Duration; -use std::sync::Once; -use std::process::Stdio; -use std::io::{self, Read}; -use rocket::fairing::AdHoc; -use rocket::http::ext::IntoOwned; -use rocket::http::uri::{self, Absolute, Uri}; -use rocket::serde::{Deserialize, Serialize}; -use rocket::{Build, Rocket}; +use reqwest::blocking::{ClientBuilder, RequestBuilder}; +use rocket::http::{ext::IntoOwned, uri::{Absolute, Uri}}; -use procspawn::SpawnError; -use thiserror::Error; -use ipc_channel::ipc::{IpcOneShotServer, IpcReceiver, IpcSender}; - -static DEFAULT_CONFIG: &str = r#" - [default] - address = "tcp:127.0.0.1" - workers = 2 - port = 0 - cli_colors = false - secret_key = "itlYmFR2vYKrOmFhupMIn/hyB6lYCCTXz4yaQX89XVg=" - - [default.shutdown] - grace = 1 - mercy = 1 -"#; +use crate::{Result, Error, Server}; #[derive(Debug)] -#[allow(unused)] pub struct Client { client: reqwest::blocking::Client, - server: procspawn::JoinHandle<()>, - tls: bool, - port: u16, - rx: IpcReceiver, -} - -#[derive(Error, Debug)] -pub enum Error { - #[error("join/kill failed: {0}")] - JoinError(#[from] SpawnError), - #[error("kill failed: {0}")] - TermFailure(#[from] nix::errno::Errno), - #[error("i/o error: {0}")] - Io(#[from] io::Error), - #[error("invalid URI: {0}")] - Uri(#[from] uri::Error<'static>), - #[error("the URI is invalid")] - InvalidUri, - #[error("bad request: {0}")] - Request(#[from] reqwest::Error), - #[error("IPC failure: {0}")] - Ipc(#[from] ipc_channel::ipc::IpcError), - #[error("liftoff failed")] - Liftoff(String, String), -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(crate = "rocket::serde")] -pub enum Message { - Liftoff(bool, u16), - Failure, -} - -#[derive(Serialize, Deserialize)] -#[serde(crate = "rocket::serde")] -#[must_use] -pub struct Token(String); - -pub type Result = std::result::Result; - -impl Token { - fn configure(&self, toml: &str, rocket: Rocket) -> Rocket { - use rocket::figment::{Figment, providers::{Format, Toml}}; - - let toml = toml.replace("{CRATE}", env!("CARGO_MANIFEST_DIR")); - let config = Figment::from(rocket.figment()) - .merge(Toml::string(DEFAULT_CONFIG).nested()) - .merge(Toml::string(&toml).nested()); - - let server = self.0.clone(); - rocket.configure(config) - .attach(AdHoc::on_liftoff("Liftoff", move |rocket| Box::pin(async move { - let tcp = rocket.endpoints().find_map(|e| e.tcp()).unwrap(); - let tls = rocket.endpoints().any(|e| e.is_tls()); - let sender = IpcSender::::connect(server).unwrap(); - let _ = sender.send(Message::Liftoff(tls, tcp.port())); - let _ = sender.send(Message::Liftoff(tls, tcp.port())); - }))) - } - - pub fn rocket(&self, toml: &str) -> Rocket { - self.configure(toml, rocket::build()) - } - - pub fn configured_launch(self, toml: &str, rocket: Rocket) { - let rocket = self.configure(toml, rocket); - if let Err(e) = rocket::execute(rocket.launch()) { - let sender = IpcSender::::connect(self.0).unwrap(); - let _ = sender.send(Message::Failure); - let _ = sender.send(Message::Failure); - e.pretty_print(); - std::process::exit(1); - } - } - - pub fn launch(self, rocket: Rocket) { - self.configured_launch(DEFAULT_CONFIG, rocket) - } -} -pub fn start(f: fn(Token)) -> Result { - static INIT: Once = Once::new(); - INIT.call_once(procspawn::init); - - let (ipc, server) = IpcOneShotServer::new()?; - let mut server = procspawn::Builder::new() - .stdin(Stdio::null()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn(Token(server), f); - - let client = reqwest::blocking::Client::builder() - .danger_accept_invalid_certs(true) - .cookie_store(true) - .tls_info(true) - .timeout(Duration::from_secs(5)) - .connect_timeout(Duration::from_secs(5)) - .build()?; - - let (rx, _) = ipc.accept().unwrap(); - match rx.recv() { - Ok(Message::Liftoff(tls, port)) => Ok(Client { client, server, tls, port, rx }), - Ok(Message::Failure) => { - let stdout = server.stdout().unwrap(); - let mut out = String::new(); - stdout.read_to_string(&mut out)?; - - let stderr = server.stderr().unwrap(); - let mut err = String::new(); - stderr.read_to_string(&mut err)?; - Err(Error::Liftoff(out, err)) - } - Err(e) => Err(e.into()), - } - -} - -pub fn default() -> Result { - start(|token| token.launch(rocket::build())) } impl Client { - pub fn read_stdout(&mut self) -> Result { - let Some(stdout) = self.server.stdout() else { - return Ok(String::new()); - }; - - let mut string = String::new(); - stdout.read_to_string(&mut string)?; - Ok(string) + pub fn default() -> Client { + Client::build() + .try_into() + .expect("default builder ok") } - pub fn read_stderr(&mut self) -> Result { - let Some(stderr) = self.server.stderr() else { - return Ok(String::new()); - }; - - let mut string = String::new(); - stderr.read_to_string(&mut string)?; - Ok(string) - } - - pub fn kill(&mut self) -> Result<()> { - Ok(self.server.kill()?) + pub fn build() -> ClientBuilder { + reqwest::blocking::Client::builder() + .danger_accept_invalid_certs(true) + .cookie_store(true) + .tls_info(true) + .timeout(Duration::from_secs(5)) + .connect_timeout(Duration::from_secs(5)) } - pub fn terminate(&mut self) -> Result<()> { - use nix::{sys::signal, unistd::Pid}; - - let pid = Pid::from_raw(self.server.pid().unwrap() as i32); - Ok(signal::kill(pid, signal::SIGTERM)?) - } - - pub fn wait(&mut self) -> Result<()> { - match self.server.join_timeout(Duration::from_secs(5)) { - Ok(_) => Ok(()), - Err(e) if e.is_remote_close() => Ok(()), - Err(e) => Err(e.into()), - } - } - - pub fn get(&self, url: &str) -> Result { + pub fn get(&self, server: &Server, url: &str) -> Result { let uri = match Uri::parse_any(url).map_err(|e| e.into_owned())? { Uri::Origin(uri) => { - let proto = if self.tls { "https" } else { "http" }; - let uri = format!("{proto}://127.0.0.1:{}{uri}", self.port); + let proto = if server.tls { "https" } else { "http" }; + let uri = format!("{proto}://127.0.0.1:{}{uri}", server.port); Absolute::parse_owned(uri)? } - Uri::Absolute(uri) => uri, - _ => return Err(Error::InvalidUri), + Uri::Absolute(mut uri) => { + if let Some(auth) = uri.authority() { + let mut auth = auth.clone(); + auth.set_port(server.port); + uri.set_authority(auth); + } + + uri + } + uri => return Err(Error::InvalidUri(uri.into_owned())), }; Ok(self.client.get(uri.to_string())) } } + +impl From for Client { + fn from(client: reqwest::blocking::Client) -> Self { + Client { client } + } +} + +impl TryFrom for Client { + type Error = Error; + + fn try_from(builder: ClientBuilder) -> Result { + Ok(Client { client: builder.build()? }) + } +} diff --git a/testbench/src/lib.rs b/testbench/src/lib.rs index c8ab57bdc4..1c34da7544 100644 --- a/testbench/src/lib.rs +++ b/testbench/src/lib.rs @@ -1,3 +1,35 @@ -pub mod client; +// pub mod session; +mod client; +mod server; +pub use server::*; pub use client::*; + +use std::io; +use thiserror::Error; +use procspawn::SpawnError; +use rocket::http::uri; + +pub type Result = std::result::Result; + +#[derive(Error, Debug)] +pub enum Error { + #[error("join/kill failed: {0}")] + JoinError(#[from] SpawnError), + #[error("kill failed: {0}")] + TermFailure(#[from] nix::errno::Errno), + #[error("i/o error: {0}")] + Io(#[from] io::Error), + #[error("invalid URI: {0}")] + Uri(#[from] uri::Error<'static>), + #[error("invalid uri: {0}")] + InvalidUri(uri::Uri<'static>), + #[error("expected certificates are not present")] + MissingCertificate, + #[error("bad request: {0}")] + Request(#[from] reqwest::Error), + #[error("IPC failure: {0}")] + Ipc(#[from] ipc_channel::ipc::IpcError), + #[error("liftoff failed")] + Liftoff(String, String), +} diff --git a/testbench/src/main.rs b/testbench/src/main.rs index 22ecd8dd15..44b1a24fac 100644 --- a/testbench/src/main.rs +++ b/testbench/src/main.rs @@ -1,54 +1,92 @@ -use rocket::{fairing::AdHoc, *}; -use rocket_testbench::client::{self, Error}; -use reqwest::tls::TlsInfo; +use std::process::ExitCode; -fn run() -> client::Result<()> { - let mut client = client::start(|token| { - #[get("/")] - fn index() -> &'static str { - "Hello, world!" - } +use rocket::listener::unix::UnixListener; +use rocket::tokio::net::TcpListener; +use rocket::yansi::Paint; +use rocket::{get, routes, Build, Rocket, State}; +use reqwest::{tls::TlsInfo, Identity}; +use testbench::*; - token.configured_launch(r#" - [default.tls] - certs = "{CRATE}/../examples/tls/private/rsa_sha256_cert.pem" - key = "{CRATE}/../examples/tls/private/rsa_sha256_key.pem" - "#, rocket::build().mount("/", routes![index])); - })?; +static DEFAULT_CONFIG: &str = r#" + [default] + address = "tcp:127.0.0.1" + workers = 2 + port = 0 + cli_colors = false + secret_key = "itlYmFR2vYKrOmFhupMIn/hyB6lYCCTXz4yaQX89XVg=" - let response = client.get("/")?.send()?; - let tls = response.extensions().get::().unwrap(); - assert!(!tls.peer_certificate().unwrap().is_empty()); - assert_eq!(response.text()?, "Hello, world!"); + [default.shutdown] + grace = 1 + mercy = 1 +"#; - client.terminate()?; - let stdout = client.read_stdout()?; - assert!(stdout.contains("Rocket has launched on https")); - assert!(stdout.contains("Graceful shutdown completed")); - assert!(stdout.contains("GET /")); - Ok(()) +static TLS_CONFIG: &str = r#" + [default.tls] + certs = "{ROCKET}/examples/tls/private/rsa_sha256_cert.pem" + key = "{ROCKET}/examples/tls/private/rsa_sha256_key.pem" +"#; + +trait RocketExt { + fn default() -> Self; + fn tls_default() -> Self; + fn configure_with_toml(self, toml: &str) -> Self; } -fn run_fail() -> client::Result<()> { - let client = client::start(|token| { +impl RocketExt for Rocket { + fn default() -> Self { + rocket::build().configure_with_toml(DEFAULT_CONFIG) + } + + fn tls_default() -> Self { + rocket::build() + .configure_with_toml(DEFAULT_CONFIG) + .configure_with_toml(TLS_CONFIG) + } + + fn configure_with_toml(self, toml: &str) -> Self { + use rocket::figment::{Figment, providers::{Format, Toml}}; + + let toml = toml.replace("{ROCKET}", rocket::fs::relative!("../")); + let config = Figment::from(self.figment()) + .merge(Toml::string(&toml).nested()); + + self.configure(config) + } +} + +fn read(path: &str) -> Result> { + let path = path.replace("{ROCKET}", rocket::fs::relative!("../")); + Ok(std::fs::read(path)?) +} + +fn cert(path: &str) -> Result> { + let mut data = std::io::Cursor::new(read(path)?); + let cert = rustls_pemfile::certs(&mut data).last(); + Ok(cert.ok_or(Error::MissingCertificate)??.to_vec()) +} + +fn run_fail() -> Result<()> { + use rocket::fairing::AdHoc; + + let server = spawn! { let fail = AdHoc::try_on_ignite("FailNow", |rocket| async { Err(rocket) }); - token.launch(rocket::build().attach(fail)); - }); + Rocket::default().attach(fail) + }; - if let Err(Error::Liftoff(stdout, _)) = client { + if let Err(Error::Liftoff(stdout, _)) = server { assert!(stdout.contains("Rocket failed to launch due to failing fairings")); assert!(stdout.contains("FailNow")); } else { - panic!("unexpected result: {client:#?}"); + panic!("unexpected result: {server:#?}"); } Ok(()) } -fn infinite() -> client::Result<()> { +fn infinite() -> Result<()> { use rocket::response::stream::TextStream; - let mut client = client::start(|token| { + let mut server = spawn! { #[get("/")] fn infinite() -> TextStream![&'static str] { TextStream! { @@ -58,37 +96,338 @@ fn infinite() -> client::Result<()> { } } - token.launch(rocket::build().mount("/", routes![infinite])); - })?; + Rocket::default().mount("/", routes![infinite]) + }?; + + let client = Client::default(); + client.get(&server, "/")?.send()?; + server.terminate()?; - client.get("/")?.send()?; - client.terminate()?; - let stdout = client.read_stdout()?; + let stdout = server.read_stdout()?; assert!(stdout.contains("Rocket has launched on http")); assert!(stdout.contains("GET /")); assert!(stdout.contains("Graceful shutdown completed")); Ok(()) } -fn main() { - let names = ["run", "run_fail", "infinite"]; - let tests = [run, run_fail, infinite]; - let handles = tests.into_iter() - .map(|test| std::thread::spawn(test)) - .collect::>(); +fn tls_info() -> Result<()> { + let mut server = spawn! { + #[get("/")] + fn hello_world() -> &'static str { + "Hello, world!" + } + + Rocket::tls_default().mount("/", routes![hello_world]) + }?; + + let client = Client::default(); + let response = client.get(&server, "/")?.send()?; + let tls = response.extensions().get::().unwrap(); + assert!(!tls.peer_certificate().unwrap().is_empty()); + assert_eq!(response.text()?, "Hello, world!"); + + server.terminate()?; + let stdout = server.read_stdout()?; + assert!(stdout.contains("Rocket has launched on https")); + assert!(stdout.contains("Graceful shutdown completed")); + assert!(stdout.contains("GET /")); + Ok(()) +} + +fn tls_resolver() -> Result<()> { + use std::sync::Arc; + use std::sync::atomic::{AtomicUsize, Ordering}; + use rocket::tls::rustls::{server::ClientHello, ServerConfig}; + use rocket::tls::{Resolver, TlsConfig}; + + struct CountingResolver { + config: Arc, + counter: Arc, + } + + #[rocket::async_trait] + impl Resolver for CountingResolver { + async fn init(rocket: &Rocket) -> rocket::tls::Result { + let config: TlsConfig = rocket.figment().extract_inner("tls")?; + let config = Arc::new(config.server_config().await?); + let counter = rocket.state::>().unwrap().clone(); + Ok(Self { config, counter }) + } + + async fn resolve(&self, _: ClientHello<'_>) -> Option> { + self.counter.fetch_add(1, Ordering::Release); + Some(self.config.clone()) + } + } + + let server = spawn! { + #[get("/count")] + fn count(counter: &State>) -> String { + let count = counter.load(Ordering::Acquire); + println!("{count}"); + count.to_string() + } + + let counter = Arc::new(AtomicUsize::new(0)); + Rocket::tls_default() + .manage(counter) + .mount("/", routes![count]) + .attach(CountingResolver::fairing()) + }?; + + let client = Client::default(); + let response = client.get(&server, "/count")?.send()?; + assert_eq!(response.text()?, "1"); + + // Use a new client so we get a new TLS session. + let client = Client::default(); + let response = client.get(&server, "/count")?.send()?; + assert_eq!(response.text()?, "2"); + Ok(()) +} + +fn test_mtls(mandatory: bool) -> Result<()> { + let server = spawn!(mandatory: bool => { + let mtls_config = format!(r#" + [default.tls.mutual] + ca_certs = "{{ROCKET}}/examples/tls/private/ca_cert.pem" + mandatory = {mandatory} + "#); + + #[get("/")] + fn hello(cert: rocket::mtls::Certificate<'_>) -> String { + format!("{}:{}[{}] {}", cert.serial(), cert.version(), cert.issuer(), cert.subject()) + } + + #[get("/", rank = 2)] + fn hi() -> &'static str { + "Hello!" + } + + Rocket::tls_default() + .configure_with_toml(&mtls_config) + .mount("/", routes![hello, hi]) + })?; + + let pem = read("{ROCKET}/examples/tls/private/client.pem")?; + let client: Client = Client::build() + .identity(Identity::from_pem(&pem)?) + .try_into()?; + + let response = client.get(&server, "/")?.send()?; + assert_eq!(response.text()?, + "611895682361338926795452113263857440769284805738:2\ + [C=US, ST=CA, O=Rocket CA, CN=Rocket Root CA] \ + C=US, ST=California, L=Silicon Valley, O=Rocket, \ + CN=Rocket TLS Example, Email=example@rocket.local"); + + let client = Client::default(); + let response = client.get(&server, "/")?.send(); + if mandatory { + assert!(response.unwrap_err().is_request()); + } else { + assert_eq!(response?.text()?, "Hello!"); + } + + Ok(()) +} + +fn tls_mtls() -> Result<()> { + test_mtls(false)?; + test_mtls(true) +} + +fn sni_resolver() -> Result<()> { + use std::sync::Arc; + use std::collections::HashMap; + + use rocket::http::uri::Host; + use rocket::tls::rustls::{server::ClientHello, ServerConfig}; + use rocket::tls::{Resolver, TlsConfig}; + + struct SniResolver { + default: Arc, + map: HashMap, Arc> + } + + #[rocket::async_trait] + impl Resolver for SniResolver { + async fn init(rocket: &Rocket) -> rocket::tls::Result { + let default: TlsConfig = rocket.figment().extract_inner("tls")?; + let sni: HashMap, TlsConfig> = rocket.figment().extract_inner("tls.sni")?; + + let default = Arc::new(default.server_config().await?); + let mut map = HashMap::new(); + for (host, config) in sni { + let config = config.server_config().await?; + map.insert(host, Arc::new(config)); + } + + Ok(SniResolver { default, map }) + } + + async fn resolve(&self, hello: ClientHello<'_>) -> Option> { + if let Some(Ok(host)) = hello.server_name().map(Host::parse) { + if let Some(config) = self.map.get(&host) { + return Some(config.clone()); + } + } + + Some(self.default.clone()) + } + } + + static SNI_TLS_CONFIG: &str = r#" + [default.tls] + certs = "{ROCKET}/examples/tls/private/rsa_sha256_cert.pem" + key = "{ROCKET}/examples/tls/private/rsa_sha256_key.pem" + + [default.tls.sni."sni1.dev"] + certs = "{ROCKET}/examples/tls/private/ecdsa_nistp256_sha256_cert.pem" + key = "{ROCKET}/examples/tls/private/ecdsa_nistp256_sha256_key_pkcs8.pem" + + [default.tls.sni."sni2.dev"] + certs = "{ROCKET}/examples/tls/private/ed25519_cert.pem" + key = "{ROCKET}/examples/tls/private/ed25519_key.pem" + "#; + + let server = spawn! { + #[get("/")] fn index() { } + + Rocket::default() + .configure_with_toml(SNI_TLS_CONFIG) + .mount("/", routes![index]) + .attach(SniResolver::fairing()) + }?; + + let client: Client = Client::build() + .resolve("unknown.dev", server.socket_addr()) + .resolve("sni1.dev", server.socket_addr()) + .resolve("sni2.dev", server.socket_addr()) + .try_into()?; + + let response = client.get(&server, "https://unknown.dev")?.send()?; + let tls = response.extensions().get::().unwrap(); + let expected = cert("{ROCKET}/examples/tls/private/rsa_sha256_cert.pem")?; + assert_eq!(tls.peer_certificate().unwrap(), expected); + + let response = client.get(&server, "https://sni1.dev")?.send()?; + let tls = response.extensions().get::().unwrap(); + let expected = cert("{ROCKET}/examples/tls/private/ecdsa_nistp256_sha256_cert.pem")?; + assert_eq!(tls.peer_certificate().unwrap(), expected); + + let response = client.get(&server, "https://sni2.dev")?.send()?; + let tls = response.extensions().get::().unwrap(); + let expected = cert("{ROCKET}/examples/tls/private/ed25519_cert.pem")?; + assert_eq!(tls.peer_certificate().unwrap(), expected); + Ok(()) +} + +fn tcp_unix_listener_fail() -> Result<()> { + let server = spawn! { + Rocket::default().configure_with_toml("[default]\naddress = 123") + }; + + if let Err(Error::Liftoff(stdout, _)) = server { + assert!(stdout.contains("expected valid TCP")); + assert!(stdout.contains("for key default.address")); + } else { + panic!("unexpected result: {server:#?}"); + } + + let server = Server::spawn((), |(token, _)| { + let rocket = Rocket::default().configure_with_toml("[default]\naddress = \"unix:foo\""); + token.launch_with::(rocket) + }); + + if let Err(Error::Liftoff(stdout, _)) = server { + assert!(stdout.contains("invalid tcp endpoint: unix:foo")); + } else { + panic!("unexpected result: {server:#?}"); + } + + let server = Server::spawn((), |(token, _)| { + token.launch_with::(Rocket::default()) + }); + + if let Err(Error::Liftoff(stdout, _)) = server { + assert!(stdout.contains("invalid unix endpoint: tcp:127.0.0.1:8000")); + } else { + panic!("unexpected result: {server:#?}"); + } + + Ok(()) +} + +macro_rules! tests { + ($($f:ident),* $(,)?) => {[ + $(Test { name: stringify!($f), func: $f, }),* + ]}; +} + +#[derive(Copy, Clone)] +struct Test { + name: &'static str, + func: fn() -> Result<()>, +} + +static TESTS: &[Test] = &tests![ + run_fail, infinite, tls_info, tls_resolver, tls_mtls, sni_resolver, + tcp_unix_listener_fail +]; + +fn main() -> ExitCode { + let filter = std::env::args().nth(1).unwrap_or_default(); + let filtered = TESTS.into_iter().filter(|test| test.name.contains(&filter)); + + println!("running {}/{} tests", filtered.clone().count(), TESTS.len()); + let handles: Vec<_> = filtered + .map(|test| (test, std::thread::spawn(move || { + if let Err(e) = (test.func)() { + println!("test {} ... {}\n {e}", test.name.bold(), "fail".red()); + return Err(e); + } + + println!("test {} ... {}", test.name.bold(), "ok".green()); + Ok(()) + }))) + .collect(); let mut failure = false; - for (handle, name) in handles.into_iter().zip(names) { + for (test, handle) in handles { let result = handle.join(); - failure = failure || matches!(result, Ok(Err(_)) | Err(_)); - match result { - Ok(Ok(_)) => continue, - Ok(Err(e)) => eprintln!("{name} failed: {e}"), - Err(_) => eprintln!("{name} failed (see panic above)"), + failure |= matches!(result, Err(_) | Ok(Err(_))); + if result.is_err() { + println!("test {} ... {}", test.name.bold(), "panic".red().underline()); } } - if failure { - std::process::exit(1); + match failure { + true => ExitCode::FAILURE, + false => ExitCode::SUCCESS } } + +// struct UpdatingResolver { +// timestamp: AtomicU64, +// config: ArcSwap +// } +// +// #[crate::async_trait] +// impl Resolver for UpdatingResolver { +// async fn resolve(&self, _: ClientHello<'_>) -> Option> { +// if let Either::Left(path) = self.tls_config.certs() { +// let metadata = tokio::fs::metadata(&path).await.ok()?; +// let modtime = metadata.modified().ok()?; +// let timestamp = modtime.duration_since(UNIX_EPOCH).ok()?.as_secs(); +// let old_timestamp = self.timestamp.load(Ordering::Acquire); +// if timestamp > old_timestamp { +// let new_config = self.tls_config.to_server_config().await.ok()?; +// self.server_config.store(Arc::new(new_config)); +// self.timestamp.store(timestamp, Ordering::Release); +// } +// } +// +// Some(self.server_config.load_full()) +// } +// } diff --git a/testbench/src/server.rs b/testbench/src/server.rs new file mode 100644 index 0000000000..2d107e4845 --- /dev/null +++ b/testbench/src/server.rs @@ -0,0 +1,178 @@ +use std::net::{Ipv4Addr, SocketAddr}; +use std::time::Duration; +use std::sync::Once; +use std::process::Stdio; +use std::io::Read; + +use rocket::fairing::AdHoc; +use rocket::listener::{Bind, DefaultListener}; +use rocket::serde::{Deserialize, DeserializeOwned, Serialize}; +use rocket::{Build, Ignite, Rocket}; + +use ipc_channel::ipc::{IpcOneShotServer, IpcReceiver, IpcSender}; + +use crate::{Result, Error}; + +#[derive(Debug)] +pub struct Server { + proc: procspawn::JoinHandle, + pub tls: bool, + pub port: u16, + _rx: IpcReceiver, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(crate = "rocket::serde")] +pub enum Message { + Liftoff(bool, u16), + Failure, +} + +#[derive(Serialize, Deserialize)] +#[serde(crate = "rocket::serde")] +pub struct Token(String); + +#[derive(Serialize, Deserialize)] +#[serde(crate = "rocket::serde")] +pub struct Launched(()); + +fn stdio() -> Stdio { + std::env::var_os("NOCAPTURE") + .map(|_| Stdio::inherit()) + .unwrap_or_else(|| Stdio::piped()) +} + +impl Server { + pub fn spawn(ctxt: T, f: fn((Token, T)) -> Launched) -> Result + where T: Serialize + DeserializeOwned + { + static INIT: Once = Once::new(); + INIT.call_once(procspawn::init); + + let (ipc, server) = IpcOneShotServer::new()?; + let mut proc = procspawn::Builder::new() + .stdin(Stdio::null()) + .stdout(stdio()) + .stderr(stdio()) + .spawn((Token(server), ctxt), f); + + let (rx, _) = ipc.accept().unwrap(); + match rx.recv()? { + Message::Liftoff(tls, port) => { + Ok(Server { proc, tls, port, _rx: rx }) + }, + Message::Failure => { + let stdout = proc.stdout().unwrap(); + let mut out = String::new(); + stdout.read_to_string(&mut out)?; + + let stderr = proc.stderr().unwrap(); + let mut err = String::new(); + stderr.read_to_string(&mut err)?; + Err(Error::Liftoff(out, err)) + } + } + } + + pub fn socket_addr(&self) -> SocketAddr { + let ip = Ipv4Addr::LOCALHOST; + SocketAddr::new(ip.into(), self.port) + } + + pub fn read_stdout(&mut self) -> Result { + let Some(stdout) = self.proc.stdout() else { + return Ok(String::new()); + }; + + let mut string = String::new(); + stdout.read_to_string(&mut string)?; + Ok(string) + } + + pub fn read_stderr(&mut self) -> Result { + let Some(stderr) = self.proc.stderr() else { + return Ok(String::new()); + }; + + let mut string = String::new(); + stderr.read_to_string(&mut string)?; + Ok(string) + } + + pub fn kill(&mut self) -> Result<()> { + Ok(self.proc.kill()?) + } + + pub fn terminate(&mut self) -> Result<()> { + use nix::{sys::signal, unistd::Pid}; + + let pid = Pid::from_raw(self.proc.pid().unwrap() as i32); + Ok(signal::kill(pid, signal::SIGTERM)?) + } + + pub fn join(&mut self, duration: Duration) -> Result<()> { + match self.proc.join_timeout(duration) { + Ok(_) => Ok(()), + Err(e) if e.is_remote_close() => Ok(()), + Err(e) => Err(e.into()), + } + } +} + +impl Token { + pub fn launch_with(self, rocket: Rocket) -> Launched + where B: for<'r> Bind<&'r Rocket> + Sync + Send + 'static + { + let server = self.0.clone(); + let rocket = rocket.attach(AdHoc::on_liftoff("Liftoff", move |rocket| Box::pin(async move { + let tcp = rocket.endpoints().find_map(|e| e.tcp()).unwrap(); + let tls = rocket.endpoints().any(|e| e.is_tls()); + let sender = IpcSender::::connect(server).unwrap(); + let _ = sender.send(Message::Liftoff(tls, tcp.port())); + let _ = sender.send(Message::Liftoff(tls, tcp.port())); + }))); + + let server = self.0.clone(); + let fut = rocket.launch_with::(); + if let Err(e) = rocket::execute(fut) { + let sender = IpcSender::::connect(server).unwrap(); + let _ = sender.send(Message::Failure); + let _ = sender.send(Message::Failure); + e.pretty_print(); + std::process::exit(1); + } + + Launched(()) + } + + pub fn launch(self, rocket: Rocket) -> Launched { + self.launch_with::(rocket) + } +} + +impl Drop for Server { + fn drop(&mut self) { + let _ = self.terminate(); + if self.join(Duration::from_secs(3)).is_err() { + let _ = self.kill(); + } + } +} + +#[macro_export] +macro_rules! spawn { + ($($arg:ident : $t:ty),* => $rocket:block) => {{ + #[allow(unused_parens)] + fn _server((token, $($arg),*): ($crate::Token, $($t),*)) -> $crate::Launched { + let rocket: rocket::Rocket = $rocket; + token.launch(rocket) + } + + Server::spawn(($($arg),*), _server) + }}; + + ($($token:tt)*) => {{ + let _unit = (); + spawn!(_unit: () => { $($token)* } ) + }}; +}