Skip to content

Commit

Permalink
wip: dynamic tls cert resolver
Browse files Browse the repository at this point in the history
  • Loading branch information
SergioBenitez committed Mar 26, 2024
1 parent bd26ca4 commit e4e46ef
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 53 deletions.
68 changes: 17 additions & 51 deletions core/lib/src/listener/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@ use std::io;
use std::sync::Arc;

use serde::Deserialize;
use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsAcceptor;
use tokio_rustls::LazyConfigAcceptor;
use rustls::server::{Acceptor, ServerConfig};

use crate::tls::{TlsConfig, Error};
use crate::tls::util::{load_cert_chain, load_key, load_ca_certs};
use crate::tls::{Error, Resolver, TlsConfig};
use crate::listener::{Listener, Bindable, Connection, Certificates, Endpoint};

#[doc(inline)]
Expand All @@ -16,59 +15,17 @@ pub use tokio_rustls::server::TlsStream;
/// A TLS listener over some listener interface L.
pub struct TlsListener<L> {
listener: L,
acceptor: TlsAcceptor,
resolver: Option<Arc<dyn Resolver>>,
default: Arc<ServerConfig>,
config: TlsConfig,
}

#[derive(Clone, Deserialize)]
#[derive(Clone)]
pub struct TlsBindable<I> {
#[serde(flatten)]
pub inner: I,
pub tls: TlsConfig,
}

impl TlsConfig {
pub(crate) fn server_config(&self) -> Result<ServerConfig, Error> {
let provider = rustls::crypto::CryptoProvider {
cipher_suites: self.ciphers().map(|c| c.into()).collect(),
..rustls::crypto::ring::default_provider()
};

#[cfg(feature = "mtls")]
let verifier = match self.mutual {
Some(ref mtls) => {
let ca_certs = load_ca_certs(&mut mtls.ca_certs_reader()?)?;
let verifier = WebPkiClientVerifier::builder(Arc::new(ca_certs));
match mtls.mandatory {
true => verifier.build()?,
false => verifier.allow_unauthenticated().build()?,
}
},
None => WebPkiClientVerifier::no_client_auth(),
};

#[cfg(not(feature = "mtls"))]
let verifier = WebPkiClientVerifier::no_client_auth();

let key = load_key(&mut self.key_reader()?)?;
let cert_chain = load_cert_chain(&mut self.certs_reader()?)?;
let mut tls_config = ServerConfig::builder_with_provider(Arc::new(provider))
.with_safe_default_protocol_versions()?
.with_client_cert_verifier(verifier)
.with_single_cert(cert_chain, key)?;

tls_config.ignore_client_order = self.prefer_server_cipher_order;
tls_config.session_storage = ServerSessionMemoryCache::new(1024);
tls_config.ticketer = rustls::crypto::ring::Ticketer::new()?;
tls_config.alpn_protocols = vec![b"http/1.1".to_vec()];
if cfg!(feature = "http2") {
tls_config.alpn_protocols.insert(0, b"h2".to_vec());
}

Ok(tls_config)
}
}

impl<I: Bindable> Bindable for TlsBindable<I>
where I::Listener: Listener<Accept = <I::Listener as Listener>::Connection>,
<I::Listener as Listener>::Connection: AsyncRead + AsyncWrite
Expand All @@ -79,7 +36,8 @@ impl<I: Bindable> Bindable for TlsBindable<I>

async fn bind(self) -> Result<Self::Listener, Self::Error> {
Ok(TlsListener {
acceptor: TlsAcceptor::from(Arc::new(self.tls.server_config()?)),
default: Arc::new(self.tls.to_server_config()?),
resolver: None,
listener: self.inner.bind().await.map_err(|e| Error::Bind(Box::new(e)))?,
config: self.tls,
})
Expand All @@ -104,7 +62,15 @@ impl<L> Listener for TlsListener<L>
}

async fn connect(&self, conn: L::Connection) -> io::Result<Self::Connection> {
self.acceptor.accept(conn).await
let acceptor = LazyConfigAcceptor::new(Acceptor::default(), conn);
let handshake = acceptor.await?;
let hello = handshake.client_hello();
let config = match &self.resolver {
Some(r) => r.resolve(hello).await.unwrap_or_else(|| self.default.clone()),
None => self.default.clone(),
};

handshake.into_stream(config).await
}

fn endpoint(&self) -> io::Result<Endpoint> {
Expand Down
53 changes: 51 additions & 2 deletions core/lib/src/tls/config.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use std::io;
use std::sync::Arc;

use figment::value::magic::{Either, RelativePathBuf};
use serde::{Deserialize, Serialize};
use indexmap::IndexSet;

use crate::tls::Result;

/// TLS configuration: certificate chain, key, and ciphersuites.
///
/// Four parameters control `tls` configuration:
Expand Down Expand Up @@ -426,8 +429,54 @@ impl TlsConfig {
self.mutual.as_ref()
}

pub fn validate(&self) -> Result<(), crate::tls::Error> {
self.server_config().map(|_| ())
/// Try to convert `self` into a [rustls] [`ServerConfig`].
///
/// [`ServerConfig`]: rustls::server::ServerConfig
pub fn to_server_config(&self) -> Result<rustls::server::ServerConfig> {
use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier};
use crate::tls::util::{load_cert_chain, load_key, load_ca_certs};

let provider = rustls::crypto::CryptoProvider {
cipher_suites: self.ciphers().map(|c| c.into()).collect(),
..rustls::crypto::ring::default_provider()
};

#[cfg(feature = "mtls")]
let verifier = match self.mutual {
Some(ref mtls) => {
let ca_certs = load_ca_certs(&mut mtls.ca_certs_reader()?)?;
let verifier = WebPkiClientVerifier::builder(Arc::new(ca_certs));
match mtls.mandatory {
true => verifier.build()?,
false => verifier.allow_unauthenticated().build()?,
}
},
None => WebPkiClientVerifier::no_client_auth(),
};

#[cfg(not(feature = "mtls"))]
let verifier = WebPkiClientVerifier::no_client_auth();

let key = load_key(&mut self.key_reader()?)?;
let cert_chain = load_cert_chain(&mut self.certs_reader()?)?;
let mut tls_config = ServerConfig::builder_with_provider(Arc::new(provider))
.with_safe_default_protocol_versions()?
.with_client_cert_verifier(verifier)
.with_single_cert(cert_chain, key)?;

tls_config.ignore_client_order = self.prefer_server_cipher_order;
tls_config.session_storage = ServerSessionMemoryCache::new(1024);
tls_config.ticketer = rustls::crypto::ring::Ticketer::new()?;
tls_config.alpn_protocols = vec![b"http/1.1".to_vec()];
if cfg!(feature = "http2") {
tls_config.alpn_protocols.insert(0, b"h2".to_vec());
}

Ok(tls_config)
}

pub fn validate(&self) -> Result<()> {
self.to_server_config().map(|_| ())
}
}

Expand Down
4 changes: 4 additions & 0 deletions core/lib/src/tls/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
mod error;
mod resolver;
pub(crate) mod config;
pub(crate) mod util;

pub use rustls;

pub use error::Result;
pub use config::{TlsConfig, CipherSuite};
pub use error::Error;
pub use resolver::{Resolver, ClientHello, ServerConfig};
96 changes: 96 additions & 0 deletions core/lib/src/tls/resolver.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use std::sync::Arc;

pub use rustls::server::{ClientHello, ServerConfig};

use crate::{fairing, Build, Rocket};

/// A dynamic TLS configuration resolver.
#[crate::async_trait]
pub trait Resolver: Send + Sync + 'static {
async fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<ServerConfig>>;

async fn fairing(self) -> Fairing where Self: Sized {
Fairing {
resolver: Arc::new(self)
}
}
}

#[derive(Clone)]
pub struct Fairing {
resolver: Arc<dyn Resolver>,
}

#[crate::async_trait]
impl fairing::Fairing for Fairing {
fn info(&self) -> fairing::Info {
fairing::Info {
name: "TLS Resolver",
kind: fairing::Kind::Ignite | fairing::Kind::Singleton
}
}

async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
Ok(rocket.manage(self.clone()))
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::collections::HashMap;
use serde::Deserialize;
use crate::http::uri::Host;
use crate::tls::{TlsConfig, ServerConfig, Resolver, ClientHello};

/// ```toml
/// [sni."api.rocket.rs"]
/// certs = "private/api_rocket_rs.rsa_sha256_cert.pem"
/// key = "private/api_rocket_rs.rsa_sha256_key.pem"
///
/// [sni."blob.rocket.rs"]
/// certs = "private/blob_rsa_sha256_cert.pem"
/// key = "private/blob_rsa_sha256_key.pem"
/// ```
#[derive(Deserialize)]
struct SniConfig {
sni: HashMap<Host<'static>, TlsConfig>,
}

struct SniResolver {
sni_map: HashMap<Host<'static>, Arc<ServerConfig>>
}

#[crate::async_trait]
impl Resolver for SniResolver {
async fn resolve(&self, hello: ClientHello<'_>) -> Option<Arc<ServerConfig>> {
let host = Host::parse(hello.server_name()?).ok()?;
self.sni_map.get(&host).cloned()
}
}

#[test]
fn test_config() {
figment::Jail::expect_with(|jail| {
use crate::fs::relative;

let cert_path = relative!("../../examples/tls/private/rsa_sha256_cert.pem");
let key_path = relative!("../../examples/tls/private/rsa_sha256_key.pem");

jail.create_file("Rocket.toml", &format!(r#"
[default.sni."api.rocket.rs"]
certs = "{cert_path}"
key = "{key_path}"
[default.sni."blob.rocket.rs"]
certs = "{cert_path}"
key = "{key_path}"
"#))?;

let config = crate::Config::figment().extract::<SniConfig>()?;
assert!(config.sni.contains_key(&Host::parse("api.rocket.rs").unwrap()));
assert!(config.sni.contains_key(&Host::parse("blob.rocket.rs").unwrap()));
Ok(())
});
}
}

0 comments on commit e4e46ef

Please sign in to comment.