Skip to content

Commit

Permalink
Upgrade 'rustls' to '0.22'.
Browse files Browse the repository at this point in the history
In the process, the following improvements were also made:

  * Error messages related to TLS were improved.
  * 'Redirector' in 'tls' example was improved.
  • Loading branch information
SergioBenitez committed Dec 16, 2023
1 parent a59f3c4 commit 9c2b74b
Show file tree
Hide file tree
Showing 14 changed files with 253 additions and 120 deletions.
6 changes: 3 additions & 3 deletions core/http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ percent-encoding = "2"
http = "0.2"
time = { version = "0.3", features = ["formatting", "macros"] }
indexmap = "2"
rustls = { version = "0.21", optional = true }
tokio-rustls = { version = "0.24", optional = true }
rustls-pemfile = { version = "1.0.2", optional = true }
rustls = { version = "0.22", optional = true }
tokio-rustls = { version = "0.25", optional = true }
rustls-pemfile = { version = "2.0.0", optional = true }
tokio = { version = "1.6.1", features = ["net", "sync", "time"] }
log = "0.4"
ref-cast = "1.0"
Expand Down
33 changes: 21 additions & 12 deletions core/http/src/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,45 @@ use state::InitCell;
pub use tokio::net::TcpListener;

/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
// NOTE: `rustls::Certificate` is exactly isomorphic to `CertificateData`.
#[doc(inline)]
#[cfg(feature = "tls")]
pub use rustls::Certificate as CertificateData;
#[cfg(not(feature = "tls"))]
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CertificateDer(pub(crate) Vec<u8>);

/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
#[cfg(not(feature = "tls"))]
#[derive(Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct CertificateData(pub Vec<u8>);
#[cfg(feature = "tls")]
#[derive(Debug, Clone, Eq, PartialEq)]
#[repr(transparent)]
pub struct CertificateDer(pub(crate) rustls::pki_types::CertificateDer<'static>);

/// A collection of raw certificate data.
#[derive(Clone, Default)]
pub struct Certificates(Arc<InitCell<Vec<CertificateData>>>);
pub struct Certificates(Arc<InitCell<Vec<CertificateDer>>>);

impl From<Vec<CertificateData>> for Certificates {
fn from(value: Vec<CertificateData>) -> Self {
impl From<Vec<CertificateDer>> for Certificates {
fn from(value: Vec<CertificateDer>) -> Self {
Certificates(Arc::new(value.into()))
}
}

#[cfg(feature = "tls")]
impl From<Vec<rustls::pki_types::CertificateDer<'static>>> for Certificates {
fn from(value: Vec<rustls::pki_types::CertificateDer<'static>>) -> Self {
let value: Vec<_> = value.into_iter().map(CertificateDer).collect();
Certificates(Arc::new(value.into()))
}
}

#[doc(hidden)]
impl Certificates {
/// Set the the raw certificate chain data. Only the first call actually
/// sets the data; the remaining do nothing.
#[cfg(feature = "tls")]
pub(crate) fn set(&self, data: Vec<CertificateData>) {
pub(crate) fn set(&self, data: Vec<CertificateDer>) {
self.0.set(data);
}

/// Returns the raw certificate chain data, if any is available.
pub fn chain_data(&self) -> Option<&[CertificateData]> {
pub fn chain_data(&self) -> Option<&[CertificateDer]> {
self.0.try_get().map(|v| v.as_slice())
}
}
Expand Down
95 changes: 95 additions & 0 deletions core/http/src/tls/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
pub type Result<T, E = Error> = std::result::Result<T, E>;

#[derive(Debug)]
pub enum KeyError {
BadKeyCount(usize),
Io(std::io::Error),
Unsupported(rustls::Error),
BadItem(rustls_pemfile::Item),
}

#[derive(Debug)]
pub enum Error {
Io(std::io::Error),
Tls(rustls::Error),
Mtls(rustls::server::VerifierBuilderError),
CertChain(std::io::Error),
PrivKey(KeyError),
CertAuth(rustls::Error),
}

impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use Error::*;

match self {
Io(e) => write!(f, "i/o error during tls binding: {e}"),
Tls(e) => write!(f, "tls configuration error: {e}"),
Mtls(e) => write!(f, "mtls verifier error: {e}"),
CertChain(e) => write!(f, "failed to process certificate chain: {e}"),
PrivKey(e) => write!(f, "failed to process private key: {e}"),
CertAuth(e) => write!(f, "failed to process certificate authority: {e}"),
}
}
}

impl std::fmt::Display for KeyError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use KeyError::*;

match self {
Io(e) => write!(f, "error reading key file: {e}"),
BadKeyCount(0) => write!(f, "no valid keys found. is the file malformed?"),
BadKeyCount(n) => write!(f, "expected exactly 1 key, found {n}"),
Unsupported(e) => write!(f, "key is valid but is unsupported: {e}"),
BadItem(i) => write!(f, "found unexpected item in key file: {i:#?}"),
}
}
}

impl std::error::Error for KeyError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
KeyError::Io(e) => Some(e),
KeyError::Unsupported(e) => Some(e),
_ => None,
}
}
}

impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Error::Io(e) => Some(e),
Error::Tls(e) => Some(e),
Error::Mtls(e) => Some(e),
Error::CertChain(e) => Some(e),
Error::PrivKey(e) => Some(e),
Error::CertAuth(e) => Some(e),
}
}
}

impl From<std::io::Error> for Error {
fn from(e: std::io::Error) -> Self {
Error::Io(e)
}
}

impl From<rustls::Error> for Error {
fn from(e: rustls::Error) -> Self {
Error::Tls(e)
}
}

impl From<rustls::server::VerifierBuilderError> for Error {
fn from(value: rustls::server::VerifierBuilderError) -> Self {
Error::Mtls(value)
}
}

impl From<KeyError> for Error {
fn from(value: KeyError) -> Self {
Error::PrivKey(value)
}
}
73 changes: 35 additions & 38 deletions core/http/src/tls/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ use std::net::SocketAddr;
use tokio::net::{TcpListener, TcpStream};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::{Accept, TlsAcceptor, server::TlsStream as BareTlsStream};
use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier};

use crate::tls::util::{load_certs, load_private_key, load_ca_certs};
use crate::listener::{Connection, Listener, Certificates};
use crate::tls::util::{load_cert_chain, load_key, load_ca_certs};
use crate::listener::{Connection, Listener, Certificates, CertificateDer};

/// A TLS listener over TCP.
pub struct TlsListener {
Expand Down Expand Up @@ -40,7 +41,7 @@ pub struct TlsListener {
///
/// To work around this, we "lie" when `peer_certificates()` are requested and
/// always return `Some(Certificates)`. Internally, `Certificates` is an
/// `Arc<InitCell<Vec<CertificateData>>>`, effectively a shared, thread-safe,
/// `Arc<InitCell<Vec<CertificateDer>>>`, effectively a shared, thread-safe,
/// `OnceCell`. The cell is initially empty and is filled as soon as the
/// handshake is complete. If the certificate data were to be requested prior to
/// this point, it would be empty. However, in Rocket, we only request
Expand Down Expand Up @@ -72,49 +73,43 @@ pub struct Config<R> {
}

impl TlsListener {
pub async fn bind<R>(addr: SocketAddr, mut c: Config<R>) -> io::Result<TlsListener>
pub async fn bind<R>(addr: SocketAddr, mut c: Config<R>) -> crate::tls::Result<TlsListener>
where R: io::BufRead
{
use rustls::server::{AllowAnyAuthenticatedClient, AllowAnyAnonymousOrAuthenticatedClient};
use rustls::server::{NoClientAuth, ServerSessionMemoryCache, ServerConfig};

let cert_chain = load_certs(&mut c.cert_chain)
.map_err(|e| io::Error::new(e.kind(), format!("bad TLS cert chain: {}", e)))?;

let key = load_private_key(&mut c.private_key)
.map_err(|e| io::Error::new(e.kind(), format!("bad TLS private key: {}", e)))?;
let provider = rustls::crypto::CryptoProvider {
cipher_suites: c.ciphersuites,
..rustls::crypto::ring::default_provider()
};

let client_auth = match c.ca_certs {
Some(ref mut ca_certs) => match load_ca_certs(ca_certs) {
Ok(ca) if c.mandatory_mtls => AllowAnyAuthenticatedClient::new(ca).boxed(),
Ok(ca) => AllowAnyAnonymousOrAuthenticatedClient::new(ca).boxed(),
Err(e) => return Err(io::Error::new(e.kind(), format!("bad CA cert(s): {}", e))),
let verifier = match c.ca_certs {
Some(ref mut ca_certs) => {
let ca_roots = Arc::new(load_ca_certs(ca_certs)?);
let verifier = WebPkiClientVerifier::builder(ca_roots);
match c.mandatory_mtls {
true => verifier.build()?,
false => verifier.allow_unauthenticated().build()?,
}
},
None => NoClientAuth::boxed(),
None => WebPkiClientVerifier::no_client_auth(),
};

let mut tls_config = ServerConfig::builder()
.with_cipher_suites(&c.ciphersuites)
.with_safe_default_kx_groups()
.with_safe_default_protocol_versions()
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))?
.with_client_cert_verifier(client_auth)
.with_single_cert(cert_chain, key)
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))?;

tls_config.ignore_client_order = c.prefer_server_order;

tls_config.alpn_protocols = vec![b"http/1.1".to_vec()];
let key = load_key(&mut c.private_key)?;
let cert_chain = load_cert_chain(&mut c.cert_chain)?;
let mut config = ServerConfig::builder_with_provider(Arc::new(provider))
.with_safe_default_protocol_versions()?
.with_client_cert_verifier(verifier)
.with_single_cert(cert_chain, key)?;

config.ignore_client_order = c.prefer_server_order;
config.session_storage = ServerSessionMemoryCache::new(1024);
config.ticketer = rustls::crypto::ring::Ticketer::new()?;
config.alpn_protocols = vec![b"http/1.1".to_vec()];
if cfg!(feature = "http2") {
tls_config.alpn_protocols.insert(0, b"h2".to_vec());
config.alpn_protocols.insert(0, b"h2".to_vec());
}

tls_config.session_storage = ServerSessionMemoryCache::new(1024);
tls_config.ticketer = rustls::Ticketer::new()
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS ticketer: {}", e)))?;

let listener = TcpListener::bind(addr).await?;
let acceptor = TlsAcceptor::from(Arc::new(tls_config));
let acceptor = TlsAcceptor::from(Arc::new(config));
Ok(TlsListener { listener, acceptor })
}
}
Expand Down Expand Up @@ -179,8 +174,10 @@ impl TlsStream {
TlsState::Handshaking(ref mut accept) => {
match futures::ready!(Pin::new(accept).poll(cx)) {
Ok(stream) => {
if let Some(cert_chain) = stream.get_ref().1.peer_certificates() {
self.certs.set(cert_chain.to_vec());
if let Some(peer_certs) = stream.get_ref().1.peer_certificates() {
self.certs.set(peer_certs.into_iter()
.map(|v| CertificateDer(v.clone().into_owned()))
.collect());
}

self.state = TlsState::Streaming(stream);
Expand Down
3 changes: 3 additions & 0 deletions core/http/src/tls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@ pub mod mtls;
pub use rustls;
pub use listener::{TlsListener, Config};
pub mod util;
pub mod error;

pub use error::Result;
6 changes: 3 additions & 3 deletions core/http/src/tls/mtls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use x509_parser::nom;
use x509::{ParsedExtension, X509Name, X509Certificate, TbsCertificate, X509Error, FromDer};
use oid::OID_X509_EXT_SUBJECT_ALT_NAME as SUBJECT_ALT_NAME;

use crate::listener::CertificateData;
use crate::listener::CertificateDer;

/// A type alias for [`Result`](std::result::Result) with the error type set to
/// [`Error`].
Expand Down Expand Up @@ -144,7 +144,7 @@ pub type Result<T, E = Error> = std::result::Result<T, E>;
#[derive(Debug, PartialEq)]
pub struct Certificate<'a> {
x509: X509Certificate<'a>,
data: &'a CertificateData,
data: &'a CertificateDer,
}

/// An X.509 Distinguished Name (DN) found in a [`Certificate`].
Expand Down Expand Up @@ -224,7 +224,7 @@ impl<'a> Certificate<'a> {

/// PRIVATE: For internal Rocket use only!
#[doc(hidden)]
pub fn parse(chain: &[CertificateData]) -> Result<Certificate<'_>> {
pub fn parse(chain: &[CertificateDer]) -> Result<Certificate<'_>> {
let data = chain.first().ok_or_else(|| Error::Empty)?;
let x509 = Certificate::parse_one(&data.0)?;
Ok(Certificate { x509, data })
Expand Down
Loading

0 comments on commit 9c2b74b

Please sign in to comment.