From e2a8c74cae58782a78e3d147d9ee7f9788b78ed2 Mon Sep 17 00:00:00 2001 From: Warren Snipes Date: Mon, 22 Jul 2024 19:07:05 +0000 Subject: [PATCH] migrate to new rustls types --- amqprs/src/api/connection.rs | 13 +++-- amqprs/src/api/tls.rs | 93 +++++++++++++++++++----------- amqprs/src/net/split_connection.rs | 4 +- amqprs/tests/common/mod.rs | 4 +- examples/src/mtls.rs | 4 +- examples/src/tls.rs | 4 +- 6 files changed, 73 insertions(+), 49 deletions(-) diff --git a/amqprs/src/api/connection.rs b/amqprs/src/api/connection.rs index 2730617..30302e2 100644 --- a/amqprs/src/api/connection.rs +++ b/amqprs/src/api/connection.rs @@ -501,7 +501,9 @@ impl TryFrom<&str> for OpenConnectionArguments { ); #[cfg(not(feature = "tls"))] - return Err(Error::UriError("can't create amqps url without the `tls` feature enabled".to_string())); + return Err(Error::UriError( + "can't create amqps url without the `tls` feature enabled".to_string(), + )); } // Check & apply query @@ -558,7 +560,7 @@ impl Connection { } SplitConnection::open_tls( &format!("{}:{}", args.host, args.port), - &tls_adaptor.domain, + tls_adaptor.domain.clone(), &tls_adaptor.connector, ) .await? @@ -1510,8 +1512,7 @@ mod tests { #[cfg(all(feature = "urispec", feature = "tls"))] #[test] fn test_urispec_amqps_simple() { - let args = OpenConnectionArguments::try_from("amqps://localhost") - .unwrap(); + let args = OpenConnectionArguments::try_from("amqps://localhost").unwrap(); assert_eq!(args.host, "localhost"); assert_eq!(args.port, 5671); assert_eq!(args.virtual_host, "/"); @@ -1535,8 +1536,8 @@ mod tests { let domain = "AMQPRS_TEST"; let tls_adaptor = crate::tls::TlsAdaptor::with_client_auth( Some(root_ca_cert.as_path()), - client_cert.as_path(), - client_private_key.as_path(), + client_cert.to_path_buf(), + client_private_key.to_path_buf(), domain.to_owned(), ) .unwrap(); diff --git a/amqprs/src/api/tls.rs b/amqprs/src/api/tls.rs index 73fad65..c90d438 100644 --- a/amqprs/src/api/tls.rs +++ b/amqprs/src/api/tls.rs @@ -6,10 +6,16 @@ //! [`OpenConnectionArguments`]: ../connection/struct.OpenConnectionArguments.html //! [`Connection::open`]: ../connection/struct.Connection.html#method.open -use std::{fs::File, io::BufReader, path::Path, sync::Arc}; +use rustls_pki_types::{CertificateDer, PrivateKeyDer}; +use std::{ + fs::File, + io::BufReader, + path::{Path, PathBuf}, + sync::Arc, +}; use tokio_rustls::{ - rustls::{Certificate, ClientConfig, OwnedTrustAnchor, PrivateKey, RootCertStore}, - webpki, TlsConnector, + rustls::{ClientConfig, RootCertStore}, + TlsConnector, }; /// The TLS adaptor used to enable TLS network stream. @@ -45,7 +51,6 @@ impl TlsAdaptor { let root_cert_store = Self::build_root_store(root_ca_cert)?; let config = ClientConfig::builder() - .with_safe_defaults() .with_root_certificates(root_cert_store) .with_no_client_auth(); let connector = TlsConnector::from(Arc::new(config)); @@ -64,17 +69,16 @@ impl TlsAdaptor { /// Panics if private key is invalid. pub fn with_client_auth( root_ca_cert: Option<&Path>, - client_cert: &Path, - client_private_key: &Path, + client_cert: PathBuf, + client_private_key: PathBuf, domain: String, ) -> std::io::Result { let root_cert_store = Self::build_root_store(root_ca_cert)?; - let client_certs = Self::build_client_certificates(client_cert)?; - let client_keys = Self::build_client_private_keys(client_private_key)?; + let client_certs: Vec = Self::build_client_certificates(client_cert)?; + let client_keys: Vec = Self::build_client_private_keys(client_private_key)?; let config = ClientConfig::builder() - .with_safe_defaults() .with_root_certificates(root_cert_store) - .with_single_cert(client_certs, client_keys.into_iter().next().unwrap()) + .with_client_auth_cert(client_certs, client_keys.into_iter().next().unwrap()) .unwrap(); let connector = TlsConnector::from(Arc::new(config)); @@ -85,44 +89,63 @@ impl TlsAdaptor { let mut root_store = RootCertStore::empty(); if let Some(root_ca_cert) = root_ca_cert { let mut pem = BufReader::new(File::open(root_ca_cert)?); + let certs = rustls_pemfile::certs(&mut pem)?; - let trust_anchors = certs.iter().map(|cert| { - let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap(); - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) + + let trust_anchors = certs.into_iter().map(|cert| { + // let c = webpki::Cert::from(cert); + let der = rustls_pki_types::CertificateDer::from(cert); + let anchor = webpki::anchor_from_trusted_cert(&der).unwrap().to_owned(); + + rustls_pki_types::TrustAnchor { + subject: anchor.subject.into(), + subject_public_key_info: anchor.subject_public_key_info.into(), + name_constraints: anchor.name_constraints.map(|f| f.into()), + } }); - root_store.add_server_trust_anchors(trust_anchors); + + // NOTE: The old rustls add_server_trust_anchors function did this + // https://github.com/rustls/rustls/blob/d1345fc39ad597e27e6355341d2b2b40c501625b/rustls/src/anchors.rs#L117-L118 + root_store.roots.extend(trust_anchors); } else { - root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map( - |ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - }, - )); + root_store + .roots + .extend(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + rustls_pki_types::TrustAnchor { + subject: ta.subject.into(), + subject_public_key_info: ta.spki.into(), + name_constraints: ta.name_constraints.map(|f| f.into()), + } + })); } Ok(root_store) } - fn build_client_certificates(client_cert: &Path) -> std::io::Result> { - let mut pem = BufReader::new(File::open(client_cert)?); - let certs = rustls_pemfile::certs(&mut pem)?; - let certs = certs.into_iter().map(Certificate); + fn build_client_certificates( + client_cert: PathBuf, + ) -> std::io::Result>> { + let file = File::open(client_cert)?; + let mut pem = BufReader::new(file); + let raw_certs = rustls_pemfile::certs(&mut pem)?; + let certs = raw_certs.into_iter().map(CertificateDer::from); Ok(certs.collect()) } - fn build_client_private_keys(client_private_key: &Path) -> std::io::Result> { + fn build_client_private_keys( + client_private_key: PathBuf, + ) -> std::io::Result>> { let mut pem = BufReader::new(File::open(client_private_key)?); let keys = Self::read_private_keys_from_pem(&mut pem)?; - let keys = keys.into_iter().map(PrivateKey); - Ok(keys.collect()) + let keys = keys + .into_iter() + .map(|c| { + PrivateKeyDer::try_from(c) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e)) + }) + .collect::, _>>()?; + + Ok(keys) } - /// Parses PEM encoded private keys. /// /// The input should PEM encoded private key in RSA, SEC1 Elliptic Curve or PKCS#8 format. diff --git a/amqprs/src/net/split_connection.rs b/amqprs/src/net/split_connection.rs index 97c3c5e..ccf6079 100644 --- a/amqprs/src/net/split_connection.rs +++ b/amqprs/src/net/split_connection.rs @@ -131,8 +131,8 @@ impl SplitConnection { } #[cfg(feature = "tls")] - pub async fn open_tls(addr: &str, domain: &str, connector: &TlsConnector) -> Result { - let domain = rustls::ServerName::try_from(domain) + pub async fn open_tls(addr: &str, domain: String, connector: &TlsConnector) -> Result { + let domain = rustls_pki_types::ServerName::try_from(domain) .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?; let stream = connector diff --git a/amqprs/tests/common/mod.rs b/amqprs/tests/common/mod.rs index 9315115..58dc982 100644 --- a/amqprs/tests/common/mod.rs +++ b/amqprs/tests/common/mod.rs @@ -30,8 +30,8 @@ pub fn build_conn_args() -> OpenConnectionArguments { .tls_adaptor( amqprs::tls::TlsAdaptor::with_client_auth( Some(root_ca_cert.as_path()), - client_cert.as_path(), - client_private_key.as_path(), + client_cert.to_path_buf(), + client_private_key.to_path_buf(), domain.to_owned(), ) .unwrap(), diff --git a/examples/src/mtls.rs b/examples/src/mtls.rs index d5eadc3..63c1c8c 100644 --- a/examples/src/mtls.rs +++ b/examples/src/mtls.rs @@ -38,8 +38,8 @@ async fn main() { .tls_adaptor( TlsAdaptor::with_client_auth( Some(root_ca_cert.as_path()), - client_cert.as_path(), - client_private_key.as_path(), + client_cert.to_path_buf(), + client_private_key.to_path_buf(), domain.to_owned(), ) .unwrap(), diff --git a/examples/src/tls.rs b/examples/src/tls.rs index 47110f4..1338995 100644 --- a/examples/src/tls.rs +++ b/examples/src/tls.rs @@ -37,8 +37,8 @@ async fn main() { .tls_adaptor( TlsAdaptor::with_client_auth( Some(root_ca_cert.as_path()), - client_cert.as_path(), - client_private_key.as_path(), + client_cert.to_path_buf(), + client_private_key.to_path_buf(), domain.to_owned(), ) .unwrap(),