diff --git a/Cargo.lock b/Cargo.lock
index 67e9d12531..1812920de3 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1193,6 +1193,19 @@ version = "1.0.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "56ce8c6da7551ec6c462cbaf3bfbc75131ebbfa1c944aeaa9dab51ca1c5f0c3b"

+[[package]]
+name = "ease-off"
+version = "0.1.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "20e90ae5e739d99dc0406f9a4e2307a999625e2414d2ecc4fbb4ded8a3945f77"
+dependencies = [
+ "async-io 2.3.2",
+ "pin-project",
+ "rand",
+ "thiserror",
+ "tokio",
+]
+
 [[package]]
 name = "either"
 version = "1.10.0"
@@ -2416,18 +2429,18 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"

 [[package]]
 name = "pin-project"
-version = "1.1.5"
+version = "1.1.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3"
+checksum = "be57f64e946e500c8ee36ef6331845d40a93055567ec57e8fae13efd33759b95"
 dependencies = [
  "pin-project-internal",
 ]

 [[package]]
 name = "pin-project-internal"
-version = "1.1.5"
+version = "1.1.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965"
+checksum = "3c0f5fad0874fc7abcd4d750e76917eaebbecaa2c20bde22e1dbeeba8beb758c"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -2436,9 +2449,9 @@
 [[package]]
 name = "pin-project-lite"
-version = "0.2.13"
+version = "0.2.14"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58"
+checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02"

 [[package]]
 name = "pin-utils"
@@ -3375,10 +3388,10 @@ dependencies = [
  "chrono",
  "crc",
  "crossbeam-queue",
+ "ease-off",
  "either",
  "event-listener 5.2.0",
  "futures-core",
- "futures-intrusive",
  "futures-io",
  "futures-util",
  "hashbrown 0.14.5",
@@ -3391,6 +3404,7 @@ dependencies = [
  "native-tls",
  "once_cell",
  "percent-encoding",
+ "pin-project-lite",
  "regex",
  "rust_decimal",
  "rustls",
@@ -3847,18 +3861,18 @@ checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76"

 [[package]]
 name = "thiserror"
-version = "1.0.58"
+version = "1.0.64"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297"
+checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84"
 dependencies = [
  "thiserror-impl",
 ]

 [[package]]
 name = "thiserror-impl"
-version = "1.0.58"
+version = "1.0.64"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7"
+checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3"
 dependencies = [
  "proc-macro2",
  "quote",
diff --git a/Cargo.toml b/Cargo.toml
index bf0a867e1e..2d55562199 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -35,8 +35,7 @@ authors = [
     "Chloe Ross ",
     "Daniel Akhterov ",
 ]
-# TODO: enable this for 0.9.0
-# rust-version = "1.80.0"
+rust-version = "1.80.0"

 [package]
 name = "sqlx"
@@ -48,6 +47,7 @@ license.workspace = true
 edition.workspace = true
 authors.workspace = true
 repository.workspace = true
+rust-version.workspace = true

 [package.metadata.docs.rs]
 features = ["all-databases", "_unstable-all-types"]
@@ -147,6 +147,7 @@ uuid = "1.1.2"

 # Common utility crates
 dotenvy = { version = "0.15.0", default-features = false }
+ease-off = "0.1.6"

 # Runtimes
 [workspace.dependencies.async-std]
diff --git a/rust-toolchain.toml b/rust-toolchain.toml
index 29f0b09695..16beae9c11 100644
--- a/rust-toolchain.toml
+++ b/rust-toolchain.toml
@@ -1,4 +1,4 @@
 # Note: should NOT increase during a minor/patch release cycle
 [toolchain]
-channel = "1.78"
+channel = "1.80"
 profile = "minimal"
diff --git a/sqlx-cli/Cargo.toml b/sqlx-cli/Cargo.toml
index 0b047ab136..7d3bf89485 100644
--- a/sqlx-cli/Cargo.toml
+++ b/sqlx-cli/Cargo.toml
@@ -56,14 +56,17 @@ native-tls = ["sqlx/runtime-tokio-native-tls"]
 # databases
 mysql = ["sqlx/mysql"]
 postgres = ["sqlx/postgres"]
-sqlite = ["sqlx/sqlite"]
-sqlite-unbundled = ["sqlx/sqlite-unbundled"]
+sqlite = ["sqlx/sqlite", "_sqlite"]
+sqlite-unbundled = ["sqlx/sqlite-unbundled", "_sqlite"]

 # workaround for musl + openssl issues
 openssl-vendored = ["openssl/vendored"]

 completions = ["dep:clap_complete"]

+# Conditional compilation
+_sqlite = []
+
 [dev-dependencies]
 assert_cmd = "2.0.11"
 tempfile = "3.10.1"
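Editor's note (not part of the patch): the private `_sqlite` feature above exists so CLI code that applies to *either* SQLite backend can be gated on a single `cfg`. A minimal sketch of such a gate, assuming a hypothetical helper function:

```rust
// Compiled whenever either `sqlite` or `sqlite-unbundled` is enabled,
// since both features also enable the private `_sqlite` feature.
#[cfg(feature = "_sqlite")]
fn maybe_create_sqlite_parent_dir(db_url: &str) -> std::io::Result<()> {
    // Hypothetical helper: ensure the parent directory of a SQLite
    // database file exists before attempting to create the database.
    if let Some(path) = db_url.strip_prefix("sqlite://") {
        if let Some(parent) = std::path::Path::new(path).parent() {
            std::fs::create_dir_all(parent)?;
        }
    }
    Ok(())
}
```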
diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml
index 789d30fb1c..71b97332bb 100644
--- a/sqlx-core/Cargo.toml
+++ b/sqlx-core/Cargo.toml
@@ -19,8 +19,8 @@ any = []
 json = ["serde", "serde_json"]

 # for conditional compilation
-_rt-async-std = ["async-std", "async-io"]
-_rt-tokio = ["tokio", "tokio-stream"]
+_rt-async-std = ["async-std", "async-io", "ease-off/async-io-2"]
+_rt-tokio = ["tokio", "tokio-stream", "ease-off/tokio"]
 _tls-native-tls = ["native-tls"]
 _tls-rustls-aws-lc-rs = ["_tls-rustls", "rustls/aws-lc-rs"]
 _tls-rustls-ring = ["_tls-rustls", "rustls/ring"]
@@ -59,7 +59,6 @@ crossbeam-queue = "0.3.2"
 either = "1.6.1"
 futures-core = { version = "0.3.19", default-features = false }
 futures-io = "0.3.24"
-futures-intrusive = "0.5.0"
 futures-util = { version = "0.3.19", default-features = false, features = ["alloc", "sink", "io"] }
 log = { version = "0.4.18", default-features = false }
 memchr = { version = "2.4.1", default-features = false }
@@ -81,9 +80,12 @@ indexmap = "2.0"
 event-listener = "5.2.0"
 hashbrown = "0.14.5"

+ease-off = { workspace = true, features = ["futures"] }
+pin-project-lite = "0.2.14"
+
 [dev-dependencies]
 sqlx = { workspace = true, features = ["postgres", "sqlite", "mysql", "migrate", "macros", "time", "uuid"] }
-tokio = { version = "1", features = ["rt"] }
+tokio = { version = "1", features = ["rt", "sync"] }

 [lints]
 workspace = true
diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs
index 17774addd2..625de3958c 100644
--- a/sqlx-core/src/error.rs
+++ b/sqlx-core/src/error.rs
@@ -11,6 +11,9 @@ use crate::database::Database;
 use crate::type_info::TypeInfo;
 use crate::types::Type;

+#[cfg(doc)]
+use crate::pool::{PoolConnector, PoolOptions};
+
 /// A specialized `Result` type for SQLx.
 pub type Result<T, E = Error> = ::std::result::Result<T, E>;

@@ -104,6 +107,19 @@ pub enum Error {
     #[error("attempted to acquire a connection on a closed pool")]
     PoolClosed,

+    /// A custom error that may be returned from a [`PoolConnector`] implementation.
+    #[error("error returned from pool connector")]
+    PoolConnector {
+        #[source]
+        source: BoxDynError,
+
+        /// If `true`, `PoolConnector::connect()` is called again in an exponential backoff loop
+        /// up to [`PoolOptions::connect_timeout`].
+        ///
+        /// See [`PoolConnector::connect()`] for details.
+        retryable: bool,
+    },
+
     /// A background worker has crashed.
     #[error("attempted to communicate with a crashed background worker")]
     WorkerCrashed,

@@ -202,11 +218,6 @@ pub trait DatabaseError: 'static + Send + Sync + StdError {
     #[doc(hidden)]
     fn into_error(self: Box<Self>) -> Box<dyn StdError + Send + Sync + 'static>;

-    #[doc(hidden)]
-    fn is_transient_in_connect_phase(&self) -> bool {
-        false
-    }
-
     /// Returns the name of the constraint that triggered the error, if applicable.
     /// If the error was caused by a conflict of a unique index, this will be the index name.
     ///
@@ -244,6 +255,24 @@ pub trait DatabaseError: 'static + Send + Sync + StdError {
     fn is_check_violation(&self) -> bool {
         matches!(self.kind(), ErrorKind::CheckViolation)
     }
+
+    /// Returns `true` if this error can be retried when connecting to the database.
+    ///
+    /// Defaults to `false`.
+    ///
+    /// For example, the Postgres driver overrides this to return `true` for the following error codes:
+    ///
+    /// * `53300 too_many_connections`: returned when the maximum connections are exceeded
+    ///   on the server. Assumed to be the result of a temporary overcommit
+    ///   (e.g. an extra application replica being spun up to replace one that is going down).
+    ///     * This error being consistently logged or returned is a likely indicator of a misconfiguration;
+    ///       the sum of [`PoolOptions::max_connections`] for all replicas should not exceed
+    ///       the maximum connections allowed by the server.
+    /// * `57P03 cannot_connect_now`: returned when the database server is still starting up
+    ///   and the tcop component is not ready to accept connections yet.
+    fn is_retryable_connect_error(&self) -> bool {
+        false
+    }
 }

 impl dyn DatabaseError {
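Editor's note (not part of the patch): a sketch of how a custom connector might surface a non-SQLx error through the new `Error::PoolConnector` variant and opt in to the retry loop. The `check_replica_ready` helper is hypothetical:

```rust
use sqlx::pool::PoolConnectMetadata;
use sqlx::postgres::PgPoolOptions;
use sqlx::{Connection, Error, PgConnection};

// Stub for the sketch; a real implementation might probe DNS or a health endpoint.
async fn check_replica_ready() -> Result<(), std::io::Error> {
    Ok(())
}

async fn _example() -> sqlx::Result<()> {
    let pool = PgPoolOptions::new()
        .connect_with_connector(move |_meta: PoolConnectMetadata| async move {
            if let Err(e) = check_replica_ready().await {
                // Wrap the custom error; `retryable: true` re-enters the backoff loop
                // instead of failing the `acquire()` call outright.
                return Err(Error::PoolConnector {
                    source: Box::new(e),
                    retryable: true,
                });
            }

            PgConnection::connect("postgres://…").await
        })
        .await?;

    drop(pool);
    Ok(())
}
```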
diff --git a/sqlx-core/src/pool/connect.rs b/sqlx-core/src/pool/connect.rs
new file mode 100644
index 0000000000..ee80591428
--- /dev/null
+++ b/sqlx-core/src/pool/connect.rs
@@ -0,0 +1,491 @@
+use crate::connection::{ConnectOptions, Connection};
+use crate::database::Database;
+use crate::pool::connection::Floating;
+use crate::pool::inner::PoolInner;
+use crate::pool::PoolConnection;
+use crate::rt::JoinHandle;
+use crate::Error;
+use ease_off::EaseOff;
+use event_listener::{listener, Event};
+use std::fmt::{Display, Formatter};
+use std::future::Future;
+use std::ptr;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Arc;
+use std::time::Instant;
+
+use std::io;
+
+/// Custom connect callback for [`Pool`][crate::pool::Pool].
+///
+/// Implemented for closures with the signature
+/// `Fn(PoolConnectMetadata) -> impl Future<Output = sqlx::Result<DB::Connection>>`.
+///
+/// See [`Self::connect()`] for details and implementation advice.
+///
+/// # Example: `after_connect` Replacement
+/// The `after_connect` callback was removed in 0.9.0 as it was made redundant by this API.
+///
+/// This example uses Postgres but may be adapted to any driver.
+///
+/// ```rust,no_run
+/// use std::sync::Arc;
+/// use sqlx::PgConnection;
+/// use sqlx::postgres::PgPoolOptions;
+/// use sqlx::Connection;
+/// use sqlx::pool::PoolConnectMetadata;
+///
+/// async fn _example() -> sqlx::Result<()> {
+///     // `PoolConnector` is implemented for closures but this has restrictions on returning borrows
+///     // due to current language limitations. Custom implementations are not subject to this.
+///     //
+///     // This example shows how to get around this using `Arc`.
+///     let database_url: Arc<str> = "postgres://...".into();
+///
+///     let pool = PgPoolOptions::new()
+///         .min_connections(5)
+///         .max_connections(30)
+///         // Type annotation on the argument is required for the trait impl to resolve.
+///         .connect_with_connector(move |meta: PoolConnectMetadata| {
+///             let database_url = database_url.clone();
+///             async move {
+///                 println!(
+///                     "opening connection {}, attempt {}; elapsed time: {:?}",
+///                     meta.pool_size,
+///                     meta.num_attempts + 1,
+///                     meta.start.elapsed()
+///                 );
+///
+///                 let mut conn = PgConnection::connect(&database_url).await?;
+///
+///                 // Override the time zone of the connection.
+///                 sqlx::raw_sql("SET TIME ZONE 'Europe/Berlin'")
+///                     .execute(&mut conn)
+///                     .await?;
+///
+///                 Ok(conn)
+///             }
+///         })
+///         .await?;
+/// # Ok(())
+/// # }
+/// ```
+///
+/// # Example: `set_connect_options` Replacement
+/// `set_connect_options` and `get_connect_options` were removed in 0.9.0 because they complicated
+/// the pool internals. They can be reimplemented by capturing a mutex, or similar, in the callback.
+///
+/// This example uses Postgres and [`tokio::sync::RwLock`] but may be adapted to any driver
+/// or `async-std`, respectively.
+///
+/// ```rust,no_run
+/// use std::sync::Arc;
+/// use tokio::sync::RwLock;
+/// use sqlx::PgConnection;
+/// use sqlx::postgres::PgConnectOptions;
+/// use sqlx::postgres::PgPoolOptions;
+/// use sqlx::ConnectOptions;
+/// use sqlx::pool::PoolConnectMetadata;
+///
+/// async fn _example() -> sqlx::Result<()> {
+///     // If you do not wish to hold the lock during the connection attempt,
+///     // you could use `RwLock<Arc<PgConnectOptions>>` and clone the `Arc` instead.
+///     let connect_opts: Arc<RwLock<PgConnectOptions>> = Arc::new(RwLock::new("postgres://...".parse()?));
+///     // We need a copy that will be captured by the closure.
+///     let connect_opts_ = connect_opts.clone();
+///
+///     let pool = PgPoolOptions::new()
+///         .connect_with_connector(move |meta: PoolConnectMetadata| {
+///             let connect_opts = connect_opts_.clone();
+///             async move {
+///                 println!(
+///                     "opening connection {}, attempt {}; elapsed time: {:?}",
+///                     meta.pool_size,
+///                     meta.num_attempts + 1,
+///                     meta.start.elapsed()
+///                 );
+///
+///                 connect_opts.read().await.connect().await
+///             }
+///         })
+///         .await?;
+///
+///     // Close the connection that was previously opened by `connect_with_connector()`.
+///     pool.acquire().await?.close().await?;
+///
+///     // Simulating a credential rotation
+///     let mut write_connect_opts = connect_opts.write().await;
+///     write_connect_opts
+///         .set_username("new_username")
+///         .set_password("new password");
+///
+///     // Should use the new credentials.
+///     let mut conn = pool.acquire().await?;
+///
+/// # Ok(())
+/// # }
+/// ```
+///
+/// # Example: Custom Implementation
+///
+/// Custom implementations of `PoolConnector` trade a little bit of boilerplate for much
+/// more flexibility. Thanks to the signature of `connect()`, they can return a `Future`
+/// type that borrows from `self`.
+///
+/// This example uses Postgres but may be adapted to any driver.
+///
+/// ```rust,no_run
+/// use sqlx::{PgConnection, Postgres};
+/// use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
+/// use sqlx_core::connection::ConnectOptions;
+/// use sqlx_core::pool::{PoolConnectMetadata, PoolConnector};
+///
+/// struct MyConnector {
+///     // A list of servers to connect to in a high-availability configuration.
+///     host_ports: Vec<(String, u16)>,
+///     username: String,
+///     password: String,
+/// }
+///
+/// impl PoolConnector<Postgres> for MyConnector {
+///     // The desugaring of `async fn` is compatible with the signature of `connect()`.
+///     async fn connect(&self, meta: PoolConnectMetadata) -> sqlx::Result<PgConnection> {
+///         self.get_connect_options(meta.num_attempts)
+///             .connect()
+///             .await
+///     }
+/// }
+///
+/// impl MyConnector {
+///     fn get_connect_options(&self, attempt: usize) -> PgConnectOptions {
+///         // Select servers in a round-robin.
+///         let (ref host, port) = self.host_ports[attempt % self.host_ports.len()];
+///
+///         PgConnectOptions::new()
+///             .host(host)
+///             .port(port)
+///             .username(&self.username)
+///             .password(&self.password)
+///     }
+/// }
+///
+/// # async fn _example() -> sqlx::Result<()> {
+/// let pool = PgPoolOptions::new()
+///     .max_connections(25)
+///     .connect_with_connector(MyConnector {
+///         host_ports: vec![
+///             ("db1.postgres.cluster.local".into(), 5432),
+///             ("db2.postgres.cluster.local".into(), 5432),
+///             ("db3.postgres.cluster.local".into(), 5432),
+///             ("db4.postgres.cluster.local".into(), 5432),
+///         ],
+///         username: "my_username".into(),
+///         password: "my password".into(),
+///     })
+///     .await?;
+///
+/// let conn = pool.acquire().await?;
+///
+/// # Ok(())
+/// # }
+/// ```
+pub trait PoolConnector<DB: Database>: Send + Sync + 'static {
+    /// Open a connection for the pool.
+    ///
+    /// Any setup that must be done on the connection should be performed before it is returned.
+    ///
+    /// If this method returns an error that is known to be retryable, it is called again
+    /// in an exponential backoff loop. Retryable errors include, but are not limited to:
+    ///
+    /// * [`io::ErrorKind::ConnectionRefused`]
+    /// * Database errors for which
+    ///   [`is_retryable_connect_error`][crate::error::DatabaseError::is_retryable_connect_error]
+    ///   returns `true`.
+    /// * [`Error::PoolConnector`] with `retryable: true`.
+    ///   This error kind is not returned internally and is designed to allow this method to return
+    ///   arbitrary error types not otherwise supported.
+    ///
+    /// Manual implementations of this method may also use the signature:
+    /// ```rust,ignore
+    /// async fn connect(
+    ///     &self,
+    ///     meta: PoolConnectMetadata
+    /// ) -> sqlx::Result<{PgConnection, MySqlConnection, SqliteConnection, etc.}>
+    /// ```
+    ///
+    /// Note: the returned future must be `Send`.
+    fn connect(
+        &self,
+        meta: PoolConnectMetadata,
+    ) -> impl Future<Output = crate::Result<DB::Connection>> + Send + '_;
+}
+
+impl<DB, F, Fut> PoolConnector<DB> for F
+where
+    DB: Database,
+    F: Fn(PoolConnectMetadata) -> Fut + Send + Sync + 'static,
+    Fut: Future<Output = crate::Result<DB::Connection>> + Send + 'static,
+{
+    fn connect(
+        &self,
+        meta: PoolConnectMetadata,
+    ) -> impl Future<Output = crate::Result<DB::Connection>> + Send + '_ {
+        self(meta)
+    }
+}
+
+pub(crate) struct DefaultConnector<DB: Database>(
+    pub <<DB as Database>::Connection as Connection>::Options,
+);
+
+impl<DB: Database> PoolConnector<DB> for DefaultConnector<DB> {
+    fn connect(
+        &self,
+        _meta: PoolConnectMetadata,
+    ) -> impl Future<Output = crate::Result<DB::Connection>> + Send + '_ {
+        self.0.connect()
+    }
+}
+
+/// Metadata passed to [`PoolConnector::connect()`] for every connection attempt.
+#[derive(Debug)]
+#[non_exhaustive]
+pub struct PoolConnectMetadata {
+    /// The instant at which the current connection task was started, including all attempts.
+    ///
+    /// May be used for reporting purposes, or to implement a custom backoff.
+    pub start: Instant,
+    /// The number of attempts that have occurred so far.
+    pub num_attempts: usize,
+    /// The current size of the pool.
+    pub pool_size: usize,
+    /// The ID of the connection, unique for the pool.
+    pub connection_id: ConnectionId,
+}
+
+pub struct DynConnector<DB: Database> {
+    // We want to spawn the connection attempt as a task anyway
+    connect: Box<
+        dyn Fn(ConnectionId, ConnectPermit<DB>) -> JoinHandle<crate::Result<PoolConnection<DB>>>
+            + Send
+            + Sync
+            + 'static,
+    >,
+}
+
+impl<DB: Database> DynConnector<DB> {
+    pub fn new(connector: impl PoolConnector<DB>) -> Self {
+        let connector = Arc::new(connector);
+
+        Self {
+            connect: Box::new(move |id, permit| {
+                crate::rt::spawn(connect_with_backoff(id, permit, connector.clone()))
+            }),
+        }
+    }
+
+    pub fn connect(
+        &self,
+        id: ConnectionId,
+        permit: ConnectPermit<DB>,
+    ) -> JoinHandle<crate::Result<PoolConnection<DB>>> {
+        (self.connect)(id, permit)
+    }
+}
+
+pub struct ConnectionCounter {
+    count: AtomicUsize,
+    next_id: AtomicUsize,
+    connect_available: Event,
+}
+
+/// An opaque connection ID, unique for every connection attempt with the same pool.
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub struct ConnectionId(usize);
+
+impl ConnectionCounter {
+    pub fn new() -> Self {
+        Self {
+            count: AtomicUsize::new(0),
+            next_id: AtomicUsize::new(1),
+            connect_available: Event::new(),
+        }
+    }
+
+    pub fn connections(&self) -> usize {
+        self.count.load(Ordering::Acquire)
+    }
+
+    pub async fn drain(&self) {
+        while self.count.load(Ordering::Acquire) > 0 {
+            listener!(self.connect_available => permit_released);
+            permit_released.await;
+        }
+    }
+
+    /// Attempt to acquire a permit from both this instance, and the parent pool, if applicable.
+    ///
+    /// Returns the permit, and the ID of the new connection.
+    pub fn try_acquire_permit<DB: Database>(
+        &self,
+        pool: &Arc<PoolInner<DB>>,
+    ) -> Option<(ConnectionId, ConnectPermit<DB>)> {
+        debug_assert!(ptr::addr_eq(self, &pool.counter));
+
+        // Don't skip the queue.
+        if pool.options.fair && self.connect_available.total_listeners() > 0 {
+            return None;
+        }
+
+        let prev_size = self
+            .count
+            .fetch_update(Ordering::Release, Ordering::Acquire, |connections| {
+                (connections < pool.options.max_connections).then_some(connections + 1)
+            })
+            .ok()?;
+
+        let size = prev_size + 1;
+
+        tracing::trace!(target: "sqlx::pool::connect", size, "increased size");
+
+        Some((
+            ConnectionId(self.next_id.fetch_add(1, Ordering::SeqCst)),
+            ConnectPermit {
+                pool: Some(Arc::clone(pool)),
+            },
+        ))
+    }
+
+    /// Attempt to acquire a permit from both this instance, and the parent pool, if applicable.
+    ///
+    /// Returns the permit, and the current size of the pool.
+    pub async fn acquire_permit<DB: Database>(
+        &self,
+        pool: &Arc<PoolInner<DB>>,
+    ) -> (ConnectionId, ConnectPermit<DB>) {
+        // Check that `self` can increase size first before we check the parent.
+        let acquired = self.acquire_permit_self(pool).await;
+
+        if let Some(parent) = pool.parent() {
+            let (_, permit) = parent.0.counter.acquire_permit_self(&parent.0).await;
+
+            // consume the parent permit
+            permit.consume();
+        }
+
+        acquired
+    }
+
+    // Separate method because `async fn`s cannot be recursive.
+    /// Attempt to acquire a [`ConnectPermit`] from this instance and this instance only.
+    async fn acquire_permit_self<DB: Database>(
+        &self,
+        pool: &Arc<PoolInner<DB>>,
+    ) -> (ConnectionId, ConnectPermit<DB>) {
+        for attempt in 1usize.. {
+            if let Some(acquired) = self.try_acquire_permit(pool) {
+                return acquired;
+            }
+
+            if attempt == 2 {
+                tracing::warn!(
+                    "unable to acquire a connect permit after sleeping; this may indicate a bug"
+                );
+            }
+
+            listener!(self.connect_available => connect_available);
+            connect_available.await;
+        }
+
+        panic!("BUG: was never able to acquire a connection despite waking many times")
+    }
+
+    pub fn release_permit<DB: Database>(&self, pool: &PoolInner<DB>) {
+        debug_assert!(ptr::addr_eq(self, &pool.counter));
+
+        self.count.fetch_sub(1, Ordering::Release);
+        self.connect_available.notify(1usize);
+
+        if let Some(parent) = &pool.options.parent_pool {
+            parent.0.counter.release_permit(&parent.0);
+        }
+    }
+}
+
+pub struct ConnectPermit<DB: Database> {
+    pool: Option<Arc<PoolInner<DB>>>,
+}
+
+impl<DB: Database> ConnectPermit<DB> {
+    pub fn float_existing(pool: Arc<PoolInner<DB>>) -> Self {
+        Self { pool: Some(pool) }
+    }
+
+    pub fn pool(&self) -> &Arc<PoolInner<DB>> {
+        self.pool.as_ref().unwrap()
+    }
+
+    pub fn consume(mut self) {
+        self.pool = None;
+    }
+}
+
+impl<DB: Database> Drop for ConnectPermit<DB> {
+    fn drop(&mut self) {
+        if let Some(pool) = self.pool.take() {
+            pool.counter.release_permit(&pool);
+        }
+    }
+}
+
+impl Display for ConnectionId {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        Display::fmt(&self.0, f)
+    }
+}
+
+#[tracing::instrument(
+    target = "sqlx::pool::connect",
+    skip_all,
+    fields(%connection_id),
+    err
+)]
+async fn connect_with_backoff<DB: Database>(
+    connection_id: ConnectionId,
+    permit: ConnectPermit<DB>,
+    connector: Arc<impl PoolConnector<DB>>,
+) -> crate::Result<PoolConnection<DB>> {
+    if permit.pool().is_closed() {
+        return Err(Error::PoolClosed);
+    }
+
+    let mut ease_off = EaseOff::start_timeout(permit.pool().options.connect_timeout);
+
+    for attempt in 1usize.. {
+        let meta = PoolConnectMetadata {
+            start: ease_off.started_at(),
+            num_attempts: attempt,
+            pool_size: permit.pool().size(),
+            connection_id,
+        };
+
+        let conn = ease_off
+            .try_async(connector.connect(meta))
+            .await
+            .or_retry_if(|e| can_retry_error(e.inner()))?;
+
+        if let Some(conn) = conn {
+            return Ok(Floating::new_live(conn, connection_id, permit).reattach());
+        }
+    }
+
+    Err(Error::PoolTimedOut)
+}
+
+fn can_retry_error(e: &Error) -> bool {
+    match e {
+        Error::Io(e) if e.kind() == io::ErrorKind::ConnectionRefused => true,
+        Error::Database(e) => e.is_retryable_connect_error(),
+        _ => false,
+    }
+}
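Editor's note (not part of the patch): under this scheme the default connect path is itself just a `PoolConnector`; `connect_with()` wraps the options in the crate-private `DefaultConnector`. A sketch of the same behavior written against the public closure API (names `opts`/`_meta` are illustrative):

```rust
use sqlx::pool::PoolConnectMetadata;
use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
use sqlx::ConnectOptions;

async fn _example() -> sqlx::Result<()> {
    let opts: PgConnectOptions = "postgres://…".parse()?;

    // Roughly equivalent to `PgPoolOptions::new().connect_with(opts)`,
    // modulo the closure's `'static` requirement (worked around by cloning).
    let pool = PgPoolOptions::new()
        .connect_with_connector(move |_meta: PoolConnectMetadata| {
            let opts = opts.clone();
            async move { opts.connect().await }
        })
        .await?;

    drop(pool);
    Ok(())
}
```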
diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs
index bf3a6d4b1c..b3044de14d 100644
--- a/sqlx-core/src/pool/connection.rs
+++ b/sqlx-core/src/pool/connection.rs
@@ -3,16 +3,17 @@ use std::ops::{Deref, DerefMut};
 use std::sync::Arc;
 use std::time::{Duration, Instant};

-use crate::sync::AsyncSemaphoreReleaser;
-
 use crate::connection::Connection;
 use crate::database::Database;
 use crate::error::Error;

-use super::inner::{is_beyond_max_lifetime, DecrementSizeGuard, PoolInner};
+use super::inner::{is_beyond_max_lifetime, PoolInner};
+use crate::pool::connect::{ConnectPermit, ConnectionId};
 use crate::pool::options::PoolConnectionMetadata;
+use crate::rt;
 use std::future::Future;

+const RETURN_TO_POOL_TIMEOUT: Duration = Duration::from_secs(5);
 const CLOSE_ON_DROP_TIMEOUT: Duration = Duration::from_secs(5);

 /// A connection managed by a [`Pool`][crate::pool::Pool].
@@ -26,6 +27,7 @@ pub struct PoolConnection<DB: Database> {

 pub(super) struct Live<DB: Database> {
     pub(super) raw: DB::Connection,
+    pub(super) id: ConnectionId,
     pub(super) created_at: Instant,
 }

@@ -37,15 +39,17 @@ pub(super) struct Idle<DB: Database> {
 /// RAII wrapper for connections being handled by functions that may drop them
 pub(super) struct Floating<DB: Database, C> {
     pub(super) inner: C,
-    pub(super) guard: DecrementSizeGuard<DB>,
+    pub(super) permit: ConnectPermit<DB>,
 }

 const EXPECT_MSG: &str = "BUG: inner connection already taken!";

 impl<DB: Database> Debug for PoolConnection<DB> {
     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
-        // TODO: Show the type name of the connection ?
-        f.debug_struct("PoolConnection").finish()
+        f.debug_struct("PoolConnection")
+            .field("database", &DB::NAME)
+            .field("id", &self.live.as_ref().map(|live| live.id))
+            .finish()
     }
 }

@@ -127,6 +131,10 @@ impl<DB: Database> PoolConnection<DB> {
         self.live.take().expect(EXPECT_MSG)
     }

+    pub(super) fn into_floating(mut self) -> Floating<DB, Live<DB>> {
+        self.take_live().float(self.pool.clone())
+    }
+
     /// Test the connection to make sure it is still live before returning it to the pool.
     ///
     /// This effectively runs the drop handler eagerly instead of spawning a task to do it.
@@ -143,7 +151,9 @@ impl<DB: Database> PoolConnection<DB> {
         async move {
             let returned_to_pool = if let Some(floating) = floating {
-                floating.return_to_pool().await
+                rt::timeout(RETURN_TO_POOL_TIMEOUT, floating.return_to_pool())
+                    .await
+                    .unwrap_or(false)
             } else {
                 false
             };
@@ -215,7 +225,7 @@ impl<DB: Database> Live<DB> {
         Floating {
             inner: self,
             // create a new guard from a previously leaked permit
-            guard: DecrementSizeGuard::new_permit(pool),
+            permit: ConnectPermit::float_existing(pool),
         }
     }

@@ -242,22 +252,23 @@ impl<DB: Database> DerefMut for Idle<DB> {
     }
 }

 impl<DB: Database> Floating<DB, Live<DB>> {
-    pub fn new_live(conn: DB::Connection, guard: DecrementSizeGuard<DB>) -> Self {
+    pub fn new_live(conn: DB::Connection, id: ConnectionId, permit: ConnectPermit<DB>) -> Self {
         Self {
             inner: Live {
                 raw: conn,
+                id,
                 created_at: Instant::now(),
             },
-            guard,
+            permit,
         }
     }

     pub fn reattach(self) -> PoolConnection<DB> {
-        let Floating { inner, guard } = self;
+        let Floating { inner, permit } = self;

-        let pool = Arc::clone(&guard.pool);
+        let pool = Arc::clone(permit.pool());

-        guard.cancel();
+        permit.consume();

         PoolConnection {
             live: Some(inner),
             close_on_drop: false,
@@ -266,7 +277,7 @@ impl<DB: Database> Floating<DB, Live<DB>> {
     }

     pub fn release(self) {
-        self.guard.pool.clone().release(self);
+        self.permit.pool().clone().release(self);
     }

     /// Return the connection to the pool.
     ///
     /// Returns `true` if the connection was successfully returned, `false` if it was closed.
     async fn return_to_pool(mut self) -> bool {
         // Immediately close the connection.
-        if self.guard.pool.is_closed() {
+        if self.permit.pool().is_closed() {
             self.close().await;
             return false;
         }

         // If the connection is beyond max lifetime, close the connection and
         // immediately create a new connection
-        if is_beyond_max_lifetime(&self.inner, &self.guard.pool.options) {
+        if is_beyond_max_lifetime(&self.inner, &self.permit.pool().options) {
             self.close().await;
             return false;
         }

-        if let Some(test) = &self.guard.pool.options.after_release {
+        if let Some(test) = &self.permit.pool().options.after_release {
             let meta = self.metadata();
             match (test)(&mut self.inner.raw, meta).await {
                 Ok(true) => (),
@@ -345,7 +356,7 @@ impl<DB: Database> Floating<DB, Live<DB>> {
     pub fn into_idle(self) -> Floating<DB, Idle<DB>> {
         Floating {
             inner: self.inner.into_idle(),
-            guard: self.guard,
+            permit: self.permit,
         }
     }

@@ -358,14 +369,10 @@ impl<DB: Database> Floating<DB, Idle<DB>> {
-    pub fn from_idle(
-        idle: Idle<DB>,
-        pool: Arc<PoolInner<DB>>,
-        permit: AsyncSemaphoreReleaser<'_>,
-    ) -> Self {
+    pub fn from_idle(idle: Idle<DB>, pool: Arc<PoolInner<DB>>) -> Self {
         Self {
             inner: idle,
-            guard: DecrementSizeGuard::from_permit(pool, permit),
+            permit: ConnectPermit::float_existing(pool),
         }
     }

@@ -376,21 +383,33 @@ impl<DB: Database> Floating<DB, Idle<DB>> {
     pub fn into_live(self) -> Floating<DB, Live<DB>> {
         Floating {
             inner: self.inner.live,
-            guard: self.guard,
+            permit: self.permit,
         }
     }

-    pub async fn close(self) -> DecrementSizeGuard<DB> {
+    pub async fn close(self) -> (ConnectionId, ConnectPermit<DB>) {
+        let connection_id = self.inner.live.id;
+
+        tracing::debug!(%connection_id, "closing connection (gracefully)");
+
         if let Err(error) = self.inner.live.raw.close().await {
-            tracing::debug!(%error, "error occurred while closing the pool connection");
+            tracing::debug!(
+                %connection_id,
+                %error,
+                "error occurred while closing the pool connection"
+            );
         }
-        self.guard
+
+        (connection_id, self.permit)
     }

-    pub async fn close_hard(self) -> DecrementSizeGuard<DB> {
+    pub async fn close_hard(self) -> (ConnectionId, ConnectPermit<DB>) {
+        let connection_id = self.inner.live.id;
+
+        tracing::debug!(%connection_id, "closing connection (hard)");
+
         let _ = self.inner.live.raw.close_hard().await;
-        self.guard
+
+        (connection_id, self.permit)
     }

     pub fn metadata(&self) -> PoolConnectionMetadata {
diff --git a/sqlx-core/src/pool/idle.rs b/sqlx-core/src/pool/idle.rs
new file mode 100644
index 0000000000..8b07b8e7c4
--- /dev/null
+++ b/sqlx-core/src/pool/idle.rs
@@ -0,0 +1,100 @@
+use crate::connection::Connection;
+use crate::database::Database;
+use crate::pool::connection::{Floating, Idle, Live};
+use crate::pool::inner::PoolInner;
+use crossbeam_queue::ArrayQueue;
+use event_listener::Event;
+use futures_util::FutureExt;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Arc;
+
+use event_listener::listener;
+
+pub struct IdleQueue<DB: Database> {
+    queue: ArrayQueue<Idle<DB>>,
+    // Keep a separate count because `ArrayQueue::len()` loops until the head and tail pointers
+    // stop changing, which may never happen at high contention.
+    len: AtomicUsize,
+    release_event: Event,
+    fair: bool,
+}
+
+impl<DB: Database> IdleQueue<DB> {
+    pub fn new(fair: bool, cap: usize) -> Self {
+        Self {
+            queue: ArrayQueue::new(cap),
+            len: AtomicUsize::new(0),
+            release_event: Event::new(),
+            fair,
+        }
+    }
+
+    pub fn len(&self) -> usize {
+        self.len.load(Ordering::Acquire)
+    }
+
+    pub async fn acquire(&self, pool: &Arc<PoolInner<DB>>) -> Floating<DB, Idle<DB>> {
+        let mut should_wait = self.fair && self.release_event.total_listeners() > 0;
+
+        for attempt in 1usize.. {
+            if should_wait {
+                listener!(self.release_event => release_event);
+                release_event.await;
+            }
+
+            if let Some(conn) = self.try_acquire(pool) {
+                return conn;
+            }
+
+            should_wait = true;
+
+            if attempt == 2 {
+                tracing::warn!(
+                    "unable to acquire a connection after sleeping; this may indicate a bug"
+                );
+            }
+        }
+
+        panic!("BUG: was never able to acquire a connection despite waking many times")
+    }
+
+    pub fn try_acquire(&self, pool: &Arc<PoolInner<DB>>) -> Option<Floating<DB, Idle<DB>>> {
+        self.len
+            .fetch_update(Ordering::Release, Ordering::Acquire, |len| {
+                len.checked_sub(1)
+            })
+            .ok()
+            .and_then(|_| {
+                let conn = self.queue.pop()?;
+
+                Some(Floating::from_idle(conn, Arc::clone(pool)))
+            })
+    }
+
+    pub fn release(&self, conn: Floating<DB, Live<DB>>) {
+        let Floating {
+            inner: conn,
+            permit,
+        } = conn.into_idle();
+
+        self.queue
+            .push(conn)
+            .unwrap_or_else(|_| panic!("BUG: idle queue capacity exceeded"));
+
+        self.len.fetch_add(1, Ordering::Release);
+
+        self.release_event.notify(1usize);
+
+        // Don't decrease the size.
+        permit.consume();
+    }
+
+    pub fn drain(&self, pool: &PoolInner<DB>) {
+        while let Some(conn) = self.queue.pop() {
+            // Hopefully will send at least a TCP FIN packet.
+            conn.live.raw.close_hard().now_or_never();
+
+            pool.counter.release_permit(pool);
+        }
+    }
+}
diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs
index bbcc43134e..51b0cd47e7 100644
--- a/sqlx-core/src/pool/inner.rs
+++ b/sqlx-core/src/pool/inner.rs
@@ -1,33 +1,30 @@
 use super::connection::{Floating, Idle, Live};
-use crate::connection::ConnectOptions;
-use crate::connection::Connection;
 use crate::database::Database;
 use crate::error::Error;
-use crate::pool::{deadline_as_timeout, CloseEvent, Pool, PoolOptions};
-use crossbeam_queue::ArrayQueue;
-
-use crate::sync::{AsyncSemaphore, AsyncSemaphoreReleaser};
+use crate::pool::{CloseEvent, Pool, PoolConnection, PoolConnector, PoolOptions};

 use std::cmp;
 use std::future::Future;
-use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
-use std::sync::{Arc, RwLock};
-use std::task::Poll;
+use std::pin::pin;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+use std::task::ready;

 use crate::logger::private_level_filter_to_trace_level;
-use crate::pool::options::PoolConnectionMetadata;
-use crate::private_tracing_dynamic_event;
-use futures_util::future::{self};
-use futures_util::FutureExt;
+use crate::pool::connect::{ConnectPermit, ConnectionCounter, ConnectionId, DynConnector};
+use crate::pool::idle::IdleQueue;
+use crate::rt::JoinHandle;
+use crate::{private_tracing_dynamic_event, rt};
+use either::Either;
+use futures_util::future::{self, OptionFuture};
+use futures_util::{FutureExt};
 use std::time::{Duration, Instant};
 use tracing::Level;

 pub(crate) struct PoolInner<DB: Database> {
-    pub(super) connect_options: RwLock<Arc<<DB::Connection as Connection>::Options>>,
-    pub(super) idle_conns: ArrayQueue<Idle<DB>>,
-    pub(super) semaphore: AsyncSemaphore,
-    pub(super) size: AtomicU32,
-    pub(super) num_idle: AtomicUsize,
+    pub(super) connector: DynConnector<DB>,
+    pub(super) counter: ConnectionCounter,
+    pub(super) idle: IdleQueue<DB>,
     is_closed: AtomicBool,
     pub(super) on_closed: event_listener::Event,
     pub(super) options: PoolOptions<DB>,
@@ -38,25 +35,12 @@ impl<DB: Database> PoolInner<DB> {
     pub(super) fn new_arc(
         options: PoolOptions<DB>,
-        connect_options: <DB::Connection as Connection>::Options,
+        connector: impl PoolConnector<DB>,
     ) -> Arc<Self> {
-        let capacity = options.max_connections as usize;
-
-        let semaphore_capacity = if let Some(parent) = &options.parent_pool {
-            assert!(options.max_connections <= parent.options().max_connections);
-            assert_eq!(options.fair, parent.options().fair);
-            // The child pool must steal permits from the parent
-            0
-        } else {
-            capacity
-        };
-
         let pool = Self {
-            connect_options: RwLock::new(Arc::new(connect_options)),
-            idle_conns: ArrayQueue::new(capacity),
-            semaphore: AsyncSemaphore::new(options.fair, semaphore_capacity),
-            size: AtomicU32::new(0),
-            num_idle: AtomicUsize::new(0),
+            connector: DynConnector::new(connector),
+            counter: ConnectionCounter::new(),
+            idle: IdleQueue::new(options.fair, options.max_connections),
             is_closed: AtomicBool::new(false),
             on_closed: event_listener::Event::new(),
             acquire_time_level: private_level_filter_to_trace_level(options.acquire_time_level),
@@ -71,17 +55,12 @@ impl<DB: Database> PoolInner<DB> {
         pool
     }

-    pub(super) fn size(&self) -> u32 {
-        self.size.load(Ordering::Acquire)
+    pub(super) fn size(&self) -> usize {
+        self.counter.connections()
     }

     pub(super) fn num_idle(&self) -> usize {
-        // We don't use `self.idle_conns.len()` as it waits for the internal
-        // head and tail pointers to stop changing for a moment before calculating the length,
-        // which may take a long time at high levels of churn.
-        //
-        // By maintaining our own atomic count, we avoid that issue entirely.
-        self.num_idle.load(Ordering::Acquire)
+        self.idle.len()
     }

     pub(super) fn is_closed(&self) -> bool {
@@ -96,19 +75,22 @@ impl<DB: Database> PoolInner<DB> {
     pub(super) fn close<'a>(self: &'a Arc<Self>) -> impl Future<Output = ()> + 'a {
         self.mark_closed();

+        // Keep clearing the idle queue as connections are released until the count reaches zero.
         async move {
-            for permits in 1..=self.options.max_connections {
-                // Close any currently idle connections in the pool.
-                while let Some(idle) = self.idle_conns.pop() {
-                    let _ = idle.live.float((*self).clone()).close().await;
-                }
-
-                if self.size() == 0 {
-                    break;
+            let mut drained = pin!(self.counter.drain());
+
+            loop {
+                let mut acquire_idle = pin!(self.idle.acquire(self));
+
+                // Not using `futures::select!{}` here because it requires a proc-macro dep,
+                // and frankly it's a little broken.
+                match future::select(drained.as_mut(), acquire_idle.as_mut()).await {
+                    // *not* `either::Either`; they rolled their own
+                    future::Either::Left(_) => break,
+                    future::Either::Right((idle, _)) => {
+                        idle.close().await;
+                    }
                 }
-
-                // Wait for all permits to be released.
-                let _permits = self.semaphore.acquire(permits).await;
             }
         }
     }
@@ -119,64 +101,7 @@ impl<DB: Database> PoolInner<DB> {
         }
     }

-    /// Attempt to pull a permit from `self.semaphore` or steal one from the parent.
-    ///
-    /// If we steal a permit from the parent but *don't* open a connection,
-    /// it should be returned to the parent.
-    async fn acquire_permit<'a>(self: &'a Arc<Self>) -> Result<AsyncSemaphoreReleaser<'a>, Error> {
-        let parent = self
-            .parent()
-            // If we're already at the max size, we shouldn't try to steal from the parent.
-            // This is just going to cause unnecessary churn in `acquire()`.
-            .filter(|_| self.size() < self.options.max_connections);
-
-        let acquire_self = self.semaphore.acquire(1).fuse();
-        let mut close_event = self.close_event();
-
-        if let Some(parent) = parent {
-            let acquire_parent = parent.0.semaphore.acquire(1);
-            let parent_close_event = parent.0.close_event();
-
-            futures_util::pin_mut!(
-                acquire_parent,
-                acquire_self,
-                close_event,
-                parent_close_event
-            );
-
-            let mut poll_parent = false;
-
-            future::poll_fn(|cx| {
-                if close_event.as_mut().poll(cx).is_ready() {
-                    return Poll::Ready(Err(Error::PoolClosed));
-                }
-
-                if parent_close_event.as_mut().poll(cx).is_ready() {
-                    // Propagate the parent's close event to the child.
-                    self.mark_closed();
-                    return Poll::Ready(Err(Error::PoolClosed));
-                }
-
-                if let Poll::Ready(permit) = acquire_self.as_mut().poll(cx) {
-                    return Poll::Ready(Ok(permit));
-                }
-
-                // Don't try the parent right away.
-                if poll_parent {
-                    acquire_parent.as_mut().poll(cx).map(Ok)
-                } else {
-                    poll_parent = true;
-                    cx.waker().wake_by_ref();
-                    Poll::Pending
-                }
-            })
-            .await
-        } else {
-            close_event.do_until(acquire_self).await
-        }
-    }
-
-    fn parent(&self) -> Option<&Pool<DB>> {
+    pub(super) fn parent(&self) -> Option<&Pool<DB>> {
         self.options.parent_pool.as_ref()
     }

@@ -186,117 +111,110 @@ impl<DB: Database> PoolInner<DB> {
             return None;
         }

-        let permit = self.semaphore.try_acquire(1)?;
-
-        self.pop_idle(permit).ok()
-    }
-
-    fn pop_idle<'a>(
-        self: &'a Arc<Self>,
-        permit: AsyncSemaphoreReleaser<'a>,
-    ) -> Result<Floating<DB, Idle<DB>>, AsyncSemaphoreReleaser<'a>> {
-        if let Some(idle) = self.idle_conns.pop() {
-            self.num_idle.fetch_sub(1, Ordering::AcqRel);
-            Ok(Floating::from_idle(idle, (*self).clone(), permit))
-        } else {
-            Err(permit)
-        }
+        self.idle.try_acquire(self)
     }

     pub(super) fn release(&self, floating: Floating<DB, Live<DB>>) {
         // `options.after_release` and other checks are in `PoolConnection::return_to_pool()`.
+        self.idle.release(floating);
+    }

-        let Floating { inner: idle, guard } = floating.into_idle();
-
-        if self.idle_conns.push(idle).is_err() {
-            panic!("BUG: connection queue overflow in release()");
+    pub(super) async fn acquire(self: &Arc<Self>) -> Result<PoolConnection<DB>, Error> {
+        if self.is_closed() {
+            return Err(Error::PoolClosed);
         }

-        // NOTE: we need to make sure we drop the permit *after* we push to the idle queue
-        // don't decrease the size
-        guard.release_permit();
-
-        self.num_idle.fetch_add(1, Ordering::AcqRel);
-    }
+        let acquire_started_at = Instant::now();

-    /// Try to atomically increment the pool size for a new connection.
-    ///
-    /// Returns `Err` if the pool is at max capacity already or is closed.
-    pub(super) fn try_increment_size<'a>(
-        self: &'a Arc<Self>,
-        permit: AsyncSemaphoreReleaser<'a>,
-    ) -> Result<DecrementSizeGuard<DB>, AsyncSemaphoreReleaser<'a>> {
-        let result = self
-            .size
-            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |size| {
-                if self.is_closed() {
-                    return None;
-                }
+        let mut close_event = pin!(self.close_event());
+        let mut deadline = pin!(rt::sleep(self.options.acquire_timeout));
+        let mut acquire_idle = pin!(self.idle.acquire(self).fuse());
+        let mut before_acquire = OptionFuture::from(None);
+        let mut acquire_connect_permit = pin!(OptionFuture::from(Some(
+            self.counter.acquire_permit(self).fuse()
+        )));
+        let mut connect = OptionFuture::from(None);

-                size.checked_add(1)
-                    .filter(|size| size <= &self.options.max_connections)
-            });
+        // The internal state machine of `acquire()`.
+        //
+        // * The initial state is racing to acquire either an idle connection or a new `ConnectPermit`.
+        // * If we acquire a `ConnectPermit`, we begin the connection loop (with backoff)
+        //   as implemented by `DynConnector`.
+        // * If we acquire an idle connection, we then start polling `check_idle_conn()`.
+        //
+        // This doesn't quite fit into `select!{}` because the set of futures that may be polled
+        // at a given time is dynamic, so it's actually simpler to hand-roll it.
+        let acquired = future::poll_fn(|cx| {
+            use std::task::Poll::*;
+
+            // First check if the pool is already closed,
+            // or register for a wakeup if it gets closed.
+            if let Ready(()) = close_event.poll_unpin(cx) {
+                return Ready(Err(Error::PoolClosed));
+            }

-        match result {
-            // we successfully incremented the size
-            Ok(_) => Ok(DecrementSizeGuard::from_permit((*self).clone(), permit)),
-            // the pool is at max capacity or is closed
-            Err(_) => Err(permit),
-        }
-    }
+            // Then check if our deadline has elapsed, or schedule a wakeup for when that happens.
+            if let Ready(()) = deadline.poll_unpin(cx) {
+                return Ready(Err(Error::PoolTimedOut));
+            }

-    pub(super) async fn acquire(self: &Arc<Self>) -> Result<Floating<DB, Live<DB>>, Error> {
-        if self.is_closed() {
-            return Err(Error::PoolClosed);
-        }
+            // Attempt to acquire a connection from the idle queue.
+            if let Ready(idle) = acquire_idle.poll_unpin(cx) {
+                // If we acquired an idle connection, run any checks that need to be done.
+                //
+                // Includes `test_on_acquire` and the `before_acquire` callback, if set.
+                match finish_acquire(idle) {
+                    // There are checks needed to be done, so they're spawned as a task
+                    // to be cancellation-safe.
+                    Either::Left(check_task) => {
+                        before_acquire = Some(check_task).into();
+                    }
+                    // The connection is ready to go.
+                    Either::Right(conn) => {
+                        return Ready(Ok(conn));
+                    }
+                }
+            }

-        let acquire_started_at = Instant::now();
-        let deadline = acquire_started_at + self.options.acquire_timeout;
-
-        let acquired = crate::rt::timeout(
-            self.options.acquire_timeout,
-            async {
-                loop {
-                    // Handles the close-event internally
-                    let permit = self.acquire_permit().await?;
-
-
-                    // First attempt to pop a connection from the idle queue.
-                    let guard = match self.pop_idle(permit) {
-
-                        // Then, check that we can use it...
-                        Ok(conn) => match check_idle_conn(conn, &self.options).await {
-
-                            // All good!
-                            Ok(live) => return Ok(live),
-
-                            // if the connection isn't usable for one reason or another,
-                            // we get the `DecrementSizeGuard` back to open a new one
-                            Err(guard) => guard,
-                        },
-                        Err(permit) => if let Ok(guard) = self.try_increment_size(permit) {
-                            // we can open a new connection
-                            guard
-                        } else {
-                            // This can happen for a child pool that's at its connection limit,
-                            // or if the pool was closed between `acquire_permit()` and
-                            // `try_increment_size()`.
-                            tracing::debug!("woke but was unable to acquire idle connection or open new one; retrying");
-                            // If so, we're likely in the current-thread runtime if it's Tokio,
-                            // and so we should yield to let any spawned return_to_pool() tasks
-                            // execute.
-                            crate::rt::yield_now().await;
-                            continue;
+            // Poll the task returned by `finish_acquire`
+            match ready!(before_acquire.poll_unpin(cx)) {
+                Some(Ok(conn)) => return Ready(Ok(conn)),
+                Some(Err((id, permit))) => {
+                    // We don't strictly need to poll `connect` here; all we really want to do
+                    // is to check if it is `None`. But since currently there's no getter for that,
+                    // it doesn't really hurt to just poll it here.
+                    match connect.poll_unpin(cx) {
+                        Ready(None) => {
+                            // If we're not already attempting to connect,
+                            // take the permit returned from closing the connection and
+                            // attempt to open a new one.
+                            connect = Some(self.connector.connect(id, permit)).into();
                         }
-                    };
+                        // `permit` is dropped in these branches, allowing another task to use it
+                        Ready(Some(res)) => return Ready(res),
+                        Pending => (),
+                    }

-                    // Attempt to connect...
-                    return self.connect(deadline, guard).await;
+                    // Attempt to acquire another idle connection concurrently to opening a new one.
+                    acquire_idle.set(self.idle.acquire(self).fuse());
+                    // Annoyingly, `OptionFuture` doesn't fuse to `None` on its own
+                    before_acquire = None.into();
                 }
+                None => (),
             }
-        )
-        .await
-        .map_err(|_| Error::PoolTimedOut)??;
+
+            if let Ready(Some((id, permit))) = acquire_connect_permit.poll_unpin(cx) {
+                connect = Some(self.connector.connect(id, permit)).into();
+            }
+
+            if let Ready(Some(res)) = connect.poll_unpin(cx) {
+                // RFC: suppress errors here?
+                return Ready(res);
+            }
+
+            Pending
+        })
+        .await?;

         let acquired_after = acquire_started_at.elapsed();

@@ -324,102 +242,29 @@ impl<DB: Database> PoolInner<DB> {
         Ok(acquired)
     }

-    pub(super) async fn connect(
-        self: &Arc<Self>,
-        deadline: Instant,
-        guard: DecrementSizeGuard<DB>,
-    ) -> Result<Floating<DB, Live<DB>>, Error> {
-        if self.is_closed() {
-            return Err(Error::PoolClosed);
-        }
-
-        let mut backoff = Duration::from_millis(10);
-        let max_backoff = deadline_as_timeout(deadline)? / 5;
-
-        loop {
-            let timeout = deadline_as_timeout(deadline)?;
-
-            // clone the connect options arc so it can be used without holding the RwLockReadGuard
-            // across an async await point
-            let connect_options = self
-                .connect_options
-                .read()
-                .expect("write-lock holder panicked")
-                .clone();
-
-            // result here is `Result<Result<C, Error>, TimeoutError>`
-            // if this block does not return, sleep for the backoff timeout and try again
-            match crate::rt::timeout(timeout, connect_options.connect()).await {
-                // successfully established connection
-                Ok(Ok(mut raw)) => {
-                    // See comment on `PoolOptions::after_connect`
-                    let meta = PoolConnectionMetadata {
-                        age: Duration::ZERO,
-                        idle_for: Duration::ZERO,
-                    };
-
-                    let res = if let Some(callback) = &self.options.after_connect {
-                        callback(&mut raw, meta).await
-                    } else {
-                        Ok(())
-                    };
-
-                    match res {
-                        Ok(()) => return Ok(Floating::new_live(raw, guard)),
-                        Err(error) => {
-                            tracing::error!(%error, "error returned from after_connect");
-                            // The connection is broken, don't try to close nicely.
-                            let _ = raw.close_hard().await;
-
-                            // Fall through to the backoff.
-                        }
-                    }
-                }
-
-                // an IO error while connecting is assumed to be the system starting up
-                Ok(Err(Error::Io(e))) if e.kind() == std::io::ErrorKind::ConnectionRefused => (),
-
-                // We got a transient database error, retry.
-                Ok(Err(Error::Database(error))) if error.is_transient_in_connect_phase() => (),
-
-                // Any other error while connection should immediately
-                // terminate and bubble the error up
-                Ok(Err(e)) => return Err(e),
-
-                // timed out
-                Err(_) => return Err(Error::PoolTimedOut),
-            }
-
-            // If the connection is refused, wait in exponentially
-            // increasing steps for the server to come up,
-            // capped by a factor of the remaining time until the deadline
-            crate::rt::sleep(backoff).await;
-            backoff = cmp::min(backoff * 2, max_backoff);
-        }
-    }
-
     /// Try to maintain `min_connections`, returning any errors (including `PoolTimedOut`).
     pub async fn try_min_connections(self: &Arc<Self>, deadline: Instant) -> Result<(), Error> {
-        while self.size() < self.options.min_connections {
-            // Don't wait for a semaphore permit.
-            //
-            // If no extra permits are available then we shouldn't be trying to spin up
-            // connections anyway.
-            let Some(permit) = self.semaphore.try_acquire(1) else {
-                return Ok(());
-            };
-
-            // We must always obey `max_connections`.
-            let Some(guard) = self.try_increment_size(permit).ok() else {
-                return Ok(());
-            };
-
-            // We skip `after_release` since the connection was never provided to user code
-            // besides `after_connect`, if they set it.
-            self.release(self.connect(deadline, guard).await?);
-        }
+        rt::timeout_at(deadline, async {
+            while self.size() < self.options.min_connections {
+                // Don't wait for a connect permit.
+                //
+                // If no extra permits are available then we shouldn't be trying to spin up
+                // connections anyway.
+                let Some((id, permit)) = self.counter.try_acquire_permit(self) else {
+                    return Ok(());
+                };
+
+                let conn = self.connector.connect(id, permit).await?;
+
+                // We skip `after_release` since the connection was never provided to user code
+                // besides inside `PoolConnector::connect()`, if they override it.
+                self.release(conn.into_floating());
+            }

-        Ok(())
+            Ok(())
+        })
+        .await
+        .unwrap_or_else(|_| Err(Error::PoolTimedOut))
     }

     /// Attempt to maintain `min_connections`, logging if unable.
@@ -443,11 +288,7 @@ impl<DB: Database> Drop for PoolInner<DB> {
 impl<DB: Database> Drop for PoolInner<DB> {
     fn drop(&mut self) {
         self.mark_closed();
-
-        if let Some(parent) = &self.options.parent_pool {
-            // Release the stolen permits.
-            parent.0.semaphore.release(self.semaphore.permits());
-        }
+        self.idle.drain(self);
     }
 }

@@ -468,42 +309,54 @@ fn is_beyond_idle_timeout<DB: Database>(idle: &Idle<DB>, options: &PoolOptions<DB>) -> bool {
         .map_or(false, |timeout| idle.idle_since.elapsed() > timeout)
 }

-async fn check_idle_conn<DB: Database>(
+/// Execute `test_before_acquire` and/or `before_acquire` in a background task, if applicable.
+///
+/// Otherwise, immediately returns the connection.
+fn finish_acquire<DB: Database>(
     mut conn: Floating<DB, Idle<DB>>,
-    options: &PoolOptions<DB>,
-) -> Result<Floating<DB, Live<DB>>, DecrementSizeGuard<DB>> {
-    if options.test_before_acquire {
-        // Check that the connection is still live
-        if let Err(error) = conn.ping().await {
-            // an error here means the other end has hung up or we lost connectivity
-            // either way we're fine to just discard the connection
-            // the error itself here isn't necessarily unexpected so WARN is too strong
-            tracing::info!(%error, "ping on idle connection returned error");
-            // connection is broken so don't try to close nicely
-            return Err(conn.close_hard().await);
-        }
-    }
-
-    if let Some(test) = &options.before_acquire {
-        let meta = conn.metadata();
-        match test(&mut conn.live.raw, meta).await {
-            Ok(false) => {
-                // connection was rejected by user-defined hook, close nicely
-                return Err(conn.close().await);
-            }
-
-            Err(error) => {
-                tracing::warn!(%error, "error from `before_acquire`");
+) -> Either<
+    JoinHandle<Result<PoolConnection<DB>, (ConnectionId, ConnectPermit<DB>)>>,
+    PoolConnection<DB>,
+> {
+    let pool = conn.permit.pool();
+
+    if pool.options.test_before_acquire || pool.options.before_acquire.is_some() {
+        // Spawn a task so the call may complete even if `acquire()` is cancelled.
+        return Either::Left(rt::spawn(async move {
+            // Check that the connection is still live
+            if let Err(error) = conn.ping().await {
+                // an error here means the other end has hung up or we lost connectivity
+                // either way we're fine to just discard the connection
+                // the error itself here isn't necessarily unexpected so WARN is too strong
+                tracing::info!(%error, "ping on idle connection returned error");
                 // connection is broken so don't try to close nicely
                 return Err(conn.close_hard().await);
             }

-            Ok(true) => {}
-        }
+            if let Some(test) = &conn.permit.pool().options.before_acquire {
+                let meta = conn.metadata();
+                match test(&mut conn.inner.live.raw, meta).await {
+                    Ok(false) => {
+                        // connection was rejected by user-defined hook, close nicely
+                        return Err(conn.close().await);
+                    }
+
+                    Err(error) => {
+                        tracing::warn!(%error, "error from `before_acquire`");
+                        // connection is broken so don't try to close nicely
+                        return Err(conn.close_hard().await);
+                    }
+
+                    Ok(true) => {}
+                }
+            }
+
+            Ok(conn.into_live().reattach())
+        }));
     }

-    // No need to re-connect; connection is alive or we don't care
-    Ok(conn.into_live())
+    // No checks are configured, return immediately.
+    Either::Right(conn.into_live().reattach())
 }

 fn spawn_maintenance_tasks<DB: Database>(pool: &Arc<PoolInner<DB>>) {
@@ -518,7 +371,7 @@ fn spawn_maintenance_tasks<DB: Database>(pool: &Arc<PoolInner<DB>>) {
         (None, None) => {
             if pool.options.min_connections > 0 {
-                crate::rt::spawn(async move {
+                rt::spawn(async move {
                     if let Some(pool) = pool_weak.upgrade() {
                         pool.min_connections_maintenance(None).await;
                     }
@@ -532,7 +385,7 @@ fn spawn_maintenance_tasks<DB: Database>(pool: &Arc<PoolInner<DB>>) {
     // Immediately cancel this task if the pool is closed.
     let mut close_event = pool.close_event();

-    crate::rt::spawn(async move {
+    rt::spawn(async move {
        let _ = close_event
            .do_until(async {
                // If the last handle to the pool was dropped while we were sleeping
@@ -565,61 +418,13 @@ fn spawn_maintenance_tasks<DB: Database>(pool: &Arc<PoolInner<DB>>) {

                    if let Some(duration) = next_run.checked_duration_since(Instant::now()) {
                        // `async-std` doesn't have a `sleep_until()`
-                        crate::rt::sleep(duration).await;
+                        rt::sleep(duration).await;
                    } else {
                        // `next_run` is in the past, just yield.
-                        crate::rt::yield_now().await;
+                        rt::yield_now().await;
                    }
                }
            })
            .await;
    });
 }
-
-/// RAII guard returned by `Pool::try_increment_size()` and others.
-///
-/// Will decrement the pool size if dropped, to avoid semantically "leaking" connections
-/// (where the pool thinks it has more connections than it does).
-pub(in crate::pool) struct DecrementSizeGuard<DB: Database> {
-    pub(crate) pool: Arc<PoolInner<DB>>,
-    cancelled: bool,
-}
-
-impl<DB: Database> DecrementSizeGuard<DB> {
-    /// Create a new guard that will release a semaphore permit on-drop.
-    pub fn new_permit(pool: Arc<PoolInner<DB>>) -> Self {
-        Self {
-            pool,
-            cancelled: false,
-        }
-    }
-
-    pub fn from_permit(pool: Arc<PoolInner<DB>>, permit: AsyncSemaphoreReleaser<'_>) -> Self {
-        // here we effectively take ownership of the permit
-        permit.disarm();
-        Self::new_permit(pool)
-    }
-
-    /// Release the semaphore permit without decreasing the pool size.
-    ///
-    /// If the permit was stolen from the pool's parent, it will be returned to the child's semaphore.
-    fn release_permit(self) {
-        self.pool.semaphore.release(1);
-        self.cancel();
-    }
-
-    pub fn cancel(mut self) {
-        self.cancelled = true;
-    }
-}
-
-impl<DB: Database> Drop for DecrementSizeGuard<DB> {
-    fn drop(&mut self) {
-        if !self.cancelled {
-            self.pool.size.fetch_sub(1, Ordering::AcqRel);
-
-            // and here we release the permit we got on construction
-            self.pool.semaphore.release(1);
-        }
-    }
-}
diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs
index e998618413..1119e1a0d3 100644
--- a/sqlx-core/src/pool/mod.rs
+++ b/sqlx-core/src/pool/mod.rs
@@ -59,7 +59,6 @@ use std::future::Future;
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
-use std::time::{Duration, Instant};

 use event_listener::EventListener;
 use futures_core::FusedFuture;
@@ -70,6 +69,7 @@ use crate::database::Database;
 use crate::error::Error;
 use crate::transaction::Transaction;

+pub use self::connect::{PoolConnectMetadata, PoolConnector};
 pub use self::connection::PoolConnection;
 use self::inner::PoolInner;
 #[doc(hidden)]
@@ -82,8 +82,11 @@ mod executor;
 #[macro_use]
 pub mod maybe;

+mod connect;
 mod connection;
 mod inner;
+
+mod idle;
 mod options;

 /// An asynchronous pool of SQLx database connections.
@@ -354,7 +357,7 @@ impl<DB: Database> Pool<DB> {
     /// returning it.
     pub fn acquire(&self) -> impl Future<Output = Result<PoolConnection<DB>, Error>> + 'static {
         let shared = self.0.clone();
-        async move { shared.acquire().await.map(|conn| conn.reattach()) }
+        async move { shared.acquire().await }
     }

     /// Attempts to retrieve a connection from the pool if there is one available.
@@ -496,7 +499,7 @@ impl<DB: Database> Pool<DB> {
     }

     /// Returns the number of connections currently active. This includes idle connections.
-    pub fn size(&self) -> u32 {
+    pub fn size(&self) -> usize {
         self.0.size()
     }

@@ -505,28 +508,6 @@ impl<DB: Database> Pool<DB> {
         self.0.num_idle()
     }

-    /// Gets a clone of the connection options for this pool
-    pub fn connect_options(&self) -> Arc<<DB::Connection as Connection>::Options> {
-        self.0
-            .connect_options
-            .read()
-            .expect("write-lock holder panicked")
-            .clone()
-    }
-
-    /// Updates the connection options this pool will use when opening any future connections. Any
-    /// existing open connection in the pool will be left as-is.
-    pub fn set_connect_options(&self, connect_options: <DB::Connection as Connection>::Options) {
-        // technically write() could also panic if the current thread already holds the lock,
-        // but because this method can't be re-entered by the same thread that shouldn't be a problem
-        let mut guard = self
-            .0
-            .connect_options
-            .write()
-            .expect("write-lock holder panicked");
-        *guard = Arc::new(connect_options);
-    }
-
     /// Get the options for this pool
     pub fn options(&self) -> &PoolOptions<DB> {
         &self.0.options
@@ -610,15 +591,6 @@ impl FusedFuture for CloseEvent {
     }
 }

-/// get the time between the deadline and now and use that as our timeout
-///
-/// returns `Error::PoolTimedOut` if the deadline is in the past
-fn deadline_as_timeout(deadline: Instant) -> Result<Duration, Error> {
-    deadline
-        .checked_duration_since(Instant::now())
-        .ok_or(Error::PoolTimedOut)
-}
-
 #[test]
 #[allow(dead_code)]
 fn assert_pool_traits() {
diff --git a/sqlx-core/src/pool/options.rs b/sqlx-core/src/pool/options.rs
index 96dbf8ee3d..c59fba008e 100644
--- a/sqlx-core/src/pool/options.rs
+++ b/sqlx-core/src/pool/options.rs
@@ -1,8 +1,9 @@
 use crate::connection::Connection;
 use crate::database::Database;
 use crate::error::Error;
+use crate::pool::connect::DefaultConnector;
 use crate::pool::inner::PoolInner;
-use crate::pool::Pool;
+use crate::pool::{Pool, PoolConnector};
 use futures_core::future::BoxFuture;
 use log::LevelFilter;
 use std::fmt::{self, Debug, Formatter};
@@ -44,14 +45,6 @@ use std::time::{Duration, Instant};
 /// the perspectives of both API designer and consumer.
 pub struct PoolOptions<DB: Database> {
     pub(crate) test_before_acquire: bool,
-    pub(crate) after_connect: Option<
-        Arc<
-            dyn Fn(&mut DB::Connection, PoolConnectionMetadata) -> BoxFuture<'_, Result<(), Error>>
-                + 'static
-                + Send
-                + Sync,
-        >,
-    >,
     pub(crate) before_acquire: Option<
         Arc<
             dyn Fn(
@@ -74,12 +67,13 @@ pub struct PoolOptions<DB: Database> {
                 + Sync,
         >,
     >,
-    pub(crate) max_connections: u32,
+    pub(crate) max_connections: usize,
     pub(crate) acquire_time_level: LevelFilter,
     pub(crate) acquire_slow_level: LevelFilter,
     pub(crate) acquire_slow_threshold: Duration,
     pub(crate) acquire_timeout: Duration,
-    pub(crate) min_connections: u32,
+    pub(crate) connect_timeout: Duration,
+    pub(crate) min_connections: usize,
     pub(crate) max_lifetime: Option<Duration>,
     pub(crate) idle_timeout: Option<Duration>,
     pub(crate) fair: bool,
@@ -94,7 +88,6 @@ impl<DB: Database> Clone for PoolOptions<DB> {
     fn clone(&self) -> Self {
         PoolOptions {
             test_before_acquire: self.test_before_acquire,
-            after_connect: self.after_connect.clone(),
             before_acquire: self.before_acquire.clone(),
             after_release: self.after_release.clone(),
             max_connections: self.max_connections,
@@ -102,6 +95,7 @@ impl<DB: Database> Clone for PoolOptions<DB> {
             acquire_slow_threshold: self.acquire_slow_threshold,
             acquire_slow_level: self.acquire_slow_level,
             acquire_timeout: self.acquire_timeout,
+            connect_timeout: self.connect_timeout,
             min_connections: self.min_connections,
             max_lifetime: self.max_lifetime,
             idle_timeout: self.idle_timeout,
@@ -143,7 +137,6 @@ impl<DB: Database> PoolOptions<DB> {
     pub fn new() -> Self {
         Self {
             // User-specifiable routines
-            after_connect: None,
             before_acquire: None,
             after_release: None,
             test_before_acquire: true,
@@ -158,6 +151,7 @@ impl<DB: Database> PoolOptions<DB> {
             // to not flag typical time to add a new connection to a pool.
             acquire_slow_threshold: Duration::from_secs(2),
             acquire_timeout: Duration::from_secs(30),
+            connect_timeout: Duration::from_secs(2 * 60),
             idle_timeout: Some(Duration::from_secs(10 * 60)),
             max_lifetime: Some(Duration::from_secs(30 * 60)),
             fair: true,
@@ -170,13 +164,13 @@
     /// Be mindful of the connection limits for your database as well as other applications
     /// which may want to connect to the same database (or even multiple instances of the same
     /// application in high-availability deployments).
-    pub fn max_connections(mut self, max: u32) -> Self {
+    pub fn max_connections(mut self, max: usize) -> Self {
         self.max_connections = max;
         self
     }
 
     /// Get the maximum number of connections that this pool should maintain
-    pub fn get_max_connections(&self) -> u32 {
+    pub fn get_max_connections(&self) -> usize {
         self.max_connections
     }
 
@@ -202,13 +196,13 @@
     /// [`max_lifetime`]: Self::max_lifetime
     /// [`idle_timeout`]: Self::idle_timeout
     /// [`max_connections`]: Self::max_connections
-    pub fn min_connections(mut self, min: u32) -> Self {
+    pub fn min_connections(mut self, min: usize) -> Self {
         self.min_connections = min;
         self
     }
 
     /// Get the minimum number of connections to maintain at all times.
-    pub fn get_min_connections(&self) -> u32 {
+    pub fn get_min_connections(&self) -> usize {
         self.min_connections
     }
 
@@ -268,6 +262,23 @@
         self.acquire_timeout
     }
 
+    /// Set the maximum amount of time to spend attempting to open a connection.
+    ///
+    /// This timeout happens independently of [`acquire_timeout`][Self::acquire_timeout].
+    ///
+    /// If shorter than `acquire_timeout`, this will cause the last connect attempt to be cut short.
+    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
+        self.connect_timeout = timeout;
+        self
+    }
+
+    /// Get the maximum amount of time to spend attempting to open a connection.
+    ///
+    /// This timeout happens independently of [`acquire_timeout`][Self::acquire_timeout].
+    pub fn get_connect_timeout(&self) -> Duration {
+        self.connect_timeout
+    }
+
     /// Set the maximum lifetime of individual connections.
     ///
     /// Any connection with a lifetime greater than this will be closed.
@@ -339,57 +350,6 @@
         self
     }
 
-    /// Perform an asynchronous action after connecting to the database.
-    ///
-    /// If the operation returns with an error then the error is logged, the connection is closed
-    /// and a new one is opened in its place and the callback is invoked again.
-    ///
-    /// This occurs in a backoff loop to avoid high CPU usage and spamming logs during a transient
-    /// error condition.
-    ///
-    /// Note that this may be called for internally opened connections, such as when maintaining
-    /// [`min_connections`][Self::min_connections], that are then immediately returned to the pool
-    /// without invoking [`after_release`][Self::after_release].
-    ///
-    /// # Example: Additional Parameters
-    /// This callback may be used to set additional configuration parameters
-    /// that are not exposed by the database's `ConnectOptions`.
-    ///
-    /// This example is written for PostgreSQL but can likely be adapted to other databases.
-    ///
-    /// ```no_run
-    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
-    /// use sqlx::Executor;
-    /// use sqlx::postgres::PgPoolOptions;
-    ///
-    /// let pool = PgPoolOptions::new()
-    ///     .after_connect(|conn, _meta| Box::pin(async move {
-    ///         // When directly invoking `Executor` methods,
-    ///         // it is possible to execute multiple statements with one call.
-    ///         conn.execute("SET application_name = 'your_app'; SET search_path = 'my_schema';")
-    ///             .await?;
-    ///
-    ///         Ok(())
-    ///     }))
-    ///     .connect("postgres:// …").await?;
-    /// # Ok(())
-    /// # }
-    /// ```
-    ///
-    /// For a discussion on why `Box::pin()` is required, see [the type-level docs][Self].
-    pub fn after_connect<F>(mut self, callback: F) -> Self
-    where
-        // We're passing the `PoolConnectionMetadata` here mostly for future-proofing.
-        // `age` and `idle_for` are obviously not useful for fresh connections.
-        for<'c> F: Fn(&'c mut DB::Connection, PoolConnectionMetadata) -> BoxFuture<'c, Result<(), Error>>
-            + 'static
-            + Send
-            + Sync,
-    {
-        self.after_connect = Some(Arc::new(callback));
-        self
-    }
-
     /// Perform an asynchronous action on a previously idle connection before giving it out.
     ///
     /// Alongside the connection, the closure gets [`PoolConnectionMetadata`] which contains
@@ -537,11 +497,25 @@
     pub async fn connect_with(
         self,
         options: <DB::Connection as Connection>::Options,
+    ) -> Result<Pool<DB>, Error> {
+        self.connect_with_connector(DefaultConnector(options)).await
+    }
+
+    /// Create a new pool from this `PoolOptions` and immediately open at least one connection.
+    ///
+    /// This ensures the configuration is correct.
+    ///
+    /// The total number of connections opened is max(1, [min_connections][Self::min_connections]).
+    ///
+    /// See [PoolConnector] for examples.
+    pub async fn connect_with_connector(
+        self,
+        connector: impl PoolConnector<DB>,
     ) -> Result<Pool<DB>, Error> {
         // Don't take longer than `acquire_timeout` starting from when this is called.
         let deadline = Instant::now() + self.acquire_timeout;
 
-        let inner = PoolInner::new_arc(self, options);
+        let inner = PoolInner::new_arc(self, connector);
 
         if inner.options.min_connections > 0 {
             // If the idle reaper is spawned then this will race with the call from that task
@@ -552,7 +526,7 @@
         // If `min_connections` is nonzero then we'll likely just pull a connection
         // from the idle queue here, but it should at least get tested first.
         let conn = inner.acquire().await?;
-        inner.release(conn);
+        inner.release(conn.into_floating());
 
         Ok(Pool(inner))
     }
@@ -578,7 +552,11 @@
     /// optimistically establish that many connections for the pool.
     pub fn connect_lazy_with(self, options: <DB::Connection as Connection>::Options) -> Pool<DB> {
         // `min_connections` is guaranteed by the idle reaper now.
-        Pool(PoolInner::new_arc(self, options))
+        self.connect_lazy_with_connector(DefaultConnector(options))
+    }
+
+    pub fn connect_lazy_with_connector(self, connector: impl PoolConnector<DB>) -> Pool<DB> {
+        Pool(PoolInner::new_arc(self, connector))
     }
 }
 
diff --git a/sqlx-core/src/raw_sql.rs b/sqlx-core/src/raw_sql.rs
index 37627d4453..1819a2bbc0 100644
--- a/sqlx-core/src/raw_sql.rs
+++ b/sqlx-core/src/raw_sql.rs
@@ -114,7 +114,7 @@ pub struct RawSql<'q>(&'q str);
 ///
 /// See [MySQL manual, section 13.3.3: Statements That Cause an Implicit Commit](https://dev.mysql.com/doc/refman/8.0/en/implicit-commit.html) for details.
 /// See also: [MariaDB manual: SQL statements That Cause an Implicit Commit](https://mariadb.com/kb/en/sql-statements-that-cause-an-implicit-commit/).
-pub fn raw_sql(sql: &str) -> RawSql<'_> {
+pub fn raw_sql<'q>(sql: &'q str) -> RawSql<'q> {
     RawSql(sql)
 }
 
@@ -138,27 +138,26 @@ impl<'q, DB: Database> Execute<'q, DB> for RawSql<'q> {
 
 impl<'q> RawSql<'q> {
     /// Execute the SQL string and return the total number of rows affected.
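The new `connect_with_connector` API accepts anything implementing `PoolConnector`; as the test changes at the end of this patch show, a closure taking `PoolConnectMetadata` and returning a connection future qualifies. A sketch for Postgres, assuming the `PoolConnectMetadata`/`PoolConnector` re-exports land under `sqlx::pool` as the `pub use` above suggests (this is the PR's API, not released sqlx):

```rust
use std::time::Duration;

use sqlx::pool::PoolConnectMetadata;
use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions};
use sqlx::ConnectOptions;

async fn connect_pool() -> sqlx::Result<PgPool> {
    PgPoolOptions::new()
        .max_connections(10)
        // New in this patch: cap each connection attempt independently
        // of `acquire_timeout`.
        .connect_timeout(Duration::from_secs(30))
        .connect_with_connector(|_meta: PoolConnectMetadata| async {
            // Re-read the URL on every (re)connect, e.g. for rotated credentials.
            let opts: PgConnectOptions = std::env::var("DATABASE_URL")
                .expect("DATABASE_URL must be set")
                .parse()?;
            opts.connect().await
        })
        .await
}
```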
-    #[inline]
-    pub async fn execute<'e, E>(
+    pub async fn execute<'e, 'c: 'e, E>(
         self,
         executor: E,
     ) -> crate::Result<<E::Database as Database>::QueryResult>
     where
         'q: 'e,
-        E: Executor<'e>,
+        E: Executor<'c>,
     {
         executor.execute(self).await
     }
 
     /// Execute the SQL string. Returns a stream which gives the number of rows affected for each statement in the string.
     #[inline]
-    pub fn execute_many<'e, E>(
+    pub fn execute_many<'e, 'c: 'e, E>(
         self,
         executor: E,
     ) -> BoxStream<'e, crate::Result<<E::Database as Database>::QueryResult>>
     where
         'q: 'e,
-        E: Executor<'e>,
+        E: Executor<'c>,
     {
         executor.execute_many(self)
     }
 
@@ -167,13 +166,13 @@
     ///
     /// If the string contains multiple statements, their results will be concatenated together.
     #[inline]
-    pub fn fetch<'e, E>(
+    pub fn fetch<'e, 'c: 'e, E>(
         self,
         executor: E,
     ) -> BoxStream<'e, Result<<E::Database as Database>::Row, Error>>
     where
         'q: 'e,
-        E: Executor<'e>,
+        E: Executor<'c>,
     {
         executor.fetch(self)
     }
@@ -183,7 +182,7 @@
     /// For each query in the stream, any generated rows are returned first,
     /// then the `QueryResult` with the number of rows affected.
     #[inline]
-    pub fn fetch_many<'e, E>(
+    pub fn fetch_many<'e, 'c: 'e, E>(
         self,
         executor: E,
     ) -> BoxStream<
@@ -195,7 +194,7 @@
     >
     where
         'q: 'e,
-        E: Executor<'e>,
+        E: Executor<'c>,
     {
         executor.fetch_many(self)
     }
@@ -208,13 +207,13 @@
     /// To avoid exhausting available memory, ensure the result set has a known upper bound,
     /// e.g. using `LIMIT`.
     #[inline]
-    pub async fn fetch_all<'e, E>(
+    pub async fn fetch_all<'e, 'c: 'e, E>(
         self,
         executor: E,
     ) -> crate::Result<Vec<<E::Database as Database>::Row>>
     where
         'q: 'e,
-        E: Executor<'e>,
+        E: Executor<'c>,
     {
         executor.fetch_all(self).await
     }
@@ -232,13 +231,13 @@
     ///
     /// Otherwise, you might want to add `LIMIT 1` to your query.
     #[inline]
-    pub async fn fetch_one<'e, E>(
+    pub async fn fetch_one<'e, 'c: 'e, E>(
         self,
         executor: E,
     ) -> crate::Result<<E::Database as Database>::Row>
     where
         'q: 'e,
-        E: Executor<'e>,
+        E: Executor<'c>,
     {
         executor.fetch_one(self).await
     }
@@ -256,13 +255,13 @@
     ///
     /// Otherwise, you might want to add `LIMIT 1` to your query.
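The signature changes in this file only relax lifetimes: the executor borrow (`'c`) no longer has to equal the returned stream or future lifetime (`'e`), it merely has to outlive it. Call sites are unchanged; for reference, a typical use of `raw_sql` and its multi-statement support:

```rust
use sqlx::postgres::PgConnection;
use sqlx::raw_sql;

// `raw_sql` sends the string verbatim, so several statements can go out in
// one call; per the doc above, `execute` reports the total rows affected.
async fn setup(conn: &mut PgConnection) -> sqlx::Result<()> {
    raw_sql(
        "CREATE TABLE IF NOT EXISTS t(id INT); \
         INSERT INTO t(id) VALUES (1);",
    )
    .execute(conn)
    .await?;
    Ok(())
}
```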
     #[inline]
-    pub async fn fetch_optional<'e, E>(
+    pub async fn fetch_optional<'e, 'c: 'e, E>(
         self,
         executor: E,
     ) -> crate::Result<<E::Database as Database>::Row>
     where
         'q: 'e,
-        E: Executor<'e>,
+        E: Executor<'c>,
     {
         executor.fetch_one(self).await
     }
diff --git a/sqlx-core/src/rt/mod.rs b/sqlx-core/src/rt/mod.rs
index 43409073ab..bef5e97158 100644
--- a/sqlx-core/src/rt/mod.rs
+++ b/sqlx-core/src/rt/mod.rs
@@ -2,7 +2,7 @@ use std::future::Future;
 use std::marker::PhantomData;
 use std::pin::Pin;
 use std::task::{Context, Poll};
-use std::time::Duration;
+use std::time::{Duration, Instant};
 
 #[cfg(feature = "_rt-async-std")]
 pub mod rt_async_std;
@@ -42,6 +42,29 @@ pub async fn timeout<F: Future>(duration: Duration, f: F) -> Result<F::Output, TimeoutError> {
     missing_rt((duration, f))
 }
 
+pub async fn timeout_at<F: Future>(deadline: Instant, f: F) -> Result<F::Output, TimeoutError> {
+    #[cfg(feature = "_rt-tokio")]
+    if rt_tokio::available() {
+        return tokio::time::timeout_at(deadline.into(), f)
+            .await
+            .map_err(|_| TimeoutError(()));
+    }
+
+    #[cfg(feature = "_rt-async-std")]
+    {
+        let Some(duration) = deadline.checked_duration_since(Instant::now()) else {
+            return Err(TimeoutError(()));
+        };
+
+        async_std::future::timeout(duration, f)
+            .await
+            .map_err(|_| TimeoutError(()))
+    }
+
+    #[cfg(not(feature = "_rt-async-std"))]
+    missing_rt((deadline, f))
+}
+
 pub async fn sleep(duration: Duration) {
     #[cfg(feature = "_rt-tokio")]
     if rt_tokio::available() {
diff --git a/sqlx-core/src/sync.rs b/sqlx-core/src/sync.rs
index 27ad29c33e..971752f88f 100644
--- a/sqlx-core/src/sync.rs
+++ b/sqlx-core/src/sync.rs
@@ -9,135 +9,3 @@ pub use async_std::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard};
 
 #[cfg(feature = "_rt-tokio")]
 pub use tokio::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard};
-
-pub struct AsyncSemaphore {
-    // We use the semaphore from futures-intrusive as the one from async-std
-    // is missing the ability to add arbitrary permits, and is not guaranteed to be fair:
-    // * https://github.com/smol-rs/async-lock/issues/22
-    // * https://github.com/smol-rs/async-lock/issues/23
-    //
-    // We're on the look-out for a replacement, however, as futures-intrusive is not maintained
-    // and there are some soundness concerns (although it turns out any intrusive future is unsound
-    // in MIRI due to the necessitated mutable aliasing):
-    // https://github.com/launchbadge/sqlx/issues/1668
-    #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
-    inner: futures_intrusive::sync::Semaphore,
-
-    #[cfg(feature = "_rt-tokio")]
-    inner: tokio::sync::Semaphore,
-}
-
-impl AsyncSemaphore {
-    #[track_caller]
-    pub fn new(fair: bool, permits: usize) -> Self {
-        if cfg!(not(any(feature = "_rt-async-std", feature = "_rt-tokio"))) {
-            crate::rt::missing_rt((fair, permits));
-        }
-
-        AsyncSemaphore {
-            #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
-            inner: futures_intrusive::sync::Semaphore::new(fair, permits),
-            #[cfg(feature = "_rt-tokio")]
-            inner: {
-                debug_assert!(fair, "Tokio only has fair permits");
-                tokio::sync::Semaphore::new(permits)
-            },
-        }
-    }
-
-    pub fn permits(&self) -> usize {
-        #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
-        return self.inner.permits();
-
-        #[cfg(feature = "_rt-tokio")]
-        return self.inner.available_permits();
-
-        #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))]
-        crate::rt::missing_rt(())
-    }
-
-    pub async fn acquire(&self, permits: u32) -> AsyncSemaphoreReleaser<'_> {
-        #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
-        return AsyncSemaphoreReleaser {
-            inner: self.inner.acquire(permits as usize).await,
-        };
-
-        #[cfg(feature = "_rt-tokio")]
-        return AsyncSemaphoreReleaser {
-            inner: self
-                .inner
-                // Weird quirk: `tokio::sync::Semaphore` mostly uses `usize` for permit counts,
-                // but `u32` for this and `try_acquire_many()`.
-                .acquire_many(permits)
-                .await
-                .expect("BUG: we do not expose the `.close()` method"),
-        };
-
-        #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))]
-        crate::rt::missing_rt(permits)
-    }
-
-    pub fn try_acquire(&self, permits: u32) -> Option<AsyncSemaphoreReleaser<'_>> {
-        #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
-        return Some(AsyncSemaphoreReleaser {
-            inner: self.inner.try_acquire(permits as usize)?,
-        });
-
-        #[cfg(feature = "_rt-tokio")]
-        return Some(AsyncSemaphoreReleaser {
-            inner: self.inner.try_acquire_many(permits).ok()?,
-        });
-
-        #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))]
-        crate::rt::missing_rt(permits)
-    }
-
-    pub fn release(&self, permits: usize) {
-        #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
-        return self.inner.release(permits);
-
-        #[cfg(feature = "_rt-tokio")]
-        return self.inner.add_permits(permits);
-
-        #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))]
-        crate::rt::missing_rt(permits)
-    }
-}
-
-pub struct AsyncSemaphoreReleaser<'a> {
-    // We use the semaphore from futures-intrusive as the one from async-std
-    // is missing the ability to add arbitrary permits, and is not guaranteed to be fair:
-    // * https://github.com/smol-rs/async-lock/issues/22
-    // * https://github.com/smol-rs/async-lock/issues/23
-    //
-    // We're on the look-out for a replacement, however, as futures-intrusive is not maintained
-    // and there are some soundness concerns (although it turns out any intrusive future is unsound
-    // in MIRI due to the necessitated mutable aliasing):
-    // https://github.com/launchbadge/sqlx/issues/1668
-    #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
-    inner: futures_intrusive::sync::SemaphoreReleaser<'a>,
-
-    #[cfg(feature = "_rt-tokio")]
-    inner: tokio::sync::SemaphorePermit<'a>,
-
-    #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))]
-    _phantom: std::marker::PhantomData<&'a ()>,
-}
-
-impl AsyncSemaphoreReleaser<'_> {
-    pub fn disarm(self) {
-        #[cfg(feature = "_rt-tokio")]
-        {
-            self.inner.forget();
-        }
-
-        #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
-        {
-            let mut this = self;
-            this.inner.disarm();
-        }
-
-        #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))]
-        crate::rt::missing_rt(())
-    }
-}
diff --git a/sqlx-mysql/src/testing/mod.rs b/sqlx-mysql/src/testing/mod.rs
index 2a9216d1b8..4c41d29bed 100644
--- a/sqlx-mysql/src/testing/mod.rs
+++ b/sqlx-mysql/src/testing/mod.rs
@@ -1,5 +1,4 @@
 use std::fmt::Write;
-use std::ops::Deref;
 use std::str::FromStr;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::time::{Duration, SystemTime};
@@ -89,27 +88,11 @@ async fn test_context(args: &TestArgs) -> Result<TestContext<MySql>, Error> {
         .max_connections(20)
         // Immediately close master connections. Tokio's I/O streams don't like hopping runtimes.
         .after_release(|_conn, _| Box::pin(async move { Ok(false) }))
-        .connect_lazy_with(master_opts);
-
-    let master_pool = match MASTER_POOL.try_insert(pool) {
-        Ok(inserted) => inserted,
-        Err((existing, pool)) => {
-            // Sanity checks.
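With the `AsyncSemaphore` wrapper gone, the pool's internals talk to the runtime's semaphore directly. For reference, the Tokio quirks the deleted wrapper papered over, in a self-contained sketch: `u32` counts for `acquire_many` but `usize` for `add_permits`, and `forget()` (which the wrapper exposed as `disarm()`) to keep permits checked out past the guard:

```rust
use tokio::sync::Semaphore;

async fn demo() {
    let sem = Semaphore::new(4);

    // `acquire_many` takes `u32`; it only errors if the semaphore is closed.
    let permit = sem.acquire_many(2).await.expect("semaphore is not closed");

    // Keep the two permits checked out past the guard's lifetime.
    permit.forget();

    // Hand permits back manually; `add_permits` takes `usize`.
    sem.add_permits(2);
    assert_eq!(sem.available_permits(), 4);
}
```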
-            assert_eq!(
-                existing.connect_options().host,
-                pool.connect_options().host,
-                "DATABASE_URL changed at runtime, host differs"
-            );
-
-            assert_eq!(
-                existing.connect_options().database,
-                pool.connect_options().database,
-                "DATABASE_URL changed at runtime, database differs"
-            );
-
-            existing
-        }
-    };
+        .connect_lazy_with(master_opts.clone());
+
+    let master_pool = MASTER_POOL
+        .try_insert(pool)
+        .unwrap_or_else(|(existing, _pool)| existing);
 
     let mut conn = master_pool.acquire().await?;
 
@@ -163,11 +146,7 @@ async fn test_context(args: &TestArgs) -> Result<TestContext<MySql>, Error> {
             // Close connections ASAP if left in the idle queue.
             .idle_timeout(Some(Duration::from_secs(1)))
             .parent(master_pool.clone()),
-        connect_opts: master_pool
-            .connect_options()
-            .deref()
-            .clone()
-            .database(&new_db_name),
+        connect_opts: master_opts.database(&new_db_name),
         db_name: new_db_name,
     })
 }
diff --git a/sqlx-postgres/src/error.rs b/sqlx-postgres/src/error.rs
index db8bcc8a10..193579b76b 100644
--- a/sqlx-postgres/src/error.rs
+++ b/sqlx-postgres/src/error.rs
@@ -186,7 +186,7 @@ impl DatabaseError for PgDatabaseError {
         self
     }
 
-    fn is_transient_in_connect_phase(&self) -> bool {
+    fn is_retryable_connect_error(&self) -> bool {
         // https://www.postgresql.org/docs/current/errcodes-appendix.html
         [
             // too_many_connections
diff --git a/sqlx-postgres/src/options/mod.rs b/sqlx-postgres/src/options/mod.rs
index a0b222606a..99327bd6b9 100644
--- a/sqlx-postgres/src/options/mod.rs
+++ b/sqlx-postgres/src/options/mod.rs
@@ -206,6 +206,11 @@ impl PgConnectOptions {
         self
     }
 
+    /// Identical to [`Self::host()`], but through a mutable reference.
+    pub fn set_host(&mut self, host: &str) {
+        host.clone_into(&mut self.host);
+    }
+
     /// Sets the port to connect to at the server host.
     ///
     /// The default port for PostgreSQL is `5432`.
@@ -222,6 +227,12 @@ impl PgConnectOptions {
         self
     }
 
+    /// Identical to [`Self::port()`], but through a mutable reference.
+    pub fn set_port(&mut self, port: u16) -> &mut Self {
+        self.port = port;
+        self
+    }
+
     /// Sets a custom path to a directory containing a unix domain socket,
     /// switching the connection method from TCP to the corresponding socket.
     ///
@@ -248,6 +259,12 @@ impl PgConnectOptions {
         self
     }
 
+    /// Identical to [`Self::username()`], but through a mutable reference.
+    pub fn set_username(&mut self, username: &str) -> &mut Self {
+        username.clone_into(&mut self.username);
+        self
+    }
+
     /// Sets the password to use if the server demands password authentication.
     ///
     /// # Example
@@ -263,6 +280,12 @@ impl PgConnectOptions {
         self
     }
 
+    /// Identical to [`Self::password()`], but through a mutable reference.
+    pub fn set_password(&mut self, password: &str) -> &mut Self {
+        self.password = Some(password.to_owned());
+        self
+    }
+
     /// Sets the database name. Defaults to be the same as the user name.
     ///
     /// # Example
diff --git a/sqlx-postgres/src/testing/mod.rs b/sqlx-postgres/src/testing/mod.rs
index fb36ab4136..f86ddefef6 100644
--- a/sqlx-postgres/src/testing/mod.rs
+++ b/sqlx-postgres/src/testing/mod.rs
@@ -1,5 +1,4 @@
 use std::fmt::Write;
-use std::ops::Deref;
 use std::str::FromStr;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::time::{Duration, SystemTime};
@@ -86,27 +85,11 @@ async fn test_context(args: &TestArgs) -> Result<TestContext<Postgres>, Error> {
         .max_connections(20)
         // Immediately close master connections. Tokio's I/O streams don't like hopping runtimes.
         .after_release(|_conn, _| Box::pin(async move { Ok(false) }))
-        .connect_lazy_with(master_opts);
-
-    let master_pool = match MASTER_POOL.try_insert(pool) {
-        Ok(inserted) => inserted,
-        Err((existing, pool)) => {
-            // Sanity checks.
-            assert_eq!(
-                existing.connect_options().host,
-                pool.connect_options().host,
-                "DATABASE_URL changed at runtime, host differs"
-            );
-
-            assert_eq!(
-                existing.connect_options().database,
-                pool.connect_options().database,
-                "DATABASE_URL changed at runtime, database differs"
-            );
-
-            existing
-        }
-    };
+        .connect_lazy_with(master_opts.clone());
+
+    let master_pool = MASTER_POOL
+        .try_insert(pool)
+        .unwrap_or_else(|(existing, _pool)| existing);
 
     let mut conn = master_pool.acquire().await?;
 
@@ -170,11 +153,7 @@ async fn test_context(args: &TestArgs) -> Result<TestContext<Postgres>, Error> {
             // Close connections ASAP if left in the idle queue.
             .idle_timeout(Some(Duration::from_secs(1)))
             .parent(master_pool.clone()),
-        connect_opts: master_pool
-            .connect_options()
-            .deref()
-            .clone()
-            .database(&new_db_name),
+        connect_opts: master_opts.database(&new_db_name),
         db_name: new_db_name,
     })
 }
diff --git a/tests/any/pool.rs b/tests/any/pool.rs
index 3130b4f1c6..2502bac8ab 100644
--- a/tests/any/pool.rs
+++ b/tests/any/pool.rs
@@ -1,44 +1,13 @@
 use sqlx::any::{AnyConnectOptions, AnyPoolOptions};
 use sqlx::Executor;
+use sqlx_core::connection::ConnectOptions;
+use sqlx_core::pool::PoolConnectMetadata;
 use std::sync::{
-    atomic::{AtomicI32, AtomicUsize, Ordering},
+    atomic::{AtomicI32, Ordering},
     Arc, Mutex,
 };
 use std::time::Duration;
 
-#[sqlx_macros::test]
-async fn pool_should_invoke_after_connect() -> anyhow::Result<()> {
-    sqlx::any::install_default_drivers();
-
-    let counter = Arc::new(AtomicUsize::new(0));
-
-    let pool = AnyPoolOptions::new()
-        .after_connect({
-            let counter = counter.clone();
-            move |_conn, _meta| {
-                let counter = counter.clone();
-                Box::pin(async move {
-                    counter.fetch_add(1, Ordering::SeqCst);
-
-                    Ok(())
-                })
-            }
-        })
-        .connect(&dotenvy::var("DATABASE_URL")?)
-        .await?;
-
-    let _ = pool.acquire().await?;
-    let _ = pool.acquire().await?;
-    let _ = pool.acquire().await?;
-    let _ = pool.acquire().await?;
-
-    // since connections are released asynchronously,
-    // `.after_connect()` may be called more than once
-    assert!(counter.load(Ordering::SeqCst) >= 1);
-
-    Ok(())
-}
-
 // https://github.com/launchbadge/sqlx/issues/527
 #[sqlx_macros::test]
 async fn pool_should_be_returned_failed_transactions() -> anyhow::Result<()> {
@@ -83,38 +52,13 @@ async fn test_pool_callbacks() -> anyhow::Result<()> {
 
     sqlx_test::setup_if_needed();
 
-    let conn_options: AnyConnectOptions = std::env::var("DATABASE_URL")?.parse()?;
+    let conn_options: Arc<AnyConnectOptions> = Arc::new(std::env::var("DATABASE_URL")?.parse()?);
 
     let current_id = AtomicI32::new(0);
 
     let pool = AnyPoolOptions::new()
         .max_connections(1)
         .acquire_timeout(Duration::from_secs(5))
-        .after_connect(move |conn, meta| {
-            assert_eq!(meta.age, Duration::ZERO);
-            assert_eq!(meta.idle_for, Duration::ZERO);
-
-            let id = current_id.fetch_add(1, Ordering::AcqRel);
-
-            Box::pin(async move {
-                let statement = format!(
-                    // language=SQL
-                    r#"
-                    CREATE TEMPORARY TABLE conn_stats(
-                        id int primary key,
-                        before_acquire_calls int default 0,
-                        after_release_calls int default 0
-                    );
-                    INSERT INTO conn_stats(id) VALUES ({});
-                    "#,
-                    // Until we have generalized bind parameters
-                    id
-                );
-
-                conn.execute(&statement[..]).await?;
-                Ok(())
-            })
-        })
         .before_acquire(|conn, meta| {
             // `age` and `idle_for` should both be nonzero
             assert_ne!(meta.age, Duration::ZERO);
@@ -165,7 +109,31 @@
             })
         })
         // Don't establish a connection yet.
-        .connect_lazy_with(conn_options);
+        .connect_lazy_with_connector(move |_meta: PoolConnectMetadata| {
+            let connect_opts = Arc::clone(&conn_options);
+            let id = current_id.fetch_add(1, Ordering::AcqRel);
+
+            async move {
+                let mut conn = connect_opts.connect().await?;
+
+                let statement = format!(
+                    // language=SQL
+                    r#"
+                    CREATE TEMPORARY TABLE conn_stats(
+                        id int primary key,
+                        before_acquire_calls int default 0,
+                        after_release_calls int default 0
+                    );
+                    INSERT INTO conn_stats(id) VALUES ({});
+                    "#,
+                    // Until we have generalized bind parameters
+                    id
+                );
+
+                conn.execute(&statement[..]).await?;
+                Ok(conn)
+            }
+        });
 
     // Expected pattern of (id, before_acquire_calls, after_release_calls)
     let pattern = [
diff --git a/tests/sqlite/any.rs b/tests/sqlite/any.rs
index 856db70c05..b71c3ba43d 100644
--- a/tests/sqlite/any.rs
+++ b/tests/sqlite/any.rs
@@ -1,4 +1,4 @@
-use sqlx::{Any, Sqlite};
+use sqlx::Any;
 use sqlx_test::new;
 
 #[sqlx_macros::test]
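Two usage notes on the API additions in this patch. First, the new by-`&mut` setters on `PgConnectOptions`. Note the asymmetry as written: `set_host` returns `()` while `set_port`, `set_username`, and `set_password` return `&mut Self`, so only the latter three chain. A short sketch (the host name is a made-up example):

```rust
use sqlx::postgres::PgConnectOptions;

// Point an existing configuration at a replica without rebuilding it,
// e.g. from inside a connector that implements failover.
fn point_at_replica(opts: &mut PgConnectOptions) {
    opts.set_host("replica.db.internal"); // hypothetical host; returns `()` per this patch
    opts.set_port(5433).set_username("readonly");
}
```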
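Second, the rewritten `test_pool_callbacks` above doubles as a migration recipe for the removed `after_connect` hook: per-connection session setup now lives in the connector. A sketch mirroring the old `after_connect` doc example, under the same assumptions about the closure-connector signature and re-exports:

```rust
use sqlx::pool::PoolConnectMetadata;
use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions};
use sqlx::{ConnectOptions, Executor};

fn pool_with_session_setup(base: PgConnectOptions) -> PgPool {
    PgPoolOptions::new().connect_lazy_with_connector(move |_meta: PoolConnectMetadata| {
        let opts = base.clone();
        async move {
            let mut conn = opts.connect().await?;
            // The statements formerly run in `after_connect`.
            conn.execute("SET application_name = 'your_app'; SET search_path = 'my_schema';")
                .await?;
            Ok(conn)
        }
    })
}
```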