Skip to content

Commit

Permalink
Introduce dynamic TLS resolvers.
Browse files Browse the repository at this point in the history
This commit introduces the ability to dynamically select a TLS
configuration based on the client's TLS hello.

Added `Authority::set_port()`.
Various `Config` structures for listeners removed.
`UdsListener` is now `UnixListener`.
`Bindable` removed in favor of new `Bind`.
`Connection` requires `AsyncRead + AsyncWrite` again
The `Debug` impl for `Endpoint` displays the underlying address in plaintext.
`Listener` must be `Sized`.
`tls` listener moved to `tls::TlsListener`
The preview `quic` listener no longer implements `Listener`.

All built-in listeners now implement `Bind<&Rocket>`.

Clarified docs for `mtls::Certificate` guard.

No reexporitng rustls from `tls`.
Added `TlsConfig::server_config()`.

Added some future helpers: `race()` and `race_io()`.

Fix an issue where the logger wouldn't respect a configuration during error
printing.

Added Rocket::launch_with(), launch_on(), bind_launch().

Added a default client.pem to the TLS example.

Revamped the testbench.

Added tests for TLS resolvers, MTLS, listener failure output.

TODO: clippy.
TODO: UDS testing.

Resolves #2730.
Resolves #2363.
Closes #2748.
Closes #2683.
Closes #2577.
  • Loading branch information
SergioBenitez committed Apr 17, 2024
1 parent 280fda4 commit 984c6e9
Show file tree
Hide file tree
Showing 17 changed files with 316 additions and 146 deletions.
4 changes: 3 additions & 1 deletion contrib/dyn_templates/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@
//! to an `Object` (a dictionary) value. The [`context!`] macro can be used to
//! create inline `Serialize`-able context objects.
//!
//! [`Serialize`]: rocket::serde::Serialize
//!
//! ```rust
//! # #[macro_use] extern crate rocket;
//! use rocket::serde::Serialize;
Expand Down Expand Up @@ -165,7 +167,7 @@
//! builds, template reloading is disabled to improve performance and cannot be
//! enabled.
//!
//! [attached]: Rocket::attach()
//! [attached]: rocket::Rocket::attach()
//!
//! ### Metadata and Rendering to `String`
//!
Expand Down
13 changes: 7 additions & 6 deletions contrib/dyn_templates/src/template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,12 @@ impl Template {
}

/// Render the template named `name` with the context `context`. The
/// `context` is typically created using the [`context!`] macro, but it can
/// be of any type that implements `Serialize`, such as `HashMap` or a
/// custom `struct`.
/// `context` is typically created using the [`context!()`](crate::context!)
/// macro, but it can be of any type that implements `Serialize`, such as
/// `HashMap` or a custom `struct`.
///
/// To render a template directly into a string, use [`Metadata::render()`].
/// To render a template directly into a string, use
/// [`Metadata::render()`](crate::Metadata::render()).
///
/// # Examples
///
Expand Down Expand Up @@ -291,8 +292,8 @@ impl Sentinel for Template {
/// A macro to easily create a template rendering context.
///
/// Invocations of this macro expand to a value of an anonymous type which
/// implements [`serde::Serialize`]. Fields can be literal expressions or
/// variables captured from a surrounding scope, as long as all fields implement
/// implements [`Serialize`]. Fields can be literal expressions or variables
/// captured from a surrounding scope, as long as all fields implement
/// `Serialize`.
///
/// # Examples
Expand Down
8 changes: 4 additions & 4 deletions core/lib/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,16 +179,16 @@ impl Error {
match self.kind() {
ErrorKind::Bind(ref a, ref e) => {
if let Some(e) = e.downcast_ref::<Self>() {
e.pretty_print();
e.pretty_print()
} else {
match a {
Some(a) => error!("Binding to {} failed.", a.primary().underline()),
None => error!("Binding to network interface failed."),
}
}

info_!("{}", e);
"aborting due to bind error"
info_!("{}", e);
"aborting due to bind error"
}
}
ErrorKind::Io(ref e) => {
error!("Rocket failed to launch due to an I/O error.");
Expand Down
173 changes: 139 additions & 34 deletions core/lib/src/listener/default.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,62 @@
use core::fmt;

use serde::Deserialize;
use tokio_util::either::{Either, Either::{Left, Right}};
use futures::TryFutureExt;
use tokio_util::either::Either::{Left, Right};
use either::Either;

use crate::error::ErrorKind;
use crate::{Ignite, Rocket};
use crate::listener::{Bind, Endpoint, tcp::TcpListener};

#[cfg(unix)] use crate::listener::unix::UnixListener;
#[cfg(feature = "tls")] use crate::tls::{TlsListener, TlsConfig};

mod private {
use super::{Either, TcpListener};
use super::*;
use tokio_util::either::Either;

#[cfg(feature = "tls")] pub type TlsListener<T> = super::TlsListener<T>;
#[cfg(not(feature = "tls"))] pub type TlsListener<T> = T;
#[cfg(unix)] pub type UnixListener = super::UnixListener;
#[cfg(not(unix))] pub type UnixListener = super::TcpListener;
#[cfg(feature = "tls")] type TlsListener<T> = super::TlsListener<T>;
#[cfg(not(feature = "tls"))] type TlsListener<T> = T;
#[cfg(unix)] type UnixListener = super::UnixListener;
#[cfg(not(unix))] type UnixListener = TcpListener;

pub type Listener = Either<
Either<TlsListener<TcpListener>, TlsListener<UnixListener>>,
Either<TcpListener, UnixListener>,
>;

/// The default connection listener.
///
/// # Configuration
///
/// Reads the following optional configuration parameters:
///
/// | parameter | type | default |
/// | ----------- | ----------------- | --------------------- |
/// | `address` | [`Endpoint`] | `tcp:127.0.0.1:8000` |
/// | `tls` | [`TlsConfig`] | None |
/// | `reuse` | boolean | `true` |
///
/// # Listener
///
/// Based on the above configuration, this listener defers to one of the
/// following existing listeners:
///
/// | listener | `address` type | `tls` enabled |
/// |-------------------------------|--------------------|---------------|
/// | [`TcpListener`] | [`Endpoint::Tcp`] | no |
/// | [`UnixListener`] | [`Endpoint::Unix`] | no |
/// | [`TlsListener<TcpListener>`] | [`Endpoint::Tcp`] | yes |
/// | [`TlsListener<UnixListener>`] | [`Endpoint::Unix`] | yes |
///
/// [`UnixListener`]: crate::listener::unix::UnixListener
/// [`TlsListener<TcpListener>`]: crate::tls::TlsListener
/// [`TlsListener<UnixListener>`]: crate::tls::TlsListener
///
/// * **address type** is the variant the `address` parameter parses as.
/// * **`tls` enabled** is `yes` when the `tls` feature is enabled _and_ a
/// `tls` configuration is provided.
#[cfg(doc)]
pub struct DefaultListener(());
}

#[derive(Deserialize)]
Expand All @@ -31,50 +67,61 @@ struct Config {
tls: Option<TlsConfig>,
}

#[cfg(doc)]
pub use private::DefaultListener;

#[cfg(doc)]
type Connection = crate::listener::tcp::TcpStream;

#[cfg(doc)]
impl<'r> Bind<&'r Rocket<Ignite>> for DefaultListener {
type Error = Error;
async fn bind(_: &'r Rocket<Ignite>) -> Result<Self, Error> { unreachable!() }
fn bind_endpoint(_: &&'r Rocket<Ignite>) -> Result<Endpoint, Error> { unreachable!() }
}

#[cfg(doc)]
impl super::Listener for DefaultListener {
#[doc(hidden)] type Accept = Connection;
#[doc(hidden)] type Connection = Connection;
#[doc(hidden)]
async fn accept(&self) -> std::io::Result<Connection> { unreachable!() }
#[doc(hidden)]
async fn connect(&self, _: Self::Accept) -> std::io::Result<Connection> { unreachable!() }
#[doc(hidden)]
fn endpoint(&self) -> std::io::Result<Endpoint> { unreachable!() }
}

#[cfg(not(doc))]
pub type DefaultListener = private::Listener;

#[cfg(not(doc))]
impl<'r> Bind<&'r Rocket<Ignite>> for DefaultListener {
type Error = crate::Error;
type Error = Error;

async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
let config: Config = rocket.figment().extract()?;
match config.address {
#[cfg(feature = "tls")]
endpoint@Endpoint::Tcp(_) if config.tls.is_some() => {
let listener = <TlsListener<TcpListener> as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

Endpoint::Tcp(_) if config.tls.is_some() => {
let listener = <TlsListener<TcpListener> as Bind<_>>::bind(rocket).await?;
Ok(Left(Left(listener)))
}
endpoint@Endpoint::Tcp(_) => {
let listener = <TcpListener as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

Endpoint::Tcp(_) => {
let listener = <TcpListener as Bind<_>>::bind(rocket).await?;
Ok(Right(Left(listener)))
}
#[cfg(all(unix, feature = "tls"))]
endpoint@Endpoint::Unix(_) if config.tls.is_some() => {
let listener = <TlsListener<UnixListener> as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

Endpoint::Unix(_) if config.tls.is_some() => {
let listener = <TlsListener<UnixListener> as Bind<_>>::bind(rocket).await?;
Ok(Left(Right(listener)))
}
#[cfg(unix)]
endpoint@Endpoint::Unix(_) => {
let listener = <UnixListener as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

Endpoint::Unix(_) => {
let listener = <UnixListener as Bind<_>>::bind(rocket).await?;
Ok(Right(Right(listener)))
}
endpoint => {
let msg = format!("unsupported bind endpoint: {endpoint}");
let error = Box::<dyn std::error::Error + Send + Sync>::from(msg);
Err(ErrorKind::Bind(Some(endpoint), error).into())
}
endpoint => Err(Error::Unsupported(endpoint)),
}
}

Expand All @@ -83,3 +130,61 @@ impl<'r> Bind<&'r Rocket<Ignite>> for DefaultListener {
Ok(config.address)
}
}

#[derive(Debug)]
pub enum Error {
Config(figment::Error),
Io(std::io::Error),
Unsupported(Endpoint),
#[cfg(feature = "tls")]
Tls(crate::tls::Error),
}

impl From<figment::Error> for Error {
fn from(value: figment::Error) -> Self {
Error::Config(value)
}
}

impl From<std::io::Error> for Error {
fn from(value: std::io::Error) -> Self {
Error::Io(value)
}
}

#[cfg(feature = "tls")]
impl From<crate::tls::Error> for Error {
fn from(value: crate::tls::Error) -> Self {
Error::Tls(value)
}
}

impl From<Either<figment::Error, std::io::Error>> for Error {
fn from(value: Either<figment::Error, std::io::Error>) -> Self {
value.either(Error::Config, Error::Io)
}
}

impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::Config(e) => e.fmt(f),
Error::Io(e) => e.fmt(f),
Error::Unsupported(e) => write!(f, "unsupported endpoint: {e:?}"),
#[cfg(feature = "tls")]
Error::Tls(error) => error.fmt(f),
}
}
}

impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Error::Config(e) => Some(e),
Error::Io(e) => Some(e),
Error::Unsupported(_) => None,
#[cfg(feature = "tls")]
Error::Tls(e) => Some(e),
}
}
}
39 changes: 20 additions & 19 deletions core/lib/src/listener/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,43 @@ use crate::http::uncased::AsUncased;
#[cfg(feature = "tls")] type TlsInfo = Option<Box<crate::tls::TlsConfig>>;
#[cfg(not(feature = "tls"))] type TlsInfo = Option<()>;

pub trait EndpointAddr: fmt::Display + fmt::Debug + Sync + Send + Any { }
pub trait CustomEndpoint: fmt::Display + fmt::Debug + Sync + Send + Any { }

impl<T: fmt::Display + fmt::Debug + Sync + Send + Any> EndpointAddr for T {}
impl<T: fmt::Display + fmt::Debug + Sync + Send + Any> CustomEndpoint for T {}

/// # Conversions
///
/// * [`&str`] - parse with [`FromStr`]
/// * [`tokio::net::unix::SocketAddr`] - must be path: [`Endpoint::Unix`]
/// * [`PathBuf`] - infallibly as [`Endpoint::Unix`]
///
/// # Syntax
///
/// The string syntax is:
///
/// ```text
/// endpoint = 'tcp' ':' socket | 'quic' ':' socket | 'unix' ':' path | socket
/// socket := IP_ADDR | SOCKET_ADDR
/// path := PATH
///
/// IP_ADDR := `std::net::IpAddr` string as defined by Rust
/// SOCKET_ADDR := `std::net::SocketAddr` string as defined by Rust
/// PATH := `PathBuf` (any UTF-8) string as defined by Rust
/// ```
///
/// If `IP_ADDR` is specified in socket, port defaults to `8000`.
#[derive(Clone)]
#[non_exhaustive]
pub enum Endpoint {
Tcp(net::SocketAddr),
Quic(net::SocketAddr),
Unix(PathBuf),
Tls(Arc<Endpoint>, TlsInfo),
Custom(Arc<dyn EndpointAddr>),
Custom(Arc<dyn CustomEndpoint>),
}

impl Endpoint {
pub fn new<T: EndpointAddr>(value: T) -> Endpoint {
pub fn new<T: CustomEndpoint>(value: T) -> Endpoint {
Endpoint::Custom(Arc::new(value))
}

Expand Down Expand Up @@ -222,21 +238,6 @@ impl Default for Endpoint {
}
}

/// Parses an address into a `Endpoint`.
///
/// The syntax is:
///
/// ```text
/// endpoint = 'tcp' ':' socket | 'quic' ':' socket | 'unix' ':' path | socket
/// socket := IP_ADDR | SOCKET_ADDR
/// path := PATH
///
/// IP_ADDR := `std::net::IpAddr` string as defined by Rust
/// SOCKET_ADDR := `std::net::SocketAddr` string as defined by Rust
/// PATH := `PathBuf` (any UTF-8) string as defined by Rust
/// ```
///
/// If `IP_ADDR` is specified in socket, port defaults to `8000`.
impl FromStr for Endpoint {
type Err = AddrParseError;

Expand Down
Loading

0 comments on commit 984c6e9

Please sign in to comment.