From 9b0d57000e80d34c24c35586876ca84d4973763d Mon Sep 17 00:00:00 2001 From: Honsun Zhu Date: Mon, 2 Dec 2024 04:35:26 +0800 Subject: [PATCH] feat(tls): Add tls handshake timeout support --- tonic/src/transport/channel/service/tls.rs | 17 +++++++++++++---- tonic/src/transport/channel/tls.rs | 11 +++++++++++ tonic/src/transport/server/service/tls.rs | 19 ++++++++++++++++--- tonic/src/transport/server/tls.rs | 13 ++++++++++++- tonic/src/transport/service/tls.rs | 2 ++ 5 files changed, 54 insertions(+), 8 deletions(-) diff --git a/tonic/src/transport/channel/service/tls.rs b/tonic/src/transport/channel/service/tls.rs index 5dd227f81..33cd8f66a 100644 --- a/tonic/src/transport/channel/service/tls.rs +++ b/tonic/src/transport/channel/service/tls.rs @@ -1,8 +1,9 @@ use std::fmt; -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use hyper_util::rt::TokioIo; use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::time; use tokio_rustls::{ rustls::{ crypto, @@ -23,6 +24,7 @@ pub(crate) struct TlsConnector { config: Arc, domain: Arc>, assume_http2: bool, + timeout: Option, } impl TlsConnector { @@ -34,6 +36,7 @@ impl TlsConnector { assume_http2: bool, #[cfg(feature = "tls-native-roots")] with_native_roots: bool, #[cfg(feature = "tls-webpki-roots")] with_webpki_roots: bool, + timeout: Option, ) -> Result { fn with_provider( provider: Arc, @@ -92,6 +95,7 @@ impl TlsConnector { config: Arc::new(config), domain: Arc::new(ServerName::try_from(domain)?.to_owned()), assume_http2, + timeout, }) } @@ -99,9 +103,14 @@ impl TlsConnector { where I: AsyncRead + AsyncWrite + Send + Unpin + 'static, { - let io = RustlsConnector::from(self.config.clone()) - .connect(self.domain.as_ref().to_owned(), io) - .await?; + let conn_fut = + RustlsConnector::from(self.config.clone()).connect(self.domain.as_ref().to_owned(), io); + let io = match self.timeout { + Some(timeout) => time::timeout(timeout, conn_fut) + .await + .map_err(|_| TlsError::HandshakeTimeout)?, + None => conn_fut.await, + }?; // Generally we require ALPN to be negotiated, but if the user has // explicitly set `assume_http2` to true, we'll allow it to be missing. diff --git a/tonic/src/transport/channel/tls.rs b/tonic/src/transport/channel/tls.rs index 0c2eb37e0..2b893d50b 100644 --- a/tonic/src/transport/channel/tls.rs +++ b/tonic/src/transport/channel/tls.rs @@ -4,6 +4,7 @@ use crate::transport::{ Error, }; use http::Uri; +use std::time::Duration; use tokio_rustls::rustls::pki_types::TrustAnchor; /// Configures TLS settings for endpoints. @@ -18,6 +19,7 @@ pub struct ClientTlsConfig { with_native_roots: bool, #[cfg(feature = "tls-webpki-roots")] with_webpki_roots: bool, + timeout: Option, } impl ClientTlsConfig { @@ -112,6 +114,14 @@ impl ClientTlsConfig { config } + /// Sets the timeout for the TLS handshake. + pub fn timeout(self, timeout: Duration) -> Self { + ClientTlsConfig { + timeout: Some(timeout), + ..self + } + } + pub(crate) fn into_tls_connector(self, uri: &Uri) -> Result { let domain = match &self.domain { Some(domain) => domain, @@ -127,6 +137,7 @@ impl ClientTlsConfig { self.with_native_roots, #[cfg(feature = "tls-webpki-roots")] self.with_webpki_roots, + self.timeout, ) } } diff --git a/tonic/src/transport/server/service/tls.rs b/tonic/src/transport/server/service/tls.rs index 395d5132b..798bed1d7 100644 --- a/tonic/src/transport/server/service/tls.rs +++ b/tonic/src/transport/server/service/tls.rs @@ -1,6 +1,7 @@ -use std::{fmt, sync::Arc}; +use std::{fmt, sync::Arc, time::Duration}; use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::time; use tokio_rustls::{ rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig}, server::TlsStream, @@ -8,13 +9,16 @@ use tokio_rustls::{ }; use crate::transport::{ - service::tls::{convert_certificate_to_pki_types, convert_identity_to_pki_types, ALPN_H2}, + service::tls::{ + convert_certificate_to_pki_types, convert_identity_to_pki_types, TlsError, ALPN_H2, + }, Certificate, Identity, }; #[derive(Clone)] pub(crate) struct TlsAcceptor { inner: Arc, + timeout: Option, } impl TlsAcceptor { @@ -22,6 +26,7 @@ impl TlsAcceptor { identity: Identity, client_ca_root: Option, client_auth_optional: bool, + timeout: Option, ) -> Result { let builder = ServerConfig::builder(); @@ -46,6 +51,7 @@ impl TlsAcceptor { config.alpn_protocols.push(ALPN_H2.into()); Ok(Self { inner: Arc::new(config), + timeout, }) } @@ -54,7 +60,14 @@ impl TlsAcceptor { IO: AsyncRead + AsyncWrite + Unpin, { let acceptor = RustlsAcceptor::from(self.inner.clone()); - acceptor.accept(io).await.map_err(Into::into) + let accept_fut = acceptor.accept(io); + match self.timeout { + Some(timeout) => time::timeout(timeout, accept_fut) + .await + .map_err(|_| TlsError::HandshakeTimeout)?, + None => accept_fut.await, + } + .map_err(Into::into) } } diff --git a/tonic/src/transport/server/tls.rs b/tonic/src/transport/server/tls.rs index 331df8d31..c1f43e0a8 100644 --- a/tonic/src/transport/server/tls.rs +++ b/tonic/src/transport/server/tls.rs @@ -1,4 +1,4 @@ -use std::fmt; +use std::{fmt, time::Duration}; use super::service::TlsAcceptor; use crate::transport::tls::{Certificate, Identity}; @@ -9,6 +9,7 @@ pub struct ServerTlsConfig { identity: Option, client_ca_root: Option, client_auth_optional: bool, + timeout: Option, } impl fmt::Debug for ServerTlsConfig { @@ -24,6 +25,7 @@ impl ServerTlsConfig { identity: None, client_ca_root: None, client_auth_optional: false, + timeout: None, } } @@ -56,11 +58,20 @@ impl ServerTlsConfig { } } + /// Sets the timeout for the TLS handshake. + pub fn timeout(self, timeout: Duration) -> Self { + ServerTlsConfig { + timeout: Some(timeout), + ..self + } + } + pub(crate) fn tls_acceptor(&self) -> Result { TlsAcceptor::new( self.identity.clone().unwrap(), self.client_ca_root.clone(), self.client_auth_optional, + self.timeout, ) } } diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index 8cb30c73c..0d6e9bc87 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -15,6 +15,7 @@ pub(crate) enum TlsError { NativeCertsNotFound, CertificateParseError, PrivateKeyParseError, + HandshakeTimeout, } impl fmt::Display for TlsError { @@ -29,6 +30,7 @@ impl fmt::Display for TlsError { f, "Error parsing TLS private key - no RSA or PKCS8-encoded keys found." ), + TlsError::HandshakeTimeout => write!(f, "TLS handshake timeout."), } } }