Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nexus] Make 'update_and_check' CTE explicitly request columns #4572

Merged
merged 2 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions nexus/db-queries/src/db/column_walker.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

//! CTE utility for iterating over all columns in a table.

use diesel::prelude::*;
use std::marker::PhantomData;

/// Used to iterate over a tuple of columns ("T").
///
/// Diesel exposes "AllColumns" as a tuple, which is difficult to iterate over
/// -- after all, all the types are distinct. However, each of these types
/// implements "Column", so we can use a macro to provide a
/// "convertion-to-iterator" implemenation for our expected tuples.
pub(crate) struct ColumnWalker<T> {
remaining: PhantomData<T>,
}

impl<T> ColumnWalker<T> {
pub fn new() -> Self {
Self { remaining: PhantomData }
}
}

macro_rules! impl_column_walker {
( $len:literal $($column:ident)+ ) => (
impl<$($column: Column),+> IntoIterator for ColumnWalker<($($column,)+)> {
type Item = &'static str;
type IntoIter = std::array::IntoIter<Self::Item, $len>;

fn into_iter(self) -> Self::IntoIter {
[$($column::NAME,)+].into_iter()
}
}
);
}

// implementations for 1 - 32 columns
impl_column_walker! { 1 A }
impl_column_walker! { 2 A B }
impl_column_walker! { 3 A B C }
impl_column_walker! { 4 A B C D }
impl_column_walker! { 5 A B C D E }
impl_column_walker! { 6 A B C D E F }
impl_column_walker! { 7 A B C D E F G }
impl_column_walker! { 8 A B C D E F G H }
impl_column_walker! { 9 A B C D E F G H I }
impl_column_walker! { 10 A B C D E F G H I J }
impl_column_walker! { 11 A B C D E F G H I J K }
impl_column_walker! { 12 A B C D E F G H I J K L }
impl_column_walker! { 13 A B C D E F G H I J K L M }
impl_column_walker! { 14 A B C D E F G H I J K L M N }
impl_column_walker! { 15 A B C D E F G H I J K L M N O }
impl_column_walker! { 16 A B C D E F G H I J K L M N O P }
impl_column_walker! { 17 A B C D E F G H I J K L M N O P Q }
impl_column_walker! { 18 A B C D E F G H I J K L M N O P Q R }
impl_column_walker! { 19 A B C D E F G H I J K L M N O P Q R S }
impl_column_walker! { 20 A B C D E F G H I J K L M N O P Q R S T }
impl_column_walker! { 21 A B C D E F G H I J K L M N O P Q R S T U }
impl_column_walker! { 22 A B C D E F G H I J K L M N O P Q R S T U V }
impl_column_walker! { 23 A B C D E F G H I J K L M N O P Q R S T U V W }
impl_column_walker! { 24 A B C D E F G H I J K L M N O P Q R S T U V W X }
impl_column_walker! { 25 A B C D E F G H I J K L M N O P Q R S T U V W X Y }
impl_column_walker! { 26 A B C D E F G H I J K L M N O P Q R S T U V W X Y Z }
impl_column_walker! { 27 A B C D E F G H I J K L M N O P Q R S T U V W X Y Z A1 }
impl_column_walker! { 28 A B C D E F G H I J K L M N O P Q R S T U V W X Y Z A1 B1 }
impl_column_walker! { 29 A B C D E F G H I J K L M N O P Q R S T U V W X Y Z A1 B1 C1 }
impl_column_walker! { 30 A B C D E F G H I J K L M N O P Q R S T U V W X Y Z A1 B1 C1 D1 }
impl_column_walker! { 31 A B C D E F G H I J K L M N O P Q R S T U V W X Y Z A1 B1 C1 D1 E1 }
impl_column_walker! { 32 A B C D E F G H I J K L M N O P Q R S T U V W X Y Z A1 B1 C1 D1 E1 F1 }

#[cfg(test)]
mod test {
use super::*;

table! {
test_schema.test_table (id) {
id -> Uuid,
value -> Int4,
time_deleted -> Nullable<Timestamptz>,
}
}

// We can convert all a tables columns into an iteratable format.
#[test]
fn test_walk_table() {
let all_columns =
ColumnWalker::<<test_table::table as Table>::AllColumns>::new();

let mut iter = all_columns.into_iter();
assert_eq!(iter.next(), Some("id"));
assert_eq!(iter.next(), Some("value"));
assert_eq!(iter.next(), Some("time_deleted"));
assert_eq!(iter.next(), None);
}

// We can, if we want to, also make a ColumnWalker out of an arbitrary tuple
// of columns.
#[test]
fn test_walk_columns() {
let all_columns = ColumnWalker::<(
test_table::columns::id,
test_table::columns::value,
)>::new();

let mut iter = all_columns.into_iter();
assert_eq!(iter.next(), Some("id"));
assert_eq!(iter.next(), Some("value"));
assert_eq!(iter.next(), None);
}
}
1 change: 1 addition & 0 deletions nexus/db-queries/src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub mod collection_attach;
pub mod collection_detach;
pub mod collection_detach_many;
pub mod collection_insert;
mod column_walker;
mod config;
mod cte_utils;
// This is marked public for use by the integration tests
Expand Down
48 changes: 28 additions & 20 deletions nexus/db-queries/src/db/update_and_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

//! CTE implementation for "UPDATE with extended return status".

use super::column_walker::ColumnWalker;
use super::pool::DbConnection;
use async_bb8_diesel::AsyncRunQueryDsl;
use diesel::associations::HasTable;
Expand All @@ -21,7 +22,7 @@ use std::marker::PhantomData;
/// allows referencing generics with names (and extending usage
/// without re-stating those generic parameters everywhere).
pub trait UpdateStatementExt {
type Table: QuerySource;
type Table: Table + QuerySource;
type WhereClause;
type Changeset;

Expand All @@ -32,7 +33,7 @@ pub trait UpdateStatementExt {

impl<T, U, V> UpdateStatementExt for UpdateStatement<T, U, V>
where
T: QuerySource,
T: Table + QuerySource,
{
type Table = T;
type WhereClause = U;
Expand Down Expand Up @@ -201,11 +202,11 @@ where
///
/// ```text
/// // WITH found AS (SELECT <primary key> FROM T WHERE <primary key = value>)
/// // updated AS (UPDATE T SET <constraints> RETURNING *)
/// // updated AS (UPDATE T SET <constraints> RETURNING <primary key>)
/// // SELECT
/// // found.<primary key>
/// // updated.<primary key>
/// // found.*
/// // found.<all columns>
/// // FROM
/// // found
/// // LEFT JOIN
Expand All @@ -217,41 +218,48 @@ impl<US, K, Q> QueryFragment<Pg> for UpdateAndQueryStatement<US, K, Q>
where
US: UpdateStatementExt,
US::Table: HasTable<Table = US::Table> + Table,
ColumnWalker<<<US as UpdateStatementExt>::Table as Table>::AllColumns>:
IntoIterator<Item = &'static str>,
PrimaryKey<US>: diesel::Column,
UpdateStatement<US::Table, US::WhereClause, US::Changeset>:
QueryFragment<Pg>,
{
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
let primary_key = <PrimaryKey<US> as Column>::NAME;

out.push_sql("WITH found AS (");
self.find_subquery.walk_ast(out.reborrow())?;
out.push_sql("), updated AS (");
self.update_statement.walk_ast(out.reborrow())?;
// TODO: Only need primary? Or would we actually want
// to pass the returned rows back through the result?
out.push_sql(" RETURNING *) ");
out.push_sql(" RETURNING ");
out.push_identifier(primary_key)?;
out.push_sql(") ");

out.push_sql("SELECT");

let name = <PrimaryKey<US> as Column>::NAME;
out.push_sql(" found.");
out.push_identifier(name)?;
out.push_identifier(primary_key)?;
out.push_sql(", updated.");
out.push_identifier(name)?;
// TODO: I'd prefer to list all columns explicitly. But how?
// The types exist within Table::AllColumns, and each one
// has a name as "<C as Column>::Name".
// But Table::AllColumns is a tuple, which makes iteration
// a pain.
//
// TODO: Technically, we're repeating the PK here.
out.push_sql(", found.*");
out.push_identifier(primary_key)?;

// List all the "found" columns explicitly.
// This admittedly repeats the primary key, but that keeps the query
// "simple" since it returns all columns in the same order as
// AllColumns.
let all_columns = ColumnWalker::<
<<US as UpdateStatementExt>::Table as Table>::AllColumns,
>::new();
for column in all_columns.into_iter() {
out.push_sql(", found.");
out.push_identifier(column)?;
}

out.push_sql(" FROM found LEFT JOIN updated ON");
out.push_sql(" found.");
out.push_identifier(name)?;
out.push_identifier(primary_key)?;
out.push_sql(" = ");
out.push_sql("updated.");
out.push_identifier(name)?;
out.push_identifier(primary_key)?;

Ok(())
}
Expand Down
Loading