diff --git a/proxy/src/bin/local_proxy.rs b/proxy/src/bin/local_proxy.rs index 38b42e15a3e9..644f670f8899 100644 --- a/proxy/src/bin/local_proxy.rs +++ b/proxy/src/bin/local_proxy.rs @@ -27,6 +27,7 @@ use proxy::rate_limiter::{ use proxy::scram::threadpool::ThreadPool; use proxy::serverless::cancel_set::CancelSet; use proxy::serverless::{self, GlobalConnPoolOptions}; +use proxy::tls::client_config::compute_client_config_with_root_certs; use proxy::types::RoleName; use proxy::url::ApiUrl; @@ -34,8 +35,6 @@ project_git_version!(GIT_VERSION); project_build_tag!(BUILD_TAG); use clap::Parser; -use rustls::crypto::ring; -use rustls::RootCertStore; use thiserror::Error; use tokio::net::TcpListener; use tokio::sync::Notify; @@ -273,19 +272,9 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes, }; - // local_proxy won't use TLS to talk to postgres. - let root_store = RootCertStore::empty(); - - let client_config = - rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider())) - .with_safe_default_protocol_versions() - .expect("ring should support the default protocol versions") - .with_root_certificates(root_store) - .with_no_client_auth(); - let compute_config = ComputeConfig { retry: RetryConfig::parse(RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)?, - tls: Arc::new(client_config), + tls: Arc::new(compute_client_config_with_root_certs()?), timeout: Duration::from_secs(2), }; diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 1dace2ec8f97..3b122d771cb1 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -3,7 +3,7 @@ use std::pin::pin; use std::sync::Arc; use std::time::Duration; -use anyhow::{bail, Context}; +use anyhow::bail; use futures::future::Either; use proxy::auth::backend::jwt::JwkCache; use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned}; @@ -24,9 +24,9 @@ use proxy::redis::{elasticache, notifications}; use proxy::scram::threadpool::ThreadPool; use proxy::serverless::cancel_set::CancelSet; use proxy::serverless::GlobalConnPoolOptions; +use proxy::tls::client_config::compute_client_config_with_root_certs; use proxy::{auth, control_plane, http, serverless, usage_metrics}; use remote_storage::RemoteStorageConfig; -use rustls::crypto::ring; use tokio::net::TcpListener; use tokio::sync::Mutex; use tokio::task::JoinSet; @@ -637,20 +637,9 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { console_redirect_confirmation_timeout: args.webauth_confirmation_timeout, }; - let root_store = load_certs() - .context("loading native tls certificates")? - .clone(); - - let client_config = - rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider())) - .with_safe_default_protocol_versions() - .expect("ring should support the default protocol versions") - .with_root_certificates(root_store) - .with_no_client_auth(); - let compute_config = ComputeConfig { retry: config::RetryConfig::parse(&args.connect_to_compute_retry)?, - tls: Arc::new(client_config), + tls: Arc::new(compute_client_config_with_root_certs()?), timeout: Duration::from_secs(2), }; @@ -674,18 +663,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { Ok(config) } -pub(crate) fn load_certs() -> anyhow::Result> { - let der_certs = rustls_native_certs::load_native_certs(); - - if !der_certs.errors.is_empty() { - bail!("could not parse certificates: {:?}", der_certs.errors); - } - - let mut store = rustls::RootCertStore::empty(); - store.add_parsable_certificates(der_certs.certs); - Ok(Arc::new(store)) -} - /// auth::Backend is created at proxy startup, and lives forever. fn build_auth_backend( args: &ProxyCliArgs, diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index 1c4860e0c325..df618cf24257 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -326,11 +326,9 @@ impl

Drop for Session

{ mod tests { use std::time::Duration; - use rustls::crypto::ring; - use rustls::RootCertStore; - use super::*; use crate::config::RetryConfig; + use crate::tls::client_config::compute_client_config_with_certs; fn config() -> ComputeConfig { let retry = RetryConfig { @@ -339,18 +337,9 @@ mod tests { backoff_factor: 2.0, }; - let root_store = RootCertStore::empty(); - - let client_config = - rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider())) - .with_safe_default_protocol_versions() - .expect("ring should support the default protocol versions") - .with_root_certificates(root_store) - .with_no_client_auth(); - ComputeConfig { retry, - tls: Arc::new(client_config), + tls: Arc::new(compute_client_config_with_certs(std::iter::empty())), timeout: Duration::from_secs(2), } } diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 148e84937278..10db2bcb303f 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -13,7 +13,7 @@ use postgres_client::tls::{MakeTlsConnect, NoTls}; use retry::{retry_after, ShouldRetryWakeCompute}; use rstest::rstest; use rustls::crypto::ring; -use rustls::{pki_types, RootCertStore}; +use rustls::pki_types; use tokio::io::DuplexStream; use super::connect_compute::ConnectMechanism; @@ -29,6 +29,7 @@ use crate::control_plane::{ self, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, NodeInfo, NodeInfoCache, }; use crate::error::ErrorKind; +use crate::tls::client_config::compute_client_config_with_certs; use crate::tls::postgres_rustls::MakeRustlsConnect; use crate::tls::server_config::CertResolver; use crate::types::{BranchId, EndpointId, ProjectId}; @@ -111,17 +112,7 @@ fn generate_tls_config<'a>( }; let client_config = { - let config = - rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider())) - .with_safe_default_protocol_versions() - .context("ring should support the default protocol versions")? - .with_root_certificates({ - let mut store = rustls::RootCertStore::empty(); - store.add(ca)?; - store - }) - .with_no_client_auth(); - let config = Arc::new(config); + let config = Arc::new(compute_client_config_with_certs([ca])); ClientConfig { config, hostname } }; @@ -585,18 +576,9 @@ fn config() -> ComputeConfig { backoff_factor: 2.0, }; - let root_store = RootCertStore::empty(); - - let client_config = - rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider())) - .with_safe_default_protocol_versions() - .expect("ring should support the default protocol versions") - .with_root_certificates(root_store) - .with_no_client_auth(); - ComputeConfig { retry, - tls: Arc::new(client_config), + tls: Arc::new(compute_client_config_with_certs(std::iter::empty())), timeout: Duration::from_secs(2), } } diff --git a/proxy/src/tls/client_config.rs b/proxy/src/tls/client_config.rs new file mode 100644 index 000000000000..a2d695aae11d --- /dev/null +++ b/proxy/src/tls/client_config.rs @@ -0,0 +1,42 @@ +use std::sync::Arc; + +use anyhow::bail; +use rustls::crypto::ring; + +pub(crate) fn load_certs() -> anyhow::Result> { + let der_certs = rustls_native_certs::load_native_certs(); + + if !der_certs.errors.is_empty() { + bail!("could not parse certificates: {:?}", der_certs.errors); + } + + let mut store = rustls::RootCertStore::empty(); + store.add_parsable_certificates(der_certs.certs); + Ok(Arc::new(store)) +} + +/// Loads the root certificates and constructs a client config suitable for connecting to the neon compute. +/// This function is blocking. +pub fn compute_client_config_with_root_certs() -> anyhow::Result { + Ok( + rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider())) + .with_safe_default_protocol_versions() + .expect("ring should support the default protocol versions") + .with_root_certificates(load_certs()?) + .with_no_client_auth(), + ) +} + +#[cfg(test)] +pub fn compute_client_config_with_certs( + certs: impl IntoIterator>, +) -> rustls::ClientConfig { + let mut store = rustls::RootCertStore::empty(); + store.add_parsable_certificates(certs); + + rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider())) + .with_safe_default_protocol_versions() + .expect("ring should support the default protocol versions") + .with_root_certificates(store) + .with_no_client_auth() +} diff --git a/proxy/src/tls/mod.rs b/proxy/src/tls/mod.rs index 2071ded23d92..d6ce6bd9fcf4 100644 --- a/proxy/src/tls/mod.rs +++ b/proxy/src/tls/mod.rs @@ -1,3 +1,4 @@ +pub mod client_config; pub mod postgres_rustls; pub mod server_config;