Skip to content

Commit

Permalink
wip: dynamic tls cert resolver
Browse files Browse the repository at this point in the history
  • Loading branch information
SergioBenitez committed Apr 15, 2024
1 parent 60f3cd5 commit 9522c68
Show file tree
Hide file tree
Showing 35 changed files with 1,304 additions and 622 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# Generated by Cargo
target

# Generated databases
# Generated test files
db.sqlite
db.sqlite-shm
db.sqlite-wal
Expand Down
24 changes: 23 additions & 1 deletion core/http/src/uri/authority.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ impl<'a> Authority<'a> {
self.host.from_cow_source(&self.source)
}

/// Returns the port part of the authority URI, if there is one.
/// Returns the `port` part of the authority URI, if there is one.
///
/// # Example
///
Expand All @@ -206,6 +206,28 @@ impl<'a> Authority<'a> {
pub fn port(&self) -> Option<u16> {
self.port
}

/// Set the `port` of the authority URI.
///
/// # Example
///
/// ```rust
/// # #[macro_use] extern crate rocket;
/// let mut uri = uri!("username:password@host:123");
/// assert_eq!(uri.port(), Some(123));
///
/// uri.set_port(1024);
/// assert_eq!(uri.port(), Some(1024));
/// assert_eq!(uri, "username:password@host:1024")
///
/// uri.set_port(None);
/// assert_eq!(uri.port(), None);
/// assert_eq!(uri, "username:password@host")
/// ```
#[inline(always)]
pub fn set_port<T: Into<Option<u16>>>(&mut self, port: T) {
self.port = port.into();
}
}

impl_serde!(Authority<'a>, "an authority-form URI");
Expand Down
1 change: 1 addition & 0 deletions core/lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,4 @@ version_check = "0.9.1"
tokio = { version = "1", features = ["macros", "io-std"] }
figment = { version = "0.10", features = ["test"] }
pretty_assertions = "1"
arc-swap = "1.7"
3 changes: 0 additions & 3 deletions core/lib/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,6 @@ mod secret_key;
#[cfg(unix)]
pub use crate::shutdown::Sig;

#[cfg(unix)]
pub use crate::listener::unix::UdsConfig;

#[cfg(feature = "secrets")]
pub use secret_key::SecretKey;

Expand Down
10 changes: 7 additions & 3 deletions core/lib/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,13 @@ impl Error {
self.mark_handled();
match self.kind() {
ErrorKind::Bind(ref a, ref e) => {
match a {
Some(a) => error!("Binding to {} failed.", a.primary().underline()),
None => error!("Binding to network interface failed."),
if let Some(e) = e.downcast_ref::<Self>() {
e.pretty_print();
} else {
match a {
Some(a) => error!("Binding to {} failed.", a.primary().underline()),
None => error!("Binding to network interface failed."),
}
}

info_!("{}", e);
Expand Down
9 changes: 9 additions & 0 deletions core/lib/src/listener/bind.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
use crate::listener::{Endpoint, Listener};

pub trait Bind<T>: Listener + 'static {
type Error: std::error::Error + Send + 'static;

async fn bind(to: T) -> Result<Self, Self::Error>;

fn bind_endpoint(to: &T) -> Option<Endpoint>;
}
52 changes: 0 additions & 52 deletions core/lib/src/listener/bindable.rs

This file was deleted.

3 changes: 2 additions & 1 deletion core/lib/src/listener/connection.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::io;
use std::borrow::Cow;

use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::either::Either;

use super::Endpoint;
Expand All @@ -9,7 +10,7 @@ use super::Endpoint;
#[derive(Clone)]
pub struct Certificates<'r>(Cow<'r, [der::CertificateDer<'r>]>);

pub trait Connection: Send + Unpin {
pub trait Connection: AsyncRead + AsyncWrite + Send + Unpin {
fn endpoint(&self) -> io::Result<Endpoint>;

/// DER-encoded X.509 certificate chain presented by the client, if any.
Expand Down
119 changes: 71 additions & 48 deletions core/lib/src/listener/default.rs
Original file line number Diff line number Diff line change
@@ -1,64 +1,87 @@
use either::Either;
use serde::Deserialize;
use tokio_util::either::{Either, Either::{Left, Right}};
use futures::TryFutureExt;

use crate::listener::{Bindable, Endpoint};
use crate::error::{Error, ErrorKind};
use crate::error::ErrorKind;
use crate::{Ignite, Rocket};
use crate::listener::{Bind, Endpoint, tcp::TcpListener};

#[derive(serde::Deserialize)]
pub struct DefaultListener {
#[cfg(unix)] use crate::listener::unix::UnixListener;
#[cfg(feature = "tls")] use crate::tls::{TlsListener, TlsConfig};

mod private {
use super::{Either, TcpListener};

#[cfg(feature = "tls")] pub type TlsListener<T> = super::TlsListener<T>;
#[cfg(not(feature = "tls"))] pub type TlsListener<T> = T;
#[cfg(unix)] pub type UnixListener = super::UnixListener;
#[cfg(not(unix))] pub type UnixListener = super::TcpListener;

pub type Listener = Either<
Either<TlsListener<TcpListener>, TlsListener<UnixListener>>,
Either<TcpListener, UnixListener>,
>;
}

#[derive(Deserialize)]
struct Config {
#[serde(default)]
pub address: Endpoint,
pub port: Option<u16>,
pub reuse: Option<bool>,
address: Endpoint,
#[cfg(feature = "tls")]
pub tls: Option<crate::tls::TlsConfig>,
tls: Option<TlsConfig>,
}

#[cfg(not(unix))] type BaseBindable = Either<std::net::SocketAddr, std::net::SocketAddr>;
#[cfg(unix)] type BaseBindable = Either<std::net::SocketAddr, super::unix::UdsConfig>;
pub type DefaultListener = private::Listener;

#[cfg(not(feature = "tls"))] type TlsBindable<T> = Either<T, T>;
#[cfg(feature = "tls")] type TlsBindable<T> = Either<super::tls::TlsBindable<T>, T>;
impl<'r> Bind<&'r Rocket<Ignite>> for DefaultListener {
type Error = crate::Error;

impl DefaultListener {
pub(crate) fn base_bindable(&self) -> Result<BaseBindable, crate::Error> {
match &self.address {
Endpoint::Tcp(mut address) => {
if let Some(port) = self.port {
address.set_port(port);
}
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
let config: Config = rocket.figment().extract()?;

Ok(BaseBindable::Left(address))
},
#[cfg(unix)]
Endpoint::Unix(path) => {
let uds = super::unix::UdsConfig { path: path.clone(), reuse: self.reuse, };
Ok(BaseBindable::Right(uds))
},
#[cfg(not(unix))]
e@Endpoint::Unix(_) => {
let msg = "Unix domain sockets unavailable on non-unix platforms.";
let boxed = Box::<dyn std::error::Error + Send + Sync>::from(msg);
Err(Error::new(ErrorKind::Bind(Some(e.clone()), boxed)))
},
other => {
let msg = format!("unsupported default listener address: {other}");
let boxed = Box::<dyn std::error::Error + Send + Sync>::from(msg);
Err(Error::new(ErrorKind::Bind(Some(other.clone()), boxed)))
let endpoint = config.address;
match endpoint {
#[cfg(feature = "tls")]
Endpoint::Tcp(_) if config.tls.is_some() => {
let listener = <TlsListener<TcpListener> as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

Ok(Left(Left(listener)))
}
}
}
Endpoint::Tcp(_) => {
let listener = <TcpListener as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

pub(crate) fn tls_bindable<T>(&self, inner: T) -> TlsBindable<T> {
#[cfg(feature = "tls")]
if let Some(tls) = self.tls.clone() {
return TlsBindable::Left(super::tls::TlsBindable { inner, tls });
}
Ok(Right(Left(listener)))
}
#[cfg(all(unix, feature = "tls"))]
Endpoint::Unix(_) if config.tls.is_some() => {
let listener = <TlsListener<UnixListener> as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

Ok(Left(Right(listener)))
}
#[cfg(unix)]
Endpoint::Unix(_) => {
let listener = <UnixListener as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

TlsBindable::Right(inner)
Ok(Right(Right(listener)))
}
_ => {
let msg = format!("unsupported bind endpoint: {endpoint}");
let error = Box::<dyn std::error::Error + Send + Sync>::from(msg);
Err(ErrorKind::Bind(Some(endpoint), error).into())
}
}
}

pub fn bindable(&self) -> Result<impl Bindable, crate::Error> {
self.base_bindable()
.map(|b| b.map_either(|b| self.tls_bindable(b), |b| self.tls_bindable(b)))
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Option<Endpoint> {
let endpoint: Option<Endpoint> = rocket.figment().extract_inner("endpoint").ok()?;
endpoint
}
}
2 changes: 1 addition & 1 deletion core/lib/src/listener/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use tokio_util::either::Either;

use crate::listener::{Connection, Endpoint};

pub trait Listener: Send + Sync {
pub trait Listener: Sized + Send + Sync {
type Accept: Send;

type Connection: Connection;
Expand Down
7 changes: 2 additions & 5 deletions core/lib/src/listener/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,20 @@ mod bounced;
mod listener;
mod endpoint;
mod connection;
mod bindable;
mod bind;
mod default;

#[cfg(unix)]
#[cfg_attr(nightly, doc(cfg(unix)))]
pub mod unix;
#[cfg(feature = "tls")]
#[cfg_attr(nightly, doc(cfg(feature = "tls")))]
pub mod tls;
pub mod tcp;
#[cfg(feature = "http3-preview")]
pub mod quic;

pub use endpoint::*;
pub use listener::*;
pub use connection::*;
pub use bindable::*;
pub use bind::*;
pub use default::*;

pub(crate) use cancellable::*;
Expand Down
27 changes: 7 additions & 20 deletions core/lib/src/listener/quic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use tokio::sync::Mutex;
use tokio_stream::StreamExt;

use crate::tls::{TlsConfig, Error};
use crate::listener::{Listener, Connection, Endpoint};
use crate::listener::Endpoint;

type H3Conn = h3::server::Connection<quic_h3::Connection, bytes::Bytes>;

Expand Down Expand Up @@ -94,25 +94,20 @@ impl QuicListener {
}
}

impl Listener for QuicListener {
type Accept = quic::Connection;

type Connection = H3Stream;

async fn accept(&self) -> io::Result<Self::Accept> {
impl QuicListener {
pub(crate) async fn accept(&self) -> Option<quic::Connection> {
self.listener
.lock().await
.accept().await
.ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "closed"))
}

async fn connect(&self, accept: Self::Accept) -> io::Result<Self::Connection> {
pub(crate) async fn connect(&self, accept: quic::Connection) -> io::Result<H3Stream> {
let quic_conn = quic_h3::Connection::new(accept);
let conn = H3Conn::new(quic_conn).await.map_err(io::Error::other)?;
Ok(H3Stream(conn))
}

fn endpoint(&self) -> io::Result<Endpoint> {
pub(crate) fn endpoint(&self) -> io::Result<Endpoint> {
Ok(Endpoint::Quic(self.endpoint).with_tls(&self.tls))
}
}
Expand Down Expand Up @@ -159,16 +154,8 @@ impl QuicTx {
}

// FIXME: Expose certificates when possible.
impl Connection for H3Stream {
fn endpoint(&self) -> io::Result<Endpoint> {
let addr = self.0.inner.conn.handle().remote_addr()?;
Ok(Endpoint::Quic(addr).assume_tls())
}
}

// FIXME: Expose certificates when possible.
impl Connection for H3Connection {
fn endpoint(&self) -> io::Result<Endpoint> {
impl H3Connection {
pub fn endpoint(&self) -> io::Result<Endpoint> {
let addr = self.handle.remote_addr()?;
Ok(Endpoint::Quic(addr).assume_tls())
}
Expand Down
Loading

0 comments on commit 9522c68

Please sign in to comment.