diff --git a/docker-compose.yml b/docker-compose.yml index 0ed44148d..991df2d01 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,7 +1,7 @@ version: '2' services: postgres: - image: postgres:14 + image: docker.io/postgres:17 ports: - 5433:5433 volumes: diff --git a/postgres-native-tls/Cargo.toml b/postgres-native-tls/Cargo.toml index 02259b3dc..e86c7ce2d 100644 --- a/postgres-native-tls/Cargo.toml +++ b/postgres-native-tls/Cargo.toml @@ -16,7 +16,7 @@ default = ["runtime"] runtime = ["tokio-postgres/runtime"] [dependencies] -native-tls = "0.2" +native-tls = { version = "0.2", features = ["alpn"] } tokio = "1.0" tokio-native-tls = "0.3" tokio-postgres = { version = "0.7.11", path = "../tokio-postgres", default-features = false } diff --git a/postgres-native-tls/src/lib.rs b/postgres-native-tls/src/lib.rs index a06f185b5..9ee7da653 100644 --- a/postgres-native-tls/src/lib.rs +++ b/postgres-native-tls/src/lib.rs @@ -53,6 +53,7 @@ //! ``` #![warn(rust_2018_idioms, clippy::all, missing_docs)] +use native_tls::TlsConnectorBuilder; use std::future::Future; use std::io; use std::pin::Pin; @@ -180,3 +181,10 @@ where } } } + +/// Set ALPN for `TlsConnectorBuilder` +/// +/// This is required when using `sslnegotiation=direct` +pub fn set_postgresql_alpn(builder: &mut TlsConnectorBuilder) { + builder.request_alpns(&["postgresql"]); +} diff --git a/postgres-native-tls/src/test.rs b/postgres-native-tls/src/test.rs index 25cc6fdbd..b34fa7351 100644 --- a/postgres-native-tls/src/test.rs +++ b/postgres-native-tls/src/test.rs @@ -5,7 +5,7 @@ use tokio_postgres::tls::TlsConnect; #[cfg(feature = "runtime")] use crate::MakeTlsConnector; -use crate::TlsConnector; +use crate::{set_postgresql_alpn, TlsConnector}; async fn smoke_test(s: &str, tls: T) where @@ -42,6 +42,20 @@ async fn require() { .await; } +#[tokio::test] +async fn direct() { + let connector = set_postgresql_alpn(native_tls::TlsConnector::builder().add_root_certificate( + Certificate::from_pem(include_bytes!("../../test/server.crt")).unwrap(), + )) + .build() + .unwrap(); + smoke_test( + "user=ssl_user dbname=postgres sslmode=require sslnegotiation=direct", + TlsConnector::new(connector, "localhost"), + ) + .await; +} + #[tokio::test] async fn prefer() { let connector = native_tls::TlsConnector::builder() diff --git a/postgres-openssl/src/lib.rs b/postgres-openssl/src/lib.rs index 837663fe7..232cccd05 100644 --- a/postgres-openssl/src/lib.rs +++ b/postgres-openssl/src/lib.rs @@ -53,7 +53,7 @@ use openssl::hash::MessageDigest; use openssl::nid::Nid; #[cfg(feature = "runtime")] use openssl::ssl::SslConnector; -use openssl::ssl::{self, ConnectConfiguration, SslRef}; +use openssl::ssl::{self, ConnectConfiguration, SslConnectorBuilder, SslRef}; use openssl::x509::X509VerifyResult; use std::error::Error; use std::fmt::{self, Debug}; @@ -250,3 +250,10 @@ fn tls_server_end_point(ssl: &SslRef) -> Option> { }; cert.digest(md).ok().map(|b| b.to_vec()) } + +/// Set ALPN for `SslConnectorBuilder` +/// +/// This is required when using `sslnegotiation=direct` +pub fn set_postgresql_alpn(builder: &mut SslConnectorBuilder) -> Result<(), ErrorStack> { + builder.set_alpn_protos(b"\x0apostgresql") +} diff --git a/postgres-openssl/src/test.rs b/postgres-openssl/src/test.rs index b361ee446..66bb22641 100644 --- a/postgres-openssl/src/test.rs +++ b/postgres-openssl/src/test.rs @@ -37,6 +37,19 @@ async fn require() { .await; } +#[tokio::test] +async fn direct() { + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_ca_file("../test/server.crt").unwrap(); + set_postgresql_alpn(&mut builder).unwrap(); + let ctx = builder.build(); + smoke_test( + "user=ssl_user dbname=postgres sslmode=require sslnegotiation=direct", + TlsConnector::new(ctx.configure().unwrap(), "localhost"), + ) + .await; +} + #[tokio::test] async fn prefer() { let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); diff --git a/postgres/src/config.rs b/postgres/src/config.rs index a32ddc78e..ae710d16b 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -12,7 +12,7 @@ use std::time::Duration; use tokio::runtime; #[doc(inline)] pub use tokio_postgres::config::{ - ChannelBinding, Host, LoadBalanceHosts, SslMode, TargetSessionAttrs, + ChannelBinding, Host, LoadBalanceHosts, SslMode, SslNegotiation, TargetSessionAttrs, }; use tokio_postgres::error::DbError; use tokio_postgres::tls::{MakeTlsConnect, TlsConnect}; @@ -40,6 +40,9 @@ use tokio_postgres::{Error, Socket}; /// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts /// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting /// with the `connect` method. +/// * `sslnegotiation` - TLS negotiation method. If set to `direct`, the client will perform direct TLS handshake, this only works for PostgreSQL 17 and newer. +/// Note that you will need to setup ALPN of TLS client configuration to `postgresql` when using direct TLS. +/// If set to `postgres`, the default value, it follows original postgres wire protocol to perform the negotiation. /// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, /// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. /// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, @@ -230,6 +233,17 @@ impl Config { self.config.get_ssl_mode() } + /// Sets the SSL negotiation method + pub fn ssl_negotiation(&mut self, ssl_negotiation: SslNegotiation) -> &mut Config { + self.config.ssl_negotiation(ssl_negotiation); + self + } + + /// Gets the SSL negotiation method + pub fn get_ssl_negotiation(&self) -> SslNegotiation { + self.config.get_ssl_negotiation() + } + /// Adds a host to the configuration. /// /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix diff --git a/tokio-postgres/src/cancel_query.rs b/tokio-postgres/src/cancel_query.rs index 078d4b8b6..2dfd47c06 100644 --- a/tokio-postgres/src/cancel_query.rs +++ b/tokio-postgres/src/cancel_query.rs @@ -1,5 +1,5 @@ use crate::client::SocketConfig; -use crate::config::SslMode; +use crate::config::{SslMode, SslNegotiation}; use crate::tls::MakeTlsConnect; use crate::{cancel_query_raw, connect_socket, Error, Socket}; use std::io; @@ -7,6 +7,7 @@ use std::io; pub(crate) async fn cancel_query( config: Option, ssl_mode: SslMode, + ssl_negotiation: SslNegotiation, mut tls: T, process_id: i32, secret_key: i32, @@ -38,6 +39,14 @@ where ) .await?; - cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, has_hostname, process_id, secret_key) - .await + cancel_query_raw::cancel_query_raw( + socket, + ssl_mode, + ssl_negotiation, + tls, + has_hostname, + process_id, + secret_key, + ) + .await } diff --git a/tokio-postgres/src/cancel_query_raw.rs b/tokio-postgres/src/cancel_query_raw.rs index 41aafe7d9..886606497 100644 --- a/tokio-postgres/src/cancel_query_raw.rs +++ b/tokio-postgres/src/cancel_query_raw.rs @@ -1,4 +1,4 @@ -use crate::config::SslMode; +use crate::config::{SslMode, SslNegotiation}; use crate::tls::TlsConnect; use crate::{connect_tls, Error}; use bytes::BytesMut; @@ -8,6 +8,7 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; pub async fn cancel_query_raw( stream: S, mode: SslMode, + negotiation: SslNegotiation, tls: T, has_hostname: bool, process_id: i32, @@ -17,7 +18,7 @@ where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { - let mut stream = connect_tls::connect_tls(stream, mode, tls, has_hostname).await?; + let mut stream = connect_tls::connect_tls(stream, mode, negotiation, tls, has_hostname).await?; let mut buf = BytesMut::new(); frontend::cancel_request(process_id, secret_key, &mut buf); diff --git a/tokio-postgres/src/cancel_token.rs b/tokio-postgres/src/cancel_token.rs index c925ce0ca..1652bec72 100644 --- a/tokio-postgres/src/cancel_token.rs +++ b/tokio-postgres/src/cancel_token.rs @@ -1,4 +1,4 @@ -use crate::config::SslMode; +use crate::config::{SslMode, SslNegotiation}; use crate::tls::TlsConnect; #[cfg(feature = "runtime")] use crate::{cancel_query, client::SocketConfig, tls::MakeTlsConnect, Socket}; @@ -12,6 +12,7 @@ pub struct CancelToken { #[cfg(feature = "runtime")] pub(crate) socket_config: Option, pub(crate) ssl_mode: SslMode, + pub(crate) ssl_negotiation: SslNegotiation, pub(crate) process_id: i32, pub(crate) secret_key: i32, } @@ -37,6 +38,7 @@ impl CancelToken { cancel_query::cancel_query( self.socket_config.clone(), self.ssl_mode, + self.ssl_negotiation, tls, self.process_id, self.secret_key, @@ -54,6 +56,7 @@ impl CancelToken { cancel_query_raw::cancel_query_raw( stream, self.ssl_mode, + self.ssl_negotiation, tls, true, self.process_id, diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 92eabde36..b38bbba37 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -1,5 +1,5 @@ use crate::codec::BackendMessages; -use crate::config::SslMode; +use crate::config::{SslMode, SslNegotiation}; use crate::connection::{Request, RequestMessages}; use crate::copy_out::CopyOutStream; #[cfg(feature = "runtime")] @@ -180,6 +180,7 @@ pub struct Client { #[cfg(feature = "runtime")] socket_config: Option, ssl_mode: SslMode, + ssl_negotiation: SslNegotiation, process_id: i32, secret_key: i32, } @@ -188,6 +189,7 @@ impl Client { pub(crate) fn new( sender: mpsc::UnboundedSender, ssl_mode: SslMode, + ssl_negotiation: SslNegotiation, process_id: i32, secret_key: i32, ) -> Client { @@ -200,6 +202,7 @@ impl Client { #[cfg(feature = "runtime")] socket_config: None, ssl_mode, + ssl_negotiation, process_id, secret_key, } @@ -550,6 +553,7 @@ impl Client { #[cfg(feature = "runtime")] socket_config: self.socket_config.clone(), ssl_mode: self.ssl_mode, + ssl_negotiation: self.ssl_negotiation, process_id: self.process_id, secret_key: self.secret_key, } diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 62b45f793..fb673b9b9 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -50,6 +50,16 @@ pub enum SslMode { Require, } +/// TLS negotiation configuration +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum SslNegotiation { + /// Use PostgreSQL SslRequest for Ssl negotiation + Postgres, + /// Start Ssl handshake without negotiation, only works for PostgreSQL 17+ + Direct, +} + /// Channel binding configuration. #[derive(Debug, Copy, Clone, PartialEq, Eq)] #[non_exhaustive] @@ -106,6 +116,9 @@ pub enum Host { /// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts /// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting /// with the `connect` method. +/// * `sslnegotiation` - TLS negotiation method. If set to `direct`, the client will perform direct TLS handshake, this only works for PostgreSQL 17 and newer. +/// Note that you will need to setup ALPN of TLS client configuration to `postgresql` when using direct TLS. +/// If set to `postgres`, the default value, it follows original postgres wire protocol to perform the negotiation. /// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, /// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. /// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, @@ -198,6 +211,7 @@ pub struct Config { pub(crate) options: Option, pub(crate) application_name: Option, pub(crate) ssl_mode: SslMode, + pub(crate) ssl_negotiation: SslNegotiation, pub(crate) host: Vec, pub(crate) hostaddr: Vec, pub(crate) port: Vec, @@ -227,6 +241,7 @@ impl Config { options: None, application_name: None, ssl_mode: SslMode::Prefer, + ssl_negotiation: SslNegotiation::Postgres, host: vec![], hostaddr: vec![], port: vec![], @@ -325,6 +340,19 @@ impl Config { self.ssl_mode } + /// Sets the SSL negotiation method. + /// + /// Defaults to `postgres`. + pub fn ssl_negotiation(&mut self, ssl_negotiation: SslNegotiation) -> &mut Config { + self.ssl_negotiation = ssl_negotiation; + self + } + + /// Gets the SSL negotiation method. + pub fn get_ssl_negotiation(&self) -> SslNegotiation { + self.ssl_negotiation + } + /// Adds a host to the configuration. /// /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix @@ -550,6 +578,18 @@ impl Config { }; self.ssl_mode(mode); } + "sslnegotiation" => { + let mode = match value { + "postgres" => SslNegotiation::Postgres, + "direct" => SslNegotiation::Direct, + _ => { + return Err(Error::config_parse(Box::new(InvalidValue( + "sslnegotiation", + )))) + } + }; + self.ssl_negotiation(mode); + } "host" => { for host in value.split(',') { self.host(host); diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 19be9eb01..cf7476cab 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -89,7 +89,14 @@ where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { - let stream = connect_tls(stream, config.ssl_mode, tls, has_hostname).await?; + let stream = connect_tls( + stream, + config.ssl_mode, + config.ssl_negotiation, + tls, + has_hostname, + ) + .await?; let mut stream = StartupStream { inner: Framed::new(stream, PostgresCodec), @@ -107,7 +114,13 @@ where let (process_id, secret_key, parameters) = read_info(&mut stream).await?; let (sender, receiver) = mpsc::unbounded(); - let client = Client::new(sender, config.ssl_mode, process_id, secret_key); + let client = Client::new( + sender, + config.ssl_mode, + config.ssl_negotiation, + process_id, + secret_key, + ); let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver); Ok((client, connection)) diff --git a/tokio-postgres/src/connect_tls.rs b/tokio-postgres/src/connect_tls.rs index 2b1229125..c7a093064 100644 --- a/tokio-postgres/src/connect_tls.rs +++ b/tokio-postgres/src/connect_tls.rs @@ -1,4 +1,4 @@ -use crate::config::SslMode; +use crate::config::{SslMode, SslNegotiation}; use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::private::ForcePrivateApi; use crate::tls::TlsConnect; @@ -10,6 +10,7 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; pub async fn connect_tls( mut stream: S, mode: SslMode, + negotiation: SslNegotiation, tls: T, has_hostname: bool, ) -> Result, Error> @@ -25,18 +26,20 @@ where SslMode::Prefer | SslMode::Require => {} } - let mut buf = BytesMut::new(); - frontend::ssl_request(&mut buf); - stream.write_all(&buf).await.map_err(Error::io)?; + if negotiation == SslNegotiation::Postgres { + let mut buf = BytesMut::new(); + frontend::ssl_request(&mut buf); + stream.write_all(&buf).await.map_err(Error::io)?; - let mut buf = [0]; - stream.read_exact(&mut buf).await.map_err(Error::io)?; + let mut buf = [0]; + stream.read_exact(&mut buf).await.map_err(Error::io)?; - if buf[0] != b'S' { - if SslMode::Require == mode { - return Err(Error::tls("server does not support TLS".into())); - } else { - return Ok(MaybeTlsStream::Raw(stream)); + if buf[0] != b'S' { + if SslMode::Require == mode { + return Err(Error::tls("server does not support TLS".into())); + } else { + return Ok(MaybeTlsStream::Raw(stream)); + } } } diff --git a/tokio-postgres/tests/test/parse.rs b/tokio-postgres/tests/test/parse.rs index 04d422e27..35eeca72b 100644 --- a/tokio-postgres/tests/test/parse.rs +++ b/tokio-postgres/tests/test/parse.rs @@ -1,5 +1,5 @@ use std::time::Duration; -use tokio_postgres::config::{Config, TargetSessionAttrs}; +use tokio_postgres::config::{Config, SslNegotiation, TargetSessionAttrs}; fn check(s: &str, config: &Config) { assert_eq!(s.parse::().expect(s), *config, "`{}`", s); @@ -42,6 +42,10 @@ fn settings() { .keepalives_idle(Duration::from_secs(30)) .target_session_attrs(TargetSessionAttrs::ReadOnly), ); + check( + "sslnegotiation=direct", + Config::new().ssl_negotiation(SslNegotiation::Direct), + ); } #[test]