Skip to content

Commit

Permalink
Merge pull request #206 from weiznich/feature/share_statement_cache_w…
Browse files Browse the repository at this point in the history
…ith_diesel

Share the statement cache with diesel
  • Loading branch information
weiznich authored Jan 16, 2025
2 parents 7e0267b + 45c8559 commit 780b348
Show file tree
Hide file tree
Showing 11 changed files with 264 additions and 188 deletions.
28 changes: 23 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@ description = "An async extension for Diesel the safe, extensible ORM and Query
rust-version = "1.78.0"

[dependencies]
diesel = { version = "~2.2.0", default-features = false, features = [
"i-implement-a-third-party-backend-and-opt-into-breaking-changes",
] }
async-trait = "0.1.66"
futures-channel = { version = "0.3.17", default-features = false, features = [
"std",
Expand All @@ -39,14 +36,35 @@ deadpool = { version = "0.12", optional = true, default-features = false, featur
mobc = { version = ">=0.7,<0.10", optional = true }
scoped-futures = { version = "0.1", features = ["std"] }

[dependencies.diesel]
version = "~2.2.0"
default-features = false
features = [
"i-implement-a-third-party-backend-and-opt-into-breaking-changes",
]
git = "https://github.com/diesel-rs/diesel"
branch = "master"

[dev-dependencies]
tokio = { version = "1.12.0", features = ["rt", "macros", "rt-multi-thread"] }
cfg-if = "1"
chrono = "0.4"
diesel = { version = "2.2.0", default-features = false, features = ["chrono"] }
diesel_migrations = "2.2.0"
assert_matches = "1.0.1"

[dev-dependencies.diesel]
version = "~2.2.0"
default-features = false
features = [
"chrono"
]
git = "https://github.com/diesel-rs/diesel"
branch = "master"

[dev-dependencies.diesel_migrations]
version = "2.2.0"
git = "https://github.com/diesel-rs/diesel"
branch = "master"

[features]
default = []
mysql = [
Expand Down
8 changes: 7 additions & 1 deletion examples/postgres/pooled-with-rustls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,17 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
diesel = { version = "2.2.0", default-features = false, features = ["postgres"] }
diesel-async = { version = "0.5.0", path = "../../../", features = ["bb8", "postgres"] }
futures-util = "0.3.21"
rustls = "0.23.8"
rustls-platform-verifier = "0.5.0"
tokio = { version = "1.2.0", default-features = false, features = ["macros", "rt-multi-thread"] }
tokio-postgres = "0.7.7"
tokio-postgres-rustls = "0.13.0"


[dependencies.diesel]
version = "2.2.0"
default-features = false
git = "https://github.com/diesel-rs/diesel"
branch = "master"
13 changes: 11 additions & 2 deletions examples/postgres/run-pending-migrations-with-rustls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,21 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
diesel = { version = "2.2.0", default-features = false, features = ["postgres"] }
diesel-async = { version = "0.5.0", path = "../../../", features = ["bb8", "postgres", "async-connection-wrapper"] }
diesel_migrations = "2.2.0"
futures-util = "0.3.21"
rustls = "0.23.8"
rustls-platform-verifier = "0.5.0"
tokio = { version = "1.2.0", default-features = false, features = ["macros", "rt-multi-thread"] }
tokio-postgres = "0.7.7"
tokio-postgres-rustls = "0.13.0"

[dependencies.diesel]
version = "2.2.0"
default-features = false
git = "https://github.com/diesel-rs/diesel"
branch = "master"

[dependencies.diesel_migrations]
version = "2.2.0"
git = "https://github.com/diesel-rs/diesel"
branch = "master"
14 changes: 12 additions & 2 deletions examples/sync-wrapper/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,22 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
diesel = { version = "2.2.0", default-features = false, features = ["returning_clauses_for_sqlite_3_35"] }
diesel-async = { version = "0.5.0", path = "../../", features = ["sync-connection-wrapper", "async-connection-wrapper"] }
diesel_migrations = "2.2.0"
futures-util = "0.3.21"
tokio = { version = "1.2.0", default-features = false, features = ["macros", "rt-multi-thread"] }

[dependencies.diesel]
version = "2.2.0"
default-features = false
features = ["returning_clauses_for_sqlite_3_35"]
git = "https://github.com/diesel-rs/diesel"
branch = "master"

[dependencies.diesel_migrations]
version = "2.2.0"
git = "https://github.com/diesel-rs/diesel"
branch = "master"

[features]
default = ["sqlite"]
sqlite = ["diesel-async/sqlite"]
6 changes: 5 additions & 1 deletion src/async_connection_wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ pub type AsyncConnectionWrapper<C, B = self::implementation::Tokio> =
pub use self::implementation::AsyncConnectionWrapper;

mod implementation {
use diesel::connection::{Instrumentation, SimpleConnection};
use diesel::connection::{CacheSize, Instrumentation, SimpleConnection};
use std::ops::{Deref, DerefMut};

use super::*;
Expand Down Expand Up @@ -187,6 +187,10 @@ mod implementation {
fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
self.inner.set_instrumentation(instrumentation);
}

fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
self.inner.set_prepared_statement_cache_size(size)
}
}

impl<C, B> diesel::connection::LoadConnection for AsyncConnectionWrapper<C, B>
Expand Down
5 changes: 4 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
)]

use diesel::backend::Backend;
use diesel::connection::Instrumentation;
use diesel::connection::{CacheSize, Instrumentation};
use diesel::query_builder::{AsQuery, QueryFragment, QueryId};
use diesel::result::Error;
use diesel::row::Row;
Expand Down Expand Up @@ -354,4 +354,7 @@ pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send {

/// Set a specific [`Instrumentation`] implementation for this connection
fn set_instrumentation(&mut self, instrumentation: impl Instrumentation);

/// Set the prepared statement cache size to [`CacheSize`] for this connection
fn set_prepared_statement_cache_size(&mut self, size: CacheSize);
}
115 changes: 62 additions & 53 deletions src/mysql/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use crate::stmt_cache::{PrepareCallback, StmtCache};
use crate::stmt_cache::{CallbackHelper, QueryFragmentHelper};
use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection};
use diesel::connection::statement_cache::{MaybeCached, StatementCacheKey};
use diesel::connection::Instrumentation;
use diesel::connection::InstrumentationEvent;
use diesel::connection::statement_cache::{
MaybeCached, QueryFragmentForCachedStatement, StatementCache,
};
use diesel::connection::StrQueryHelper;
use diesel::connection::{CacheSize, Instrumentation};
use diesel::connection::{DynInstrumentation, InstrumentationEvent};
use diesel::mysql::{Mysql, MysqlQueryBuilder, MysqlType};
use diesel::query_builder::QueryBuilder;
use diesel::query_builder::{bind_collector::RawBytesBindCollector, QueryFragment, QueryId};
Expand All @@ -27,9 +29,9 @@ use self::serialize::ToSqlHelper;
/// `mysql://[user[:password]@]host/database_name`
pub struct AsyncMysqlConnection {
conn: mysql_async::Conn,
stmt_cache: StmtCache<Mysql, Statement>,
stmt_cache: StatementCache<Mysql, Statement>,
transaction_manager: AnsiTransactionManager,
instrumentation: std::sync::Mutex<Option<Box<dyn Instrumentation>>>,
instrumentation: DynInstrumentation,
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -72,7 +74,7 @@ impl AsyncConnection for AsyncMysqlConnection {
type TransactionManager = AnsiTransactionManager;

async fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
let mut instrumentation = diesel::connection::get_default_instrumentation();
let mut instrumentation = DynInstrumentation::default_instrumentation();
instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
database_url,
));
Expand All @@ -82,7 +84,7 @@ impl AsyncConnection for AsyncMysqlConnection {
r.as_ref().err(),
));
let mut conn = r?;
conn.instrumentation = std::sync::Mutex::new(instrumentation);
conn.instrumentation = instrumentation;
Ok(conn)
}

Expand Down Expand Up @@ -177,16 +179,15 @@ impl AsyncConnection for AsyncMysqlConnection {
}

fn instrumentation(&mut self) -> &mut dyn Instrumentation {
self.instrumentation
.get_mut()
.unwrap_or_else(|p| p.into_inner())
&mut *self.instrumentation
}

fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
*self
.instrumentation
.get_mut()
.unwrap_or_else(|p| p.into_inner()) = Some(Box::new(instrumentation));
self.instrumentation = instrumentation.into();
}

fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
self.stmt_cache.set_cache_size(size);
}
}

Expand All @@ -207,17 +208,24 @@ fn update_transaction_manager_status<T>(
query_result
}

#[async_trait::async_trait]
impl PrepareCallback<Statement, MysqlType> for &'_ mut mysql_async::Conn {
async fn prepare(
self,
sql: &str,
_metadata: &[MysqlType],
_is_for_cache: diesel::connection::statement_cache::PrepareForCache,
) -> QueryResult<(Statement, Self)> {
let s = self.prep(sql).await.map_err(ErrorHelper)?;
Ok((s, self))
}
fn prepare_statement_helper<'a, 'b>(
conn: &'a mut mysql_async::Conn,
sql: &'b str,
_is_for_cache: diesel::connection::statement_cache::PrepareForCache,
_metadata: &[MysqlType],
) -> CallbackHelper<impl Future<Output = QueryResult<(Statement, &'a mut mysql_async::Conn)>> + Send>
{
// ideally we wouldn't clone the SQL string here
// but as we usually cache statements anyway
// this is a fixed one time const
//
// The probleme with not cloning it is that we then cannot express
// the right result lifetime anymore (at least not easily)
let sql = sql.to_owned();
CallbackHelper(async move {
let s = conn.prep(sql).await.map_err(ErrorHelper)?;
Ok((s, conn))
})
}

impl AsyncMysqlConnection {
Expand All @@ -229,11 +237,9 @@ impl AsyncMysqlConnection {
use crate::run_query_dsl::RunQueryDsl;
let mut conn = AsyncMysqlConnection {
conn,
stmt_cache: StmtCache::new(),
stmt_cache: StatementCache::new(),
transaction_manager: AnsiTransactionManager::default(),
instrumentation: std::sync::Mutex::new(
diesel::connection::get_default_instrumentation(),
),
instrumentation: DynInstrumentation::default_instrumentation(),
};

for stmt in CONNECTION_SETUP_QUERIES {
Expand Down Expand Up @@ -286,36 +292,29 @@ impl AsyncMysqlConnection {
} = bind_collector?;
let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
let sql = sql?;
let helper = QueryFragmentHelper {
sql,
safe_to_cache: is_safe_to_cache_prepared,
};
let inner = async {
let cache_key = if let Some(query_id) = query_id {
StatementCacheKey::Type(query_id)
} else {
StatementCacheKey::Sql {
sql: sql.clone(),
bind_types: metadata.clone(),
}
};

let (stmt, conn) = stmt_cache
.cached_prepared_statement(
cache_key,
sql.clone(),
is_safe_to_cache_prepared,
.cached_statement_non_generic(
query_id,
&helper,
&Mysql,
&metadata,
conn,
instrumentation,
prepare_statement_helper,
&mut **instrumentation,
)
.await?;
callback(conn, stmt, ToSqlHelper { metadata, binds }).await
};
let r = update_transaction_manager_status(inner.await, transaction_manager);
instrumentation
.get_mut()
.unwrap_or_else(|p| p.into_inner())
.on_connection_event(InstrumentationEvent::finish_query(
&StrQueryHelper::new(&sql),
r.as_ref().err(),
));
instrumentation.on_connection_event(InstrumentationEvent::finish_query(
&StrQueryHelper::new(&helper.sql),
r.as_ref().err(),
));
r
}
.boxed()
Expand Down Expand Up @@ -370,9 +369,9 @@ impl AsyncMysqlConnection {

Ok(AsyncMysqlConnection {
conn,
stmt_cache: StmtCache::new(),
stmt_cache: StatementCache::new(),
transaction_manager: AnsiTransactionManager::default(),
instrumentation: std::sync::Mutex::new(None),
instrumentation: DynInstrumentation::none(),
})
}
}
Expand Down Expand Up @@ -427,3 +426,13 @@ mod tests {
}
}
}

impl QueryFragmentForCachedStatement<Mysql> for QueryFragmentHelper {
fn construct_sql(&self, _backend: &Mysql) -> QueryResult<String> {
Ok(self.sql.clone())
}

fn is_safe_to_cache_prepared(&self, _backend: &Mysql) -> QueryResult<bool> {
Ok(self.safe_to_cache)
}
}
Loading

0 comments on commit 780b348

Please sign in to comment.