diff --git a/diesel/src/connection/mod.rs b/diesel/src/connection/mod.rs index ad744bae901c..bbd40199b867 100644 --- a/diesel/src/connection/mod.rs +++ b/diesel/src/connection/mod.rs @@ -45,6 +45,15 @@ pub use self::instrumentation::StrQueryHelper; ))] pub(crate) use self::private::MultiConnectionHelper; +/// Set cache size for a connection +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CacheSize { + /// Caches all queries if possible + Unbounded, + /// Disable statement cache + Disabled, +} + /// Perform simple operations on a backend. /// /// You should likely use [`Connection`] instead. @@ -401,6 +410,9 @@ where /// Set a specific [`Instrumentation`] implementation for this connection fn set_instrumentation(&mut self, instrumentation: impl Instrumentation); + + /// Set a cache size [`CacheSize`] for this connection + fn set_cache_size(&mut self, size: CacheSize); } /// The specific part of a [`Connection`] which actually loads data from the database diff --git a/diesel/src/connection/statement_cache.rs b/diesel/src/connection/statement_cache/mod.rs similarity index 84% rename from diesel/src/connection/statement_cache.rs rename to diesel/src/connection/statement_cache/mod.rs index eb44461a4d15..4c61b615bd55 100644 --- a/diesel/src/connection/statement_cache.rs +++ b/diesel/src/connection/statement_cache/mod.rs @@ -10,8 +10,9 @@ //! statements is [`SimpleConnection::batch_execute`](super::SimpleConnection::batch_execute). //! //! In order to avoid the cost of re-parsing and planning subsequent queries, -//! Diesel caches the prepared statement whenever possible. Queries will fall -//! into one of three buckets: +//! by default Diesel caches the prepared statement whenever possible, but +//! this an be customized by calling [`Connection::set_cache_size`](super::Connection::set_cache_size). +//! Queries will fall into one of three buckets: //! //! - Unsafe to cache //! - Cached by SQL @@ -94,16 +95,21 @@ use std::any::TypeId; use std::borrow::Cow; -use std::collections::HashMap; use std::hash::Hash; use std::ops::{Deref, DerefMut}; +use strategy::{StatementCacheStrategy, WithCacheStrategy, WithoutCacheStrategy}; + use crate::backend::Backend; use crate::connection::InstrumentationEvent; use crate::query_builder::*; use crate::result::QueryResult; -use super::Instrumentation; +use super::{CacheSize, Instrumentation}; + +/// Various interfaces and implementations to control connection statement caching. +#[allow(unreachable_pub)] +pub mod strategy; /// A prepared statement cache #[allow(missing_debug_implementations, unreachable_pub)] @@ -112,7 +118,10 @@ use super::Instrumentation; doc(cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes")) )] pub struct StatementCache { - pub(crate) cache: HashMap, Statement>, + cache: Box>, + // increment every time a query is cached + // some backends might use it to create unique prepared statement names + cache_counter: u64, } /// A helper type that indicates if a certain query @@ -134,39 +143,41 @@ pub enum PrepareForCache { No, } -#[allow( - clippy::len_without_is_empty, - clippy::new_without_default, - unreachable_pub -)] +#[allow(clippy::new_without_default, unreachable_pub)] impl StatementCache where - DB: Backend, + DB: Backend + 'static, + Statement: 'static, DB::TypeMetadata: Clone, DB::QueryBuilder: Default, StatementCacheKey: Hash + Eq, { - /// Create a new prepared statement cache + /// Create a new prepared statement cache using [`CacheSize::Unbounded`] as caching strategy. #[allow(unreachable_pub)] pub fn new() -> Self { StatementCache { - cache: HashMap::new(), + cache: Box::new(WithCacheStrategy::default()), + cache_counter: 0, } } - /// Get the current length of the statement cache - #[allow(unreachable_pub)] - #[cfg(any( - feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes", - feature = "postgres", - all(feature = "sqlite", test) - ))] - #[cfg_attr( - docsrs, - doc(cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes")) - )] - pub fn len(&self) -> usize { - self.cache.len() + /// Set caching strategy from predefined implementations + pub fn set_cache_size(&mut self, size: CacheSize) { + if self.cache.strategy() != size { + self.cache = match size { + CacheSize::Unbounded => Box::new(WithCacheStrategy::default()), + CacheSize::Disabled => Box::new(WithoutCacheStrategy::default()), + } + } + } + + /// Setting custom caching strategy. It is used in tests, to verify caching logic + #[allow(dead_code)] + pub(crate) fn set_strategy(&mut self, s: Strategy) + where + Strategy: StatementCacheStrategy + 'static, + { + self.cache = Box::new(s); } /// Prepare a query as prepared statement @@ -191,52 +202,41 @@ where ) -> QueryResult> where T: QueryFragment + QueryId, - F: FnMut(&str, PrepareForCache) -> QueryResult, + F: FnMut(&str, Option) -> QueryResult, { - self.cached_statement_non_generic( + Self::cached_statement_non_generic( + self.cache.as_mut(), T::query_id(), source, backend, bind_types, - &mut prepare_fn, - instrumentation, + &mut |sql, is_cached| { + if let PrepareForCache::Yes = is_cached { + instrumentation.on_connection_event(InstrumentationEvent::CacheQuery { sql }); + self.cache_counter += 1; + prepare_fn(sql, Some(self.cache_counter)) + } else { + prepare_fn(sql, None) + } + }, ) } /// Reduce the amount of monomorphized code by factoring this via dynamic dispatch - fn cached_statement_non_generic( - &mut self, + fn cached_statement_non_generic<'a>( + cache: &'a mut dyn StatementCacheStrategy, maybe_type_id: Option, source: &dyn QueryFragmentForCachedStatement, backend: &DB, bind_types: &[DB::TypeMetadata], prepare_fn: &mut dyn FnMut(&str, PrepareForCache) -> QueryResult, - instrumentation: &mut dyn Instrumentation, - ) -> QueryResult> { - use std::collections::hash_map::Entry::{Occupied, Vacant}; - + ) -> QueryResult> { let cache_key = StatementCacheKey::for_source(maybe_type_id, source, bind_types, backend)?; - if !source.is_safe_to_cache_prepared(backend)? { let sql = cache_key.sql(source, backend)?; return prepare_fn(&sql, PrepareForCache::No).map(MaybeCached::CannotCache); } - - let cached_result = match self.cache.entry(cache_key) { - Occupied(entry) => entry.into_mut(), - Vacant(entry) => { - let statement = { - let sql = entry.key().sql(source, backend)?; - instrumentation - .on_connection_event(InstrumentationEvent::CacheQuery { sql: &sql }); - prepare_fn(&sql, PrepareForCache::Yes) - }; - - entry.insert(statement?) - } - }; - - Ok(MaybeCached::Cached(cached_result)) + cache.get(cache_key, backend, source, prepare_fn) } } diff --git a/diesel/src/connection/statement_cache/strategy.rs b/diesel/src/connection/statement_cache/strategy.rs new file mode 100644 index 000000000000..3bd5a3986eed --- /dev/null +++ b/diesel/src/connection/statement_cache/strategy.rs @@ -0,0 +1,241 @@ +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::hash::Hash; + +use crate::{backend::Backend, result::Error}; + +use super::{ + CacheSize, MaybeCached, PrepareForCache, QueryFragmentForCachedStatement, StatementCacheKey, +}; + +/// Implement this trait, in order to control statement caching. +#[allow(unreachable_pub)] +pub trait StatementCacheStrategy +where + DB: Backend, + StatementCacheKey: Hash + Eq, +{ + /// Returns which strategy is implemented by this trait + fn strategy(&self) -> CacheSize; + + /// Every query (which is safe to cache) will go through this function + /// Implementation will decide whether to cache statement or not + fn get( + &mut self, + key: StatementCacheKey, + backend: &DB, + source: &dyn QueryFragmentForCachedStatement, + prepare_fn: &mut dyn FnMut(&str, PrepareForCache) -> Result, + ) -> Result, Error>; +} + +/// Cache all (safe) statements for as long as connection is alive. +#[allow(missing_debug_implementations, unreachable_pub)] +pub struct WithCacheStrategy +where + DB: Backend, +{ + cache: HashMap, Statement>, +} + +impl Default for WithCacheStrategy +where + DB: Backend, +{ + fn default() -> Self { + Self { + cache: Default::default(), + } + } +} + +impl StatementCacheStrategy for WithCacheStrategy +where + DB: Backend, + StatementCacheKey: Hash + Eq, + DB::TypeMetadata: Clone, + DB::QueryBuilder: Default, +{ + fn get( + &mut self, + key: StatementCacheKey, + backend: &DB, + source: &dyn QueryFragmentForCachedStatement, + prepare_fn: &mut dyn FnMut(&str, PrepareForCache) -> Result, + ) -> Result, Error> { + let entry = self.cache.entry(key); + match entry { + Entry::Occupied(e) => Ok(MaybeCached::Cached(e.into_mut())), + Entry::Vacant(e) => { + let sql = e.key().sql(source, backend)?; + let st = prepare_fn(&sql, PrepareForCache::Yes)?; + Ok(MaybeCached::Cached(e.insert(st))) + } + } + } + + fn strategy(&self) -> CacheSize { + CacheSize::Unbounded + } +} + +/// No statements will be cached, +#[allow(missing_debug_implementations, unreachable_pub)] +#[derive(Clone, Copy, Default)] +pub struct WithoutCacheStrategy {} + +impl StatementCacheStrategy for WithoutCacheStrategy +where + DB: Backend, + StatementCacheKey: Hash + Eq, + DB::TypeMetadata: Clone, + DB::QueryBuilder: Default, +{ + fn get( + &mut self, + key: StatementCacheKey, + backend: &DB, + source: &dyn QueryFragmentForCachedStatement, + prepare_fn: &mut dyn FnMut(&str, PrepareForCache) -> Result, + ) -> Result, Error> { + let sql = key.sql(source, backend)?; + Ok(MaybeCached::CannotCache(prepare_fn( + &sql, + PrepareForCache::No, + )?)) + } + + fn strategy(&self) -> CacheSize { + CacheSize::Disabled + } +} + +/// Utilities that help to introspect statement caching behaviour in tests. +#[cfg(test)] +pub mod testing_utils { + use std::cell::RefCell; + + use super::*; + + thread_local! { + static INTROSPECT_CACHING_STRATEGY: RefCell> = const { RefCell::new(Vec::new()) }; + } + + /// Wraps caching strategy and records all outcome of all calls to `get`. + /// Later all recorded calls can be observed by calling free function + /// [`consume_statement_caching_calls`]. + #[allow(missing_debug_implementations)] + pub struct IntrospectCachingStrategy { + inner: Box>, + } + + impl IntrospectCachingStrategy + where + DB: Backend + 'static, + Statement: 'static, + StatementCacheKey: Hash + Eq, + DB::TypeMetadata: Clone, + DB::QueryBuilder: Default, + { + /// Wrap internal cache strategy and record all calls to it that happen on a thread. + /// Later call `consume_statement_caching_calls` to get results. + pub fn new(strategy: Strategy) -> Self + where + Strategy: StatementCacheStrategy + 'static, + { + consume_statement_caching_calls(); // clear everything once new connection is created + IntrospectCachingStrategy { + inner: Box::new(strategy), + } + } + } + + /// Outcome of call to [`StatementCacheStrategy::get`] implementation. + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub enum CachingOutcome { + /// Statement was taken from cache + UseCached, + /// Statement was put to cache + Cache, + /// Statement wasn't cached + DontCache, + } + + /// Result summary of call to [`StatementCacheStrategy::get`] + #[derive(Debug, PartialEq, Eq)] + pub struct CallInfo { + /// Sql query + pub sql: String, + /// Caching outcome + pub outcome: CachingOutcome, + } + + /// Helper type that makes it simpler to verify [`CachingOutcome`]. + #[derive(Debug)] + pub struct IntrospectedCalls { + /// All introspected calls + pub calls: Vec, + } + + impl IntrospectedCalls { + /// Count how many calls matches required outcome. + pub fn count(&self, outcome: CachingOutcome) -> usize { + self.calls + .iter() + .filter(|info| info.outcome == outcome) + .count() + } + /// Returns true if there was not calls introspected. + pub fn is_empty(&self) -> bool { + self.calls.is_empty() + } + } + + /// Return all calls that was recorded for current thread using [`IntrospectCachingStrategy`] + pub fn consume_statement_caching_calls() -> IntrospectedCalls { + IntrospectedCalls { + calls: INTROSPECT_CACHING_STRATEGY.with_borrow_mut(std::mem::take), + } + } + + impl StatementCacheStrategy + for IntrospectCachingStrategy + where + DB: Backend, + StatementCacheKey: Hash + Eq, + DB::TypeMetadata: Clone, + DB::QueryBuilder: Default, + { + fn get( + &mut self, + key: StatementCacheKey, + backend: &DB, + source: &dyn QueryFragmentForCachedStatement, + prepare_fn: &mut dyn FnMut(&str, PrepareForCache) -> Result, + ) -> Result, Error> { + let mut outcome = None; + + let sql = key.sql(source, backend)?.to_string(); + let res = self + .inner + .get(key, backend, source, &mut |sql, is_cached| { + outcome = Some(match is_cached { + PrepareForCache::Yes => CachingOutcome::Cache, + PrepareForCache::No => CachingOutcome::DontCache, + }); + prepare_fn(sql, is_cached) + })?; + INTROSPECT_CACHING_STRATEGY.with_borrow_mut(|calls| { + calls.push(CallInfo { + sql, + outcome: outcome.unwrap_or(CachingOutcome::UseCached), + }) + }); + Ok(res) + } + + fn strategy(&self) -> CacheSize { + self.inner.strategy() + } + } +} diff --git a/diesel/src/connection/transaction_manager.rs b/diesel/src/connection/transaction_manager.rs index bad7846e83a8..c812b8c9527c 100644 --- a/diesel/src/connection/transaction_manager.rs +++ b/diesel/src/connection/transaction_manager.rs @@ -594,6 +594,10 @@ mod test { ) { self.instrumentation = Some(Box::new(instrumentation)); } + + fn set_cache_size(&mut self, _size: crate::connection::CacheSize) { + panic!("implement, if you want to use it") + } } } diff --git a/diesel/src/mysql/connection/mod.rs b/diesel/src/mysql/connection/mod.rs index 5b87a363215f..b2bc24218cb5 100644 --- a/diesel/src/mysql/connection/mod.rs +++ b/diesel/src/mysql/connection/mod.rs @@ -205,6 +205,10 @@ impl Connection for MysqlConnection { fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) { self.instrumentation = instrumentation.into(); } + + fn set_cache_size(&mut self, size: CacheSize) { + self.statement_cache.set_cache_size(size); + } } #[inline(always)] diff --git a/diesel/src/pg/connection/mod.rs b/diesel/src/pg/connection/mod.rs index 8b0e1925c071..05188b1938d0 100644 --- a/diesel/src/pg/connection/mod.rs +++ b/diesel/src/pg/connection/mod.rs @@ -124,7 +124,8 @@ pub(super) use self::result::PgResult; #[allow(missing_debug_implementations)] #[cfg(feature = "postgres")] pub struct PgConnection { - statement_cache: StatementCache, + /// pub(crate) for tests + pub(crate) statement_cache: StatementCache, metadata_cache: PgMetadataCache, connection_and_transaction_manager: ConnectionAndTransactionManager, } @@ -236,6 +237,10 @@ impl Connection for PgConnection { fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) { self.connection_and_transaction_manager.instrumentation = instrumentation.into(); } + + fn set_cache_size(&mut self, size: CacheSize) { + self.statement_cache.set_cache_size(size); + } } impl LoadConnection for PgConnection @@ -499,19 +504,14 @@ impl PgConnection { let binds = bind_collector.binds; let metadata = bind_collector.metadata; - let cache_len = self.statement_cache.len(); let cache = &mut self.statement_cache; let conn = &mut self.connection_and_transaction_manager.raw_connection; let query = cache.cached_statement( &source, &Pg, &metadata, - |sql, _| { - let query_name = if source.is_safe_to_cache_prepared(&Pg)? { - Some(format!("__diesel_stmt_{cache_len}")) - } else { - None - }; + |sql, counter| { + let query_name = counter.map(|counter| format!("__diesel_stmt_{counter}")); Statement::prepare(conn, sql, query_name.as_deref(), &metadata) }, &mut *self.connection_and_transaction_manager.instrumentation, @@ -613,6 +613,10 @@ mod private { mod tests { extern crate dotenvy; + use statement_cache::strategy::testing_utils::{ + consume_statement_caching_calls, CachingOutcome, + }; + use super::*; use crate::dsl::sql; use crate::prelude::*; @@ -620,6 +624,10 @@ mod tests { use crate::sql_types::{Integer, VarChar}; use std::num::NonZeroU32; + fn connection() -> PgConnection { + crate::test_helpers::pg_connection_no_transaction() + } + #[test] fn malformed_sql_query() { let connection = &mut connection(); @@ -641,7 +649,9 @@ mod tests { assert_eq!(Ok(1), query.get_result(connection)); assert_eq!(Ok(1), query.get_result(connection)); - assert_eq!(1, connection.statement_cache.len()); + let outcome = consume_statement_caching_calls(); + assert_eq!(1, outcome.count(CachingOutcome::UseCached)); + assert_eq!(1, outcome.count(CachingOutcome::Cache)); } #[test] @@ -653,7 +663,10 @@ mod tests { assert_eq!(Ok(1), query.get_result(connection)); assert_eq!(Ok("hi".to_string()), query2.get_result(connection)); - assert_eq!(2, connection.statement_cache.len()); + assert_eq!( + 2, + consume_statement_caching_calls().count(CachingOutcome::Cache) + ); } #[test] @@ -663,10 +676,12 @@ mod tests { let query = crate::select(1.into_sql::()).into_boxed::(); let query2 = crate::select("hi".into_sql::()).into_boxed::(); - assert_eq!(0, connection.statement_cache.len()); assert_eq!(Ok(1), query.get_result(connection)); assert_eq!(Ok("hi".to_string()), query2.get_result(connection)); - assert_eq!(2, connection.statement_cache.len()); + assert_eq!( + 2, + consume_statement_caching_calls().count(CachingOutcome::Cache) + ); } define_sql_function!(fn lower(x: VarChar) -> VarChar); @@ -679,10 +694,12 @@ mod tests { let query = crate::select(hi).into_boxed::(); let query2 = crate::select(lower(hi)).into_boxed::(); - assert_eq!(0, connection.statement_cache.len()); assert_eq!(Ok("HI".to_string()), query.get_result(connection)); assert_eq!(Ok("hi".to_string()), query2.get_result(connection)); - assert_eq!(2, connection.statement_cache.len()); + assert_eq!( + 2, + consume_statement_caching_calls().count(CachingOutcome::Cache) + ); } #[test] @@ -691,7 +708,7 @@ mod tests { let query = crate::select(sql::("1")); assert_eq!(Ok(1), query.get_result(connection)); - assert_eq!(0, connection.statement_cache.len()); + assert!(consume_statement_caching_calls().is_empty()); } table! { @@ -717,14 +734,16 @@ mod tests { .insert_into(users::table) .into_columns((users::id, users::name)); assert!(insert.execute(connection).is_ok()); - assert_eq!(1, connection.statement_cache.len()); let query = users::table.filter(users::id.eq(42)).into_boxed(); let insert = query .insert_into(users::table) .into_columns((users::id, users::name)); assert!(insert.execute(connection).is_ok()); - assert_eq!(2, connection.statement_cache.len()); + assert_eq!( + 2, + consume_statement_caching_calls().count(CachingOutcome::Cache) + ); } #[test] @@ -742,7 +761,10 @@ mod tests { crate::insert_into(users::table).values((users::id.eq(42), users::name.eq("Foo"))); assert!(insert.execute(connection).is_ok()); - assert_eq!(1, connection.statement_cache.len()); + assert_eq!( + 1, + consume_statement_caching_calls().count(CachingOutcome::Cache) + ); } #[test] @@ -760,7 +782,7 @@ mod tests { .values(vec![(users::id.eq(42), users::name.eq("Foo"))]); assert!(insert.execute(connection).is_ok()); - assert_eq!(0, connection.statement_cache.len()); + assert!(consume_statement_caching_calls().is_empty()); } #[test] @@ -778,7 +800,9 @@ mod tests { crate::insert_into(users::table).values([(users::id.eq(42), users::name.eq("Foo"))]); assert!(insert.execute(connection).is_ok()); - assert_eq!(1, connection.statement_cache.len()); + let outcome = consume_statement_caching_calls(); + assert_eq!(1, outcome.count(CachingOutcome::Cache)); + assert_eq!(1, outcome.calls.len()); } #[test] @@ -788,11 +812,10 @@ mod tests { let query = crate::select(one_as_expr.eq_any(vec![1, 2, 3])); assert_eq!(Ok(true), query.get_result(connection)); - assert_eq!(1, connection.statement_cache.len()); - } - - fn connection() -> PgConnection { - crate::test_helpers::pg_connection_no_transaction() + assert_eq!( + 1, + consume_statement_caching_calls().count(CachingOutcome::Cache) + ); } #[test] diff --git a/diesel/src/r2d2.rs b/diesel/src/r2d2.rs index 8abca2f386ba..cdc6d3901e65 100644 --- a/diesel/src/r2d2.rs +++ b/diesel/src/r2d2.rs @@ -263,6 +263,10 @@ where fn set_instrumentation(&mut self, instrumentation: impl crate::connection::Instrumentation) { (**self).set_instrumentation(instrumentation) } + + fn set_cache_size(&mut self, size: crate::connection::CacheSize) { + (**self).set_cache_size(size) + } } impl LoadConnection for PooledConnection diff --git a/diesel/src/sqlite/connection/mod.rs b/diesel/src/sqlite/connection/mod.rs index deee02775795..f7ba3aa2eb1e 100644 --- a/diesel/src/sqlite/connection/mod.rs +++ b/diesel/src/sqlite/connection/mod.rs @@ -31,6 +31,7 @@ use crate::result::*; use crate::serialize::ToSql; use crate::sql_types::{HasSqlType, TypeMetadata}; use crate::sqlite::Sqlite; +use statement_cache::PrepareForCache; /// Connections for the SQLite backend. Unlike other backends, SQLite supported /// connection URLs are: @@ -121,7 +122,8 @@ pub struct SqliteConnection { // statement_cache needs to be before raw_connection // otherwise we will get errors about open statements before closing the // connection itself - statement_cache: StatementCache, + // pub(crate) for tests + pub(crate) statement_cache: StatementCache, raw_connection: RawConnection, transaction_state: AnsiTransactionManager, // this exists for the sole purpose of implementing `WithMetadataLookup` trait @@ -204,6 +206,10 @@ impl Connection for SqliteConnection { fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) { self.instrumentation = instrumentation.into(); } + + fn set_cache_size(&mut self, size: CacheSize) { + self.statement_cache.set_cache_size(size); + } } impl LoadConnection for SqliteConnection { @@ -351,7 +357,15 @@ impl SqliteConnection { &source, &Sqlite, &[], - |sql, is_cached| Statement::prepare(raw_connection, sql, is_cached), + |sql, counter| { + Statement::prepare( + raw_connection, + sql, + counter + .map(|_| PrepareForCache::Yes) + .unwrap_or(PrepareForCache::No), + ) + }, &mut *self.instrumentation, ) { Ok(statement) => statement, @@ -558,6 +572,17 @@ mod tests { use crate::dsl::sql; use crate::prelude::*; use crate::sql_types::Integer; + use crate::sqlite::connection::statement_cache::strategy::testing_utils::{ + consume_statement_caching_calls, CachingOutcome, IntrospectCachingStrategy, + }; + use crate::sqlite::connection::statement_cache::strategy::WithCacheStrategy; + + fn connection() -> SqliteConnection { + let mut conn = SqliteConnection::establish(":memory:").unwrap(); + conn.statement_cache + .set_strategy(IntrospectCachingStrategy::new(WithCacheStrategy::default())); + conn + } #[test] fn database_serializes_and_deserializes_successfully() { @@ -574,73 +599,79 @@ mod tests { ), ]; - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let conn1 = &mut connection(); let _ = crate::sql_query("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, email TEXT)") - .execute(connection); + .execute(conn1); let _ = crate::sql_query("INSERT INTO users (name, email) VALUES ('John Doe', 'john.doe@example.com'), ('Jane Doe', 'jane.doe@example.com')") - .execute(connection); + .execute(conn1); - let serialized_database = connection.serialize_database_to_buffer(); + let serialized_database = conn1.serialize_database_to_buffer(); - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); - connection + let conn2 = &mut connection(); + conn2 .deserialize_readonly_database_from_buffer(serialized_database.as_slice()) .unwrap(); let query = sql::<(Integer, Text, Text)>("SELECT id, name, email FROM users ORDER BY id"); - let actual_users = query.load::<(i32, String, String)>(connection).unwrap(); + let actual_users = query.load::<(i32, String, String)>(conn2).unwrap(); assert_eq!(expected_users, actual_users); } #[test] fn prepared_statements_are_cached_when_run() { - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); let query = crate::select(1.into_sql::()); assert_eq!(Ok(1), query.get_result(connection)); assert_eq!(Ok(1), query.get_result(connection)); - assert_eq!(1, connection.statement_cache.len()); + + let outcome = consume_statement_caching_calls(); + assert_eq!(1, outcome.count(CachingOutcome::Cache)); + assert_eq!(1, outcome.count(CachingOutcome::UseCached)); } #[test] fn sql_literal_nodes_are_not_cached() { - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); let query = crate::select(sql::("1")); assert_eq!(Ok(1), query.get_result(connection)); - assert_eq!(0, connection.statement_cache.len()); + assert!(consume_statement_caching_calls().is_empty()); } #[test] fn queries_containing_sql_literal_nodes_are_not_cached() { - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); let one_as_expr = 1.into_sql::(); let query = crate::select(one_as_expr.eq(sql::("1"))); assert_eq!(Ok(true), query.get_result(connection)); - assert_eq!(0, connection.statement_cache.len()); + assert!(consume_statement_caching_calls().is_empty()); } #[test] fn queries_containing_in_with_vec_are_not_cached() { - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); let one_as_expr = 1.into_sql::(); let query = crate::select(one_as_expr.eq_any(vec![1, 2, 3])); assert_eq!(Ok(true), query.get_result(connection)); - assert_eq!(0, connection.statement_cache.len()); + assert!(consume_statement_caching_calls().is_empty()); } #[test] fn queries_containing_in_with_subselect_are_cached() { - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); let one_as_expr = 1.into_sql::(); let query = crate::select(one_as_expr.eq_any(crate::select(one_as_expr))); assert_eq!(Ok(true), query.get_result(connection)); - assert_eq!(1, connection.statement_cache.len()); + assert_eq!( + 1, + consume_statement_caching_calls().count(CachingOutcome::Cache) + ); } use crate::sql_types::Text; @@ -648,7 +679,7 @@ mod tests { #[test] fn register_custom_function() { - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); fun_case_utils::register_impl(connection, |x: String| { x.chars() .enumerate() @@ -673,7 +704,7 @@ mod tests { #[test] fn register_multiarg_function() { - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); my_add_utils::register_impl(connection, |x: i32, y: i32| x + y).unwrap(); let added = crate::select(my_add(1, 2)).get_result::(connection); @@ -684,7 +715,7 @@ mod tests { #[test] fn register_noarg_function() { - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); answer_utils::register_impl(connection, || 42).unwrap(); let answer = crate::select(answer()).get_result::(connection); @@ -693,7 +724,7 @@ mod tests { #[test] fn register_nondeterministic_noarg_function() { - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); answer_utils::register_nondeterministic_impl(connection, || 42).unwrap(); let answer = crate::select(answer()).get_result::(connection); @@ -704,7 +735,7 @@ mod tests { #[test] fn register_nondeterministic_function() { - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); let mut y = 0; add_counter_utils::register_nondeterministic_impl(connection, move |x: i32| { y += 1; @@ -750,7 +781,7 @@ mod tests { fn register_aggregate_function() { use self::my_sum_example::dsl::*; - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); crate::sql_query( "CREATE TABLE my_sum_example (id integer primary key autoincrement, value integer)", ) @@ -772,7 +803,7 @@ mod tests { fn register_aggregate_function_returns_finalize_default_on_empty_set() { use self::my_sum_example::dsl::*; - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); crate::sql_query( "CREATE TABLE my_sum_example (id integer primary key autoincrement, value integer)", ) @@ -834,7 +865,7 @@ mod tests { fn register_aggregate_multiarg_function() { use self::range_max_example::dsl::*; - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); crate::sql_query( r#"CREATE TABLE range_max_example ( id integer primary key autoincrement, @@ -870,7 +901,7 @@ mod tests { fn register_collation_function() { use self::my_collation_example::dsl::*; - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); connection .register_collation("RUSTNOCASE", |rhs, lhs| { @@ -950,7 +981,7 @@ mod tests { } } - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); let res = crate::select( CustomWrapper("".into()) @@ -980,7 +1011,7 @@ mod tests { } } - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); let res = crate::select( CustomWrapper(Vec::new()) diff --git a/diesel/src/sqlite/connection/stmt.rs b/diesel/src/sqlite/connection/stmt.rs index 92b12465d772..067134a3bcc9 100644 --- a/diesel/src/sqlite/connection/stmt.rs +++ b/diesel/src/sqlite/connection/stmt.rs @@ -15,7 +15,7 @@ use std::io::{stderr, Write}; use std::os::raw as libc; use std::ptr::{self, NonNull}; -pub(super) struct Statement { +pub(crate) struct Statement { inner_statement: NonNull, } diff --git a/diesel/src/test_helpers.rs b/diesel/src/test_helpers.rs index 2bc0572b31c3..f0df65dc5a17 100644 --- a/diesel/src/test_helpers.rs +++ b/diesel/src/test_helpers.rs @@ -60,7 +60,13 @@ pub fn pg_connection() -> PgConnection { #[cfg(feature = "postgres")] pub fn pg_connection_no_transaction() -> PgConnection { - PgConnection::establish(&pg_database_url()).unwrap() + use crate::connection::statement_cache::strategy::{ + testing_utils::IntrospectCachingStrategy, WithCacheStrategy, + }; + let mut conn = PgConnection::establish(&pg_database_url()).unwrap(); + conn.statement_cache + .set_strategy(IntrospectCachingStrategy::new(WithCacheStrategy::default())); + conn } #[cfg(feature = "postgres")] diff --git a/diesel_derives/src/multiconnection.rs b/diesel_derives/src/multiconnection.rs index fde7838aad8b..bfc7edb18875 100644 --- a/diesel_derives/src/multiconnection.rs +++ b/diesel_derives/src/multiconnection.rs @@ -118,6 +118,15 @@ fn generate_connection_impl( } }); + let set_cache_impl = connection_types.iter().map(|c| { + let variant_ident = c.name; + quote::quote! { + #ident::#variant_ident(conn) => { + diesel::connection::Connection::set_cache_size(conn, size); + } + } + }); + let get_instrumentation_impl = connection_types.iter().map(|c| { let variant_ident = c.name; quote::quote! { @@ -367,6 +376,12 @@ fn generate_connection_impl( } } + fn set_cache_size(&mut self, size: diesel::connection::CacheSize) { + match self { + #(#set_cache_impl,)* + } + } + fn begin_test_transaction(&mut self) -> diesel::QueryResult<()> { match self { #(#impl_begin_test_transaction,)*