Skip to content

Commit

Permalink
Fix more things
Browse files Browse the repository at this point in the history
  • Loading branch information
SergioBenitez committed Jan 27, 2024
1 parent 0097746 commit e763c54
Show file tree
Hide file tree
Showing 12 changed files with 109 additions and 67 deletions.
3 changes: 2 additions & 1 deletion core/lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ rust-version = "1.64"
all-features = true

[features]
default = ["http2"]
default = ["http2", "tokio-macros"]
http2 = ["hyper/http2", "hyper-util/http2"]
secrets = ["cookie/private", "cookie/key-expansion"]
json = ["serde_json"]
msgpack = ["rmp-serde"]
uuid = ["uuid_", "rocket_http/uuid"]
tls = ["rustls", "tokio-rustls", "rustls-pemfile"]
mtls = ["tls", "x509-parser"]
tokio-macros = ["tokio/macros"]

[dependencies]
# Optional serialization dependencies.
Expand Down
2 changes: 1 addition & 1 deletion core/lib/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ pub enum ErrorKind {
FailedFairings(Vec<crate::fairing::Info>),
/// Sentinels requested abort.
SentinelAborts(Vec<crate::sentinel::Sentry>),
/// The configuration profile is not debug but not secret key is configured.
/// The configuration profile is not debug but no secret key is configured.
InsecureSecretKey(Profile),
/// Shutdown failed. Contains the Rocket instance that failed to shutdown.
Shutdown(Arc<Rocket<Orbit>>),
Expand Down
39 changes: 22 additions & 17 deletions core/lib/src/listener/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ use tokio_rustls::TlsAcceptor;

use crate::tls::{TlsConfig, Error};
use crate::tls::util::{load_cert_chain, load_key, load_ca_certs};
use crate::listener::{Listener, Connection, Certificates, ListenerAddr};

use super::Bindable;
use crate::listener::{Listener, Bindable, Connection, Certificates, ListenerAddr};

#[doc(inline)]
pub use tokio_rustls::server::TlsStream;
Expand All @@ -27,19 +25,15 @@ pub struct TlsBindable<I> {
pub tls: TlsConfig,
}

impl<I: Bindable> Bindable for TlsBindable<I> {
type Listener = TlsListener<I::Listener>;

type Error = Error;

async fn bind(self) -> Result<Self::Listener, Self::Error> {
impl TlsConfig {
pub(crate) fn acceptor(&self) -> Result<tokio_rustls::TlsAcceptor, Error> {
let provider = rustls::crypto::CryptoProvider {
cipher_suites: self.tls.ciphers().map(|c| c.into()).collect(),
cipher_suites: self.ciphers().map(|c| c.into()).collect(),
..rustls::crypto::ring::default_provider()
};

#[cfg(feature = "mtls")]
let verifier = match self.tls.mutual {
let verifier = match self.mutual {
Some(ref mtls) => {
let ca_certs = load_ca_certs(&mut mtls.ca_certs_reader()?)?;
let verifier = WebPkiClientVerifier::builder(Arc::new(ca_certs));
Expand All @@ -54,24 +48,35 @@ impl<I: Bindable> Bindable for TlsBindable<I> {
#[cfg(not(feature = "mtls"))]
let verifier = WebPkiClientVerifier::no_client_auth();

let key = load_key(&mut self.tls.key_reader()?)?;
let cert_chain = load_cert_chain(&mut self.tls.certs_reader()?)?;
let key = load_key(&mut self.key_reader()?)?;
let cert_chain = load_cert_chain(&mut self.certs_reader()?)?;
let mut tls_config = ServerConfig::builder_with_provider(Arc::new(provider))
.with_safe_default_protocol_versions()?
.with_client_cert_verifier(verifier)
.with_single_cert(cert_chain, key)?;

tls_config.ignore_client_order = self.tls.prefer_server_cipher_order;
tls_config.ignore_client_order = self.prefer_server_cipher_order;
tls_config.session_storage = ServerSessionMemoryCache::new(1024);
tls_config.ticketer = rustls::crypto::ring::Ticketer::new()?;
tls_config.alpn_protocols = vec![b"http/1.1".to_vec()];
if cfg!(feature = "http2") {
tls_config.alpn_protocols.insert(0, b"h2".to_vec());
}

let acceptor = TlsAcceptor::from(Arc::new(tls_config));
let listener = self.inner.bind().await.map_err(|e| Error::Bind(Box::new(e)))?;
Ok(TlsListener { listener, acceptor })
Ok(TlsAcceptor::from(Arc::new(tls_config)))
}
}

impl<I: Bindable> Bindable for TlsBindable<I> {
type Listener = TlsListener<I::Listener>;

type Error = Error;

async fn bind(self) -> Result<Self::Listener, Self::Error> {
Ok(TlsListener {
acceptor: self.tls.acceptor()?,
listener: self.inner.bind().await.map_err(|e| Error::Bind(Box::new(e)))?,
})
}
}

Expand Down
11 changes: 9 additions & 2 deletions core/lib/src/local/asynchronous/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use parking_lot::RwLock;
use crate::{Rocket, Phase, Orbit, Ignite, Error};
use crate::local::asynchronous::{LocalRequest, LocalResponse};
use crate::http::{Method, uri::Origin};
use crate::listener::ListenerAddr;

/// An `async` client to construct and dispatch local requests.
///
Expand Down Expand Up @@ -55,9 +56,15 @@ pub struct Client {
impl Client {
pub(crate) async fn _new<P: Phase>(
rocket: Rocket<P>,
tracked: bool
tracked: bool,
secure: bool,
) -> Result<Client, Error> {
let rocket = rocket.local_launch().await?;
let mut listener = ListenerAddr::new("local client");
if secure {
listener = listener.into_tls();
}

let rocket = rocket.local_launch(listener).await?;
let cookies = RwLock::new(cookie::CookieJar::new());
Ok(Client { rocket, cookies, tracked })
}
Expand Down
4 changes: 2 additions & 2 deletions core/lib/src/local/blocking/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub struct Client {
}

impl Client {
fn _new<P: Phase>(rocket: Rocket<P>, tracked: bool) -> Result<Client, Error> {
fn _new<P: Phase>(rocket: Rocket<P>, tracked: bool, secure: bool) -> Result<Client, Error> {
let runtime = tokio::runtime::Builder::new_multi_thread()
.thread_name("rocket-local-client-worker-thread")
.worker_threads(1)
Expand All @@ -39,7 +39,7 @@ impl Client {
.expect("create tokio runtime");

// Initialize the Rocket instance
let inner = Some(runtime.block_on(asynchronous::Client::_new(rocket, tracked))?);
let inner = Some(runtime.block_on(asynchronous::Client::_new(rocket, tracked, secure))?);
Ok(Self { inner, runtime: RefCell::new(runtime) })
}

Expand Down
23 changes: 12 additions & 11 deletions core/lib/src/local/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ macro_rules! req_method {
macro_rules! pub_client_impl {
($import:literal $(@$prefix:tt $suffix:tt)?) =>
{

/// Construct a new `Client` from an instance of `Rocket` _with_ cookie
/// tracking. This is typically the desired mode of operation for testing.
///
Expand Down Expand Up @@ -68,7 +69,12 @@ macro_rules! pub_client_impl {
/// ```
#[inline(always)]
pub $($prefix)? fn tracked<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> {
Self::_new(rocket, true) $(.$suffix)?
Self::_new(rocket, true, false) $(.$suffix)?
}

#[inline(always)]
pub $($prefix)? fn tracked_secure<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> {
Self::_new(rocket, true, true) $(.$suffix)?
}

/// Construct a new `Client` from an instance of `Rocket` _without_
Expand All @@ -92,7 +98,11 @@ macro_rules! pub_client_impl {
/// let client = Client::untracked(rocket);
/// ```
pub $($prefix)? fn untracked<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> {
Self::_new(rocket, false) $(.$suffix)?
Self::_new(rocket, false, false) $(.$suffix)?
}

pub $($prefix)? fn untracked_secure<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> {
Self::_new(rocket, false, true) $(.$suffix)?
}

/// Terminates `Client` by initiating a graceful shutdown via
Expand Down Expand Up @@ -135,15 +145,6 @@ macro_rules! pub_client_impl {
Self::tracked(rocket.configure(figment)) $(.$suffix)?
}

/// Deprecated alias to [`Client::tracked()`].
#[deprecated(
since = "0.6.0-dev",
note = "choose between `Client::untracked()` and `Client::tracked()`"
)]
pub $($prefix)? fn new<P: Phase>(rocket: Rocket<P>) -> Result<Self, Error> {
Self::tracked(rocket) $(.$suffix)?
}

/// Returns a reference to the `Rocket` this client is creating requests
/// for.
///
Expand Down
10 changes: 5 additions & 5 deletions core/lib/src/rocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -672,8 +672,8 @@ impl Rocket<Ignite> {
})
}

async fn _local_launch(self) -> Rocket<Orbit> {
let rocket = self.into_orbit(ListenerAddr::new("local client"));
async fn _local_launch(self, addr: ListenerAddr) -> Rocket<Orbit> {
let rocket = self.into_orbit(addr);
Rocket::liftoff(&rocket).await;
rocket
}
Expand Down Expand Up @@ -879,10 +879,10 @@ impl<P: Phase> Rocket<P> {
}
}

pub(crate) async fn local_launch(self) -> Result<Rocket<Orbit>, Error> {
pub(crate) async fn local_launch(self, l: ListenerAddr) -> Result<Rocket<Orbit>, Error> {
let rocket = match self.0.into_state() {
State::Build(s) => Rocket::from(s).ignite().await?._local_launch().await,
State::Ignite(s) => Rocket::from(s)._local_launch().await,
State::Build(s) => Rocket::from(s).ignite().await?._local_launch(l).await,
State::Ignite(s) => Rocket::from(s)._local_launch(l).await,
State::Orbit(s) => Rocket::from(s)
};

Expand Down
4 changes: 4 additions & 0 deletions core/lib/src/tls/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,10 @@ impl TlsConfig {
pub fn key_reader(&self) -> io::Result<Box<dyn io::BufRead + Sync + Send>> {
to_reader(&self.key)
}

pub fn validate(&self) -> Result<(), crate::tls::Error> {
self.acceptor().map(|_| ())
}
}

impl CipherSuite {
Expand Down
4 changes: 0 additions & 4 deletions examples/config/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,11 @@ async fn test_config(profile: &str) {
let config = rocket.config();
match &*profile {
"debug" => {
assert_eq!(config.address, std::net::Ipv4Addr::LOCALHOST);
assert_eq!(config.port, 8000);
assert_eq!(config.workers, 1);
assert_eq!(config.keep_alive, 0);
assert_eq!(config.log_level, LogLevel::Normal);
}
"release" => {
assert_eq!(config.address, std::net::Ipv4Addr::LOCALHOST);
assert_eq!(config.port, 8000);
assert_eq!(config.workers, 12);
assert_eq!(config.keep_alive, 5);
assert_eq!(config.log_level, LogLevel::Critical);
Expand Down
2 changes: 1 addition & 1 deletion examples/tls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ edition = "2021"
publish = false

[dependencies]
rocket = { path = "../../core/lib", features = ["tls", "mtls"] }
rocket = { path = "../../core/lib", features = ["tls", "mtls", "secrets"] }
yansi = "1.0.0-rc.1"
71 changes: 49 additions & 22 deletions examples/tls/src/tests.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
use std::fs::{self, File};

use rocket::http::{CookieJar, Cookie};
use rocket::local::blocking::Client;
use rocket::fs::relative;

#[get("/cookie")]
fn cookie(jar: &CookieJar<'_>) {
jar.add(("k1", "v1"));
jar.add_private(("k2", "v2"));

jar.add(Cookie::build(("k1u", "v1u")).secure(false));
jar.add_private(Cookie::build(("k2u", "v2u")).secure(false));
}

#[test]
fn hello_mutual() {
let client = Client::tracked(super::rocket()).unwrap();
let client = Client::tracked_secure(super::rocket()).unwrap();
let cert_paths = fs::read_dir(relative!("private")).unwrap()
.map(|entry| entry.unwrap().path().to_string_lossy().into_owned())
.filter(|path| path.ends_with("_cert.pem") && !path.ends_with("ca_cert.pem"));
Expand All @@ -23,35 +33,43 @@ fn hello_mutual() {

#[test]
fn secure_cookies() {
use rocket::http::{CookieJar, Cookie};

#[get("/cookie")]
fn cookie(jar: &CookieJar<'_>) {
jar.add(("k1", "v1"));
jar.add_private(("k2", "v2"));

jar.add(Cookie::build(("k1u", "v1u")).secure(false));
jar.add_private(Cookie::build(("k2u", "v2u")).secure(false));
}
let rocket = super::rocket().mount("/", routes![cookie]);
let client = Client::tracked_secure(rocket).unwrap();

let client = Client::tracked(super::rocket().mount("/", routes![cookie])).unwrap();
let response = client.get("/cookie").dispatch();

let c1 = response.cookies().get("k1").unwrap();
assert_eq!(c1.secure(), Some(true));

let c2 = response.cookies().get_private("k2").unwrap();
let c3 = response.cookies().get("k1u").unwrap();
let c4 = response.cookies().get_private("k2u").unwrap();

assert_eq!(c1.secure(), Some(true));
assert_eq!(c2.secure(), Some(true));
assert_ne!(c3.secure(), Some(true));
assert_ne!(c4.secure(), Some(true));
}

let c1 = response.cookies().get("k1u").unwrap();
assert_ne!(c1.secure(), Some(true));
#[test]
fn insecure_cookies() {
let rocket = super::rocket().mount("/", routes![cookie]);
let client = Client::tracked(rocket).unwrap();

let response = client.get("/cookie").dispatch();
let c1 = response.cookies().get("k1").unwrap();
let c2 = response.cookies().get_private("k2").unwrap();
let c3 = response.cookies().get("k1u").unwrap();
let c4 = response.cookies().get_private("k2u").unwrap();

let c2 = response.cookies().get_private("k2u").unwrap();
assert_ne!(c2.secure(), Some(true));
assert_eq!(c1.secure(), None);
assert_eq!(c2.secure(), None);
assert_eq!(c3.secure(), None);
assert_eq!(c4.secure(), None);
}

#[test]
fn hello_world() {
use rocket::listener::DefaultListener;
use rocket::config::{Config, SecretKey};

let profiles = [
"rsa_sha256",
"ecdsa_nistp256_sha256_pkcs8",
Expand All @@ -61,11 +79,20 @@ fn hello_world() {
"ed25519",
];

// TODO: Testing doesn't actually read keys since we don't do TLS locally.
for profile in profiles {
let config = rocket::Config::figment().select(profile);
let client = Client::tracked(super::rocket().configure(config)).unwrap();
let config = Config {
secret_key: SecretKey::generate().unwrap(),
..Config::debug_default()
};

let figment = Config::figment().merge(config).select(profile);
let client = Client::tracked_secure(super::rocket().configure(figment)).unwrap();
let response = client.get("/").dispatch();
assert_eq!(response.into_string().unwrap(), "Hello, world!");

let figment = client.rocket().figment();
let listener: DefaultListener = figment.extract().unwrap();
assert_eq!(figment.profile(), profile);
listener.tls.as_ref().unwrap().validate().expect("valid TLS config");
}
}
3 changes: 2 additions & 1 deletion scripts/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,11 @@ function test_contrib() {

function test_core() {
FEATURES=(
tokio-macros
http2
secrets
tls
mtls
http2
json
msgpack
uuid
Expand Down

0 comments on commit e763c54

Please sign in to comment.