From 146d8ecdf207c420e24510dfbf376f8952e6cdf6 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Tue, 16 Apr 2024 15:17:10 -0700 Subject: [PATCH] Introduce dynamic TLS resolvers. This commit introduces the ability to dynamically select a TLS configuration based on the client's TLS hello. Added `Authority::set_port()`. Various `Config` structures for listeners removed. `UdsListener` is now `UnixListener`. `Bindable` removed in favor of new `Bind`. `Connection` requires `AsyncRead + AsyncWrite` again The `Debug` impl for `Endpoint` displays the underlying address in plaintext. `Listener` must be `Sized`. `tls` listener moved to `tls::TlsListener` The preview `quic` listener no longer implements `Listener`. All built-in listeners now implement `Bind<&Rocket>`. Clarified docs for `mtls::Certificate` guard. No reexporitng rustls from `tls`. Added `TlsConfig::server_config()`. Added some future helpers: `race()` and `race_io()`. Fix an issue where the logger wouldn't respect a configuration during error printing. Added Rocket::launch_with(), launch_on(), bind_launch(). Added a default client.pem to the TLS example. Revamped the testbench. Added tests for TLS resolvers, MTLS, listener failure output. TODO: clippy. TODO: UDS testing. Resolves #2730. Resolves #2363. Closes #2748. Closes #2683. Closes #2577. --- contrib/dyn_templates/src/lib.rs | 4 +- contrib/dyn_templates/src/template.rs | 13 +- core/lib/src/error.rs | 8 +- core/lib/src/listener/default.rs | 173 +++++++++++++++++++++----- core/lib/src/listener/endpoint.rs | 39 +++--- core/lib/src/listener/quic.rs | 16 +-- core/lib/src/listener/tcp.rs | 11 ++ core/lib/src/listener/unix.rs | 12 +- core/lib/src/mtls/config.rs | 7 +- core/lib/src/request/from_request.rs | 6 +- core/lib/src/response/response.rs | 20 ++- core/lib/src/shutdown/handle.rs | 2 +- core/lib/src/tls/config.rs | 3 +- core/lib/src/tls/mod.rs | 2 - core/lib/src/tls/resolver.rs | 33 +++++ scripts/test.sh | 9 ++ testbench/src/main.rs | 81 +++++++----- testbench/src/server.rs | 38 +++--- 18 files changed, 327 insertions(+), 150 deletions(-) diff --git a/contrib/dyn_templates/src/lib.rs b/contrib/dyn_templates/src/lib.rs index 6ed0c36547..eb8553690a 100644 --- a/contrib/dyn_templates/src/lib.rs +++ b/contrib/dyn_templates/src/lib.rs @@ -117,6 +117,8 @@ //! to an `Object` (a dictionary) value. The [`context!`] macro can be used to //! create inline `Serialize`-able context objects. //! +//! [`Serialize`]: rocket::serde::Serialize +//! //! ```rust //! # #[macro_use] extern crate rocket; //! use rocket::serde::Serialize; @@ -165,7 +167,7 @@ //! builds, template reloading is disabled to improve performance and cannot be //! enabled. //! -//! [attached]: Rocket::attach() +//! [attached]: rocket::Rocket::attach() //! //! ### Metadata and Rendering to `String` //! diff --git a/contrib/dyn_templates/src/template.rs b/contrib/dyn_templates/src/template.rs index 3e275c05e5..2cf88a8cbb 100644 --- a/contrib/dyn_templates/src/template.rs +++ b/contrib/dyn_templates/src/template.rs @@ -140,11 +140,12 @@ impl Template { } /// Render the template named `name` with the context `context`. The - /// `context` is typically created using the [`context!`] macro, but it can - /// be of any type that implements `Serialize`, such as `HashMap` or a - /// custom `struct`. + /// `context` is typically created using the [`context!()`](crate::context!) + /// macro, but it can be of any type that implements `Serialize`, such as + /// `HashMap` or a custom `struct`. /// - /// To render a template directly into a string, use [`Metadata::render()`]. + /// To render a template directly into a string, use + /// [`Metadata::render()`](crate::Metadata::render()). /// /// # Examples /// @@ -291,8 +292,8 @@ impl Sentinel for Template { /// A macro to easily create a template rendering context. /// /// Invocations of this macro expand to a value of an anonymous type which -/// implements [`serde::Serialize`]. Fields can be literal expressions or -/// variables captured from a surrounding scope, as long as all fields implement +/// implements [`Serialize`]. Fields can be literal expressions or variables +/// captured from a surrounding scope, as long as all fields implement /// `Serialize`. /// /// # Examples diff --git a/core/lib/src/error.rs b/core/lib/src/error.rs index 85867017dd..5802e817fb 100644 --- a/core/lib/src/error.rs +++ b/core/lib/src/error.rs @@ -179,16 +179,16 @@ impl Error { match self.kind() { ErrorKind::Bind(ref a, ref e) => { if let Some(e) = e.downcast_ref::() { - e.pretty_print(); + e.pretty_print() } else { match a { Some(a) => error!("Binding to {} failed.", a.primary().underline()), None => error!("Binding to network interface failed."), } - } - info_!("{}", e); - "aborting due to bind error" + info_!("{}", e); + "aborting due to bind error" + } } ErrorKind::Io(ref e) => { error!("Rocket failed to launch due to an I/O error."); diff --git a/core/lib/src/listener/default.rs b/core/lib/src/listener/default.rs index a82201e3dd..24d3971568 100644 --- a/core/lib/src/listener/default.rs +++ b/core/lib/src/listener/default.rs @@ -1,8 +1,9 @@ +use core::fmt; + use serde::Deserialize; -use tokio_util::either::{Either, Either::{Left, Right}}; -use futures::TryFutureExt; +use tokio_util::either::Either::{Left, Right}; +use either::Either; -use crate::error::ErrorKind; use crate::{Ignite, Rocket}; use crate::listener::{Bind, Endpoint, tcp::TcpListener}; @@ -10,17 +11,52 @@ use crate::listener::{Bind, Endpoint, tcp::TcpListener}; #[cfg(feature = "tls")] use crate::tls::{TlsListener, TlsConfig}; mod private { - use super::{Either, TcpListener}; + use super::*; + use tokio_util::either::Either; - #[cfg(feature = "tls")] pub type TlsListener = super::TlsListener; - #[cfg(not(feature = "tls"))] pub type TlsListener = T; - #[cfg(unix)] pub type UnixListener = super::UnixListener; - #[cfg(not(unix))] pub type UnixListener = super::TcpListener; + #[cfg(feature = "tls")] type TlsListener = super::TlsListener; + #[cfg(not(feature = "tls"))] type TlsListener = T; + #[cfg(unix)] type UnixListener = super::UnixListener; + #[cfg(not(unix))] type UnixListener = TcpListener; pub type Listener = Either< Either, TlsListener>, Either, >; + + /// The default connection listener. + /// + /// # Configuration + /// + /// Reads the following optional configuration parameters: + /// + /// | parameter | type | default | + /// | ----------- | ----------------- | --------------------- | + /// | `address` | [`Endpoint`] | `tcp:127.0.0.1:8000` | + /// | `tls` | [`TlsConfig`] | None | + /// | `reuse` | boolean | `true` | + /// + /// # Listener + /// + /// Based on the above configuration, this listener defers to one of the + /// following existing listeners: + /// + /// | listener | `address` type | `tls` enabled | + /// |-------------------------------|--------------------|---------------| + /// | [`TcpListener`] | [`Endpoint::Tcp`] | no | + /// | [`UnixListener`] | [`Endpoint::Unix`] | no | + /// | [`TlsListener`] | [`Endpoint::Tcp`] | yes | + /// | [`TlsListener`] | [`Endpoint::Unix`] | yes | + /// + /// [`UnixListener`]: crate::listener::unix::UnixListener + /// [`TlsListener`]: crate::tls::TlsListener + /// [`TlsListener`]: crate::tls::TlsListener + /// + /// * **address type** is the variant the `address` parameter parses as. + /// * **`tls` enabled** is `yes` when the `tls` feature is enabled _and_ a + /// `tls` configuration is provided. + #[cfg(doc)] + pub struct DefaultListener(()); } #[derive(Deserialize)] @@ -31,50 +67,61 @@ struct Config { tls: Option, } +#[cfg(doc)] +pub use private::DefaultListener; + +#[cfg(doc)] +type Connection = crate::listener::tcp::TcpStream; + +#[cfg(doc)] +impl<'r> Bind<&'r Rocket> for DefaultListener { + type Error = Error; + async fn bind(_: &'r Rocket) -> Result { unreachable!() } + fn bind_endpoint(_: &&'r Rocket) -> Result { unreachable!() } +} + +#[cfg(doc)] +impl super::Listener for DefaultListener { + #[doc(hidden)] type Accept = Connection; + #[doc(hidden)] type Connection = Connection; + #[doc(hidden)] + async fn accept(&self) -> std::io::Result { unreachable!() } + #[doc(hidden)] + async fn connect(&self, _: Self::Accept) -> std::io::Result { unreachable!() } + #[doc(hidden)] + fn endpoint(&self) -> std::io::Result { unreachable!() } +} + +#[cfg(not(doc))] pub type DefaultListener = private::Listener; +#[cfg(not(doc))] impl<'r> Bind<&'r Rocket> for DefaultListener { - type Error = crate::Error; + type Error = Error; async fn bind(rocket: &'r Rocket) -> Result { let config: Config = rocket.figment().extract()?; match config.address { #[cfg(feature = "tls")] - endpoint@Endpoint::Tcp(_) if config.tls.is_some() => { - let listener = as Bind<_>>::bind(rocket) - .map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e))) - .await?; - + Endpoint::Tcp(_) if config.tls.is_some() => { + let listener = as Bind<_>>::bind(rocket).await?; Ok(Left(Left(listener))) } - endpoint@Endpoint::Tcp(_) => { - let listener = >::bind(rocket) - .map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e))) - .await?; - + Endpoint::Tcp(_) => { + let listener = >::bind(rocket).await?; Ok(Right(Left(listener))) } #[cfg(all(unix, feature = "tls"))] - endpoint@Endpoint::Unix(_) if config.tls.is_some() => { - let listener = as Bind<_>>::bind(rocket) - .map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e))) - .await?; - + Endpoint::Unix(_) if config.tls.is_some() => { + let listener = as Bind<_>>::bind(rocket).await?; Ok(Left(Right(listener))) } #[cfg(unix)] - endpoint@Endpoint::Unix(_) => { - let listener = >::bind(rocket) - .map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e))) - .await?; - + Endpoint::Unix(_) => { + let listener = >::bind(rocket).await?; Ok(Right(Right(listener))) } - endpoint => { - let msg = format!("unsupported bind endpoint: {endpoint}"); - let error = Box::::from(msg); - Err(ErrorKind::Bind(Some(endpoint), error).into()) - } + endpoint => Err(Error::Unsupported(endpoint)), } } @@ -83,3 +130,61 @@ impl<'r> Bind<&'r Rocket> for DefaultListener { Ok(config.address) } } + +#[derive(Debug)] +pub enum Error { + Config(figment::Error), + Io(std::io::Error), + Unsupported(Endpoint), + #[cfg(feature = "tls")] + Tls(crate::tls::Error), +} + +impl From for Error { + fn from(value: figment::Error) -> Self { + Error::Config(value) + } +} + +impl From for Error { + fn from(value: std::io::Error) -> Self { + Error::Io(value) + } +} + +#[cfg(feature = "tls")] +impl From for Error { + fn from(value: crate::tls::Error) -> Self { + Error::Tls(value) + } +} + +impl From> for Error { + fn from(value: Either) -> Self { + value.either(Error::Config, Error::Io) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Error::Config(e) => e.fmt(f), + Error::Io(e) => e.fmt(f), + Error::Unsupported(e) => write!(f, "unsupported endpoint: {e:?}"), + #[cfg(feature = "tls")] + Error::Tls(error) => error.fmt(f), + } + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Error::Config(e) => Some(e), + Error::Io(e) => Some(e), + Error::Unsupported(_) => None, + #[cfg(feature = "tls")] + Error::Tls(e) => Some(e), + } + } +} diff --git a/core/lib/src/listener/endpoint.rs b/core/lib/src/listener/endpoint.rs index 2864e75c60..a9a49b11fd 100644 --- a/core/lib/src/listener/endpoint.rs +++ b/core/lib/src/listener/endpoint.rs @@ -13,15 +13,31 @@ use crate::http::uncased::AsUncased; #[cfg(feature = "tls")] type TlsInfo = Option>; #[cfg(not(feature = "tls"))] type TlsInfo = Option<()>; -pub trait EndpointAddr: fmt::Display + fmt::Debug + Sync + Send + Any { } +pub trait CustomEndpoint: fmt::Display + fmt::Debug + Sync + Send + Any { } -impl EndpointAddr for T {} +impl CustomEndpoint for T {} /// # Conversions /// /// * [`&str`] - parse with [`FromStr`] /// * [`tokio::net::unix::SocketAddr`] - must be path: [`Endpoint::Unix`] /// * [`PathBuf`] - infallibly as [`Endpoint::Unix`] +/// +/// # Syntax +/// +/// The string syntax is: +/// +/// ```text +/// endpoint = 'tcp' ':' socket | 'quic' ':' socket | 'unix' ':' path | socket +/// socket := IP_ADDR | SOCKET_ADDR +/// path := PATH +/// +/// IP_ADDR := `std::net::IpAddr` string as defined by Rust +/// SOCKET_ADDR := `std::net::SocketAddr` string as defined by Rust +/// PATH := `PathBuf` (any UTF-8) string as defined by Rust +/// ``` +/// +/// If `IP_ADDR` is specified in socket, port defaults to `8000`. #[derive(Clone)] #[non_exhaustive] pub enum Endpoint { @@ -29,11 +45,11 @@ pub enum Endpoint { Quic(net::SocketAddr), Unix(PathBuf), Tls(Arc, TlsInfo), - Custom(Arc), + Custom(Arc), } impl Endpoint { - pub fn new(value: T) -> Endpoint { + pub fn new(value: T) -> Endpoint { Endpoint::Custom(Arc::new(value)) } @@ -222,21 +238,6 @@ impl Default for Endpoint { } } -/// Parses an address into a `Endpoint`. -/// -/// The syntax is: -/// -/// ```text -/// endpoint = 'tcp' ':' socket | 'quic' ':' socket | 'unix' ':' path | socket -/// socket := IP_ADDR | SOCKET_ADDR -/// path := PATH -/// -/// IP_ADDR := `std::net::IpAddr` string as defined by Rust -/// SOCKET_ADDR := `std::net::SocketAddr` string as defined by Rust -/// PATH := `PathBuf` (any UTF-8) string as defined by Rust -/// ``` -/// -/// If `IP_ADDR` is specified in socket, port defaults to `8000`. impl FromStr for Endpoint { type Err = AddrParseError; diff --git a/core/lib/src/listener/quic.rs b/core/lib/src/listener/quic.rs index 9ba98d4056..dd94fd054f 100644 --- a/core/lib/src/listener/quic.rs +++ b/core/lib/src/listener/quic.rs @@ -51,14 +51,16 @@ pub struct QuicListener { pub struct H3Stream(H3Conn); pub struct H3Connection { - pub handle: quic::connection::Handle, - pub parts: http::request::Parts, - pub tx: QuicTx, - pub rx: QuicRx, + pub(crate) handle: quic::connection::Handle, + pub(crate) parts: http::request::Parts, + pub(crate) tx: QuicTx, + pub(crate) rx: QuicRx, } +#[doc(hidden)] pub struct QuicRx(h3::server::RequestStream); +#[doc(hidden)] pub struct QuicTx(h3::server::RequestStream, Bytes>); impl QuicListener { @@ -95,19 +97,19 @@ impl QuicListener { } impl QuicListener { - pub(crate) async fn accept(&self) -> Option { + pub async fn accept(&self) -> Option { self.listener .lock().await .accept().await } - pub(crate) async fn connect(&self, accept: quic::Connection) -> io::Result { + pub async fn connect(&self, accept: quic::Connection) -> io::Result { let quic_conn = quic_h3::Connection::new(accept); let conn = H3Conn::new(quic_conn).await.map_err(io::Error::other)?; Ok(H3Stream(conn)) } - pub(crate) fn endpoint(&self) -> io::Result { + pub fn endpoint(&self) -> io::Result { Ok(Endpoint::Quic(self.endpoint).with_tls(&self.tls)) } } diff --git a/core/lib/src/listener/tcp.rs b/core/lib/src/listener/tcp.rs index e5c1a418c1..09348cba07 100644 --- a/core/lib/src/listener/tcp.rs +++ b/core/lib/src/listener/tcp.rs @@ -1,3 +1,14 @@ +//! TCP listener. +//! +//! # Configuration +//! +//! Reads the following configuration parameters: +//! +//! | parameter | type | default | note | +//! |-----------|--------------|-------------|---------------------------------| +//! | `address` | [`Endpoint`] | `127.0.0.1` | must be `tcp:ip` | +//! | `port` | `u16` | `8000` | replaces the port in `address ` | + use std::io; use std::net::{Ipv4Addr, SocketAddr}; diff --git a/core/lib/src/listener/unix.rs b/core/lib/src/listener/unix.rs index 92e6326497..dac1faec4f 100644 --- a/core/lib/src/listener/unix.rs +++ b/core/lib/src/listener/unix.rs @@ -11,6 +11,16 @@ use crate::{Ignite, Rocket}; pub use tokio::net::UnixStream; +/// Unix domain sockets listener. +/// +/// # Configuration +/// +/// Reads the following configuration parameters: +/// +/// | parameter | type | default | note | +/// |-----------|--------------|---------|-------------------------------------------| +/// | `address` | [`Endpoint`] | | required: must be `unix:path` | +/// | `reuse` | boolean | `true` | whether to create/reuse/delete the socket | pub struct UnixListener { path: PathBuf, lock: Option, @@ -18,7 +28,7 @@ pub struct UnixListener { } impl UnixListener { - async fn bind>(path: P, reuse: bool) -> io::Result { + pub async fn bind>(path: P, reuse: bool) -> io::Result { let path = path.as_ref(); let lock = if reuse { let lock_ext = match path.extension().and_then(|s| s.to_str()) { diff --git a/core/lib/src/mtls/config.rs b/core/lib/src/mtls/config.rs index 8bcbf0c055..fdc16efcfe 100644 --- a/core/lib/src/mtls/config.rs +++ b/core/lib/src/mtls/config.rs @@ -79,8 +79,8 @@ pub struct MtlsConfig { impl MtlsConfig { /// Constructs a `MtlsConfig` from a path to a PEM file with a certificate /// authority `ca_certs` DER-encoded X.509 TLS certificate chain. This - /// method does no validation; it simply creates a structure suitable for - /// passing into a [`TlsConfig`]. + /// method does no validation; it simply creates an [`MtlsConfig`] for later + /// use. /// /// These certificates will be used to verify client-presented certificates /// in TLS connections. @@ -101,8 +101,7 @@ impl MtlsConfig { /// Constructs a `MtlsConfig` from a byte buffer to a certificate authority /// `ca_certs` DER-encoded X.509 TLS certificate chain. This method does no - /// validation; it simply creates a structure suitable for passing into a - /// [`TlsConfig`]. + /// validation; it simply creates an [`MtlsConfig`] for later use. /// /// These certificates will be used to verify client-presented certificates /// in TLS connections. diff --git a/core/lib/src/request/from_request.rs b/core/lib/src/request/from_request.rs index a87b3bd4bb..f4677db3c9 100644 --- a/core/lib/src/request/from_request.rs +++ b/core/lib/src/request/from_request.rs @@ -1,6 +1,6 @@ -use std::convert::Infallible; use std::fmt::Debug; -use std::net::IpAddr; +use std::convert::Infallible; +use std::net::{IpAddr, SocketAddr}; use crate::{Request, Route}; use crate::outcome::{self, IntoOutcome, Outcome::*}; @@ -496,7 +496,7 @@ impl<'r> FromRequest<'r> for &'r Endpoint { } #[crate::async_trait] -impl<'r> FromRequest<'r> for std::net::SocketAddr { +impl<'r> FromRequest<'r> for SocketAddr { type Error = Infallible; async fn from_request(request: &'r Request<'_>) -> Outcome { diff --git a/core/lib/src/response/response.rs b/core/lib/src/response/response.rs index 4d399e9174..e3b266f014 100644 --- a/core/lib/src/response/response.rs +++ b/core/lib/src/response/response.rs @@ -114,9 +114,8 @@ impl<'r> Builder<'r> { /// the same name exist, they are all removed, and only the new header and /// value will remain. /// - /// The type of `header` can be any type that implements `Into
`. - /// This includes `Header` itself, [`ContentType`](crate::http::ContentType) and - /// [hyper::header types](crate::http::hyper::header). + /// The type of `header` can be any type that implements `Into
`. See + /// [trait implementations](Header#trait-implementations). /// /// # Example /// @@ -144,9 +143,8 @@ impl<'r> Builder<'r> { /// `Response`. This allows for multiple headers with the same name and /// potentially different values to be present in the `Response`. /// - /// The type of `header` can be any type that implements `Into
`. - /// This includes `Header` itself, [`ContentType`](crate::http::ContentType) - /// and [`Accept`](crate::http::Accept). + /// The type of `header` can be any type that implements `Into
`. See + /// [trait implementations](Header#trait-implementations). /// /// # Example /// @@ -641,9 +639,8 @@ impl<'r> Response<'r> { /// Sets the header `header` in `self`. Any existing headers with the name /// `header.name` will be lost, and only `header` will remain. The type of - /// `header` can be any type that implements `Into
`. This includes - /// `Header` itself, [`ContentType`](crate::http::ContentType) and - /// [`hyper::header` types](crate::http::hyper::header). + /// `header` can be any type that implements `Into
`. See [trait + /// implementations](Header#trait-implementations). /// /// # Example /// @@ -723,10 +720,7 @@ impl<'r> Response<'r> { /// Adds a custom header with name `name` and value `value` to `self`. If /// `self` already contains headers with the name `name`, another header - /// with the same `name` and `value` is added. The type of `header` can be - /// any type implements `Into
`. This includes `Header` itself, - /// [`ContentType`](crate::http::ContentType) and [`hyper::header` - /// types](crate::http::hyper::header). + /// with the same `name` and `value` is added. /// /// # Example /// diff --git a/core/lib/src/shutdown/handle.rs b/core/lib/src/shutdown/handle.rs index 862a11df95..1b11366cd0 100644 --- a/core/lib/src/shutdown/handle.rs +++ b/core/lib/src/shutdown/handle.rs @@ -88,7 +88,7 @@ impl Shutdown { /// This function returns immediately; pending requests will continue to run /// until completion or expiration of the grace period, which ever comes /// first, before the actual shutdown occurs. The grace period can be - /// configured via [`Shutdown::grace`](crate::config::ShutdownConfig::grace). + /// configured via [`ShutdownConfig`]'s `grace` field. /// /// ```rust /// # use rocket::*; diff --git a/core/lib/src/tls/config.rs b/core/lib/src/tls/config.rs index c291ffebae..e6e9fb30b2 100644 --- a/core/lib/src/tls/config.rs +++ b/core/lib/src/tls/config.rs @@ -39,7 +39,8 @@ use crate::tls::error::{Result, Error, KeyError}; /// /// Additionally, the `mutual` parameter controls if and how the server /// authenticates clients via mutual TLS. It works in concert with the -/// [`mtls`](crate::mtls) module. See [`MtlsConfig`] for configuration details. +/// [`mtls`](crate::mtls) module. See [`MtlsConfig`](crate::mtls::MtlsConfig) +/// for configuration details. /// /// In `Rocket.toml`, configuration might look like: /// diff --git a/core/lib/src/tls/mod.rs b/core/lib/src/tls/mod.rs index df9899dfb9..8c439e42cd 100644 --- a/core/lib/src/tls/mod.rs +++ b/core/lib/src/tls/mod.rs @@ -3,8 +3,6 @@ mod resolver; mod listener; pub(crate) mod config; -pub use rustls; - pub use error::{Error, Result}; pub use config::{TlsConfig, CipherSuite}; pub use resolver::{Resolver, ClientHello, ServerConfig}; diff --git a/core/lib/src/tls/resolver.rs b/core/lib/src/tls/resolver.rs index 475ec2b8e9..d7fb39677e 100644 --- a/core/lib/src/tls/resolver.rs +++ b/core/lib/src/tls/resolver.rs @@ -15,6 +15,39 @@ pub(crate) struct DynResolver(Arc); pub struct Fairing(PhantomData); /// A dynamic TLS configuration resolver. +/// +/// # Example +/// +/// This is an async trait. Implement it as follows: +/// +/// ```rust +/// # #[macro_use] extern crate rocket; +/// use std::sync::Arc; +/// use rocket::tls::{self, Resolver, TlsConfig, ClientHello, ServerConfig}; +/// use rocket::{Rocket, Build}; +/// +/// struct MyResolver(Arc); +/// +/// #[rocket::async_trait] +/// impl Resolver for MyResolver { +/// async fn init(rocket: &Rocket) -> tls::Result { +/// // This is equivalent to what the default resolver would do. +/// let config: TlsConfig = rocket.figment().extract_inner("tls")?; +/// let server_config = config.server_config().await?; +/// Ok(MyResolver(Arc::new(server_config))) +/// } +/// +/// async fn resolve(&self, hello: ClientHello<'_>) -> Option> { +/// // return a `ServerConfig` based on `hello`; here we ignore it +/// Some(self.0.clone()) +/// } +/// } +/// +/// #[launch] +/// fn rocket() -> _ { +/// rocket::build().attach(MyResolver::fairing()) +/// } +/// ``` #[crate::async_trait] pub trait Resolver: Send + Sync + 'static { async fn init(rocket: &Rocket) -> crate::tls::Result where Self: Sized { diff --git a/scripts/test.sh b/scripts/test.sh index 5212be19d0..257fc759bd 100755 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -171,6 +171,15 @@ function test_default() { echo ":: Checking fuzzers..." indir "${FUZZ_ROOT}" $CARGO update indir "${FUZZ_ROOT}" $CARGO check --all --all-features $@ + + case "$OSTYPE" in + darwin* | linux*) + echo ":: Checking testbench..." + indir "${TESTBENCH_ROOT}" $CARGO update + indir "${TESTBENCH_ROOT}" $CARGO check $@ + ;; + *) echo ":: Skipping testbench [$OSTYPE]" ;; + esac } function test_ui() { diff --git a/testbench/src/main.rs b/testbench/src/main.rs index 44b1a24fac..093099fe8d 100644 --- a/testbench/src/main.rs +++ b/testbench/src/main.rs @@ -1,4 +1,5 @@ use std::process::ExitCode; +use std::time::Duration; use rocket::listener::unix::UnixListener; use rocket::tokio::net::TcpListener; @@ -137,8 +138,7 @@ fn tls_info() -> Result<()> { fn tls_resolver() -> Result<()> { use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; - use rocket::tls::rustls::{server::ClientHello, ServerConfig}; - use rocket::tls::{Resolver, TlsConfig}; + use rocket::tls::{Resolver, TlsConfig, ClientHello, ServerConfig}; struct CountingResolver { config: Arc, @@ -163,9 +163,7 @@ fn tls_resolver() -> Result<()> { let server = spawn! { #[get("/count")] fn count(counter: &State>) -> String { - let count = counter.load(Ordering::Acquire); - println!("{count}"); - count.to_string() + counter.load(Ordering::Acquire).to_string() } let counter = Arc::new(AtomicUsize::new(0)); @@ -242,8 +240,7 @@ fn sni_resolver() -> Result<()> { use std::collections::HashMap; use rocket::http::uri::Host; - use rocket::tls::rustls::{server::ClientHello, ServerConfig}; - use rocket::tls::{Resolver, TlsConfig}; + use rocket::tls::{Resolver, TlsConfig, ClientHello, ServerConfig}; struct SniResolver { default: Arc, @@ -329,8 +326,8 @@ fn tcp_unix_listener_fail() -> Result<()> { }; if let Err(Error::Liftoff(stdout, _)) = server { - assert!(stdout.contains("expected valid TCP")); - assert!(stdout.contains("for key default.address")); + assert!(stdout.contains("expected valid TCP (ip) or unix (path)")); + assert!(stdout.contains("default.address")); } else { panic!("unexpected result: {server:#?}"); } @@ -361,14 +358,17 @@ fn tcp_unix_listener_fail() -> Result<()> { macro_rules! tests { ($($f:ident),* $(,)?) => {[ - $(Test { name: stringify!($f), func: $f, }),* + $(Test { + name: stringify!($f), + run: |_: ()| $f().map_err(|e| e.to_string()), + }),* ]}; } #[derive(Copy, Clone)] struct Test { name: &'static str, - func: fn() -> Result<()>, + run: fn(()) -> Result<(), String>, } static TESTS: &[Test] = &tests![ @@ -377,37 +377,58 @@ static TESTS: &[Test] = &tests![ ]; fn main() -> ExitCode { + procspawn::init(); + let filter = std::env::args().nth(1).unwrap_or_default(); let filtered = TESTS.into_iter().filter(|test| test.name.contains(&filter)); println!("running {}/{} tests", filtered.clone().count(), TESTS.len()); - let handles: Vec<_> = filtered - .map(|test| (test, std::thread::spawn(move || { - if let Err(e) = (test.func)() { - println!("test {} ... {}\n {e}", test.name.bold(), "fail".red()); - return Err(e); + let handles = filtered.map(|test| (test, std::thread::spawn(|| { + let name = test.name; + let start = std::time::SystemTime::now(); + let mut proc = procspawn::spawn((), test.run); + let result = loop { + match proc.join_timeout(Duration::from_secs(10)) { + Err(e) if e.is_timeout() => { + let elapsed = start.elapsed().unwrap().as_secs(); + println!("{name} has been running for {elapsed} seconds..."); + + if elapsed >= 30 { + println!("{name} timeout"); + break Err(e); + } + }, + result => break result, } + }; - println!("test {} ... {}", test.name.bold(), "ok".green()); - Ok(()) - }))) - .collect(); - - let mut failure = false; - for (test, handle) in handles { - let result = handle.join(); - failure |= matches!(result, Err(_) | Ok(Err(_))); - if result.is_err() { - println!("test {} ... {}", test.name.bold(), "panic".red().underline()); + match result.as_ref().map_err(|e| e.panic_info()) { + Ok(Ok(_)) => println!("test {name} ... {}", "ok".green()), + Ok(Err(e)) => println!("test {name} ... {}\n {e}", "fail".red()), + Err(Some(_)) => println!("test {name} ... {}", "panic".red().underline()), + Err(None) => println!("test {name} ... {}", "error".magenta()), } + + matches!(result, Ok(Ok(()))) + }))); + + let mut success = true; + for (_, handle) in handles { + success &= handle.join().unwrap_or(false); } - match failure { - true => ExitCode::FAILURE, - false => ExitCode::SUCCESS + match success { + true => ExitCode::SUCCESS, + false => { + println!("note: use `NOCAPTURE=1` to see test output"); + ExitCode::FAILURE + } } } +// TODO: Implement an `UpdatingResolver`. Expose `SniResolver` and +// `UpdatingResolver` in a `contrib` library or as part of `rocket`. +// // struct UpdatingResolver { // timestamp: AtomicU64, // config: ArcSwap diff --git a/testbench/src/server.rs b/testbench/src/server.rs index 8a07cc14be..13b40c3e8d 100644 --- a/testbench/src/server.rs +++ b/testbench/src/server.rs @@ -42,6 +42,16 @@ fn stdio() -> Stdio { .unwrap_or_else(Stdio::piped) } +fn read(io: Option) -> Result { + if let Some(mut io) = io { + let mut string = String::new(); + io.read_to_string(&mut string)?; + return Ok(string); + } + + Ok(String::new()) +} + impl Server { pub fn spawn(ctxt: T, f: fn((Token, T)) -> Launched) -> Result where T: Serialize + DeserializeOwned @@ -62,14 +72,7 @@ impl Server { Ok(Server { proc, tls, port, _rx: rx }) }, Message::Failure => { - let stdout = proc.stdout().unwrap(); - let mut out = String::new(); - stdout.read_to_string(&mut out)?; - - let stderr = proc.stderr().unwrap(); - let mut err = String::new(); - stderr.read_to_string(&mut err)?; - Err(Error::Liftoff(out, err)) + Err(Error::Liftoff(read(proc.stdout())?, read(proc.stderr())?)) } } } @@ -80,23 +83,11 @@ impl Server { } pub fn read_stdout(&mut self) -> Result { - let Some(stdout) = self.proc.stdout() else { - return Ok(String::new()); - }; - - let mut string = String::new(); - stdout.read_to_string(&mut string)?; - Ok(string) + read(self.proc.stdout()) } pub fn read_stderr(&mut self) -> Result { - let Some(stderr) = self.proc.stderr() else { - return Ok(String::new()); - }; - - let mut string = String::new(); - stderr.read_to_string(&mut string)?; - Ok(string) + read(self.proc.stderr()) } pub fn kill(&mut self) -> Result<()> { @@ -133,8 +124,7 @@ impl Token { }))); let server = self.0.clone(); - let fut = rocket.launch_with::(); - if let Err(e) = rocket::execute(fut) { + if let Err(e) = rocket::execute(rocket.launch_with::()) { let sender = IpcSender::::connect(server).unwrap(); let _ = sender.send(Message::Failure); let _ = sender.send(Message::Failure);