Skip to content

Commit

Permalink
wip: tls-resolver
Browse files Browse the repository at this point in the history
  • Loading branch information
SergioBenitez committed Apr 16, 2024
1 parent 60f3cd5 commit be10677
Show file tree
Hide file tree
Showing 37 changed files with 1,368 additions and 644 deletions.
24 changes: 23 additions & 1 deletion core/http/src/uri/authority.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ impl<'a> Authority<'a> {
self.host.from_cow_source(&self.source)
}

/// Returns the port part of the authority URI, if there is one.
/// Returns the `port` part of the authority URI, if there is one.
///
/// # Example
///
Expand All @@ -206,6 +206,28 @@ impl<'a> Authority<'a> {
pub fn port(&self) -> Option<u16> {
self.port
}

/// Set the `port` of the authority URI.
///
/// # Example
///
/// ```rust
/// # #[macro_use] extern crate rocket;
/// let mut uri = uri!("username:password@host:123");
/// assert_eq!(uri.port(), Some(123));
///
/// uri.set_port(1024);
/// assert_eq!(uri.port(), Some(1024));
/// assert_eq!(uri, "username:password@host:1024");
///
/// uri.set_port(None);
/// assert_eq!(uri.port(), None);
/// assert_eq!(uri, "username:password@host");
/// ```
#[inline(always)]
pub fn set_port<T: Into<Option<u16>>>(&mut self, port: T) {
self.port = port.into();
}
}

impl_serde!(Authority<'a>, "an authority-form URI");
Expand Down
5 changes: 3 additions & 2 deletions core/lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ ref-swap = "0.1.2"
parking_lot = "0.12"
ubyte = {version = "0.10.2", features = ["serde"] }
serde = { version = "1.0", features = ["derive"] }
figment = { version = "0.10.13", features = ["toml", "env"] }
figment = { version = "0.10.17", features = ["toml", "env"] }
rand = "0.8"
either = "1"
pin-project-lite = "0.2"
Expand Down Expand Up @@ -140,5 +140,6 @@ version_check = "0.9.1"

[dev-dependencies]
tokio = { version = "1", features = ["macros", "io-std"] }
figment = { version = "0.10", features = ["test"] }
figment = { version = "0.10.17", features = ["test"] }
pretty_assertions = "1"
arc-swap = "1.7"
3 changes: 0 additions & 3 deletions core/lib/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,6 @@ mod secret_key;
#[cfg(unix)]
pub use crate::shutdown::Sig;

#[cfg(unix)]
pub use crate::listener::unix::UdsConfig;

#[cfg(feature = "secrets")]
pub use secret_key::SecretKey;

Expand Down
10 changes: 7 additions & 3 deletions core/lib/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,13 @@ impl Error {
self.mark_handled();
match self.kind() {
ErrorKind::Bind(ref a, ref e) => {
match a {
Some(a) => error!("Binding to {} failed.", a.primary().underline()),
None => error!("Binding to network interface failed."),
if let Some(e) = e.downcast_ref::<Self>() {
e.pretty_print();
} else {
match a {
Some(a) => error!("Binding to {} failed.", a.primary().underline()),
None => error!("Binding to network interface failed."),
}
}

info_!("{}", e);
Expand Down
10 changes: 10 additions & 0 deletions core/lib/src/listener/bind.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
use crate::listener::{Endpoint, Listener};

pub trait Bind<T>: Listener + 'static {
type Error: std::error::Error + Send + 'static;

#[crate::async_bound(Send)]
async fn bind(to: T) -> Result<Self, Self::Error>;

fn bind_endpoint(to: &T) -> Result<Endpoint, Self::Error>;
}
52 changes: 0 additions & 52 deletions core/lib/src/listener/bindable.rs

This file was deleted.

3 changes: 2 additions & 1 deletion core/lib/src/listener/connection.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::io;
use std::borrow::Cow;

use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::either::Either;

use super::Endpoint;
Expand All @@ -9,7 +10,7 @@ use super::Endpoint;
#[derive(Clone)]
pub struct Certificates<'r>(Cow<'r, [der::CertificateDer<'r>]>);

pub trait Connection: Send + Unpin {
pub trait Connection: AsyncRead + AsyncWrite + Send + Unpin {
fn endpoint(&self) -> io::Result<Endpoint>;

/// DER-encoded X.509 certificate chain presented by the client, if any.
Expand Down
118 changes: 70 additions & 48 deletions core/lib/src/listener/default.rs
Original file line number Diff line number Diff line change
@@ -1,64 +1,86 @@
use either::Either;
use serde::Deserialize;
use tokio_util::either::{Either, Either::{Left, Right}};
use futures::TryFutureExt;

use crate::listener::{Bindable, Endpoint};
use crate::error::{Error, ErrorKind};
use crate::error::ErrorKind;
use crate::{Ignite, Rocket};
use crate::listener::{Bind, Endpoint, tcp::TcpListener};

#[derive(serde::Deserialize)]
pub struct DefaultListener {
#[cfg(unix)] use crate::listener::unix::UnixListener;
#[cfg(feature = "tls")] use crate::tls::{TlsListener, TlsConfig};

mod private {
use super::{Either, TcpListener};

#[cfg(feature = "tls")] pub type TlsListener<T> = super::TlsListener<T>;
#[cfg(not(feature = "tls"))] pub type TlsListener<T> = T;
#[cfg(unix)] pub type UnixListener = super::UnixListener;
#[cfg(not(unix))] pub type UnixListener = super::TcpListener;

pub type Listener = Either<
Either<TlsListener<TcpListener>, TlsListener<UnixListener>>,
Either<TcpListener, UnixListener>,
>;
}

#[derive(Deserialize)]
struct Config {
#[serde(default)]
pub address: Endpoint,
pub port: Option<u16>,
pub reuse: Option<bool>,
address: Endpoint,
#[cfg(feature = "tls")]
pub tls: Option<crate::tls::TlsConfig>,
tls: Option<TlsConfig>,
}

#[cfg(not(unix))] type BaseBindable = Either<std::net::SocketAddr, std::net::SocketAddr>;
#[cfg(unix)] type BaseBindable = Either<std::net::SocketAddr, super::unix::UdsConfig>;
pub type DefaultListener = private::Listener;

#[cfg(not(feature = "tls"))] type TlsBindable<T> = Either<T, T>;
#[cfg(feature = "tls")] type TlsBindable<T> = Either<super::tls::TlsBindable<T>, T>;
impl<'r> Bind<&'r Rocket<Ignite>> for DefaultListener {
type Error = crate::Error;

impl DefaultListener {
pub(crate) fn base_bindable(&self) -> Result<BaseBindable, crate::Error> {
match &self.address {
Endpoint::Tcp(mut address) => {
if let Some(port) = self.port {
address.set_port(port);
}
async fn bind(rocket: &'r Rocket<Ignite>) -> Result<Self, Self::Error> {
let config: Config = rocket.figment().extract()?;

Ok(BaseBindable::Left(address))
},
#[cfg(unix)]
Endpoint::Unix(path) => {
let uds = super::unix::UdsConfig { path: path.clone(), reuse: self.reuse, };
Ok(BaseBindable::Right(uds))
},
#[cfg(not(unix))]
e@Endpoint::Unix(_) => {
let msg = "Unix domain sockets unavailable on non-unix platforms.";
let boxed = Box::<dyn std::error::Error + Send + Sync>::from(msg);
Err(Error::new(ErrorKind::Bind(Some(e.clone()), boxed)))
},
other => {
let msg = format!("unsupported default listener address: {other}");
let boxed = Box::<dyn std::error::Error + Send + Sync>::from(msg);
Err(Error::new(ErrorKind::Bind(Some(other.clone()), boxed)))
let endpoint = config.address;
match endpoint {
#[cfg(feature = "tls")]
Endpoint::Tcp(_) if config.tls.is_some() => {
let listener = <TlsListener<TcpListener> as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

Ok(Left(Left(listener)))
}
}
}
Endpoint::Tcp(_) => {
let listener = <TcpListener as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

pub(crate) fn tls_bindable<T>(&self, inner: T) -> TlsBindable<T> {
#[cfg(feature = "tls")]
if let Some(tls) = self.tls.clone() {
return TlsBindable::Left(super::tls::TlsBindable { inner, tls });
}
Ok(Right(Left(listener)))
}
#[cfg(all(unix, feature = "tls"))]
Endpoint::Unix(_) if config.tls.is_some() => {
let listener = <TlsListener<UnixListener> as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

Ok(Left(Right(listener)))
}
#[cfg(unix)]
Endpoint::Unix(_) => {
let listener = <UnixListener as Bind<_>>::bind(rocket)
.map_err(|e| ErrorKind::Bind(Some(endpoint), Box::new(e)))
.await?;

TlsBindable::Right(inner)
Ok(Right(Right(listener)))
}
_ => {
let msg = format!("unsupported bind endpoint: {endpoint}");
let error = Box::<dyn std::error::Error + Send + Sync>::from(msg);
Err(ErrorKind::Bind(Some(endpoint), error).into())
}
}
}

pub fn bindable(&self) -> Result<impl Bindable, crate::Error> {
self.base_bindable()
.map(|b| b.map_either(|b| self.tls_bindable(b), |b| self.tls_bindable(b)))
fn bind_endpoint(rocket: &&'r Rocket<Ignite>) -> Result<Endpoint, Self::Error> {
Ok(rocket.figment().extract_inner("address")?)
}
}
52 changes: 37 additions & 15 deletions core/lib/src/listener/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::Arc;

use figment::Figment;
use serde::de;

use crate::http::uncased::AsUncased;
Expand All @@ -21,7 +22,7 @@ impl<T: fmt::Display + fmt::Debug + Sync + Send + Any> EndpointAddr for T {}
/// * [`&str`] - parse with [`FromStr`]
/// * [`tokio::net::unix::SocketAddr`] - must be path: [`Endpoint::Unix`]
/// * [`PathBuf`] - infallibly as [`Endpoint::Unix`]
#[derive(Debug, Clone)]
#[derive(Clone)]
#[non_exhaustive]
pub enum Endpoint {
Tcp(net::SocketAddr),
Expand Down Expand Up @@ -152,6 +153,31 @@ impl Endpoint {

Self::Tls(Arc::new(self), None)
}

/// Fetch the endpoint at `path` in `figment` of kind `kind` (e.g, "tcp")
/// then map the value, if any, with `f` into a different value. If the
/// conversion succeeds, returns `Ok(value)`. If the conversion fails and
/// `Some` value was passed in, returns an error indicating the endpoint was
/// an invalid `kind` and otherwise returns a "missing field" error.
pub(crate) fn fetch<T, F>(figment: &Figment, kind: &str, path: &str, f: F) -> figment::Result<T>
where F: FnOnce(Option<&Endpoint>) -> Option<T>
{
use figment::error::{Error, Kind};

let endpoint = figment.extract_inner::<Option<Endpoint>>(path)?;
if let Some(value) = f(endpoint.as_ref()) {
return Ok(value);
}

let mut error = match endpoint {
Some(e) => Error::from(format!("invalid {kind} endpoint: {e:?}")).with_path(path),
None => Error::from(Kind::MissingField(path.to_string().into())),
};

error.profile = Some(figment.profile().clone());
error.metadata = figment.find_metadata(path).cloned();
Err(error)
}
}

impl fmt::Display for Endpoint {
Expand Down Expand Up @@ -180,9 +206,15 @@ impl fmt::Display for Endpoint {
}
}

impl From<PathBuf> for Endpoint {
fn from(value: PathBuf) -> Self {
Self::Unix(value)
impl fmt::Debug for Endpoint {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Tcp(a) => write!(f, "tcp:{a}"),
Self::Quic(a) => write!(f, "quic:{a}]"),
Self::Unix(a) => write!(f, "unix:{}", a.display()),
Self::Tls(e, _) => write!(f, "unix:{:?}", &**e),
Self::Custom(e) => e.fmt(f),
}
}
}

Expand All @@ -197,14 +229,6 @@ impl TryFrom<tokio::net::unix::SocketAddr> for Endpoint {
}
}

impl TryFrom<&str> for Endpoint {
type Error = AddrParseError;

fn try_from(value: &str) -> Result<Self, Self::Error> {
value.parse()
}
}

impl Default for Endpoint {
fn default() -> Self {
Endpoint::Tcp(net::SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8000))
Expand Down Expand Up @@ -237,8 +261,6 @@ impl FromStr for Endpoint {
if let Some((proto, string)) = string.split_once(':') {
if proto.trim().as_uncased() == "tcp" {
return parse_tcp(string.trim(), 8000).map(Self::Tcp);
} else if proto.trim().as_uncased() == "quic" {
return parse_tcp(string.trim(), 8000).map(Self::Quic);
} else if proto.trim().as_uncased() == "unix" {
return Ok(Self::Unix(PathBuf::from(string.trim())));
}
Expand All @@ -256,7 +278,7 @@ impl<'de> de::Deserialize<'de> for Endpoint {
type Value = Endpoint;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("TCP or Unix address")
formatter.write_str("valid TCP (ip) or unix (path) endpoint")
}

fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
Expand Down
Loading

0 comments on commit be10677

Please sign in to comment.