Skip to content

Commit

Permalink
[nexus] Make 'update_and_check' CTE explicitly request columns (#4572)
Browse files Browse the repository at this point in the history
Related to #4570, but not a direct fix for it.

This PR removes a usage of ".*" from a SQL query. Using ".*" in SQL
queries is somewhat risky -- it creates an implicit dependency on column
order, and can make backwards compatibility difficult in certain circumstances.

Instead, this PR provides a `ColumnWalker`, for converting a tuple of
columns to an iterator, and requests the expected columns explicitly.
  • Loading branch information
smklein authored Nov 29, 2023
1 parent bb7ee84 commit a4e1216
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 20 deletions.
112 changes: 112 additions & 0 deletions nexus/db-queries/src/db/column_walker.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

//! CTE utility for iterating over all columns in a table.
use diesel::prelude::*;
use std::marker::PhantomData;

/// Used to iterate over a tuple of columns ("T").
///
/// Diesel exposes "AllColumns" as a tuple, which is difficult to iterate over
/// -- after all, all the types are distinct. However, each of these types
/// implements "Column", so we can use a macro to provide a
/// "conversion-to-iterator" implementation for our expected tuples.
///
/// The walker is a zero-sized marker: it never stores a value of `T`, it only
/// carries the tuple type so that the matching `IntoIterator` implementation
/// can be selected.
pub(crate) struct ColumnWalker<T> {
    // Records the tuple type without holding a value of it.
    remaining: PhantomData<T>,
}

impl<T> ColumnWalker<T> {
    /// Creates a walker for the column tuple `T`.
    pub const fn new() -> Self {
        Self { remaining: PhantomData }
    }
}

// Written by hand rather than derived: `#[derive(Default)]` would require
// `T: Default`, which the walker does not need since it never constructs a `T`.
impl<T> Default for ColumnWalker<T> {
    fn default() -> Self {
        Self::new()
    }
}

/// Implements `IntoIterator` for `ColumnWalker<(C1, C2, ...)>`, yielding each
/// column's `Column::NAME` in the same order as the tuple's elements.
///
/// Invocation shape: `impl_column_walker! { <len> <ident>+ }`.
///
/// * `<len>` is the number of identifiers that follow. If the two disagree,
///   the array built in `into_iter` no longer matches `IntoIter` and the
///   generated impl fails to compile.
/// * Each `<ident>` becomes one generic parameter bounded by `Column`.
macro_rules! impl_column_walker {
    ( $len:literal $($column:ident)+ ) => {
        impl<$($column: Column),+> IntoIterator for ColumnWalker<($($column,)+)> {
            type Item = &'static str;
            type IntoIter = std::array::IntoIter<Self::Item, $len>;

            fn into_iter(self) -> Self::IntoIter {
                // Call the trait function on the array value directly.
                // Method-call syntax (`[..].into_iter()`) resolves to the
                // by-reference impl in editions before 2021, which would not
                // produce `std::array::IntoIter`.
                IntoIterator::into_iter([$($column::NAME,)+])
            }
        }
    };
}

// Implementations for tuples of 1 - 32 columns.
//
// The leading literal is the tuple length and must equal the number of
// identifiers after it; each identifier names one generic column type.
// NOTE(review): 32 presumably matches the widest "AllColumns" tuple Diesel
// generates with the feature set this crate enables -- confirm before relying
// on it for wider tables.
impl_column_walker! { 1 A }
impl_column_walker! { 2 A B }
impl_column_walker! { 3 A B C }
impl_column_walker! { 4 A B C D }
impl_column_walker! { 5 A B C D E }
impl_column_walker! { 6 A B C D E F }
impl_column_walker! { 7 A B C D E F G }
impl_column_walker! { 8 A B C D E F G H }
impl_column_walker! { 9 A B C D E F G H I }
impl_column_walker! { 10 A B C D E F G H I J }
impl_column_walker! { 11 A B C D E F G H I J K }
impl_column_walker! { 12 A B C D E F G H I J K L }
impl_column_walker! { 13 A B C D E F G H I J K L M }
impl_column_walker! { 14 A B C D E F G H I J K L M N }
impl_column_walker! { 15 A B C D E F G H I J K L M N O }
impl_column_walker! { 16 A B C D E F G H I J K L M N O P }
impl_column_walker! { 17 A B C D E F G H I J K L M N O P Q }
impl_column_walker! { 18 A B C D E F G H I J K L M N O P Q R }
impl_column_walker! { 19 A B C D E F G H I J K L M N O P Q R S }
impl_column_walker! { 20 A B C D E F G H I J K L M N O P Q R S T }
impl_column_walker! { 21 A B C D E F G H I J K L M N O P Q R S T U }
impl_column_walker! { 22 A B C D E F G H I J K L M N O P Q R S T U V }
impl_column_walker! { 23 A B C D E F G H I J K L M N O P Q R S T U V W }
impl_column_walker! { 24 A B C D E F G H I J K L M N O P Q R S T U V W X }
impl_column_walker! { 25 A B C D E F G H I J K L M N O P Q R S T U V W X Y }
impl_column_walker! { 26 A B C D E F G H I J K L M N O P Q R S T U V W X Y Z }
impl_column_walker! { 27 A B C D E F G H I J K L M N O P Q R S T U V W X Y Z A1 }
impl_column_walker! { 28 A B C D E F G H I J K L M N O P Q R S T U V W X Y Z A1 B1 }
impl_column_walker! { 29 A B C D E F G H I J K L M N O P Q R S T U V W X Y Z A1 B1 C1 }
impl_column_walker! { 30 A B C D E F G H I J K L M N O P Q R S T U V W X Y Z A1 B1 C1 D1 }
impl_column_walker! { 31 A B C D E F G H I J K L M N O P Q R S T U V W X Y Z A1 B1 C1 D1 E1 }
impl_column_walker! { 32 A B C D E F G H I J K L M N O P Q R S T U V W X Y Z A1 B1 C1 D1 E1 F1 }

#[cfg(test)]
mod test {
    use super::*;

    table! {
        test_schema.test_table (id) {
            id -> Uuid,
            value -> Int4,
            time_deleted -> Nullable<Timestamptz>,
        }
    }

    // Walking a table's "AllColumns" yields every column name, in the order
    // the columns were declared in the schema.
    #[test]
    fn test_walk_table() {
        type Everything = <test_table::table as Table>::AllColumns;

        let names: Vec<&'static str> =
            ColumnWalker::<Everything>::new().into_iter().collect();

        assert_eq!(names, ["id", "value", "time_deleted"]);
    }

    // A ColumnWalker can also be built from an arbitrary, hand-picked tuple
    // of columns rather than a table's full column list.
    #[test]
    fn test_walk_columns() {
        type IdAndValue =
            (test_table::columns::id, test_table::columns::value);

        let names: Vec<&'static str> =
            ColumnWalker::<IdAndValue>::new().into_iter().collect();

        assert_eq!(names, ["id", "value"]);
    }
}
1 change: 1 addition & 0 deletions nexus/db-queries/src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub mod collection_attach;
pub mod collection_detach;
pub mod collection_detach_many;
pub mod collection_insert;
mod column_walker;
mod config;
mod cte_utils;
// This is marked public for use by the integration tests
Expand Down
48 changes: 28 additions & 20 deletions nexus/db-queries/src/db/update_and_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

//! CTE implementation for "UPDATE with extended return status".
use super::column_walker::ColumnWalker;
use super::pool::DbConnection;
use async_bb8_diesel::AsyncRunQueryDsl;
use diesel::associations::HasTable;
Expand All @@ -21,7 +22,7 @@ use std::marker::PhantomData;
/// allows referencing generics with names (and extending usage
/// without re-stating those generic parameters everywhere).
pub trait UpdateStatementExt {
type Table: QuerySource;
type Table: Table + QuerySource;
type WhereClause;
type Changeset;

Expand All @@ -32,7 +33,7 @@ pub trait UpdateStatementExt {

impl<T, U, V> UpdateStatementExt for UpdateStatement<T, U, V>
where
T: QuerySource,
T: Table + QuerySource,
{
type Table = T;
type WhereClause = U;
Expand Down Expand Up @@ -201,11 +202,11 @@ where
///
/// ```text
/// // WITH found AS (SELECT <primary key> FROM T WHERE <primary key = value>)
/// // updated AS (UPDATE T SET <constraints> RETURNING *)
/// // updated AS (UPDATE T SET <constraints> RETURNING <primary key>)
/// // SELECT
/// // found.<primary key>
/// // updated.<primary key>
/// // found.*
/// // found.<all columns>
/// // FROM
/// // found
/// // LEFT JOIN
Expand All @@ -217,41 +218,48 @@ impl<US, K, Q> QueryFragment<Pg> for UpdateAndQueryStatement<US, K, Q>
where
US: UpdateStatementExt,
US::Table: HasTable<Table = US::Table> + Table,
ColumnWalker<<<US as UpdateStatementExt>::Table as Table>::AllColumns>:
IntoIterator<Item = &'static str>,
PrimaryKey<US>: diesel::Column,
UpdateStatement<US::Table, US::WhereClause, US::Changeset>:
QueryFragment<Pg>,
{
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
let primary_key = <PrimaryKey<US> as Column>::NAME;

out.push_sql("WITH found AS (");
self.find_subquery.walk_ast(out.reborrow())?;
out.push_sql("), updated AS (");
self.update_statement.walk_ast(out.reborrow())?;
// TODO: Only need primary? Or would we actually want
// to pass the returned rows back through the result?
out.push_sql(" RETURNING *) ");
out.push_sql(" RETURNING ");
out.push_identifier(primary_key)?;
out.push_sql(") ");

out.push_sql("SELECT");

let name = <PrimaryKey<US> as Column>::NAME;
out.push_sql(" found.");
out.push_identifier(name)?;
out.push_identifier(primary_key)?;
out.push_sql(", updated.");
out.push_identifier(name)?;
// TODO: I'd prefer to list all columns explicitly. But how?
// The types exist within Table::AllColumns, and each one
// has a name as "<C as Column>::Name".
// But Table::AllColumns is a tuple, which makes iteration
// a pain.
//
// TODO: Technically, we're repeating the PK here.
out.push_sql(", found.*");
out.push_identifier(primary_key)?;

// List all the "found" columns explicitly.
// This admittedly repeats the primary key, but that keeps the query
// "simple" since it returns all columns in the same order as
// AllColumns.
let all_columns = ColumnWalker::<
<<US as UpdateStatementExt>::Table as Table>::AllColumns,
>::new();
for column in all_columns.into_iter() {
out.push_sql(", found.");
out.push_identifier(column)?;
}

out.push_sql(" FROM found LEFT JOIN updated ON");
out.push_sql(" found.");
out.push_identifier(name)?;
out.push_identifier(primary_key)?;
out.push_sql(" = ");
out.push_sql("updated.");
out.push_identifier(name)?;
out.push_identifier(primary_key)?;

Ok(())
}
Expand Down

0 comments on commit a4e1216

Please sign in to comment.