Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Share the statement cache with diesel #206

Merged
merged 2 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@ description = "An async extension for Diesel the safe, extensible ORM and Query
rust-version = "1.78.0"

[dependencies]
diesel = { version = "~2.2.0", default-features = false, features = [
"i-implement-a-third-party-backend-and-opt-into-breaking-changes",
] }
async-trait = "0.1.66"
futures-channel = { version = "0.3.17", default-features = false, features = [
"std",
Expand All @@ -39,14 +36,35 @@ deadpool = { version = "0.12", optional = true, default-features = false, featur
mobc = { version = ">=0.7,<0.10", optional = true }
scoped-futures = { version = "0.1", features = ["std"] }

[dependencies.diesel]
version = "~2.2.0"
default-features = false
features = [
"i-implement-a-third-party-backend-and-opt-into-breaking-changes",
]
git = "https://github.com/diesel-rs/diesel"
branch = "master"

[dev-dependencies]
tokio = { version = "1.12.0", features = ["rt", "macros", "rt-multi-thread"] }
cfg-if = "1"
chrono = "0.4"
diesel = { version = "2.2.0", default-features = false, features = ["chrono"] }
diesel_migrations = "2.2.0"
assert_matches = "1.0.1"

[dev-dependencies.diesel]
version = "~2.2.0"
default-features = false
features = [
"chrono"
]
git = "https://github.com/diesel-rs/diesel"
branch = "master"

[dev-dependencies.diesel_migrations]
version = "2.2.0"
git = "https://github.com/diesel-rs/diesel"
branch = "master"

[features]
default = []
mysql = [
Expand Down
8 changes: 7 additions & 1 deletion examples/postgres/pooled-with-rustls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,17 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
diesel = { version = "2.2.0", default-features = false, features = ["postgres"] }
diesel-async = { version = "0.5.0", path = "../../../", features = ["bb8", "postgres"] }
futures-util = "0.3.21"
rustls = "0.23.8"
rustls-platform-verifier = "0.5.0"
tokio = { version = "1.2.0", default-features = false, features = ["macros", "rt-multi-thread"] }
tokio-postgres = "0.7.7"
tokio-postgres-rustls = "0.13.0"


[dependencies.diesel]
version = "2.2.0"
default-features = false
git = "https://github.com/diesel-rs/diesel"
branch = "master"
13 changes: 11 additions & 2 deletions examples/postgres/run-pending-migrations-with-rustls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,21 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
diesel = { version = "2.2.0", default-features = false, features = ["postgres"] }
diesel-async = { version = "0.5.0", path = "../../../", features = ["bb8", "postgres", "async-connection-wrapper"] }
diesel_migrations = "2.2.0"
futures-util = "0.3.21"
rustls = "0.23.8"
rustls-platform-verifier = "0.5.0"
tokio = { version = "1.2.0", default-features = false, features = ["macros", "rt-multi-thread"] }
tokio-postgres = "0.7.7"
tokio-postgres-rustls = "0.13.0"

[dependencies.diesel]
version = "2.2.0"
default-features = false
git = "https://github.com/diesel-rs/diesel"
branch = "master"

[dependencies.diesel_migrations]
version = "2.2.0"
git = "https://github.com/diesel-rs/diesel"
branch = "master"
14 changes: 12 additions & 2 deletions examples/sync-wrapper/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,22 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
diesel = { version = "2.2.0", default-features = false, features = ["returning_clauses_for_sqlite_3_35"] }
diesel-async = { version = "0.5.0", path = "../../", features = ["sync-connection-wrapper", "async-connection-wrapper"] }
diesel_migrations = "2.2.0"
futures-util = "0.3.21"
tokio = { version = "1.2.0", default-features = false, features = ["macros", "rt-multi-thread"] }

[dependencies.diesel]
version = "2.2.0"
default-features = false
features = ["returning_clauses_for_sqlite_3_35"]
git = "https://github.com/diesel-rs/diesel"
branch = "master"

[dependencies.diesel_migrations]
version = "2.2.0"
git = "https://github.com/diesel-rs/diesel"
branch = "master"

[features]
default = ["sqlite"]
sqlite = ["diesel-async/sqlite"]
6 changes: 5 additions & 1 deletion src/async_connection_wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ pub type AsyncConnectionWrapper<C, B = self::implementation::Tokio> =
pub use self::implementation::AsyncConnectionWrapper;

mod implementation {
use diesel::connection::{Instrumentation, SimpleConnection};
use diesel::connection::{CacheSize, Instrumentation, SimpleConnection};
use std::ops::{Deref, DerefMut};

use super::*;
Expand Down Expand Up @@ -187,6 +187,10 @@ mod implementation {
fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
self.inner.set_instrumentation(instrumentation);
}

fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
self.inner.set_prepared_statement_cache_size(size)
}
}

impl<C, B> diesel::connection::LoadConnection for AsyncConnectionWrapper<C, B>
Expand Down
5 changes: 4 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
)]

use diesel::backend::Backend;
use diesel::connection::Instrumentation;
use diesel::connection::{CacheSize, Instrumentation};
use diesel::query_builder::{AsQuery, QueryFragment, QueryId};
use diesel::result::Error;
use diesel::row::Row;
Expand Down Expand Up @@ -354,4 +354,7 @@ pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send {

/// Set a specific [`Instrumentation`] implementation for this connection
fn set_instrumentation(&mut self, instrumentation: impl Instrumentation);

/// Set the prepared statement cache size to [`CacheSize`] for this connection
fn set_prepared_statement_cache_size(&mut self, size: CacheSize);
}
115 changes: 62 additions & 53 deletions src/mysql/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use crate::stmt_cache::{PrepareCallback, StmtCache};
use crate::stmt_cache::{CallbackHelper, QueryFragmentHelper};
use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection};
use diesel::connection::statement_cache::{MaybeCached, StatementCacheKey};
use diesel::connection::Instrumentation;
use diesel::connection::InstrumentationEvent;
use diesel::connection::statement_cache::{
MaybeCached, QueryFragmentForCachedStatement, StatementCache,
};
use diesel::connection::StrQueryHelper;
use diesel::connection::{CacheSize, Instrumentation};
use diesel::connection::{DynInstrumentation, InstrumentationEvent};
use diesel::mysql::{Mysql, MysqlQueryBuilder, MysqlType};
use diesel::query_builder::QueryBuilder;
use diesel::query_builder::{bind_collector::RawBytesBindCollector, QueryFragment, QueryId};
Expand All @@ -27,9 +29,9 @@ use self::serialize::ToSqlHelper;
/// `mysql://[user[:password]@]host/database_name`
pub struct AsyncMysqlConnection {
conn: mysql_async::Conn,
stmt_cache: StmtCache<Mysql, Statement>,
stmt_cache: StatementCache<Mysql, Statement>,
transaction_manager: AnsiTransactionManager,
instrumentation: std::sync::Mutex<Option<Box<dyn Instrumentation>>>,
instrumentation: DynInstrumentation,
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -72,7 +74,7 @@ impl AsyncConnection for AsyncMysqlConnection {
type TransactionManager = AnsiTransactionManager;

async fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
let mut instrumentation = diesel::connection::get_default_instrumentation();
let mut instrumentation = DynInstrumentation::default_instrumentation();
instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
database_url,
));
Expand All @@ -82,7 +84,7 @@ impl AsyncConnection for AsyncMysqlConnection {
r.as_ref().err(),
));
let mut conn = r?;
conn.instrumentation = std::sync::Mutex::new(instrumentation);
conn.instrumentation = instrumentation;
Ok(conn)
}

Expand Down Expand Up @@ -177,16 +179,15 @@ impl AsyncConnection for AsyncMysqlConnection {
}

fn instrumentation(&mut self) -> &mut dyn Instrumentation {
self.instrumentation
.get_mut()
.unwrap_or_else(|p| p.into_inner())
&mut *self.instrumentation
}

fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
*self
.instrumentation
.get_mut()
.unwrap_or_else(|p| p.into_inner()) = Some(Box::new(instrumentation));
self.instrumentation = instrumentation.into();
}

fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
self.stmt_cache.set_cache_size(size);
}
}

Expand All @@ -207,17 +208,24 @@ fn update_transaction_manager_status<T>(
query_result
}

#[async_trait::async_trait]
impl PrepareCallback<Statement, MysqlType> for &'_ mut mysql_async::Conn {
async fn prepare(
self,
sql: &str,
_metadata: &[MysqlType],
_is_for_cache: diesel::connection::statement_cache::PrepareForCache,
) -> QueryResult<(Statement, Self)> {
let s = self.prep(sql).await.map_err(ErrorHelper)?;
Ok((s, self))
}
fn prepare_statement_helper<'a, 'b>(
conn: &'a mut mysql_async::Conn,
sql: &'b str,
_is_for_cache: diesel::connection::statement_cache::PrepareForCache,
_metadata: &[MysqlType],
) -> CallbackHelper<impl Future<Output = QueryResult<(Statement, &'a mut mysql_async::Conn)>> + Send>
{
// ideally we wouldn't clone the SQL string here
// but as we usually cache statements anyway
// this is a fixed one time const
//
// The probleme with not cloning it is that we then cannot express
// the right result lifetime anymore (at least not easily)
let sql = sql.to_owned();
CallbackHelper(async move {
let s = conn.prep(sql).await.map_err(ErrorHelper)?;
Ok((s, conn))
})
}

impl AsyncMysqlConnection {
Expand All @@ -229,11 +237,9 @@ impl AsyncMysqlConnection {
use crate::run_query_dsl::RunQueryDsl;
let mut conn = AsyncMysqlConnection {
conn,
stmt_cache: StmtCache::new(),
stmt_cache: StatementCache::new(),
transaction_manager: AnsiTransactionManager::default(),
instrumentation: std::sync::Mutex::new(
diesel::connection::get_default_instrumentation(),
),
instrumentation: DynInstrumentation::default_instrumentation(),
};

for stmt in CONNECTION_SETUP_QUERIES {
Expand Down Expand Up @@ -286,36 +292,29 @@ impl AsyncMysqlConnection {
} = bind_collector?;
let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
let sql = sql?;
let helper = QueryFragmentHelper {
sql,
safe_to_cache: is_safe_to_cache_prepared,
};
let inner = async {
let cache_key = if let Some(query_id) = query_id {
StatementCacheKey::Type(query_id)
} else {
StatementCacheKey::Sql {
sql: sql.clone(),
bind_types: metadata.clone(),
}
};

let (stmt, conn) = stmt_cache
.cached_prepared_statement(
cache_key,
sql.clone(),
is_safe_to_cache_prepared,
.cached_statement_non_generic(
query_id,
&helper,
&Mysql,
&metadata,
conn,
instrumentation,
prepare_statement_helper,
&mut **instrumentation,
)
.await?;
callback(conn, stmt, ToSqlHelper { metadata, binds }).await
};
let r = update_transaction_manager_status(inner.await, transaction_manager);
instrumentation
.get_mut()
.unwrap_or_else(|p| p.into_inner())
.on_connection_event(InstrumentationEvent::finish_query(
&StrQueryHelper::new(&sql),
r.as_ref().err(),
));
instrumentation.on_connection_event(InstrumentationEvent::finish_query(
&StrQueryHelper::new(&helper.sql),
r.as_ref().err(),
));
r
}
.boxed()
Expand Down Expand Up @@ -370,9 +369,9 @@ impl AsyncMysqlConnection {

Ok(AsyncMysqlConnection {
conn,
stmt_cache: StmtCache::new(),
stmt_cache: StatementCache::new(),
transaction_manager: AnsiTransactionManager::default(),
instrumentation: std::sync::Mutex::new(None),
instrumentation: DynInstrumentation::none(),
})
}
}
Expand Down Expand Up @@ -427,3 +426,13 @@ mod tests {
}
}
}

impl QueryFragmentForCachedStatement<Mysql> for QueryFragmentHelper {
fn construct_sql(&self, _backend: &Mysql) -> QueryResult<String> {
Ok(self.sql.clone())
}

fn is_safe_to_cache_prepared(&self, _backend: &Mysql) -> QueryResult<bool> {
Ok(self.safe_to_cache)
}
}
Loading
Loading