Skip to content

Commit

Permalink
chore(proxy): pre-load native tls certificates and propagate compute …
Browse files Browse the repository at this point in the history
…client config
  • Loading branch information
conradludgate committed Dec 18, 2024
1 parent d63602c commit ad091c6
Show file tree
Hide file tree
Showing 15 changed files with 211 additions and 160 deletions.
27 changes: 23 additions & 4 deletions proxy/src/bin/local_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::local::{LocalBackend, JWKS_ROLE_MAP};
use proxy::auth::{self};
use proxy::cancellation::CancellationHandlerMain;
use proxy::config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig};
use proxy::config::{
self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
};
use proxy::control_plane::locks::ApiLocks;
use proxy::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use proxy::http::health_server::AppMetrics;
Expand All @@ -32,6 +34,8 @@ project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);

use clap::Parser;
use rustls::crypto::ring;
use rustls::RootCertStore;
use thiserror::Error;
use tokio::net::TcpListener;
use tokio::sync::Notify;
Expand Down Expand Up @@ -209,6 +213,7 @@ async fn main() -> anyhow::Result<()> {
http_listener,
shutdown.clone(),
Arc::new(CancellationHandlerMain::new(
&config.connect_to_compute,
Arc::new(DashMap::new()),
None,
proxy::metrics::CancellationSource::Local,
Expand Down Expand Up @@ -268,6 +273,22 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
};

// local_proxy won't use TLS to talk to postgres.
let root_store = RootCertStore::empty();

let client_config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(root_store)
.with_no_client_auth();

let compute_config = ComputeConfig {
retry: RetryConfig::parse(RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)?,
tls: Arc::new(client_config),
timeout: Duration::from_secs(2),
};

Ok(Box::leak(Box::new(ProxyConfig {
tls_config: None,
metric_collection: None,
Expand All @@ -289,9 +310,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
region: "local".into(),
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
connect_compute_locks,
connect_to_compute_retry_config: RetryConfig::parse(
RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES,
)?,
connect_to_compute: compute_config,
})))
}

Expand Down
42 changes: 37 additions & 5 deletions proxy/src/bin/proxy.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
use std::net::SocketAddr;
use std::pin::pin;
use std::sync::Arc;
use std::time::Duration;

use anyhow::bail;
use anyhow::{bail, Context};
use futures::future::Either;
use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned};
use proxy::cancellation::{CancelMap, CancellationHandler};
use proxy::config::{
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, HttpConfig,
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig,
ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2,
};
use proxy::context::parquet::ParquetUploadArgs;
Expand All @@ -25,6 +26,7 @@ use proxy::serverless::cancel_set::CancelSet;
use proxy::serverless::GlobalConnPoolOptions;
use proxy::{auth, control_plane, http, serverless, usage_metrics};
use remote_storage::RemoteStorageConfig;
use rustls::crypto::ring;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
Expand Down Expand Up @@ -397,6 +399,7 @@ async fn main() -> anyhow::Result<()> {
let cancellation_handler = Arc::new(CancellationHandler::<
Option<Arc<Mutex<RedisPublisherClient>>>,
>::new(
&config.connect_to_compute,
cancel_map.clone(),
redis_publisher,
proxy::metrics::CancellationSource::FromClient,
Expand Down Expand Up @@ -492,6 +495,7 @@ async fn main() -> anyhow::Result<()> {
let cache = api.caches.project_info.clone();
if let Some(client) = client1 {
maintenance_tasks.spawn(notifications::task_main(
config,
client,
cache.clone(),
cancel_map.clone(),
Expand All @@ -500,6 +504,7 @@ async fn main() -> anyhow::Result<()> {
}
if let Some(client) = client2 {
maintenance_tasks.spawn(notifications::task_main(
config,
client,
cache.clone(),
cancel_map.clone(),
Expand Down Expand Up @@ -632,6 +637,23 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
console_redirect_confirmation_timeout: args.webauth_confirmation_timeout,
};

let root_store = load_certs()
.context("loading native tls certificates")?
.clone();

let client_config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(root_store)
.with_no_client_auth();

let compute_config = ComputeConfig {
retry: config::RetryConfig::parse(&args.connect_to_compute_retry)?,
tls: Arc::new(client_config),
timeout: Duration::from_secs(2),
};

let config = ProxyConfig {
tls_config,
metric_collection,
Expand All @@ -642,9 +664,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
region: args.region.clone(),
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
connect_compute_locks,
connect_to_compute_retry_config: config::RetryConfig::parse(
&args.connect_to_compute_retry,
)?,
connect_to_compute: compute_config,
};

let config = Box::leak(Box::new(config));
Expand All @@ -654,6 +674,18 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
Ok(config)
}

pub(crate) fn load_certs() -> anyhow::Result<Arc<rustls::RootCertStore>> {
let der_certs = rustls_native_certs::load_native_certs();

if !der_certs.errors.is_empty() {
bail!("could not parse certificates: {:?}", der_certs.errors);
}

let mut store = rustls::RootCertStore::empty();
store.add_parsable_certificates(der_certs.certs);
Ok(Arc::new(store))
}

/// auth::Backend is created at proxy startup, and lives forever.
fn build_auth_backend(
args: &ProxyCliArgs,
Expand Down
85 changes: 56 additions & 29 deletions proxy/src/cancellation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,17 @@ use std::sync::Arc;

use dashmap::DashMap;
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use once_cell::sync::OnceCell;
use postgres_client::tls::MakeTlsConnect;
use postgres_client::CancelToken;
use pq_proto::CancelKeyData;
use rustls::crypto::ring;
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tracing::{debug, info};
use uuid::Uuid;

use crate::auth::{check_peer_addr_is_in_list, IpPattern};
use crate::compute::load_certs;
use crate::config::ComputeConfig;
use crate::error::ReportableError;
use crate::ext::LockExt;
use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
Expand All @@ -35,6 +33,7 @@ type IpSubnetKey = IpNet;
///
/// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
pub struct CancellationHandler<P> {
compute_config: &'static ComputeConfig,
map: CancelMap,
client: P,
/// This field used for the monitoring purposes.
Expand Down Expand Up @@ -183,7 +182,7 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
"cancelling query per user's request using key {key}, hostname {}, address: {}",
cancel_closure.hostname, cancel_closure.socket_addr
);
cancel_closure.try_cancel_query().await
cancel_closure.try_cancel_query(self.compute_config).await
}

#[cfg(test)]
Expand All @@ -198,8 +197,13 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
}

impl CancellationHandler<()> {
pub fn new(map: CancelMap, from: CancellationSource) -> Self {
pub fn new(
compute_config: &'static ComputeConfig,
map: CancelMap,
from: CancellationSource,
) -> Self {
Self {
compute_config,
map,
client: (),
from,
Expand All @@ -214,8 +218,14 @@ impl CancellationHandler<()> {
}

impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
pub fn new(map: CancelMap, client: Option<Arc<Mutex<P>>>, from: CancellationSource) -> Self {
pub fn new(
compute_config: &'static ComputeConfig,
map: CancelMap,
client: Option<Arc<Mutex<P>>>,
from: CancellationSource,
) -> Self {
Self {
compute_config,
map,
client,
from,
Expand All @@ -229,8 +239,6 @@ impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
}
}

static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();

/// This should've been a [`std::future::Future`], but
/// it's impossible to name a type of an unboxed future
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
Expand All @@ -257,27 +265,13 @@ impl CancelClosure {
}
}
/// Cancels the query running on user's compute node.
pub(crate) async fn try_cancel_query(self) -> Result<(), CancelError> {
pub(crate) async fn try_cancel_query(
self,
compute_config: &ComputeConfig,
) -> Result<(), CancelError> {
let socket = TcpStream::connect(self.socket_addr).await?;

let root_store = TLS_ROOTS
.get_or_try_init(load_certs)
.map_err(|_e| {
CancelError::IO(std::io::Error::new(
std::io::ErrorKind::Other,
"TLS root store initialization failed".to_string(),
))
})?
.clone();

let client_config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(root_store)
.with_no_client_auth();

let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config);
let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
let tls = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
&mut mk_tls,
&self.hostname,
Expand Down Expand Up @@ -329,11 +323,41 @@ impl<P> Drop for Session<P> {
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
use std::time::Duration;

use rustls::crypto::ring;
use rustls::RootCertStore;

use super::*;
use crate::config::RetryConfig;

fn config() -> ComputeConfig {
let retry = RetryConfig {
base_delay: Duration::from_secs(1),
max_retries: 5,
backoff_factor: 2.0,
};

let root_store = RootCertStore::empty();

let client_config =
rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_safe_default_protocol_versions()
.expect("ring should support the default protocol versions")
.with_root_certificates(root_store)
.with_no_client_auth();

ComputeConfig {
retry,
tls: Arc::new(client_config),
timeout: Duration::from_secs(2),
}
}

#[tokio::test]
async fn check_session_drop() -> anyhow::Result<()> {
let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
Box::leak(Box::new(config())),
CancelMap::default(),
CancellationSource::FromRedis,
));
Expand All @@ -349,8 +373,11 @@ mod tests {

#[tokio::test]
async fn cancel_session_noop_regression() {
let handler =
CancellationHandler::<()>::new(CancelMap::default(), CancellationSource::Local);
let handler = CancellationHandler::<()>::new(
Box::leak(Box::new(config())),
CancelMap::default(),
CancellationSource::Local,
);
handler
.cancel_session(
CancelKeyData {
Expand Down
Loading

0 comments on commit ad091c6

Please sign in to comment.