diff --git a/core/lib/Cargo.toml b/core/lib/Cargo.toml index c09efb119d..5dcb7bdb3a 100644 --- a/core/lib/Cargo.toml +++ b/core/lib/Cargo.toml @@ -131,3 +131,4 @@ version_check = "0.9.1" tokio = { version = "1", features = ["macros", "io-std"] } figment = { version = "0.10", features = ["test"] } pretty_assertions = "1" +arc-swap = "1.7" diff --git a/core/lib/src/listener/tls.rs b/core/lib/src/listener/tls.rs index ab52ddb676..d955a0ff8d 100644 --- a/core/lib/src/listener/tls.rs +++ b/core/lib/src/listener/tls.rs @@ -1,7 +1,6 @@ use std::io; use std::sync::Arc; -use serde::Deserialize; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::LazyConfigAcceptor; use rustls::server::{Acceptor, ServerConfig}; diff --git a/core/lib/src/tls/resolver.rs b/core/lib/src/tls/resolver.rs index 4d864c1758..05ea9ed71a 100644 --- a/core/lib/src/tls/resolver.rs +++ b/core/lib/src/tls/resolver.rs @@ -37,8 +37,13 @@ impl fairing::Fairing for Fairing { #[cfg(test)] mod tests { + use std::sync::atomic::AtomicU64; + use std::sync::atomic::Ordering; use std::sync::Arc; use std::collections::HashMap; + use std::time::UNIX_EPOCH; + use arc_swap::ArcSwap; + use either::Either; use serde::Deserialize; use crate::http::uri::Host; use crate::tls::{TlsConfig, ServerConfig, Resolver, ClientHello}; @@ -69,10 +74,49 @@ mod tests { } } + struct UpdatingResolver { + timestamp: AtomicU64, + tls_config: TlsConfig, + server_config: ArcSwap + } + + impl TryFrom for UpdatingResolver { + type Error = crate::tls::Error; + + fn try_from(tls_config: TlsConfig) -> Result { + Ok(UpdatingResolver { + timestamp: AtomicU64::new(0), + server_config: ArcSwap::new(Arc::new(tls_config.to_server_config()?)), + tls_config, + }) + } + } + + #[crate::async_trait] + impl Resolver for UpdatingResolver { + async fn resolve(&self, _: ClientHello<'_>) -> Option> { + if let Either::Left(path) = self.tls_config.certs() { + let metadata = tokio::fs::metadata(&path).await.ok()?; + let modtime = metadata.modified().ok()?; + let timestamp = modtime.duration_since(UNIX_EPOCH).ok()?.as_secs(); + let old_timestamp = self.timestamp.load(Ordering::Acquire); + if timestamp > old_timestamp { + let new_config = self.tls_config.to_server_config().ok()?; + self.server_config.store(Arc::new(new_config)); + self.timestamp.store(timestamp, Ordering::Release); + } + } + + Some(self.server_config.load_full()) + } + } + #[test] fn test_config() { figment::Jail::expect_with(|jail| { use crate::fs::relative; + use figment::Figment; + use figment::providers::{Toml, Format}; let cert_path = relative!("../../examples/tls/private/rsa_sha256_cert.pem"); let key_path = relative!("../../examples/tls/private/rsa_sha256_key.pem"); @@ -87,7 +131,8 @@ mod tests { key = "{key_path}" "#))?; - let config = crate::Config::figment().extract::()?; + let toml = Toml::file("Rocket.toml").nested(); + let config: SniConfig = Figment::from(toml).extract().unwrap(); assert!(config.sni.contains_key(&Host::parse("api.rocket.rs").unwrap())); assert!(config.sni.contains_key(&Host::parse("blob.rocket.rs").unwrap())); Ok(())