Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: sslnegotiation and direct ssl for postgres 17 #1151

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
version: '2'
services:
postgres:
image: postgres:14
image: docker.io/postgres:17
ports:
- 5433:5433
volumes:
Expand Down
2 changes: 1 addition & 1 deletion postgres-native-tls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ default = ["runtime"]
runtime = ["tokio-postgres/runtime"]

[dependencies]
native-tls = "0.2"
native-tls = { version = "0.2", features = ["alpn"] }
tokio = "1.0"
tokio-native-tls = "0.3"
tokio-postgres = { version = "0.7.11", path = "../tokio-postgres", default-features = false }
Expand Down
8 changes: 8 additions & 0 deletions postgres-native-tls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
//! ```
#![warn(rust_2018_idioms, clippy::all, missing_docs)]

use native_tls::TlsConnectorBuilder;
use std::future::Future;
use std::io;
use std::pin::Pin;
Expand Down Expand Up @@ -180,3 +181,10 @@ where
}
}
}

/// Set ALPN for `TlsConnectorBuilder`
///
/// This is required when using `sslnegotiation=direct`
pub fn set_postgresql_alpn(builder: &mut TlsConnectorBuilder) {
builder.request_alpns(&["postgresql"]);
}
16 changes: 15 additions & 1 deletion postgres-native-tls/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use tokio_postgres::tls::TlsConnect;

#[cfg(feature = "runtime")]
use crate::MakeTlsConnector;
use crate::TlsConnector;
use crate::{set_postgresql_alpn, TlsConnector};

async fn smoke_test<T>(s: &str, tls: T)
where
Expand Down Expand Up @@ -42,6 +42,20 @@ async fn require() {
.await;
}

#[tokio::test]
async fn direct() {
let connector = set_postgresql_alpn(native_tls::TlsConnector::builder().add_root_certificate(
Certificate::from_pem(include_bytes!("../../test/server.crt")).unwrap(),
))
.build()
.unwrap();
smoke_test(
"user=ssl_user dbname=postgres sslmode=require sslnegotiation=direct",
TlsConnector::new(connector, "localhost"),
)
.await;
}

#[tokio::test]
async fn prefer() {
let connector = native_tls::TlsConnector::builder()
Expand Down
9 changes: 8 additions & 1 deletion postgres-openssl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ use openssl::hash::MessageDigest;
use openssl::nid::Nid;
#[cfg(feature = "runtime")]
use openssl::ssl::SslConnector;
use openssl::ssl::{self, ConnectConfiguration, SslRef};
use openssl::ssl::{self, ConnectConfiguration, SslConnectorBuilder, SslRef};
use openssl::x509::X509VerifyResult;
use std::error::Error;
use std::fmt::{self, Debug};
Expand Down Expand Up @@ -250,3 +250,10 @@ fn tls_server_end_point(ssl: &SslRef) -> Option<Vec<u8>> {
};
cert.digest(md).ok().map(|b| b.to_vec())
}

/// Set ALPN for `SslConnectorBuilder`
///
/// This is required when using `sslnegotiation=direct`
pub fn set_postgresql_alpn(builder: &mut SslConnectorBuilder) -> Result<(), ErrorStack> {
builder.set_alpn_protos(b"\x0apostgresql")
}
13 changes: 13 additions & 0 deletions postgres-openssl/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@ async fn require() {
.await;
}

#[tokio::test]
async fn direct() {
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
builder.set_ca_file("../test/server.crt").unwrap();
set_postgresql_alpn(&mut builder).unwrap();
let ctx = builder.build();
smoke_test(
"user=ssl_user dbname=postgres sslmode=require sslnegotiation=direct",
TlsConnector::new(ctx.configure().unwrap(), "localhost"),
)
.await;
}

#[tokio::test]
async fn prefer() {
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
Expand Down
16 changes: 15 additions & 1 deletion postgres/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::time::Duration;
use tokio::runtime;
#[doc(inline)]
pub use tokio_postgres::config::{
ChannelBinding, Host, LoadBalanceHosts, SslMode, TargetSessionAttrs,
ChannelBinding, Host, LoadBalanceHosts, SslMode, SslNegotiation, TargetSessionAttrs,
};
use tokio_postgres::error::DbError;
use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
Expand Down Expand Up @@ -40,6 +40,9 @@ use tokio_postgres::{Error, Socket};
/// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts
/// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting
/// with the `connect` method.
/// * `sslnegotiation` - TLS negotiation method. If set to `direct`, the client will perform direct TLS handshake, this only works for PostgreSQL 17 and newer.
/// Note that you will need to setup ALPN of TLS client configuration to `postgresql` when using direct TLS.
/// If set to `postgres`, the default value, it follows original postgres wire protocol to perform the negotiation.
sunng87 marked this conversation as resolved.
Show resolved Hide resolved
/// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format,
/// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses.
/// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address,
Expand Down Expand Up @@ -230,6 +233,17 @@ impl Config {
self.config.get_ssl_mode()
}

/// Sets the SSL negotiation method
pub fn ssl_negotiation(&mut self, ssl_negotiation: SslNegotiation) -> &mut Config {
self.config.ssl_negotiation(ssl_negotiation);
self
}

/// Gets the SSL negotiation method
pub fn get_ssl_negotiation(&self) -> SslNegotiation {
self.config.get_ssl_negotiation()
}

/// Adds a host to the configuration.
///
/// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix
Expand Down
15 changes: 12 additions & 3 deletions tokio-postgres/src/cancel_query.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use crate::client::SocketConfig;
use crate::config::SslMode;
use crate::config::{SslMode, SslNegotiation};
use crate::tls::MakeTlsConnect;
use crate::{cancel_query_raw, connect_socket, Error, Socket};
use std::io;

pub(crate) async fn cancel_query<T>(
config: Option<SocketConfig>,
ssl_mode: SslMode,
ssl_negotiation: SslNegotiation,
mut tls: T,
process_id: i32,
secret_key: i32,
Expand Down Expand Up @@ -38,6 +39,14 @@ where
)
.await?;

cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, has_hostname, process_id, secret_key)
.await
cancel_query_raw::cancel_query_raw(
socket,
ssl_mode,
ssl_negotiation,
tls,
has_hostname,
process_id,
secret_key,
)
.await
}
5 changes: 3 additions & 2 deletions tokio-postgres/src/cancel_query_raw.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::config::SslMode;
use crate::config::{SslMode, SslNegotiation};
use crate::tls::TlsConnect;
use crate::{connect_tls, Error};
use bytes::BytesMut;
Expand All @@ -8,6 +8,7 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
pub async fn cancel_query_raw<S, T>(
stream: S,
mode: SslMode,
negotiation: SslNegotiation,
tls: T,
has_hostname: bool,
process_id: i32,
Expand All @@ -17,7 +18,7 @@ where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
{
let mut stream = connect_tls::connect_tls(stream, mode, tls, has_hostname).await?;
let mut stream = connect_tls::connect_tls(stream, mode, negotiation, tls, has_hostname).await?;

let mut buf = BytesMut::new();
frontend::cancel_request(process_id, secret_key, &mut buf);
Expand Down
5 changes: 4 additions & 1 deletion tokio-postgres/src/cancel_token.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::config::SslMode;
use crate::config::{SslMode, SslNegotiation};
use crate::tls::TlsConnect;
#[cfg(feature = "runtime")]
use crate::{cancel_query, client::SocketConfig, tls::MakeTlsConnect, Socket};
Expand All @@ -12,6 +12,7 @@ pub struct CancelToken {
#[cfg(feature = "runtime")]
pub(crate) socket_config: Option<SocketConfig>,
pub(crate) ssl_mode: SslMode,
pub(crate) ssl_negotiation: SslNegotiation,
pub(crate) process_id: i32,
pub(crate) secret_key: i32,
}
Expand All @@ -37,6 +38,7 @@ impl CancelToken {
cancel_query::cancel_query(
self.socket_config.clone(),
self.ssl_mode,
self.ssl_negotiation,
tls,
self.process_id,
self.secret_key,
Expand All @@ -54,6 +56,7 @@ impl CancelToken {
cancel_query_raw::cancel_query_raw(
stream,
self.ssl_mode,
self.ssl_negotiation,
tls,
true,
self.process_id,
Expand Down
6 changes: 5 additions & 1 deletion tokio-postgres/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::codec::BackendMessages;
use crate::config::SslMode;
use crate::config::{SslMode, SslNegotiation};
use crate::connection::{Request, RequestMessages};
use crate::copy_out::CopyOutStream;
#[cfg(feature = "runtime")]
Expand Down Expand Up @@ -180,6 +180,7 @@ pub struct Client {
#[cfg(feature = "runtime")]
socket_config: Option<SocketConfig>,
ssl_mode: SslMode,
ssl_negotiation: SslNegotiation,
process_id: i32,
secret_key: i32,
}
Expand All @@ -188,6 +189,7 @@ impl Client {
pub(crate) fn new(
sender: mpsc::UnboundedSender<Request>,
ssl_mode: SslMode,
ssl_negotiation: SslNegotiation,
process_id: i32,
secret_key: i32,
) -> Client {
Expand All @@ -200,6 +202,7 @@ impl Client {
#[cfg(feature = "runtime")]
socket_config: None,
ssl_mode,
ssl_negotiation,
process_id,
secret_key,
}
Expand Down Expand Up @@ -550,6 +553,7 @@ impl Client {
#[cfg(feature = "runtime")]
socket_config: self.socket_config.clone(),
ssl_mode: self.ssl_mode,
ssl_negotiation: self.ssl_negotiation,
process_id: self.process_id,
secret_key: self.secret_key,
}
Expand Down
40 changes: 40 additions & 0 deletions tokio-postgres/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@ pub enum SslMode {
Require,
}

/// TLS negotiation configuration
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum SslNegotiation {
/// Use PostgreSQL SslRequest for Ssl negotiation
Postgres,
/// Start Ssl handshake without negotiation, only works for PostgreSQL 17+
Direct,
}

/// Channel binding configuration.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
Expand Down Expand Up @@ -106,6 +116,9 @@ pub enum Host {
/// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts
/// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting
/// with the `connect` method.
/// * `sslnegotiation` - TLS negotiation method. If set to `direct`, the client will perform direct TLS handshake, this only works for PostgreSQL 17 and newer.
/// Note that you will need to setup ALPN of TLS client configuration to `postgresql` when using direct TLS.
/// If set to `postgres`, the default value, it follows original postgres wire protocol to perform the negotiation.
/// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format,
/// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses.
/// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address,
Expand Down Expand Up @@ -198,6 +211,7 @@ pub struct Config {
pub(crate) options: Option<String>,
pub(crate) application_name: Option<String>,
pub(crate) ssl_mode: SslMode,
pub(crate) ssl_negotiation: SslNegotiation,
pub(crate) host: Vec<Host>,
pub(crate) hostaddr: Vec<IpAddr>,
pub(crate) port: Vec<u16>,
Expand Down Expand Up @@ -227,6 +241,7 @@ impl Config {
options: None,
application_name: None,
ssl_mode: SslMode::Prefer,
ssl_negotiation: SslNegotiation::Postgres,
host: vec![],
hostaddr: vec![],
port: vec![],
Expand Down Expand Up @@ -325,6 +340,19 @@ impl Config {
self.ssl_mode
}

/// Sets the SSL negotiation method.
///
/// Defaults to `postgres`.
pub fn ssl_negotiation(&mut self, ssl_negotiation: SslNegotiation) -> &mut Config {
self.ssl_negotiation = ssl_negotiation;
self
}

/// Gets the SSL negotiation method.
pub fn get_ssl_negotiation(&self) -> SslNegotiation {
self.ssl_negotiation
}

/// Adds a host to the configuration.
///
/// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix
Expand Down Expand Up @@ -550,6 +578,18 @@ impl Config {
};
self.ssl_mode(mode);
}
"sslnegotiation" => {
let mode = match value {
"postgres" => SslNegotiation::Postgres,
"direct" => SslNegotiation::Direct,
_ => {
return Err(Error::config_parse(Box::new(InvalidValue(
"sslnegotiation",
))))
}
};
self.ssl_negotiation(mode);
}
"host" => {
for host in value.split(',') {
self.host(host);
Expand Down
17 changes: 15 additions & 2 deletions tokio-postgres/src/connect_raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,14 @@ where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
{
let stream = connect_tls(stream, config.ssl_mode, tls, has_hostname).await?;
let stream = connect_tls(
stream,
config.ssl_mode,
config.ssl_negotiation,
tls,
has_hostname,
)
.await?;

let mut stream = StartupStream {
inner: Framed::new(stream, PostgresCodec),
Expand All @@ -107,7 +114,13 @@ where
let (process_id, secret_key, parameters) = read_info(&mut stream).await?;

let (sender, receiver) = mpsc::unbounded();
let client = Client::new(sender, config.ssl_mode, process_id, secret_key);
let client = Client::new(
sender,
config.ssl_mode,
config.ssl_negotiation,
process_id,
secret_key,
);
let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver);

Ok((client, connection))
Expand Down
Loading