From 4d349df3983d0b27e8f16434a089fbd77a2a6d05 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Tue, 17 Dec 2024 20:21:53 +0100 Subject: [PATCH] Share the statement cache with diesel This commit refactors diesel-async to use the same statement cache implementation as diesel. That brings in all the optimisations done to the diesel statement cache. --- Cargo.toml | 28 +++- .../postgres/pooled-with-rustls/Cargo.toml | 8 +- .../Cargo.toml | 13 +- examples/sync-wrapper/Cargo.toml | 14 +- src/async_connection_wrapper.rs | 6 +- src/lib.rs | 5 +- src/mysql/mod.rs | 115 +++++++++-------- src/pg/mod.rs | 121 +++++++++++------- src/pooled_connection/mod.rs | 6 +- src/stmt_cache.rs | 120 +++++++---------- src/sync_connection_wrapper/mod.rs | 16 ++- 11 files changed, 264 insertions(+), 188 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index dd22776..656fed9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,9 +13,6 @@ description = "An async extension for Diesel the safe, extensible ORM and Query rust-version = "1.78.0" [dependencies] -diesel = { version = "~2.2.0", default-features = false, features = [ - "i-implement-a-third-party-backend-and-opt-into-breaking-changes", -] } async-trait = "0.1.66" futures-channel = { version = "0.3.17", default-features = false, features = [ "std", @@ -39,14 +36,35 @@ deadpool = { version = "0.12", optional = true, default-features = false, featur mobc = { version = ">=0.7,<0.10", optional = true } scoped-futures = { version = "0.1", features = ["std"] } +[dependencies.diesel] +version = "~2.2.0" +default-features = false +features = [ + "i-implement-a-third-party-backend-and-opt-into-breaking-changes", +] +git = "https://github.com/diesel-rs/diesel" +branch = "master" + [dev-dependencies] tokio = { version = "1.12.0", features = ["rt", "macros", "rt-multi-thread"] } cfg-if = "1" chrono = "0.4" -diesel = { version = "2.2.0", default-features = false, features = ["chrono"] } -diesel_migrations = "2.2.0" assert_matches = "1.0.1" +[dev-dependencies.diesel] +version = "~2.2.0" +default-features = false +features = [ + "chrono" +] +git = "https://github.com/diesel-rs/diesel" +branch = "master" + +[dev-dependencies.diesel_migrations] +version = "2.2.0" +git = "https://github.com/diesel-rs/diesel" +branch = "master" + [features] default = [] mysql = [ diff --git a/examples/postgres/pooled-with-rustls/Cargo.toml b/examples/postgres/pooled-with-rustls/Cargo.toml index 28c6093..452b28c 100644 --- a/examples/postgres/pooled-with-rustls/Cargo.toml +++ b/examples/postgres/pooled-with-rustls/Cargo.toml @@ -6,7 +6,6 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -diesel = { version = "2.2.0", default-features = false, features = ["postgres"] } diesel-async = { version = "0.5.0", path = "../../../", features = ["bb8", "postgres"] } futures-util = "0.3.21" rustls = "0.23.8" @@ -14,3 +13,10 @@ rustls-native-certs = "0.7.1" tokio = { version = "1.2.0", default-features = false, features = ["macros", "rt-multi-thread"] } tokio-postgres = "0.7.7" tokio-postgres-rustls = "0.12.0" + + +[dependencies.diesel] +version = "2.2.0" +default-features = false +git = "https://github.com/diesel-rs/diesel" +branch = "master" diff --git a/examples/postgres/run-pending-migrations-with-rustls/Cargo.toml b/examples/postgres/run-pending-migrations-with-rustls/Cargo.toml index 2f54ab4..0621ce7 100644 --- a/examples/postgres/run-pending-migrations-with-rustls/Cargo.toml +++ 
b/examples/postgres/run-pending-migrations-with-rustls/Cargo.toml @@ -6,12 +6,21 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -diesel = { version = "2.2.0", default-features = false, features = ["postgres"] } diesel-async = { version = "0.5.0", path = "../../../", features = ["bb8", "postgres", "async-connection-wrapper"] } -diesel_migrations = "2.2.0" futures-util = "0.3.21" rustls = "0.23.10" rustls-native-certs = "0.7.1" tokio = { version = "1.2.0", default-features = false, features = ["macros", "rt-multi-thread"] } tokio-postgres = "0.7.7" tokio-postgres-rustls = "0.12.0" + +[dependencies.diesel] +version = "2.2.0" +default-features = false +git = "https://github.com/diesel-rs/diesel" +branch = "master" + +[dependencies.diesel_migrations] +version = "2.2.0" +git = "https://github.com/diesel-rs/diesel" +branch = "master" diff --git a/examples/sync-wrapper/Cargo.toml b/examples/sync-wrapper/Cargo.toml index d578028..c271019 100644 --- a/examples/sync-wrapper/Cargo.toml +++ b/examples/sync-wrapper/Cargo.toml @@ -6,12 +6,22 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -diesel = { version = "2.2.0", default-features = false, features = ["returning_clauses_for_sqlite_3_35"] } diesel-async = { version = "0.5.0", path = "../../", features = ["sync-connection-wrapper", "async-connection-wrapper"] } -diesel_migrations = "2.2.0" futures-util = "0.3.21" tokio = { version = "1.2.0", default-features = false, features = ["macros", "rt-multi-thread"] } +[dependencies.diesel] +version = "2.2.0" +default-features = false +features = ["returning_clauses_for_sqlite_3_35"] +git = "https://github.com/diesel-rs/diesel" +branch = "master" + +[dependencies.diesel_migrations] +version = "2.2.0" +git = "https://github.com/diesel-rs/diesel" +branch = "master" + [features] default = ["sqlite"] sqlite = ["diesel-async/sqlite"] diff --git a/src/async_connection_wrapper.rs b/src/async_connection_wrapper.rs index 238f279..c817633 100644 --- a/src/async_connection_wrapper.rs +++ b/src/async_connection_wrapper.rs @@ -100,7 +100,7 @@ pub type AsyncConnectionWrapper = pub use self::implementation::AsyncConnectionWrapper; mod implementation { - use diesel::connection::{Instrumentation, SimpleConnection}; + use diesel::connection::{CacheSize, Instrumentation, SimpleConnection}; use std::ops::{Deref, DerefMut}; use super::*; @@ -187,6 +187,10 @@ mod implementation { fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) { self.inner.set_instrumentation(instrumentation); } + + fn set_prepared_statement_cache_size(&mut self, size: CacheSize) { + self.inner.set_prepared_statement_cache_size(size) + } } impl diesel::connection::LoadConnection for AsyncConnectionWrapper diff --git a/src/lib.rs b/src/lib.rs index 1a4b49c..1b9740c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -74,7 +74,7 @@ )] use diesel::backend::Backend; -use diesel::connection::Instrumentation; +use diesel::connection::{CacheSize, Instrumentation}; use diesel::query_builder::{AsQuery, QueryFragment, QueryId}; use diesel::result::Error; use diesel::row::Row; @@ -354,4 +354,7 @@ pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send { /// Set a specific [`Instrumentation`] implementation for this connection fn set_instrumentation(&mut self, instrumentation: impl Instrumentation); + + /// Set the prepared statement cache size to [`CacheSize`] for this connection + fn 
set_prepared_statement_cache_size(&mut self, size: CacheSize); } diff --git a/src/mysql/mod.rs b/src/mysql/mod.rs index 9158f62..b357304 100644 --- a/src/mysql/mod.rs +++ b/src/mysql/mod.rs @@ -1,9 +1,11 @@ -use crate::stmt_cache::{PrepareCallback, StmtCache}; +use crate::stmt_cache::{CallbackHelper, QueryFragmentHelper}; use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection}; -use diesel::connection::statement_cache::{MaybeCached, StatementCacheKey}; -use diesel::connection::Instrumentation; -use diesel::connection::InstrumentationEvent; +use diesel::connection::statement_cache::{ + MaybeCached, QueryFragmentForCachedStatement, StatementCache, +}; use diesel::connection::StrQueryHelper; +use diesel::connection::{CacheSize, Instrumentation}; +use diesel::connection::{DynInstrumentation, InstrumentationEvent}; use diesel::mysql::{Mysql, MysqlQueryBuilder, MysqlType}; use diesel::query_builder::QueryBuilder; use diesel::query_builder::{bind_collector::RawBytesBindCollector, QueryFragment, QueryId}; @@ -27,9 +29,9 @@ use self::serialize::ToSqlHelper; /// `mysql://[user[:password]@]host/database_name` pub struct AsyncMysqlConnection { conn: mysql_async::Conn, - stmt_cache: StmtCache, + stmt_cache: StatementCache, transaction_manager: AnsiTransactionManager, - instrumentation: std::sync::Mutex>>, + instrumentation: DynInstrumentation, } #[async_trait::async_trait] @@ -72,7 +74,7 @@ impl AsyncConnection for AsyncMysqlConnection { type TransactionManager = AnsiTransactionManager; async fn establish(database_url: &str) -> diesel::ConnectionResult { - let mut instrumentation = diesel::connection::get_default_instrumentation(); + let mut instrumentation = DynInstrumentation::default_instrumentation(); instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection( database_url, )); @@ -82,7 +84,7 @@ impl AsyncConnection for AsyncMysqlConnection { r.as_ref().err(), )); let mut conn = r?; - conn.instrumentation = std::sync::Mutex::new(instrumentation); + conn.instrumentation = instrumentation; Ok(conn) } @@ -177,16 +179,15 @@ impl AsyncConnection for AsyncMysqlConnection { } fn instrumentation(&mut self) -> &mut dyn Instrumentation { - self.instrumentation - .get_mut() - .unwrap_or_else(|p| p.into_inner()) + &mut *self.instrumentation } fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) { - *self - .instrumentation - .get_mut() - .unwrap_or_else(|p| p.into_inner()) = Some(Box::new(instrumentation)); + self.instrumentation = instrumentation.into(); + } + + fn set_prepared_statement_cache_size(&mut self, size: CacheSize) { + self.stmt_cache.set_cache_size(size); + } } @@ -207,17 +208,24 @@ fn update_transaction_manager_status( query_result } -#[async_trait::async_trait] -impl PrepareCallback for &'_ mut mysql_async::Conn { - async fn prepare( - self, - sql: &str, - _metadata: &[MysqlType], - _is_for_cache: diesel::connection::statement_cache::PrepareForCache, - ) -> QueryResult<(Statement, Self)> { - let s = self.prep(sql).await.map_err(ErrorHelper)?; - Ok((s, self)) - } +fn prepare_statement_helper<'a, 'b>( + conn: &'a mut mysql_async::Conn, + sql: &'b str, + _is_for_cache: diesel::connection::statement_cache::PrepareForCache, + _metadata: &[MysqlType], +) -> CallbackHelper> + Send> +{ + // ideally we wouldn't clone the SQL string here + // but as we usually cache statements anyway + // this is a fixed one time cost + // + // The problem with not cloning it is that we then cannot express + // the right result lifetime anymore (at least 
not easily) + let sql = sql.to_owned(); + CallbackHelper(async move { + let s = conn.prep(sql).await.map_err(ErrorHelper)?; + Ok((s, conn)) + }) } impl AsyncMysqlConnection { @@ -229,11 +237,9 @@ impl AsyncMysqlConnection { use crate::run_query_dsl::RunQueryDsl; let mut conn = AsyncMysqlConnection { conn, - stmt_cache: StmtCache::new(), + stmt_cache: StatementCache::new(), transaction_manager: AnsiTransactionManager::default(), - instrumentation: std::sync::Mutex::new( - diesel::connection::get_default_instrumentation(), - ), + instrumentation: DynInstrumentation::default_instrumentation(), }; for stmt in CONNECTION_SETUP_QUERIES { @@ -286,36 +292,29 @@ impl AsyncMysqlConnection { } = bind_collector?; let is_safe_to_cache_prepared = is_safe_to_cache_prepared?; let sql = sql?; + let helper = QueryFragmentHelper { + sql, + safe_to_cache: is_safe_to_cache_prepared, + }; let inner = async { - let cache_key = if let Some(query_id) = query_id { - StatementCacheKey::Type(query_id) - } else { - StatementCacheKey::Sql { - sql: sql.clone(), - bind_types: metadata.clone(), - } - }; - let (stmt, conn) = stmt_cache - .cached_prepared_statement( - cache_key, - sql.clone(), - is_safe_to_cache_prepared, + .cached_statement_non_generic( + query_id, + &helper, + &Mysql, &metadata, conn, - instrumentation, + prepare_statement_helper, + &mut **instrumentation, ) .await?; callback(conn, stmt, ToSqlHelper { metadata, binds }).await }; let r = update_transaction_manager_status(inner.await, transaction_manager); - instrumentation - .get_mut() - .unwrap_or_else(|p| p.into_inner()) - .on_connection_event(InstrumentationEvent::finish_query( - &StrQueryHelper::new(&sql), - r.as_ref().err(), - )); + instrumentation.on_connection_event(InstrumentationEvent::finish_query( + &StrQueryHelper::new(&helper.sql), + r.as_ref().err(), + )); r } .boxed() @@ -370,9 +369,9 @@ impl AsyncMysqlConnection { Ok(AsyncMysqlConnection { conn, - stmt_cache: StmtCache::new(), + stmt_cache: StatementCache::new(), transaction_manager: AnsiTransactionManager::default(), - instrumentation: std::sync::Mutex::new(None), + instrumentation: DynInstrumentation::none(), }) } } @@ -427,3 +426,13 @@ mod tests { } } } + +impl QueryFragmentForCachedStatement for QueryFragmentHelper { + fn construct_sql(&self, _backend: &Mysql) -> QueryResult { + Ok(self.sql.clone()) + } + + fn is_safe_to_cache_prepared(&self, _backend: &Mysql) -> QueryResult { + Ok(self.safe_to_cache) + } +} diff --git a/src/pg/mod.rs b/src/pg/mod.rs index 2ee7145..a888027 100644 --- a/src/pg/mod.rs +++ b/src/pg/mod.rs @@ -7,12 +7,14 @@ use self::error_helper::ErrorHelper; use self::row::PgRow; use self::serialize::ToSqlHelper; -use crate::stmt_cache::{PrepareCallback, StmtCache}; +use crate::stmt_cache::{CallbackHelper, QueryFragmentHelper}; use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection}; -use diesel::connection::statement_cache::{PrepareForCache, StatementCacheKey}; -use diesel::connection::Instrumentation; -use diesel::connection::InstrumentationEvent; +use diesel::connection::statement_cache::{ + PrepareForCache, QueryFragmentForCachedStatement, StatementCache, +}; use diesel::connection::StrQueryHelper; +use diesel::connection::{CacheSize, Instrumentation}; +use diesel::connection::{DynInstrumentation, InstrumentationEvent}; use diesel::pg::{ Pg, PgMetadataCache, PgMetadataCacheKey, PgMetadataLookup, PgQueryBuilder, PgTypeMetadata, }; @@ -122,13 +124,13 @@ const FAKE_OID: u32 = 0; /// [tokio_postgres_rustls]: 
https://docs.rs/tokio-postgres-rustls/0.12.0/tokio_postgres_rustls/ pub struct AsyncPgConnection { conn: Arc, - stmt_cache: Arc>>, + stmt_cache: Arc>>, transaction_state: Arc>, metadata_cache: Arc>, connection_future: Option>>, shutdown_channel: Option>, // a sync mutex is fine here as we only hold it for a really short time - instrumentation: Arc>>>, + instrumentation: Arc>, } #[async_trait::async_trait] @@ -162,7 +164,7 @@ impl AsyncConnection for AsyncPgConnection { type TransactionManager = AnsiTransactionManager; async fn establish(database_url: &str) -> ConnectionResult { - let mut instrumentation = diesel::connection::get_default_instrumentation(); + let mut instrumentation = DynInstrumentation::default_instrumentation(); instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection( database_url, )); @@ -229,14 +231,25 @@ impl AsyncConnection for AsyncPgConnection { // that means there is only one instance of this arc and // we can simply access the inner data if let Some(instrumentation) = Arc::get_mut(&mut self.instrumentation) { - instrumentation.get_mut().unwrap_or_else(|p| p.into_inner()) + &mut **(instrumentation.get_mut().unwrap_or_else(|p| p.into_inner())) } else { panic!("Cannot access shared instrumentation") } } fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) { - self.instrumentation = Arc::new(std::sync::Mutex::new(Some(Box::new(instrumentation)))); + self.instrumentation = Arc::new(std::sync::Mutex::new(instrumentation.into())); + } + + fn set_prepared_statement_cache_size(&mut self, size: CacheSize) { + // there should be no other pending future when this is called + // that means there is only one instance of this arc and + // we can simply access the inner data + if let Some(cache) = Arc::get_mut(&mut self.stmt_cache) { + cache.get_mut().set_cache_size(size) + } else { + panic!("Cannot access shared statement cache") + } } } @@ -293,25 +306,33 @@ fn update_transaction_manager_status( query_result } -#[async_trait::async_trait] -impl PrepareCallback for Arc { - async fn prepare( - self, - sql: &str, - metadata: &[PgTypeMetadata], - _is_for_cache: PrepareForCache, - ) -> QueryResult<(Statement, Self)> { - let bind_types = metadata - .iter() - .map(type_from_oid) - .collect::>>()?; - - let stmt = self - .prepare_typed(sql, &bind_types) +fn prepare_statement_helper<'a>( + conn: Arc, + sql: &'a str, + _is_for_cache: PrepareForCache, + metadata: &[PgTypeMetadata], +) -> CallbackHelper< + impl Future)>> + Send, +> { + let bind_types = metadata + .iter() + .map(type_from_oid) + .collect::>>(); + // ideally we wouldn't clone the SQL string here + // but as we usually cache statements anyway + // this is a fixed one time cost + // + // The problem with not cloning it is that we then cannot express + // the right result lifetime anymore (at least not easily) + let sql = sql.to_string(); + CallbackHelper(async move { + let bind_types = bind_types?; + let stmt = conn + .prepare_typed(&sql, &bind_types) .await .map_err(ErrorHelper); - Ok((stmt?, self)) - } + Ok((stmt?, conn)) + }) } fn type_from_oid(t: &PgTypeMetadata) -> QueryResult { @@ -369,7 +390,7 @@ impl AsyncPgConnection { None, None, Arc::new(std::sync::Mutex::new( - diesel::connection::get_default_instrumentation(), + DynInstrumentation::default_instrumentation(), )), ) .await @@ -390,9 +411,7 @@ impl AsyncPgConnection { client, Some(error_rx), Some(shutdown_tx), - Arc::new(std::sync::Mutex::new( - diesel::connection::get_default_instrumentation(), - )), + 
Arc::new(std::sync::Mutex::new(DynInstrumentation::none())), ) .await } @@ -401,11 +420,11 @@ impl AsyncPgConnection { conn: tokio_postgres::Client, connection_future: Option>>, shutdown_channel: Option>, - instrumentation: Arc>>>, + instrumentation: Arc>, ) -> ConnectionResult { let mut conn = Self { conn: Arc::new(conn), - stmt_cache: Arc::new(Mutex::new(StmtCache::new())), + stmt_cache: Arc::new(Mutex::new(StatementCache::new())), transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())), metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())), connection_future, @@ -559,23 +578,27 @@ impl AsyncPgConnection { })?; } } - let key = match query_id { - Some(id) => StatementCacheKey::Type(id), - None => StatementCacheKey::Sql { - sql: sql.clone(), - bind_types: bind_collector.metadata.clone(), - }, - }; let stmt = { let mut stmt_cache = stmt_cache.lock().await; + let helper = QueryFragmentHelper { + sql: sql.clone(), + safe_to_cache: is_safe_to_cache_prepared, + }; + let instrumentation = Arc::clone(&instrumentation); stmt_cache - .cached_prepared_statement( - key, - sql.clone(), - is_safe_to_cache_prepared, + .cached_statement_non_generic( + query_id, + &helper, + &Pg, &bind_collector.metadata, raw_connection.clone(), - &instrumentation + prepare_statement_helper, + &mut move |event: InstrumentationEvent<'_>| { + // we wrap this lock into another callback to prevent locking + // the instrumentation longer than necessary + instrumentation.lock().unwrap_or_else(|e| e.into_inner()) + .on_connection_event(event); + }, ) .await? .0 @@ -894,6 +917,16 @@ impl crate::pooled_connection::PoolableConnection for AsyncPgConnection { } } +impl QueryFragmentForCachedStatement for QueryFragmentHelper { + fn construct_sql(&self, _backend: &Pg) -> QueryResult { + Ok(self.sql.clone()) + } + + fn is_safe_to_cache_prepared(&self, _backend: &Pg) -> QueryResult { + Ok(self.safe_to_cache) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/pooled_connection/mod.rs b/src/pooled_connection/mod.rs index 21471b1..e701e8d 100644 --- a/src/pooled_connection/mod.rs +++ b/src/pooled_connection/mod.rs @@ -8,7 +8,7 @@ use crate::{AsyncConnection, SimpleAsyncConnection}; use crate::{TransactionManager, UpdateAndFetchResults}; use diesel::associations::HasTable; -use diesel::connection::Instrumentation; +use diesel::connection::{CacheSize, Instrumentation}; use diesel::QueryResult; use futures_util::{future, FutureExt}; use std::borrow::Cow; @@ -241,6 +241,10 @@ where fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) { self.deref_mut().set_instrumentation(instrumentation); } + + fn set_prepared_statement_cache_size(&mut self, size: CacheSize) { + self.deref_mut().set_prepared_statement_cache_size(size); + } } #[doc(hidden)] diff --git a/src/stmt_cache.rs b/src/stmt_cache.rs index 9d6b9af..a17568a 100644 --- a/src/stmt_cache.rs +++ b/src/stmt_cache.rs @@ -1,91 +1,57 @@ -use std::collections::HashMap; -use std::hash::Hash; - -use diesel::backend::Backend; -use diesel::connection::statement_cache::{MaybeCached, PrepareForCache, StatementCacheKey}; -use diesel::connection::Instrumentation; -use diesel::connection::InstrumentationEvent; +use diesel::connection::statement_cache::{MaybeCached, StatementCallbackReturnType}; use diesel::QueryResult; -use futures_util::{future, FutureExt}; +use futures_util::{future, FutureExt, TryFutureExt}; +use std::future::Future; -#[derive(Default)] -pub struct StmtCache { - cache: HashMap, S>, -} +pub(crate) struct 
CallbackHelper(pub(crate) F); -type PrepareFuture<'a, F, S> = future::Either< - future::Ready, F)>>, - future::BoxFuture<'a, QueryResult<(MaybeCached<'a, S>, F)>>, +type PrepareFuture<'a, C, S> = future::Either< + future::Ready, C)>>, + future::BoxFuture<'a, QueryResult<(MaybeCached<'a, S>, C)>>, >; -#[async_trait::async_trait] -pub trait PrepareCallback: Sized { - async fn prepare( - self, - sql: &str, - metadata: &[M], - is_for_cache: PrepareForCache, - ) -> QueryResult<(S, Self)>; -} +impl<'b, S, F, C> StatementCallbackReturnType for CallbackHelper +where + F: Future> + Send + 'b, + S: 'static, +{ + type Return<'a> = PrepareFuture<'a, C, S>; -impl StmtCache { - pub fn new() -> Self { - Self { - cache: HashMap::new(), - } + fn from_error<'a>(e: diesel::result::Error) -> Self::Return<'a> { + future::Either::Left(future::ready(Err(e))) } - pub fn cached_prepared_statement<'a, F>( - &'a mut self, - cache_key: StatementCacheKey, - sql: String, - is_query_safe_to_cache: bool, - metadata: &[DB::TypeMetadata], - prepare_fn: F, - instrumentation: &std::sync::Mutex>>, - ) -> PrepareFuture<'a, F, S> + fn map_to_no_cache<'a>(self) -> Self::Return<'a> where - S: Send, - DB::QueryBuilder: Default, - DB::TypeMetadata: Clone + Send + Sync, - F: PrepareCallback + Send + 'a, - StatementCacheKey: Hash + Eq, + Self: 'a, { - use std::collections::hash_map::Entry::{Occupied, Vacant}; - - if !is_query_safe_to_cache { - let metadata = metadata.to_vec(); - let f = async move { - let stmt = prepare_fn - .prepare(&sql, &metadata, PrepareForCache::No) - .await?; - Ok((MaybeCached::CannotCache(stmt.0), stmt.1)) - } - .boxed(); - return future::Either::Right(f); - } + future::Either::Right( + self.0 + .map_ok(|(stmt, conn)| (MaybeCached::CannotCache(stmt), conn)) + .boxed(), + ) + } - match self.cache.entry(cache_key) { - Occupied(entry) => future::Either::Left(future::ready(Ok(( - MaybeCached::Cached(entry.into_mut()), - prepare_fn, - )))), - Vacant(entry) => { - let metadata = metadata.to_vec(); - instrumentation - .lock() - .unwrap_or_else(|p| p.into_inner()) - .on_connection_event(InstrumentationEvent::cache_query(&sql)); - let f = async move { - let statement = prepare_fn - .prepare(&sql, &metadata, PrepareForCache::Yes) - .await?; + fn map_to_cache<'a>(stmt: &'a mut S, conn: C) -> Self::Return<'a> { + future::Either::Left(future::ready(Ok((MaybeCached::Cached(stmt), conn)))) + } - Ok((MaybeCached::Cached(entry.insert(statement.0)), statement.1)) - } - .boxed(); - future::Either::Right(f) - } - } + fn register_cache<'a>( + self, + callback: impl FnOnce(S) -> &'a mut S + Send + 'a, + ) -> Self::Return<'a> + where + Self: 'a, + { + future::Either::Right( + self.0 + .map_ok(|(stmt, conn)| (MaybeCached::Cached(callback(stmt)), conn)) + .boxed(), + ) } } + +pub(crate) struct QueryFragmentHelper { + pub(crate) sql: String, + pub(crate) safe_to_cache: bool, +} diff --git a/src/sync_connection_wrapper/mod.rs b/src/sync_connection_wrapper/mod.rs index 76a06da..a926d76 100644 --- a/src/sync_connection_wrapper/mod.rs +++ b/src/sync_connection_wrapper/mod.rs @@ -9,7 +9,7 @@ use crate::{AsyncConnection, SimpleAsyncConnection, TransactionManager}; use diesel::backend::{Backend, DieselReserveSpecialization}; -use diesel::connection::Instrumentation; +use diesel::connection::{CacheSize, Instrumentation}; use diesel::connection::{ Connection, LoadConnection, TransactionManagerStatus, WithMetadataLookup, }; @@ -188,6 +188,20 @@ where panic!("Cannot access shared instrumentation") } } + + fn set_prepared_statement_cache_size(&mut 
self, size: CacheSize) { + // there should be no other pending future when this is called + // that means there is only one instance of this arc and + // we can simply access the inner data + if let Some(inner) = Arc::get_mut(&mut self.inner) { + inner + .get_mut() + .unwrap_or_else(|p| p.into_inner()) + .set_prepared_statement_cache_size(size) + } else { + panic!("Cannot access shared cache") + } + } } /// A wrapper of a diesel transaction manager usable in async context.
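
Usage sketch (not part of the diff above): with this change every async connection type forwards the cache-size setting to diesel's shared StatementCache. The example below is a minimal, hypothetical illustration of calling the new `set_prepared_statement_cache_size` method on an `AsyncPgConnection`; it assumes the `CacheSize::{Unbounded, Disabled}` variants exported from `diesel::connection` in diesel 2.2, a tokio runtime built with the `macros` and `rt-multi-thread` features, and a placeholder database URL.

```rust
use diesel::connection::CacheSize;
use diesel_async::{AsyncConnection, AsyncPgConnection};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder URL; point this at a real database in your environment.
    let database_url = "postgres://postgres:password@localhost/diesel_example";
    let mut conn = AsyncPgConnection::establish(database_url).await?;

    // After this patch the call is forwarded to diesel's shared StatementCache
    // (see set_prepared_statement_cache_size in src/pg/mod.rs above).
    // CacheSize::Disabled turns prepared-statement caching off entirely;
    // CacheSize::Unbounded (the default) caches every statement that is safe to cache.
    conn.set_prepared_statement_cache_size(CacheSize::Disabled);

    Ok(())
}
```

The same call should work for `AsyncMysqlConnection`, the pooled connections, and the sync/async wrappers touched by this patch, since each of them now forwards the size to the underlying statement cache.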