Skip to content

Commit

Permalink
Fix several issues with #[derive(MultiConnection)]
Browse files Browse the repository at this point in the history
This commit fixes several issues that occur with multiconnection
implementations:

* We now generate an implementation of
`Connection::begin_test_transaction` that just calls the inner
implementations as the inner connection implementations might
overwrite the default implementation

* We now mark queries as unsafe to cache if they are in marked as unsafe
to be cached for the specific backend.

* Another fix for binding null values as the assumption that the actual
types don't matter seems to be not correct. I've replaced the hack with
an distinct function on the inner bind collectors itself instead.

* Another fix for the the type metadata returned by `HasSqlType` as that
could result in calling the impl for the wrong backend if backends share
the same `MetadataLookup` type. The previous implementation did just use
the first backend that returned a concrete type metadata lookup value,
this patch just calls all possible backends and return all possible
values in a struct instead of an enum.

In addition this PR adds some doc comments to some locations that are
warned of while building with the
`i-implement-a-third-party-backend-and-opt-into-breaking-changes` flag
  • Loading branch information
weiznich committed Feb 19, 2024
1 parent 7c4ba73 commit 49ac723
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 18 deletions.
1 change: 1 addition & 0 deletions diesel/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ where
&mut self,
) -> &mut <Self::TransactionManager as TransactionManager<Self>>::TransactionStateData;

/// Get the instrumentation instance stored in this connection
#[diesel_derives::__diesel_public_if(
feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"
)]
Expand Down
3 changes: 3 additions & 0 deletions diesel/src/connection/statement_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,9 @@ where
doc(cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"))
)]
pub trait QueryFragmentForCachedStatement<DB> {
/// Convert the query fragment into a SQL string for the given backend
fn construct_sql(&self, backend: &DB) -> QueryResult<String>;
/// Check whether it's safe to cache the query
fn is_safe_to_cache_prepared(&self, backend: &DB) -> QueryResult<bool>;
}
impl<T, DB> QueryFragmentForCachedStatement<DB> for T
Expand All @@ -269,6 +271,7 @@ where
self.to_sql(&mut query_builder, backend)?;
Ok(query_builder.finish())
}

fn is_safe_to_cache_prepared(&self, backend: &DB) -> QueryResult<bool> {
<T as QueryFragment<DB>>::is_safe_to_cache_prepared(self, backend)
}
Expand Down
18 changes: 18 additions & 0 deletions diesel/src/query_builder/bind_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ pub trait BindCollector<'a, DB: TypeMetadata>: Sized {
where
DB: Backend + HasSqlType<T>,
U: ToSql<T, DB> + ?Sized + 'a;

/// Push a null value with the given type information onto the bind collector
///
// For backward compatibility reasons we provide a default implementation
// but custom backends that want to support `#[derive(MultiConnection)]`
// need to provide a customized implementation of this function
#[diesel_derives::__diesel_public_if(
feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"
)]
fn push_null_value(&mut self, _metadata: DB::TypeMetadata) -> QueryResult<()> {
Ok(())
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -105,6 +117,12 @@ where
self.metadata.push(metadata);
Ok(())
}

fn push_null_value(&mut self, metadata: DB::TypeMetadata) -> QueryResult<()> {
self.metadata.push(metadata);
self.binds.push(None);
Ok(())
}
}

// This is private for now as we may want to add `Into` impls for the wrapper type
Expand Down
5 changes: 5 additions & 0 deletions diesel/src/sqlite/connection/bind_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,4 +194,9 @@ impl<'a> BindCollector<'a, Sqlite> for SqliteBindCollector<'a> {
));
Ok(())
}

fn push_null_value(&mut self, metadata: SqliteType) -> QueryResult<()> {
self.binds.push((InternalSqliteBindValue::Null, metadata));
Ok(())
}
}
69 changes: 51 additions & 18 deletions diesel_derives/src/multiconnection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,13 @@ fn generate_connection_impl(
}
});

let impl_begin_test_transaction = connection_types.iter().map(|c| {
let ident = c.name;
quote::quote! {
Self::#ident(conn) => conn.begin_test_transaction()
}
});

let r2d2_impl = if cfg!(feature = "r2d2") {
let impl_ping_r2d2 = connection_types.iter().map(|c| {
let ident = c.name;
Expand Down Expand Up @@ -295,6 +302,9 @@ fn generate_connection_impl(
let mut query_builder = self.query_builder.duplicate();
self.inner.to_sql(&mut query_builder, &self.backend)?;
pass.push_sql(&query_builder.finish());
if !self.inner.is_safe_to_cache_prepared(&self.backend)? {
pass.unsafe_to_cache_prepared();
}
if let Some((outer_collector, lookup)) = pass.bind_collector() {
C::handle_inner_pass(outer_collector, lookup, &self.backend, &self.inner)?;
}
Expand Down Expand Up @@ -356,6 +366,12 @@ fn generate_connection_impl(
#(#instrumentation_impl,)*
}
}

fn begin_test_transaction(&mut self) -> diesel::QueryResult<()> {
match self {
#(#impl_begin_test_transaction,)*
}
}
}

impl LoadConnection for MultiConnection
Expand Down Expand Up @@ -757,6 +773,18 @@ fn generate_bind_collector(connection_types: &[ConnectionVariant]) -> TokenStrea
}
});

let push_null_to_inner_collector = connection_types
.iter()
.map(|c| {
let ident = c.name;
quote::quote! {
(Self::#ident(ref mut bc), super::backend::MultiTypeMetadata{ #ident: Some(metadata), .. }) => {
bc.push_null_value(metadata)?;
}
}
})
.collect::<Vec<_>>();

let push_bound_value_super_traits = connection_types
.iter()
.map(|c| {
Expand Down Expand Up @@ -948,20 +976,14 @@ fn generate_bind_collector(connection_types: &[ConnectionVariant]) -> TokenStrea
// set the `inner` field of `BindValue` to something for the `None`
// case. Therefore we need to handle that explicitly here.
//
// We just use a specific sql + rust type here to workaround
// the fact that rustc is not able to see that the underlying DBMS
// must support that sql + rust type combination. All tested DBMS
// (postgres, sqlite, mysql, oracle) seems to not care about the
// actual type here and coerce null values to the "right" type
// anyway
BindValue {
inner: Some(InnerBindValue {
value: InnerBindValueKind::Null,
push_bound_value_to_collector: &PushBoundValueToCollectorImpl {
p: std::marker::PhantomData::<(diesel::sql_types::Integer, i32)>
}
})
let metadata = <MultiBackend as diesel::sql_types::HasSqlType<T>>::metadata(metadata_lookup);
match (self, metadata) {
#(#push_null_to_inner_collector)*
_ => {
unreachable!("We have matching metadata")
},
}
return Ok(());
} else {
out.into_inner()
}
Expand All @@ -972,6 +994,14 @@ fn generate_bind_collector(connection_types: &[ConnectionVariant]) -> TokenStrea

Ok(())
}

fn push_null_value(&mut self, metadata: super::backend::MultiTypeMetadata) -> diesel::QueryResult<()> {
match (self, metadata) {
#(#push_null_to_inner_collector)*
_ => unreachable!("We have matching metadata"),
}
Ok(())
}
}

#(#to_sql_impls)*
Expand Down Expand Up @@ -1368,8 +1398,8 @@ fn generate_backend(connection_types: &[ConnectionVariant]) -> TokenStream {
let type_metadata_variants = connection_types.iter().map(|c| {
let ident = c.name;
let ty = c.ty;
quote::quote!{
#ident(<<#ty as diesel::Connection>::Backend as diesel::sql_types::TypeMetadata>::TypeMetadata)
quote::quote! {
pub(super) #ident: Option<<<#ty as diesel::Connection>::Backend as diesel::sql_types::TypeMetadata>::TypeMetadata>
}
});

Expand Down Expand Up @@ -1456,7 +1486,7 @@ fn generate_backend(connection_types: &[ConnectionVariant]) -> TokenStream {

quote::quote!{
if let Some(lookup) = <#ty as diesel::internal::derives::multiconnection::MultiConnectionHelper>::from_any(lookup) {
return MultiTypeMetadata::#name(<<#ty as diesel::Connection>::Backend as diesel::sql_types::HasSqlType<ST>>::metadata(lookup));
ret.#name = Some(<<#ty as diesel::Connection>::Backend as diesel::sql_types::HasSqlType<ST>>::metadata(lookup));
}
}

Expand All @@ -1480,8 +1510,9 @@ fn generate_backend(connection_types: &[ConnectionVariant]) -> TokenStream {
pub fn lookup_sql_type<ST>(lookup: &mut dyn std::any::Any) -> MultiTypeMetadata
where #(#lookup_sql_type_bounds,)*
{
let mut ret = MultiTypeMetadata::default();
#(#lookup_impl)*
unreachable!()
ret
}
}

Expand Down Expand Up @@ -1519,7 +1550,9 @@ fn generate_backend(connection_types: &[ConnectionVariant]) -> TokenStream {
type BindCollector<'a> = super::bind_collector::MultiBindCollector<'a>;
}

pub enum MultiTypeMetadata {
#[derive(Default)]
#[allow(non_snake_case)]
pub struct MultiTypeMetadata {
#(#type_metadata_variants,)*
}

Expand Down

0 comments on commit 49ac723

Please sign in to comment.