From e763c54a56ed1d54f01a563a2b2a653650dc1181 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Fri, 26 Jan 2024 18:38:07 -0800 Subject: [PATCH] Fix more things --- core/lib/Cargo.toml | 3 +- core/lib/src/error.rs | 2 +- core/lib/src/listener/tls.rs | 39 +++++++------ core/lib/src/local/asynchronous/client.rs | 11 +++- core/lib/src/local/blocking/client.rs | 4 +- core/lib/src/local/client.rs | 23 ++++---- core/lib/src/rocket.rs | 10 ++-- core/lib/src/tls/config.rs | 4 ++ examples/config/src/tests.rs | 4 -- examples/tls/Cargo.toml | 2 +- examples/tls/src/tests.rs | 71 ++++++++++++++++------- scripts/test.sh | 3 +- 12 files changed, 109 insertions(+), 67 deletions(-) diff --git a/core/lib/Cargo.toml b/core/lib/Cargo.toml index ef3f7fdb36..6724b21d32 100644 --- a/core/lib/Cargo.toml +++ b/core/lib/Cargo.toml @@ -20,7 +20,7 @@ rust-version = "1.64" all-features = true [features] -default = ["http2"] +default = ["http2", "tokio-macros"] http2 = ["hyper/http2", "hyper-util/http2"] secrets = ["cookie/private", "cookie/key-expansion"] json = ["serde_json"] @@ -28,6 +28,7 @@ msgpack = ["rmp-serde"] uuid = ["uuid_", "rocket_http/uuid"] tls = ["rustls", "tokio-rustls", "rustls-pemfile"] mtls = ["tls", "x509-parser"] +tokio-macros = ["tokio/macros"] [dependencies] # Optional serialization dependencies. diff --git a/core/lib/src/error.rs b/core/lib/src/error.rs index 2f42cb58e6..21753b1f1b 100644 --- a/core/lib/src/error.rs +++ b/core/lib/src/error.rs @@ -87,7 +87,7 @@ pub enum ErrorKind { FailedFairings(Vec), /// Sentinels requested abort. SentinelAborts(Vec), - /// The configuration profile is not debug but not secret key is configured. + /// The configuration profile is not debug but no secret key is configured. InsecureSecretKey(Profile), /// Shutdown failed. Contains the Rocket instance that failed to shutdown. Shutdown(Arc>), diff --git a/core/lib/src/listener/tls.rs b/core/lib/src/listener/tls.rs index 1bee67d75d..40a1d96700 100644 --- a/core/lib/src/listener/tls.rs +++ b/core/lib/src/listener/tls.rs @@ -7,9 +7,7 @@ use tokio_rustls::TlsAcceptor; use crate::tls::{TlsConfig, Error}; use crate::tls::util::{load_cert_chain, load_key, load_ca_certs}; -use crate::listener::{Listener, Connection, Certificates, ListenerAddr}; - -use super::Bindable; +use crate::listener::{Listener, Bindable, Connection, Certificates, ListenerAddr}; #[doc(inline)] pub use tokio_rustls::server::TlsStream; @@ -27,19 +25,15 @@ pub struct TlsBindable { pub tls: TlsConfig, } -impl Bindable for TlsBindable { - type Listener = TlsListener; - - type Error = Error; - - async fn bind(self) -> Result { +impl TlsConfig { + pub(crate) fn acceptor(&self) -> Result { let provider = rustls::crypto::CryptoProvider { - cipher_suites: self.tls.ciphers().map(|c| c.into()).collect(), + cipher_suites: self.ciphers().map(|c| c.into()).collect(), ..rustls::crypto::ring::default_provider() }; #[cfg(feature = "mtls")] - let verifier = match self.tls.mutual { + let verifier = match self.mutual { Some(ref mtls) => { let ca_certs = load_ca_certs(&mut mtls.ca_certs_reader()?)?; let verifier = WebPkiClientVerifier::builder(Arc::new(ca_certs)); @@ -54,14 +48,14 @@ impl Bindable for TlsBindable { #[cfg(not(feature = "mtls"))] let verifier = WebPkiClientVerifier::no_client_auth(); - let key = load_key(&mut self.tls.key_reader()?)?; - let cert_chain = load_cert_chain(&mut self.tls.certs_reader()?)?; + let key = load_key(&mut self.key_reader()?)?; + let cert_chain = load_cert_chain(&mut self.certs_reader()?)?; let mut tls_config = ServerConfig::builder_with_provider(Arc::new(provider)) .with_safe_default_protocol_versions()? .with_client_cert_verifier(verifier) .with_single_cert(cert_chain, key)?; - tls_config.ignore_client_order = self.tls.prefer_server_cipher_order; + tls_config.ignore_client_order = self.prefer_server_cipher_order; tls_config.session_storage = ServerSessionMemoryCache::new(1024); tls_config.ticketer = rustls::crypto::ring::Ticketer::new()?; tls_config.alpn_protocols = vec![b"http/1.1".to_vec()]; @@ -69,9 +63,20 @@ impl Bindable for TlsBindable { tls_config.alpn_protocols.insert(0, b"h2".to_vec()); } - let acceptor = TlsAcceptor::from(Arc::new(tls_config)); - let listener = self.inner.bind().await.map_err(|e| Error::Bind(Box::new(e)))?; - Ok(TlsListener { listener, acceptor }) + Ok(TlsAcceptor::from(Arc::new(tls_config))) + } +} + +impl Bindable for TlsBindable { + type Listener = TlsListener; + + type Error = Error; + + async fn bind(self) -> Result { + Ok(TlsListener { + acceptor: self.tls.acceptor()?, + listener: self.inner.bind().await.map_err(|e| Error::Bind(Box::new(e)))?, + }) } } diff --git a/core/lib/src/local/asynchronous/client.rs b/core/lib/src/local/asynchronous/client.rs index 4d82f88335..fba607698b 100644 --- a/core/lib/src/local/asynchronous/client.rs +++ b/core/lib/src/local/asynchronous/client.rs @@ -5,6 +5,7 @@ use parking_lot::RwLock; use crate::{Rocket, Phase, Orbit, Ignite, Error}; use crate::local::asynchronous::{LocalRequest, LocalResponse}; use crate::http::{Method, uri::Origin}; +use crate::listener::ListenerAddr; /// An `async` client to construct and dispatch local requests. /// @@ -55,9 +56,15 @@ pub struct Client { impl Client { pub(crate) async fn _new( rocket: Rocket

, - tracked: bool + tracked: bool, + secure: bool, ) -> Result { - let rocket = rocket.local_launch().await?; + let mut listener = ListenerAddr::new("local client"); + if secure { + listener = listener.into_tls(); + } + + let rocket = rocket.local_launch(listener).await?; let cookies = RwLock::new(cookie::CookieJar::new()); Ok(Client { rocket, cookies, tracked }) } diff --git a/core/lib/src/local/blocking/client.rs b/core/lib/src/local/blocking/client.rs index 5408df30b3..f87df009f2 100644 --- a/core/lib/src/local/blocking/client.rs +++ b/core/lib/src/local/blocking/client.rs @@ -30,7 +30,7 @@ pub struct Client { } impl Client { - fn _new(rocket: Rocket

, tracked: bool) -> Result { + fn _new(rocket: Rocket

, tracked: bool, secure: bool) -> Result { let runtime = tokio::runtime::Builder::new_multi_thread() .thread_name("rocket-local-client-worker-thread") .worker_threads(1) @@ -39,7 +39,7 @@ impl Client { .expect("create tokio runtime"); // Initialize the Rocket instance - let inner = Some(runtime.block_on(asynchronous::Client::_new(rocket, tracked))?); + let inner = Some(runtime.block_on(asynchronous::Client::_new(rocket, tracked, secure))?); Ok(Self { inner, runtime: RefCell::new(runtime) }) } diff --git a/core/lib/src/local/client.rs b/core/lib/src/local/client.rs index f2b3b922d0..983019258f 100644 --- a/core/lib/src/local/client.rs +++ b/core/lib/src/local/client.rs @@ -41,6 +41,7 @@ macro_rules! req_method { macro_rules! pub_client_impl { ($import:literal $(@$prefix:tt $suffix:tt)?) => { + /// Construct a new `Client` from an instance of `Rocket` _with_ cookie /// tracking. This is typically the desired mode of operation for testing. /// @@ -68,7 +69,12 @@ macro_rules! pub_client_impl { /// ``` #[inline(always)] pub $($prefix)? fn tracked(rocket: Rocket

) -> Result { - Self::_new(rocket, true) $(.$suffix)? + Self::_new(rocket, true, false) $(.$suffix)? + } + + #[inline(always)] + pub $($prefix)? fn tracked_secure(rocket: Rocket

) -> Result { + Self::_new(rocket, true, true) $(.$suffix)? } /// Construct a new `Client` from an instance of `Rocket` _without_ @@ -92,7 +98,11 @@ macro_rules! pub_client_impl { /// let client = Client::untracked(rocket); /// ``` pub $($prefix)? fn untracked(rocket: Rocket

) -> Result { - Self::_new(rocket, false) $(.$suffix)? + Self::_new(rocket, false, false) $(.$suffix)? + } + + pub $($prefix)? fn untracked_secure(rocket: Rocket

) -> Result { + Self::_new(rocket, false, true) $(.$suffix)? } /// Terminates `Client` by initiating a graceful shutdown via @@ -135,15 +145,6 @@ macro_rules! pub_client_impl { Self::tracked(rocket.configure(figment)) $(.$suffix)? } - /// Deprecated alias to [`Client::tracked()`]. - #[deprecated( - since = "0.6.0-dev", - note = "choose between `Client::untracked()` and `Client::tracked()`" - )] - pub $($prefix)? fn new(rocket: Rocket

) -> Result { - Self::tracked(rocket) $(.$suffix)? - } - /// Returns a reference to the `Rocket` this client is creating requests /// for. /// diff --git a/core/lib/src/rocket.rs b/core/lib/src/rocket.rs index ba790dfd5c..dc73f24886 100644 --- a/core/lib/src/rocket.rs +++ b/core/lib/src/rocket.rs @@ -672,8 +672,8 @@ impl Rocket { }) } - async fn _local_launch(self) -> Rocket { - let rocket = self.into_orbit(ListenerAddr::new("local client")); + async fn _local_launch(self, addr: ListenerAddr) -> Rocket { + let rocket = self.into_orbit(addr); Rocket::liftoff(&rocket).await; rocket } @@ -879,10 +879,10 @@ impl Rocket

{ } } - pub(crate) async fn local_launch(self) -> Result, Error> { + pub(crate) async fn local_launch(self, l: ListenerAddr) -> Result, Error> { let rocket = match self.0.into_state() { - State::Build(s) => Rocket::from(s).ignite().await?._local_launch().await, - State::Ignite(s) => Rocket::from(s)._local_launch().await, + State::Build(s) => Rocket::from(s).ignite().await?._local_launch(l).await, + State::Ignite(s) => Rocket::from(s)._local_launch(l).await, State::Orbit(s) => Rocket::from(s) }; diff --git a/core/lib/src/tls/config.rs b/core/lib/src/tls/config.rs index 9c0d43e5f4..387e175b00 100644 --- a/core/lib/src/tls/config.rs +++ b/core/lib/src/tls/config.rs @@ -425,6 +425,10 @@ impl TlsConfig { pub fn key_reader(&self) -> io::Result> { to_reader(&self.key) } + + pub fn validate(&self) -> Result<(), crate::tls::Error> { + self.acceptor().map(|_| ()) + } } impl CipherSuite { diff --git a/examples/config/src/tests.rs b/examples/config/src/tests.rs index 7cabb9dc51..e774f7ec05 100644 --- a/examples/config/src/tests.rs +++ b/examples/config/src/tests.rs @@ -6,15 +6,11 @@ async fn test_config(profile: &str) { let config = rocket.config(); match &*profile { "debug" => { - assert_eq!(config.address, std::net::Ipv4Addr::LOCALHOST); - assert_eq!(config.port, 8000); assert_eq!(config.workers, 1); assert_eq!(config.keep_alive, 0); assert_eq!(config.log_level, LogLevel::Normal); } "release" => { - assert_eq!(config.address, std::net::Ipv4Addr::LOCALHOST); - assert_eq!(config.port, 8000); assert_eq!(config.workers, 12); assert_eq!(config.keep_alive, 5); assert_eq!(config.log_level, LogLevel::Critical); diff --git a/examples/tls/Cargo.toml b/examples/tls/Cargo.toml index cf77632390..9c72493908 100644 --- a/examples/tls/Cargo.toml +++ b/examples/tls/Cargo.toml @@ -6,5 +6,5 @@ edition = "2021" publish = false [dependencies] -rocket = { path = "../../core/lib", features = ["tls", "mtls"] } +rocket = { path = "../../core/lib", features = ["tls", "mtls", "secrets"] } yansi = "1.0.0-rc.1" diff --git a/examples/tls/src/tests.rs b/examples/tls/src/tests.rs index 2629e3c487..61efbec9ff 100644 --- a/examples/tls/src/tests.rs +++ b/examples/tls/src/tests.rs @@ -1,11 +1,21 @@ use std::fs::{self, File}; +use rocket::http::{CookieJar, Cookie}; use rocket::local::blocking::Client; use rocket::fs::relative; +#[get("/cookie")] +fn cookie(jar: &CookieJar<'_>) { + jar.add(("k1", "v1")); + jar.add_private(("k2", "v2")); + + jar.add(Cookie::build(("k1u", "v1u")).secure(false)); + jar.add_private(Cookie::build(("k2u", "v2u")).secure(false)); +} + #[test] fn hello_mutual() { - let client = Client::tracked(super::rocket()).unwrap(); + let client = Client::tracked_secure(super::rocket()).unwrap(); let cert_paths = fs::read_dir(relative!("private")).unwrap() .map(|entry| entry.unwrap().path().to_string_lossy().into_owned()) .filter(|path| path.ends_with("_cert.pem") && !path.ends_with("ca_cert.pem")); @@ -23,35 +33,43 @@ fn hello_mutual() { #[test] fn secure_cookies() { - use rocket::http::{CookieJar, Cookie}; - - #[get("/cookie")] - fn cookie(jar: &CookieJar<'_>) { - jar.add(("k1", "v1")); - jar.add_private(("k2", "v2")); - - jar.add(Cookie::build(("k1u", "v1u")).secure(false)); - jar.add_private(Cookie::build(("k2u", "v2u")).secure(false)); - } + let rocket = super::rocket().mount("/", routes![cookie]); + let client = Client::tracked_secure(rocket).unwrap(); - let client = Client::tracked(super::rocket().mount("/", routes![cookie])).unwrap(); let response = client.get("/cookie").dispatch(); - let c1 = response.cookies().get("k1").unwrap(); - assert_eq!(c1.secure(), Some(true)); - let c2 = response.cookies().get_private("k2").unwrap(); + let c3 = response.cookies().get("k1u").unwrap(); + let c4 = response.cookies().get_private("k2u").unwrap(); + + assert_eq!(c1.secure(), Some(true)); assert_eq!(c2.secure(), Some(true)); + assert_ne!(c3.secure(), Some(true)); + assert_ne!(c4.secure(), Some(true)); +} - let c1 = response.cookies().get("k1u").unwrap(); - assert_ne!(c1.secure(), Some(true)); +#[test] +fn insecure_cookies() { + let rocket = super::rocket().mount("/", routes![cookie]); + let client = Client::tracked(rocket).unwrap(); + + let response = client.get("/cookie").dispatch(); + let c1 = response.cookies().get("k1").unwrap(); + let c2 = response.cookies().get_private("k2").unwrap(); + let c3 = response.cookies().get("k1u").unwrap(); + let c4 = response.cookies().get_private("k2u").unwrap(); - let c2 = response.cookies().get_private("k2u").unwrap(); - assert_ne!(c2.secure(), Some(true)); + assert_eq!(c1.secure(), None); + assert_eq!(c2.secure(), None); + assert_eq!(c3.secure(), None); + assert_eq!(c4.secure(), None); } #[test] fn hello_world() { + use rocket::listener::DefaultListener; + use rocket::config::{Config, SecretKey}; + let profiles = [ "rsa_sha256", "ecdsa_nistp256_sha256_pkcs8", @@ -61,11 +79,20 @@ fn hello_world() { "ed25519", ]; - // TODO: Testing doesn't actually read keys since we don't do TLS locally. for profile in profiles { - let config = rocket::Config::figment().select(profile); - let client = Client::tracked(super::rocket().configure(config)).unwrap(); + let config = Config { + secret_key: SecretKey::generate().unwrap(), + ..Config::debug_default() + }; + + let figment = Config::figment().merge(config).select(profile); + let client = Client::tracked_secure(super::rocket().configure(figment)).unwrap(); let response = client.get("/").dispatch(); assert_eq!(response.into_string().unwrap(), "Hello, world!"); + + let figment = client.rocket().figment(); + let listener: DefaultListener = figment.extract().unwrap(); + assert_eq!(figment.profile(), profile); + listener.tls.as_ref().unwrap().validate().expect("valid TLS config"); } } diff --git a/scripts/test.sh b/scripts/test.sh index 40525a1f05..555f4c7bc5 100755 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -126,10 +126,11 @@ function test_contrib() { function test_core() { FEATURES=( + tokio-macros + http2 secrets tls mtls - http2 json msgpack uuid