Skip to content

Commit

Permalink
migrate to new rustls types
Browse files Browse the repository at this point in the history
  • Loading branch information
LockedThread committed Jul 22, 2024
1 parent a60e0c4 commit e2a8c74
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 49 deletions.
13 changes: 7 additions & 6 deletions amqprs/src/api/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,9 @@ impl TryFrom<&str> for OpenConnectionArguments {
);

#[cfg(not(feature = "tls"))]
return Err(Error::UriError("can't create amqps url without the `tls` feature enabled".to_string()));
return Err(Error::UriError(
"can't create amqps url without the `tls` feature enabled".to_string(),
));
}

// Check & apply query
Expand Down Expand Up @@ -558,7 +560,7 @@ impl Connection {
}
SplitConnection::open_tls(
&format!("{}:{}", args.host, args.port),
&tls_adaptor.domain,
tls_adaptor.domain.clone(),
&tls_adaptor.connector,
)
.await?
Expand Down Expand Up @@ -1510,8 +1512,7 @@ mod tests {
#[cfg(all(feature = "urispec", feature = "tls"))]
#[test]
fn test_urispec_amqps_simple() {
let args = OpenConnectionArguments::try_from("amqps://localhost")
.unwrap();
let args = OpenConnectionArguments::try_from("amqps://localhost").unwrap();
assert_eq!(args.host, "localhost");
assert_eq!(args.port, 5671);
assert_eq!(args.virtual_host, "/");
Expand All @@ -1535,8 +1536,8 @@ mod tests {
let domain = "AMQPRS_TEST";
let tls_adaptor = crate::tls::TlsAdaptor::with_client_auth(
Some(root_ca_cert.as_path()),
client_cert.as_path(),
client_private_key.as_path(),
client_cert.to_path_buf(),
client_private_key.to_path_buf(),
domain.to_owned(),
)
.unwrap();
Expand Down
93 changes: 58 additions & 35 deletions amqprs/src/api/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@
//! [`OpenConnectionArguments`]: ../connection/struct.OpenConnectionArguments.html
//! [`Connection::open`]: ../connection/struct.Connection.html#method.open
use std::{fs::File, io::BufReader, path::Path, sync::Arc};
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
use std::{
fs::File,
io::BufReader,
path::{Path, PathBuf},
sync::Arc,
};
use tokio_rustls::{
rustls::{Certificate, ClientConfig, OwnedTrustAnchor, PrivateKey, RootCertStore},
webpki, TlsConnector,
rustls::{ClientConfig, RootCertStore},
TlsConnector,
};

/// The TLS adaptor used to enable TLS network stream.
Expand Down Expand Up @@ -45,7 +51,6 @@ impl TlsAdaptor {
let root_cert_store = Self::build_root_store(root_ca_cert)?;

let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_cert_store)
.with_no_client_auth();
let connector = TlsConnector::from(Arc::new(config));
Expand All @@ -64,17 +69,16 @@ impl TlsAdaptor {
/// Panics if private key is invalid.
pub fn with_client_auth(
root_ca_cert: Option<&Path>,
client_cert: &Path,
client_private_key: &Path,
client_cert: PathBuf,
client_private_key: PathBuf,
domain: String,
) -> std::io::Result<Self> {
let root_cert_store = Self::build_root_store(root_ca_cert)?;
let client_certs = Self::build_client_certificates(client_cert)?;
let client_keys = Self::build_client_private_keys(client_private_key)?;
let client_certs: Vec<CertificateDer> = Self::build_client_certificates(client_cert)?;
let client_keys: Vec<PrivateKeyDer> = Self::build_client_private_keys(client_private_key)?;
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_cert_store)
.with_single_cert(client_certs, client_keys.into_iter().next().unwrap())
.with_client_auth_cert(client_certs, client_keys.into_iter().next().unwrap())
.unwrap();
let connector = TlsConnector::from(Arc::new(config));

Expand All @@ -85,44 +89,63 @@ impl TlsAdaptor {
let mut root_store = RootCertStore::empty();
if let Some(root_ca_cert) = root_ca_cert {
let mut pem = BufReader::new(File::open(root_ca_cert)?);

let certs = rustls_pemfile::certs(&mut pem)?;
let trust_anchors = certs.iter().map(|cert| {
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)

let trust_anchors = certs.into_iter().map(|cert| {
// let c = webpki::Cert::from(cert);
let der = rustls_pki_types::CertificateDer::from(cert);
let anchor = webpki::anchor_from_trusted_cert(&der).unwrap().to_owned();

rustls_pki_types::TrustAnchor {
subject: anchor.subject.into(),
subject_public_key_info: anchor.subject_public_key_info.into(),
name_constraints: anchor.name_constraints.map(|f| f.into()),
}
});
root_store.add_server_trust_anchors(trust_anchors);

// NOTE: The old rustls add_server_trust_anchors function did this
// https://github.com/rustls/rustls/blob/d1345fc39ad597e27e6355341d2b2b40c501625b/rustls/src/anchors.rs#L117-L118
root_store.roots.extend(trust_anchors);
} else {
root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(
|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
},
));
root_store
.roots
.extend(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
rustls_pki_types::TrustAnchor {
subject: ta.subject.into(),
subject_public_key_info: ta.spki.into(),
name_constraints: ta.name_constraints.map(|f| f.into()),
}
}));
}
Ok(root_store)
}

fn build_client_certificates(client_cert: &Path) -> std::io::Result<Vec<Certificate>> {
let mut pem = BufReader::new(File::open(client_cert)?);
let certs = rustls_pemfile::certs(&mut pem)?;
let certs = certs.into_iter().map(Certificate);
fn build_client_certificates(
client_cert: PathBuf,
) -> std::io::Result<Vec<CertificateDer<'static>>> {
let file = File::open(client_cert)?;
let mut pem = BufReader::new(file);
let raw_certs = rustls_pemfile::certs(&mut pem)?;
let certs = raw_certs.into_iter().map(CertificateDer::from);
Ok(certs.collect())
}

fn build_client_private_keys(client_private_key: &Path) -> std::io::Result<Vec<PrivateKey>> {
fn build_client_private_keys(
client_private_key: PathBuf,
) -> std::io::Result<Vec<PrivateKeyDer<'static>>> {
let mut pem = BufReader::new(File::open(client_private_key)?);
let keys = Self::read_private_keys_from_pem(&mut pem)?;
let keys = keys.into_iter().map(PrivateKey);
Ok(keys.collect())
let keys = keys
.into_iter()
.map(|c| {
PrivateKeyDer::try_from(c)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
})
.collect::<Result<Vec<_>, _>>()?;

Ok(keys)
}

/// Parses PEM encoded private keys.
///
/// The input should PEM encoded private key in RSA, SEC1 Elliptic Curve or PKCS#8 format.
Expand Down
4 changes: 2 additions & 2 deletions amqprs/src/net/split_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ impl SplitConnection {
}

#[cfg(feature = "tls")]
pub async fn open_tls(addr: &str, domain: &str, connector: &TlsConnector) -> Result<Self> {
let domain = rustls::ServerName::try_from(domain)
pub async fn open_tls(addr: &str, domain: String, connector: &TlsConnector) -> Result<Self> {
let domain = rustls_pki_types::ServerName::try_from(domain)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?;

let stream = connector
Expand Down
4 changes: 2 additions & 2 deletions amqprs/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ pub fn build_conn_args() -> OpenConnectionArguments {
.tls_adaptor(
amqprs::tls::TlsAdaptor::with_client_auth(
Some(root_ca_cert.as_path()),
client_cert.as_path(),
client_private_key.as_path(),
client_cert.to_path_buf(),
client_private_key.to_path_buf(),
domain.to_owned(),
)
.unwrap(),
Expand Down
4 changes: 2 additions & 2 deletions examples/src/mtls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ async fn main() {
.tls_adaptor(
TlsAdaptor::with_client_auth(
Some(root_ca_cert.as_path()),
client_cert.as_path(),
client_private_key.as_path(),
client_cert.to_path_buf(),
client_private_key.to_path_buf(),
domain.to_owned(),
)
.unwrap(),
Expand Down
4 changes: 2 additions & 2 deletions examples/src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ async fn main() {
.tls_adaptor(
TlsAdaptor::with_client_auth(
Some(root_ca_cert.as_path()),
client_cert.as_path(),
client_private_key.as_path(),
client_cert.to_path_buf(),
client_private_key.to_path_buf(),
domain.to_owned(),
)
.unwrap(),
Expand Down

0 comments on commit e2a8c74

Please sign in to comment.