Skip to content

Commit

Permalink
wip: tls-resolver
Browse files Browse the repository at this point in the history
  • Loading branch information
SergioBenitez committed Apr 16, 2024
1 parent 60f3cd5 commit b21225d
Show file tree
Hide file tree
Showing 38 changed files with 1,424 additions and 655 deletions.
2 changes: 1 addition & 1 deletion contrib/sync_db_pools/lib/tests/shutdown.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#[cfg(all(feature = "diesel_sqlite_pool"))]
#[cfg(test)]
#[cfg(all(feature = "diesel_sqlite_pool"))]
mod sqlite_shutdown_test {
use rocket::{async_test, Build, Rocket};
use rocket_sync_db_pools::database;
Expand Down
24 changes: 23 additions & 1 deletion core/http/src/uri/authority.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ impl<'a> Authority<'a> {
self.host.from_cow_source(&self.source)
}

/// Returns the port part of the authority URI, if there is one.
/// Returns the `port` part of the authority URI, if there is one.
///
/// # Example
///
Expand All @@ -206,6 +206,28 @@ impl<'a> Authority<'a> {
pub fn port(&self) -> Option<u16> {
self.port
}

/// Set the `port` of the authority URI.
///
/// # Example
///
/// ```rust
/// # #[macro_use] extern crate rocket;
/// let mut uri = uri!("username:password@host:123");
/// assert_eq!(uri.port(), Some(123));
///
/// uri.set_port(1024);
/// assert_eq!(uri.port(), Some(1024));
/// assert_eq!(uri, "username:password@host:1024");
///
/// uri.set_port(None);
/// assert_eq!(uri.port(), None);
/// assert_eq!(uri, "username:password@host");
/// ```
#[inline(always)]
pub fn set_port<T: Into<Option<u16>>>(&mut self, port: T) {
self.port = port.into();
}
}

impl_serde!(Authority<'a>, "an authority-form URI");
Expand Down
5 changes: 3 additions & 2 deletions core/lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ ref-swap = "0.1.2"
parking_lot = "0.12"
ubyte = {version = "0.10.2", features = ["serde"] }
serde = { version = "1.0", features = ["derive"] }
figment = { version = "0.10.13", features = ["toml", "env"] }
figment = { version = "0.10.17", features = ["toml", "env"] }
rand = "0.8"
either = "1"
pin-project-lite = "0.2"
Expand Down Expand Up @@ -140,5 +140,6 @@ version_check = "0.9.1"

[dev-dependencies]
tokio = { version = "1", features = ["macros", "io-std"] }
figment = { version = "0.10", features = ["test"] }
figment = { version = "0.10.17", features = ["test"] }
pretty_assertions = "1"
arc-swap = "1.7"
3 changes: 0 additions & 3 deletions core/lib/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,6 @@ mod secret_key;
#[cfg(unix)]
pub use crate::shutdown::Sig;

#[cfg(unix)]
pub use crate::listener::unix::UdsConfig;

#[cfg(feature = "secrets")]
pub use secret_key::SecretKey;

Expand Down
10 changes: 7 additions & 3 deletions core/lib/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,13 @@ impl Error {
self.mark_handled();
match self.kind() {
ErrorKind::Bind(ref a, ref e) => {
match a {
Some(a) => error!("Binding to {} failed.", a.primary().underline()),
None => error!("Binding to network interface failed."),
if let Some(e) = e.downcast_ref::<Self>() {
e.pretty_print();
} else {
match a {
Some(a) => error!("Binding to {} failed.", a.primary().underline()),
None => error!("Binding to network interface failed."),
}
}

info_!("{}", e);
Expand Down
10 changes: 10 additions & 0 deletions core/lib/src/listener/bind.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
use crate::listener::{Endpoint, Listener};

pub trait Bind<T>: Listener + 'static {
type Error: std::error::Error + Send + 'static;

#[crate::async_bound(Send)]
async fn bind(to: T) -> Result<Self, Self::Error>;

fn bind_endpoint(to: &T) -> Result<Endpoint, Self::Error>;
}
52 changes: 0 additions & 52 deletions core/lib/src/listener/bindable.rs

This file was deleted.

3 changes: 2 additions & 1 deletion core/lib/src/listener/connection.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::io;
use std::borrow::Cow;

use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::either::Either;

use super::Endpoint;
Expand All @@ -9,7 +10,7 @@ use super::Endpoint;
#[derive(Clone)]
pub struct Certificates<'r>(Cow<'r, [der::CertificateDer<'r>]>);

pub trait Connection: Send + Unpin {
pub trait Connection: AsyncRead + AsyncWrite + Send + Unpin {
fn endpoint(&self) -> io::Result<Endpoint>;

/// DER-encoded X.509 certificate chain presented by the client, if any.
Expand Down
140 changes: 93 additions & 47 deletions core/lib/src/listener/default.rs
Original file line number Diff line number Diff line change
@@ -1,64 +1,110 @@
use either::Either;
use serde::Deserialize;
use tokio_util::either::{Either, Either::{Left, Right}};
use futures::TryFutureExt;

use crate::listener::{Bindable, Endpoint};
use crate::error::{Error, ErrorKind};
use crate::error::ErrorKind;
use crate::{Ignite, Rocket};
use crate::listener::{Bind, Endpoint, tcp::TcpListener};

#[derive(serde::Deserialize)]
pub struct DefaultListener {
#[cfg(unix)] use crate::listener::unix::UnixListener;
#[cfg(feature = "tls")] use crate::tls::{TlsListener, TlsConfig};

mod private {
use super::{Either, TcpListener};

#[cfg(feature = "tls")] pub type TlsListener<T> = super::TlsListener<T>;
#[cfg(not(feature = "tls"))] pub type TlsListener<T> = T;
#[cfg(unix)] pub type UnixListener = super::UnixListener;
#[cfg(not(unix))] pub type UnixListener = super::TcpListener;

pub type Listener = Either<
Either<TlsListener<TcpListener>, TlsListener<UnixListener>>,
Either<TcpListener, UnixListener>,
>;
}

#[derive(Deserialize)]
struct Config {
#[serde(default)]
pub address: Endpoint,
pub port: Option<u16>,
pub reuse: Option<bool>,
address: Endpoint,
#[cfg(feature = "tls")]
pub tls: Option<crate::tls::TlsConfig>,
tls: Option<TlsConfig>,
}

#[cfg(not(unix))] type BaseBindable = Either<std::net::SocketAddr, std::net::SocketAddr>;
#[cfg(unix)] type BaseBindable = Either<std::net::SocketAddr, super::unix::UdsConfig>;
pub type DefaultListener = private::Listener;

#[cfg(not(feature = "tls"))] type TlsBindable<T> = Either<T, T>;
#[cfg(feature = "tls")] type TlsBindable<T> = Either<super::tls::TlsBindable<T>, T>;
impl<'r> Bind<&'r Rocket<Ignite>> for DefaultListener {
type Error = crate::Error;

impl DefaultListener {
pub(crate) fn base_bindable(&self) -> Result<BaseBindable, crate::Error> {
match &self.address {
Endpoint::Tcp(mut address) => {
if let Some(port) = self.port {
address.set_port(port);
}
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
let config: Config = rocket.figment().extract()?;
match config.address {
#[cfg(feature = "tls")]
endpoint@Endpoint::Tcp(_) if config.tls.is_some() => {
let listener = <TlsListener<TcpListener> as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

Ok(BaseBindable::Left(address))
},
Ok(Left(Left(listener)))
}
endpoint@Endpoint::Tcp(_) => {
let listener = <TcpListener as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

Ok(Right(Left(listener)))
}
#[cfg(all(unix, feature = "tls"))]
endpoint@Endpoint::Unix(_) if config.tls.is_some() => {
let listener = <TlsListener<UnixListener> as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

Ok(Left(Right(listener)))
}
#[cfg(unix)]
Endpoint::Unix(path) => {
let uds = super::unix::UdsConfig { path: path.clone(), reuse: self.reuse, };
Ok(BaseBindable::Right(uds))
},
#[cfg(not(unix))]
e@Endpoint::Unix(_) => {
let msg = "Unix domain sockets unavailable on non-unix platforms.";
let boxed = Box::<dyn std::error::Error + Send + Sync>::from(msg);
Err(Error::new(ErrorKind::Bind(Some(e.clone()), boxed)))
},
other => {
let msg = format!("unsupported default listener address: {other}");
let boxed = Box::<dyn std::error::Error + Send + Sync>::from(msg);
Err(Error::new(ErrorKind::Bind(Some(other.clone()), boxed)))
endpoint@Endpoint::Unix(_) => {
let listener = <UnixListener as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

Ok(Right(Right(listener)))
}
endpoint => {
let msg = format!("unsupported bind endpoint: {endpoint}");
let error = Box::<dyn std::error::Error + Send + Sync>::from(msg);
Err(ErrorKind::Bind(Some(endpoint), error).into())
}
}
}

pub(crate) fn tls_bindable<T>(&self, inner: T) -> TlsBindable<T> {
#[cfg(feature = "tls")]
if let Some(tls) = self.tls.clone() {
return TlsBindable::Left(super::tls::TlsBindable { inner, tls });
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
let config: Config = rocket.figment().extract()?;
match config.address {
#[cfg(feature = "tls")]
endpoint@Endpoint::Tcp(_) if config.tls.is_some() => {
<TlsListener<TcpListener> as Bind<_>>::bind_endpoint(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)).into())
}
endpoint@Endpoint::Tcp(_) => {
<TcpListener as Bind<_>>::bind_endpoint(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)).into())
}
#[cfg(all(unix, feature = "tls"))]
endpoint@Endpoint::Unix(_) if config.tls.is_some() => {
<TlsListener<UnixListener> as Bind<_>>::bind_endpoint(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)).into())
}
#[cfg(unix)]
endpoint@Endpoint::Unix(_) => {
<UnixListener as Bind<_>>::bind_endpoint(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)).into())
}
endpoint => {
let msg = format!("unsupported bind endpoint: {endpoint}");
let error = Box::<dyn std::error::Error + Send + Sync>::from(msg);
Err(ErrorKind::Bind(Some(endpoint), error).into())
}
}

TlsBindable::Right(inner)
}

pub fn bindable(&self) -> Result<impl Bindable, crate::Error> {
self.base_bindable()
.map(|b| b.map_either(|b| self.tls_bindable(b), |b| self.tls_bindable(b)))
}
}
Loading

0 comments on commit b21225d

Please sign in to comment.