From 201e4b84535572db23aaeea3fe1f785cddbb6b80 Mon Sep 17 00:00:00 2001 From: Ten0 <9094255+Ten0@users.noreply.github.com> Date: Tue, 23 May 2023 15:05:24 +0200 Subject: [PATCH] Implement field length introspection and make available in schema (#3552) Resolves #3551 --- .github/workflows/ci.yml | 4 + diesel/src/lib.rs | 1 + diesel/src/query_source/mod.rs | 9 ++ .../infer_schema_internals/data_structures.rs | 67 ++++---- .../src/infer_schema_internals/mysql.rs | 61 +++++++- diesel_cli/src/infer_schema_internals/pg.rs | 69 ++++++++- .../src/infer_schema_internals/sqlite.rs | 16 ++ diesel_cli/src/migrations/diff_schema.rs | 18 ++- diesel_cli/src/print_schema.rs | 22 ++- .../mysql/schema_out.rs/expected.snap | 3 +- .../postgres/schema_out.rs/expected.snap | 3 +- .../mysql/down.sql/expected.snap | 3 +- .../postgres/down.sql/expected.snap | 9 +- .../mysql/expected.snap | 1 + .../mysql/expected.snap | 3 + .../mysql/expected.snap | 2 + diesel_derives/src/table.rs | 10 ++ diesel_table_macro_syntax/src/lib.rs | 143 +++++++++--------- diesel_table_macro_syntax/tests/basic.rs | 18 +++ diesel_table_macro_syntax/tests/basic.rs.in | 7 + 20 files changed, 327 insertions(+), 142 deletions(-) create mode 100644 diesel_table_macro_syntax/tests/basic.rs create mode 100644 diesel_table_macro_syntax/tests/basic.rs.in diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d5f76f7095dd..322d62d727c7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -246,6 +246,10 @@ jobs: shell: bash run: cargo +${{ matrix.rust }} test --manifest-path diesel_migrations/migrations_macros/Cargo.toml --features "diesel/${{ matrix.backend }} ${{ matrix.backend }}" + - name: Test table-macro-syntax + shell: bash + run: cargo +${{ matrix.rust }} test --manifest-path diesel_table_macro_syntax/Cargo.toml + - name: Test diesel_migrations shell: bash run: cargo +${{ matrix.rust }} test --manifest-path diesel_migrations/Cargo.toml --features "${{ matrix.backend }} diesel/${{ matrix.backend }}" diff --git a/diesel/src/lib.rs b/diesel/src/lib.rs index f4c047eaef64..5b89ff49ad07 100644 --- a/diesel/src/lib.rs +++ b/diesel/src/lib.rs @@ -649,6 +649,7 @@ pub mod prelude { pub use crate::query_dsl::{ BelongingToDsl, CombineDsl, JoinOnDsl, QueryDsl, RunQueryDsl, SaveChangesDsl, }; + pub use crate::query_source::SizeRestrictedColumn as _; #[doc(inline)] pub use crate::query_source::{Column, JoinTo, QuerySource, Table}; #[doc(inline)] diff --git a/diesel/src/query_source/mod.rs b/diesel/src/query_source/mod.rs index aaf1be08060f..a007c7434a84 100644 --- a/diesel/src/query_source/mod.rs +++ b/diesel/src/query_source/mod.rs @@ -189,3 +189,12 @@ mod impls_which_are_only_here_to_improve_error_messages { type Selection = this_table_appears_in_your_query_more_than_once_and_must_be_aliased; } } + +/// Max length for columns of type Char/Varchar... +/// +/// If a given column has a such constraint, this trait will be implemented and specify that +/// length. +pub trait SizeRestrictedColumn: Column { + /// Max length of that column + const MAX_LENGTH: usize; +} diff --git a/diesel_cli/src/infer_schema_internals/data_structures.rs b/diesel_cli/src/infer_schema_internals/data_structures.rs index 2a7964a4a061..4f51b04e73e1 100644 --- a/diesel_cli/src/infer_schema_internals/data_structures.rs +++ b/diesel_cli/src/infer_schema_internals/data_structures.rs @@ -1,11 +1,7 @@ -#[cfg(feature = "uses_information_schema")] -use diesel::backend::Backend; -use diesel::deserialize::{self, FromStaticSqlRow, Queryable}; -#[cfg(feature = "sqlite")] -use diesel::sqlite::Sqlite; +use diesel_table_macro_syntax::ColumnDef; + +use std::error::Error; -#[cfg(feature = "uses_information_schema")] -use super::information_schema::DefaultSchema; use super::table_data::TableName; #[derive(Debug, Clone, PartialEq, Eq)] @@ -14,6 +10,7 @@ pub struct ColumnInformation { pub type_name: String, pub type_schema: Option, pub nullable: bool, + pub max_length: Option, pub comment: Option, } @@ -25,10 +22,26 @@ pub struct ColumnType { pub is_array: bool, pub is_nullable: bool, pub is_unsigned: bool, + pub max_length: Option, } -impl From<&syn::TypePath> for ColumnType { - fn from(t: &syn::TypePath) -> Self { +impl ColumnType { + pub(crate) fn for_column_def(c: &ColumnDef) -> Result> { + Ok(Self::for_type_path( + &c.tpe, + c.max_length + .as_ref() + .map(|l| { + l.base10_parse::() + .map_err(|e| -> Box { + format!("Column length literal can't be parsed as u64: {e}").into() + }) + }) + .transpose()?, + )) + } + + fn for_type_path(t: &syn::TypePath, max_length: Option) -> Self { let last = t .path .segments @@ -42,6 +55,7 @@ impl From<&syn::TypePath> for ColumnType { is_array: last.ident == "Array", is_nullable: last.ident == "Nullable", is_unsigned: last.ident == "Unsigned", + max_length, }; let sql_name = if !ret.is_nullable && !ret.is_array && !ret.is_unsigned { @@ -53,7 +67,7 @@ impl From<&syn::TypePath> for ColumnType { } else if let syn::PathArguments::AngleBracketed(ref args) = last.arguments { let arg = args.args.first().expect("There is at least one argument"); if let syn::GenericArgument::Type(syn::Type::Path(p)) = arg { - let s = Self::from(p); + let s = Self::for_type_path(p, max_length); ret.is_nullable |= s.is_nullable; ret.is_array |= s.is_array; ret.is_unsigned |= s.is_unsigned; @@ -110,6 +124,7 @@ impl ColumnInformation { type_name: U, type_schema: Option, nullable: bool, + max_length: Option, comment: Option, ) -> Self where @@ -121,42 +136,12 @@ impl ColumnInformation { type_name: type_name.into(), type_schema, nullable, + max_length, comment, } } } -#[cfg(feature = "uses_information_schema")] -impl Queryable for ColumnInformation -where - DB: Backend + DefaultSchema, - (String, String, Option, String, Option): FromStaticSqlRow, -{ - type Row = (String, String, Option, String, Option); - - fn build(row: Self::Row) -> deserialize::Result { - Ok(ColumnInformation::new( - row.0, - row.1, - row.2, - row.3 == "YES", - row.4, - )) - } -} - -#[cfg(feature = "sqlite")] -impl Queryable for ColumnInformation -where - (i32, String, String, bool, Option, bool, i32): FromStaticSqlRow, -{ - type Row = (i32, String, String, bool, Option, bool, i32); - - fn build(row: Self::Row) -> deserialize::Result { - Ok(ColumnInformation::new(row.1, row.2, None, !row.3, None)) - } -} - #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct ForeignKeyConstraint { pub child_table: TableName, diff --git a/diesel_cli/src/infer_schema_internals/mysql.rs b/diesel_cli/src/infer_schema_internals/mysql.rs index 9a4e14ea7912..6a07a0443c1e 100644 --- a/diesel_cli/src/infer_schema_internals/mysql.rs +++ b/diesel_cli/src/infer_schema_internals/mysql.rs @@ -1,3 +1,4 @@ +use diesel::deserialize::{self, FromStaticSqlRow, Queryable}; use diesel::mysql::{Mysql, MysqlConnection}; use diesel::*; use heck::ToUpperCamelCase; @@ -33,14 +34,60 @@ pub fn get_table_data( column_type, type_schema, __is_nullable, + character_maximum_length, // MySQL comments are not nullable and are empty strings if not set null_if_text(column_comment, ""), )) .filter(table_name.eq(&table.sql_name)) .filter(table_schema.eq(schema_name)); - match column_sorting { - ColumnSorting::OrdinalPosition => query.order(ordinal_position).load(conn), - ColumnSorting::Name => query.order(column_name).load(conn), + let mut table_columns: Vec = match column_sorting { + ColumnSorting::OrdinalPosition => query.order(ordinal_position).load(conn)?, + ColumnSorting::Name => query.order(column_name).load(conn)?, + }; + for c in &mut table_columns { + if c.max_length.is_some() && !c.type_name.contains('(') { + // Mysql returns something in character_maximum_length regardless + // of whether it's specified at field creation time + // In addition there is typically a shared limitation at row level, + // so it's typically not even the real max. + // This basically means no max. + // https://dev.mysql.com/doc/refman/8.0/en/column-count-limit.html + // https://chartio.com/resources/tutorials/understanding-strorage-sizes-for-mysql-text-data-types/ + c.max_length = None; + } + } + Ok(table_columns) +} + +impl Queryable for ColumnInformation +where + ( + String, + String, + Option, + String, + Option, + Option, + ): FromStaticSqlRow, +{ + type Row = ( + String, + String, + Option, + String, + Option, + Option, + ); + + fn build(row: Self::Row) -> deserialize::Result { + Ok(ColumnInformation::new( + row.0, + row.1, + row.2, + row.3 == "YES", + row.4, + row.5, + )) } } @@ -84,7 +131,8 @@ mod information_schema { column_name -> VarChar, #[sql_name = "is_nullable"] __is_nullable -> VarChar, - ordinal_position -> BigInt, + character_maximum_length -> Nullable>, + ordinal_position -> Unsigned, udt_name -> VarChar, udt_schema -> VarChar, column_type -> VarChar, @@ -171,6 +219,7 @@ pub fn determine_column_type( is_array: false, is_nullable: attr.nullable, is_unsigned: unsigned, + max_length: attr.max_length, }) } @@ -324,9 +373,11 @@ mod test { "varchar(255)", None, false, + Some(255), Some("column comment".to_string()), ); - let id_without_comment = ColumnInformation::new("id", "varchar(255)", None, false, None); + let id_without_comment = + ColumnInformation::new("id", "varchar(255)", None, false, Some(255), None); assert_eq!( Ok(vec![id_with_comment]), get_table_data(&mut connection, &table_1, &ColumnSorting::OrdinalPosition) diff --git a/diesel_cli/src/infer_schema_internals/pg.rs b/diesel_cli/src/infer_schema_internals/pg.rs index 1cf7f5968649..bb0635979aa6 100644 --- a/diesel_cli/src/infer_schema_internals/pg.rs +++ b/diesel_cli/src/infer_schema_internals/pg.rs @@ -2,7 +2,14 @@ use super::data_structures::*; use super::information_schema::DefaultSchema; use super::TableName; use crate::print_schema::ColumnSorting; -use diesel::{dsl::AsExprOf, expression::AsExpression, pg::Pg, prelude::*, sql_types}; +use diesel::{ + deserialize::{self, FromStaticSqlRow, Queryable}, + dsl::AsExprOf, + expression::AsExpression, + pg::Pg, + prelude::*, + sql_types, +}; use heck::ToUpperCamelCase; use std::borrow::Cow; use std::error::Error; @@ -48,6 +55,7 @@ pub fn determine_column_type( is_array, is_nullable: attr.nullable, is_unsigned: false, + max_length: attr.max_length, }) } @@ -79,6 +87,7 @@ pub fn get_table_data( udt_name, udt_schema.nullable(), __is_nullable, + character_maximum_length, col_description(regclass(table), ordinal_position), )) .filter(table_name.eq(&table.sql_name)) @@ -89,6 +98,44 @@ pub fn get_table_data( } } +impl Queryable for ColumnInformation +where + ( + String, + String, + Option, + String, + Option, + Option, + ): FromStaticSqlRow, +{ + type Row = ( + String, + String, + Option, + String, + Option, + Option, + ); + + fn build(row: Self::Row) -> deserialize::Result { + Ok(ColumnInformation::new( + row.0, + row.1, + row.2, + row.3 == "YES", + row.4 + .map(|n| { + std::convert::TryInto::try_into(n).map_err(|e| { + format!("Max column length can't be converted to u64: {e} (got: {n})") + }) + }) + .transpose()?, + row.5, + )) + } +} + sql_function!(fn obj_description(oid: sql_types::Oid, catalog: sql_types::Text) -> Nullable); pub fn get_table_comment( @@ -108,6 +155,7 @@ mod information_schema { column_name -> VarChar, #[sql_name = "is_nullable"] __is_nullable -> VarChar, + character_maximum_length -> Nullable, ordinal_position -> BigInt, udt_name -> VarChar, udt_schema -> VarChar, @@ -221,7 +269,7 @@ mod test { .execute(&mut connection) .unwrap(); diesel::sql_query( - "CREATE TABLE test_schema.table_1 (id SERIAL PRIMARY KEY, text_col VARCHAR, not_null TEXT NOT NULL)", + "CREATE TABLE test_schema.table_1 (id SERIAL PRIMARY KEY, text_col VARCHAR(128), not_null TEXT NOT NULL)", ).execute(&mut connection) .unwrap(); diesel::sql_query("COMMENT ON COLUMN test_schema.table_1.id IS 'column comment'") @@ -239,12 +287,21 @@ mod test { "int4", pg_catalog.clone(), false, + None, Some("column comment".to_string()), ); - let text_col = - ColumnInformation::new("text_col", "varchar", pg_catalog.clone(), true, None); - let not_null = ColumnInformation::new("not_null", "text", pg_catalog.clone(), false, None); - let array_col = ColumnInformation::new("array_col", "_varchar", pg_catalog, false, None); + let text_col = ColumnInformation::new( + "text_col", + "varchar", + pg_catalog.clone(), + true, + Some(128), + None, + ); + let not_null = + ColumnInformation::new("not_null", "text", pg_catalog.clone(), false, None, None); + let array_col = + ColumnInformation::new("array_col", "_varchar", pg_catalog, false, None, None); assert_eq!( Ok(vec![id, text_col, not_null]), get_table_data(&mut connection, &table_1, &ColumnSorting::OrdinalPosition) diff --git a/diesel_cli/src/infer_schema_internals/sqlite.rs b/diesel_cli/src/infer_schema_internals/sqlite.rs index f77fe843c591..cbc64e267988 100644 --- a/diesel_cli/src/infer_schema_internals/sqlite.rs +++ b/diesel_cli/src/infer_schema_internals/sqlite.rs @@ -1,6 +1,8 @@ use std::error::Error; +use diesel::deserialize::{self, FromStaticSqlRow, Queryable}; use diesel::dsl::sql; +use diesel::sqlite::Sqlite; use diesel::*; use super::data_structures::*; @@ -151,6 +153,19 @@ pub fn get_table_data( Ok(result) } +impl Queryable for ColumnInformation +where + (i32, String, String, bool, Option, bool, i32): FromStaticSqlRow, +{ + type Row = (i32, String, String, bool, Option, bool, i32); + + fn build(row: Self::Row) -> deserialize::Result { + Ok(ColumnInformation::new( + row.1, row.2, None, !row.3, None, None, + )) + } +} + #[derive(Queryable)] struct FullTableInfo { _cid: i32, @@ -232,6 +247,7 @@ pub fn determine_column_type( is_array: false, is_nullable: attr.nullable, is_unsigned: false, + max_length: attr.max_length, }) } diff --git a/diesel_cli/src/migrations/diff_schema.rs b/diesel_cli/src/migrations/diff_schema.rs index 2e4fb8405fb7..107cca6be50c 100644 --- a/diesel_cli/src/migrations/diff_schema.rs +++ b/diesel_cli/src/migrations/diff_schema.rs @@ -126,7 +126,7 @@ pub fn generate_sql_based_on_diff_schema( for c in columns.column_data { if let Some(def) = expected_column_map.remove(&c.sql_name.to_lowercase()) { - let tpe = ColumnType::from(&def.tpe); + let tpe = ColumnType::for_column_def(&def)?; if !is_same_type(&c.ty, tpe) { changed_columns.push((c, def)); } @@ -242,6 +242,7 @@ fn is_same_type(ty: &ColumnType, tpe: ColumnType) -> bool { if ty.is_array != tpe.is_array || ty.is_nullable != tpe.is_nullable || ty.is_unsigned != tpe.is_unsigned + || ty.max_length != tpe.max_length { return false; } @@ -313,15 +314,16 @@ impl SchemaDiff { .column_defs .iter() .map(|c| { - let ty = ColumnType::from(&c.tpe); - ColumnDefinition { + let ty = ColumnType::for_column_def(c) + .map_err(diesel::result::Error::QueryBuilderError)?; + Ok(ColumnDefinition { sql_name: c.sql_name.to_lowercase(), rust_name: c.sql_name.clone(), ty, comment: None, - } + }) }) - .collect::>(); + .collect::>>()?; let foreign_keys = foreign_keys .iter() .map(|(f, pk)| { @@ -361,7 +363,8 @@ impl SchemaDiff { query_builder, &table.to_lowercase(), &c.column_name.to_string().to_lowercase(), - &ColumnType::from(&c.tpe), + &ColumnType::for_column_def(c) + .map_err(diesel::result::Error::QueryBuilderError)?, )?; query_builder.push_sql("\n"); } @@ -548,6 +551,9 @@ where { // TODO: handle schema query_builder.push_sql(&format!(" {}", ty.sql_name.to_uppercase())); + if let Some(max_length) = ty.max_length { + query_builder.push_sql(&format!("({max_length})")); + } if !ty.is_nullable { query_builder.push_sql(" NOT NULL"); } diff --git a/diesel_cli/src/print_schema.rs b/diesel_cli/src/print_schema.rs index 72f7bc1e7a7d..c1dca8a4f893 100644 --- a/diesel_cli/src/print_schema.rs +++ b/diesel_cli/src/print_schema.rs @@ -696,12 +696,24 @@ impl<'a> Display for ColumnDefinitions<'a> { } } - if column.rust_name == column.sql_name { - writeln!(out, "{} -> {},", column.sql_name, column_type)?; - } else { - writeln!(out, r#"#[sql_name = "{}"]"#, column.sql_name)?; - writeln!(out, "{} -> {},", column.rust_name, column_type)?; + // Write out attributes + if column.rust_name != column.sql_name || column.ty.max_length.is_some() { + let mut is_first = true; + write!(out, r#"#["#)?; + if column.rust_name != column.sql_name { + write!(out, r#"sql_name = {:?}"#, column.sql_name)?; + is_first = false; + } + if let Some(max_length) = column.ty.max_length { + if !is_first { + write!(out, ", ")?; + } + write!(out, "max_length = {}", max_length)?; + } + writeln!(out, r#"]"#)?; } + + writeln!(out, "{} -> {},", column.rust_name, column_type)?; } } writeln!(f, "}}")?; diff --git a/diesel_cli/tests/generate_migrations/diff_add_table_all_the_types/mysql/schema_out.rs/expected.snap b/diesel_cli/tests/generate_migrations/diff_add_table_all_the_types/mysql/schema_out.rs/expected.snap index 2e888b34447b..535aa4253fd6 100644 --- a/diesel_cli/tests/generate_migrations/diff_add_table_all_the_types/mysql/schema_out.rs/expected.snap +++ b/diesel_cli/tests/generate_migrations/diff_add_table_all_the_types/mysql/schema_out.rs/expected.snap @@ -1,6 +1,5 @@ --- source: diesel_cli/tests/migration_generate.rs -assertion_line: 354 description: "Test: diff_add_table_all_the_types" --- // @generated automatically by Diesel CLI. @@ -12,6 +11,7 @@ diesel::table! { integer_column -> Integer, small_int_col -> Smallint, big_int_col -> Bigint, + #[max_length = 1] binary_col -> Binary, text_col -> Text, double_col -> Double, @@ -28,6 +28,7 @@ diesel::table! { big_int2_col -> Bigint, float8_col -> Double, decimal_col -> Decimal, + #[max_length = 1] char_col -> Char, tinytext_col -> Tinytext, mediumtext_col -> Mediumtext, diff --git a/diesel_cli/tests/generate_migrations/diff_add_table_all_the_types/postgres/schema_out.rs/expected.snap b/diesel_cli/tests/generate_migrations/diff_add_table_all_the_types/postgres/schema_out.rs/expected.snap index 549367029c3d..256bc5c28f57 100644 --- a/diesel_cli/tests/generate_migrations/diff_add_table_all_the_types/postgres/schema_out.rs/expected.snap +++ b/diesel_cli/tests/generate_migrations/diff_add_table_all_the_types/postgres/schema_out.rs/expected.snap @@ -1,6 +1,5 @@ --- source: diesel_cli/tests/migration_generate.rs -assertion_line: 354 description: "Test: diff_add_table_all_the_types" --- // @generated automatically by Diesel CLI. @@ -28,7 +27,9 @@ diesel::table! { decimal_col -> Numeric, varchar_col -> Varchar, varchar2_col -> Varchar, + #[max_length = 1] char_col -> Bpchar, + #[max_length = 1] bit_col -> Bit, cidr_col -> Cidr, inet_col -> Inet, diff --git a/diesel_cli/tests/generate_migrations/diff_drop_table_all_the_types/mysql/down.sql/expected.snap b/diesel_cli/tests/generate_migrations/diff_drop_table_all_the_types/mysql/down.sql/expected.snap index 369d2e5c1b48..c2d617bf1cdd 100644 --- a/diesel_cli/tests/generate_migrations/diff_drop_table_all_the_types/mysql/down.sql/expected.snap +++ b/diesel_cli/tests/generate_migrations/diff_drop_table_all_the_types/mysql/down.sql/expected.snap @@ -1,6 +1,5 @@ --- source: diesel_cli/tests/migration_generate.rs -assertion_line: 338 description: "Test: diff_drop_table_all_the_types" --- -- This file should undo anything in `up.sql` @@ -19,7 +18,7 @@ CREATE TABLE `test`( `datetime` DATETIME, `timestamp` TIMESTAMP NOT NULL, `time` TIME, - `char` CHAR, + `char` CHAR(50), `blob` BLOB, `text` TEXT ); diff --git a/diesel_cli/tests/generate_migrations/diff_drop_table_all_the_types/postgres/down.sql/expected.snap b/diesel_cli/tests/generate_migrations/diff_drop_table_all_the_types/postgres/down.sql/expected.snap index d4af36ca206e..24a46f5d935a 100644 --- a/diesel_cli/tests/generate_migrations/diff_drop_table_all_the_types/postgres/down.sql/expected.snap +++ b/diesel_cli/tests/generate_migrations/diff_drop_table_all_the_types/postgres/down.sql/expected.snap @@ -1,6 +1,5 @@ --- source: diesel_cli/tests/migration_generate.rs -assertion_line: 319 description: "Test: diff_drop_table_all_the_types" --- -- This file should undo anything in `up.sql` @@ -12,10 +11,10 @@ CREATE TABLE "test"( "bigserial2" INT8 NOT NULL, "boolean" BOOL, "bytea" BYTEA, - "character" BPCHAR, - "char" BPCHAR, - "varchar" VARCHAR, - "varchar2" VARCHAR, + "character" BPCHAR(5), + "char" BPCHAR(5), + "varchar" VARCHAR(5), + "varchar2" VARCHAR(5), "cidr" CIDR, "date" DATE, "double" FLOAT8, diff --git a/diesel_cli/tests/print_schema/print_schema_column_order/mysql/expected.snap b/diesel_cli/tests/print_schema/print_schema_column_order/mysql/expected.snap index 5d7de645d117..23127a80cb6c 100644 --- a/diesel_cli/tests/print_schema/print_schema_column_order/mysql/expected.snap +++ b/diesel_cli/tests/print_schema/print_schema_column_order/mysql/expected.snap @@ -7,6 +7,7 @@ description: "Test: print_schema_column_order" diesel::table! { abc (a) { a -> Integer, + #[max_length = 16] b -> Varchar, c -> Nullable, } diff --git a/diesel_cli/tests/print_schema/print_schema_several_keys_with_compound_key/mysql/expected.snap b/diesel_cli/tests/print_schema/print_schema_several_keys_with_compound_key/mysql/expected.snap index ed9f20fc399e..5099596c0383 100644 --- a/diesel_cli/tests/print_schema/print_schema_several_keys_with_compound_key/mysql/expected.snap +++ b/diesel_cli/tests/print_schema/print_schema_several_keys_with_compound_key/mysql/expected.snap @@ -7,6 +7,7 @@ description: "Test: print_schema_several_keys_with_compound_key" diesel::table! { payment_card (id) { id -> Integer, + #[max_length = 50] code -> Varchar, } } @@ -14,6 +15,7 @@ diesel::table! { diesel::table! { transaction_one (id) { id -> Integer, + #[max_length = 50] card_code -> Varchar, payment_card_id -> Integer, by_card_id -> Integer, @@ -24,6 +26,7 @@ diesel::table! { transaction_two (id) { id -> Integer, payment_card_id -> Integer, + #[max_length = 50] card_code -> Varchar, } } diff --git a/diesel_cli/tests/print_schema/print_schema_with_enum_set_types/mysql/expected.snap b/diesel_cli/tests/print_schema/print_schema_with_enum_set_types/mysql/expected.snap index 9138e4b0e558..cfa188fcfc1b 100644 --- a/diesel_cli/tests/print_schema/print_schema_with_enum_set_types/mysql/expected.snap +++ b/diesel_cli/tests/print_schema/print_schema_with_enum_set_types/mysql/expected.snap @@ -45,12 +45,14 @@ diesel::table! { /// Its SQL type is `Users1UserStateEnum`. /// /// (Automatically generated by Diesel.) + #[max_length = 8] user_state -> Users1UserStateEnum, /// The `enabled_features` column of the `users1` table. /// /// Its SQL type is `Users1EnabledFeaturesSet`. /// /// (Automatically generated by Diesel.) + #[max_length = 19] enabled_features -> Users1EnabledFeaturesSet, } } diff --git a/diesel_derives/src/table.rs b/diesel_derives/src/table.rs index a32e0d4413df..377004bc99b5 100644 --- a/diesel_derives/src/table.rs +++ b/diesel_derives/src/table.rs @@ -663,6 +663,14 @@ fn expand_column_def(column_def: &ColumnDef) -> TokenStream { None }; + let max_length = column_def.max_length.as_ref().map(|column_max_length| { + quote::quote! { + impl self::diesel::query_source::SizeRestrictedColumn for #column_name { + const MAX_LENGTH: usize = #column_max_length; + } + } + }); + quote::quote_spanned! {span=> #(#meta)* #[allow(non_camel_case_types, dead_code)] @@ -768,6 +776,8 @@ fn expand_column_def(column_def: &ColumnDef) -> TokenStream { } } + #max_length + #ops_impls #backend_specific_column_impl } diff --git a/diesel_table_macro_syntax/src/lib.rs b/diesel_table_macro_syntax/src/lib.rs index f32084a0a933..c1f8cf2b6869 100644 --- a/diesel_table_macro_syntax/src/lib.rs +++ b/diesel_table_macro_syntax/src/lib.rs @@ -2,59 +2,31 @@ use syn::spanned::Spanned; use syn::Ident; use syn::MetaNameValue; -#[allow(dead_code)] // paren_token is currently unused -pub struct PrimaryKey { - paren_token: syn::token::Paren, - pub keys: syn::punctuated::Punctuated, -} - -#[allow(dead_code)] // arrow is currently unused -pub struct ColumnDef { - pub meta: Vec, - pub column_name: Ident, - pub sql_name: String, - arrow: syn::Token![->], - pub tpe: syn::TypePath, -} - -#[allow(dead_code)] // punct and brace_token is currently unused pub struct TableDecl { pub use_statements: Vec, pub meta: Vec, pub schema: Option, - punct: Option, + _punct: Option, pub sql_name: String, pub table_name: Ident, pub primary_keys: Option, - brace_token: syn::token::Brace, + _brace_token: syn::token::Brace, pub column_defs: syn::punctuated::Punctuated, } -#[allow(dead_code)] // eq is currently unused -struct SqlNameAttribute { - eq: syn::Token![=], - lit: syn::LitStr, +#[allow(dead_code)] // paren_token is currently unused +pub struct PrimaryKey { + paren_token: syn::token::Paren, + pub keys: syn::punctuated::Punctuated, } -impl SqlNameAttribute { - fn from_attribute(element: syn::Attribute) -> Result { - if let syn::Meta::NameValue(MetaNameValue { - eq_token, - value: - syn::Expr::Lit(syn::ExprLit { - lit: syn::Lit::Str(lit), - .. - }), - .. - }) = element.meta - { - Ok(SqlNameAttribute { eq: eq_token, lit }) - } else { - Err(syn::Error::new( - element.span(), - "Invalid `#[sql_name = \"column_name\"]` attribute", - )) - } - } + +pub struct ColumnDef { + pub meta: Vec, + pub column_name: Ident, + pub sql_name: String, + _arrow: syn::Token![->], + pub tpe: syn::TypePath, + pub max_length: Option, } impl syn::parse::Parse for TableDecl { @@ -68,7 +40,7 @@ impl syn::parse::Parse for TableDecl { break; }; } - let meta = syn::Attribute::parse_outer(buf)?; + let mut meta = syn::Attribute::parse_outer(buf)?; let fork = buf.fork(); let (schema, punct, table_name) = if parse_table_with_schema(&fork).is_ok() { let (schema, punct, table_name) = parse_table_with_schema(buf)?; @@ -86,16 +58,16 @@ impl syn::parse::Parse for TableDecl { let content; let brace_token = syn::braced!(content in buf); let column_defs = syn::punctuated::Punctuated::parse_terminated(&content)?; - let (sql_name, meta) = get_sql_name(meta, &table_name)?; + let sql_name = get_sql_name(&mut meta, &table_name)?; Ok(Self { use_statements, meta, table_name, primary_keys, - brace_token, + _brace_token: brace_token, column_defs, sql_name, - punct, + _punct: punct, schema, }) } @@ -112,29 +84,28 @@ impl syn::parse::Parse for PrimaryKey { impl syn::parse::Parse for ColumnDef { fn parse(input: syn::parse::ParseStream) -> syn::Result { - let meta = syn::Attribute::parse_outer(input)?; - let column_name = input.parse()?; - let arrow = input.parse()?; - let tpe = input.parse()?; - let (sql_name, meta) = get_sql_name(meta, &column_name)?; + let mut meta = syn::Attribute::parse_outer(input)?; + let column_name: syn::Ident = input.parse()?; + let _arrow: syn::Token![->] = input.parse()?; + let tpe: syn::TypePath = input.parse()?; + + let sql_name = get_sql_name(&mut meta, &column_name)?; + let max_length = take_lit(&mut meta, "max_length", |lit| match lit { + syn::Lit::Int(lit_int) => Some(lit_int), + _ => None, + })?; + Ok(Self { meta, column_name, - arrow, + _arrow, tpe, + max_length, sql_name, }) } } -impl syn::parse::Parse for SqlNameAttribute { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let eq = input.parse()?; - let lit = input.parse()?; - Ok(Self { eq, lit }) - } -} - pub fn parse_table_with_schema( input: &syn::parse::ParseBuffer<'_>, ) -> Result<(syn::Ident, syn::Token![.], syn::Ident), syn::Error> { @@ -142,19 +113,51 @@ pub fn parse_table_with_schema( } fn get_sql_name( - mut meta: Vec, - ident: &syn::Ident, -) -> Result<(String, Vec), syn::Error> { - if let Some(pos) = meta.iter().position(|m| { + meta: &mut Vec, + fallback_ident: &syn::Ident, +) -> Result { + Ok( + match take_lit(meta, "sql_name", |lit| match lit { + syn::Lit::Str(lit_str) => Some(lit_str), + _ => None, + })? { + None => fallback_ident.to_string(), + Some(str_lit) => str_lit.value(), + }, + ) +} + +fn take_lit( + meta: &mut Vec, + attribute_name: &'static str, + extraction_fn: F, +) -> Result, syn::Error> +where + F: FnOnce(syn::Lit) -> Option, +{ + if let Some(index) = meta.iter().position(|m| { m.path() .get_ident() - .map(|i| i == "sql_name") + .map(|i| i == attribute_name) .unwrap_or(false) }) { - let element = meta.remove(pos); - let inner = SqlNameAttribute::from_attribute(element)?; - Ok((inner.lit.value(), meta)) - } else { - Ok((ident.to_string(), meta)) + let attribute = meta.remove(index); + let span = attribute.span(); + let extraction_after_finding_attr = if let syn::Meta::NameValue(MetaNameValue { + value: syn::Expr::Lit(syn::ExprLit { lit, .. }), + .. + }) = attribute.meta + { + extraction_fn(lit) + } else { + None + }; + return Ok(Some(extraction_after_finding_attr.ok_or_else(|| { + syn::Error::new( + span, + format_args!("Invalid `#[sql_name = {attribute_name:?}]` attribute"), + ) + })?)); } + Ok(None) } diff --git a/diesel_table_macro_syntax/tests/basic.rs b/diesel_table_macro_syntax/tests/basic.rs new file mode 100644 index 000000000000..b1b4e2929533 --- /dev/null +++ b/diesel_table_macro_syntax/tests/basic.rs @@ -0,0 +1,18 @@ +use diesel_table_macro_syntax::*; + +#[test] +fn basic() { + let input = include_str!("basic.rs.in"); + let t: TableDecl = syn::parse_str(input).unwrap(); + assert_eq!(t.column_defs.len(), 3); + assert_eq!( + t.column_defs + .iter() + .map(|c| c + .max_length + .as_ref() + .map(|n| n.base10_parse::().unwrap())) + .collect::>(), + &[None, Some(120), Some(120)] + ) +} diff --git a/diesel_table_macro_syntax/tests/basic.rs.in b/diesel_table_macro_syntax/tests/basic.rs.in new file mode 100644 index 000000000000..5637c71e021d --- /dev/null +++ b/diesel_table_macro_syntax/tests/basic.rs.in @@ -0,0 +1,7 @@ +t1 (id) { + f0 -> Varchar, + #[max_length = 120] + f1 -> Varchar, + #[max_length = 120] + f2 -> Nullable, +}