From e2bedf6c08b8d423fddf17babf8bd14c83a4054b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Medina?= Date: Mon, 2 Dec 2024 16:07:56 -0800 Subject: [PATCH 1/3] Allow non-rust idents column names in SerializeRow derived struct fixes cases where the "column name" is not a valid rust identifier; such as trying to pass in a dynamic TTL --- scylla-cql/src/types/serialize/row.rs | 17 +++++++++++++++++ scylla-macros/src/serialize/row.rs | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/scylla-cql/src/types/serialize/row.rs b/scylla-cql/src/types/serialize/row.rs index 1fd3b820a3..665335ebe4 100644 --- a/scylla-cql/src/types/serialize/row.rs +++ b/scylla-cql/src/types/serialize/row.rs @@ -1568,6 +1568,23 @@ pub(crate) mod tests { assert_eq!(reference, row); } + #[test] + fn test_row_serialization_with_not_rust_idents() { + #[derive(SerializeRow, Debug)] + #[scylla(crate = crate)] + struct RowWithTTL { + #[scylla(rename = "[ttl]")] + ttl: i32, + } + + let spec = [col("[ttl]", ColumnType::Int)]; + + let reference = do_serialize((42i32,), &spec); + let row = do_serialize(RowWithTTL { ttl: 42 }, &spec); + + assert_eq!(reference, row); + } + #[derive(SerializeRow, Debug)] #[scylla(crate = crate)] struct TestRowWithSkippedFields { diff --git a/scylla-macros/src/serialize/row.rs b/scylla-macros/src/serialize/row.rs index fc0de5234f..ffa2c7a2b4 100644 --- a/scylla-macros/src/serialize/row.rs +++ b/scylla-macros/src/serialize/row.rs @@ -223,7 +223,7 @@ impl Generator for ColumnSortingGenerator<'_> { statements.push(self.ctx.generate_mk_ser_err()); // Generate a "visited" flag for each field - let visited_flag_names = rust_field_names + let visited_flag_names = rust_field_idents .iter() .map(|s| syn::Ident::new(&format!("visited_flag_{}", s), Span::call_site())) .collect::>(); From 1c0d353be602404c85ab82f77bc4c21462e3b1b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Medina?= Date: Mon, 2 Dec 2024 16:08:35 -0800 Subject: [PATCH 2/3] fix attribute documentation missing closing square bracket --- scylla/src/macros.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scylla/src/macros.rs b/scylla/src/macros.rs index 6549507bef..ce64153e80 100644 --- a/scylla/src/macros.rs +++ b/scylla/src/macros.rs @@ -360,7 +360,7 @@ pub use scylla_cql::macros::SerializeRow; /// If the value of the field received from DB is null, the field will be /// initialized with `Default::default()`. /// -/// `#[scylla(rename = "field_name")` +/// `#[scylla(rename = "field_name")]` /// /// By default, the generated implementation will try to match the Rust field /// to a UDT field with the same name. This attribute instead allows to match @@ -475,7 +475,7 @@ pub use scylla_macros::DeserializeValue; /// The field will be completely ignored during deserialization and will /// be initialized with `Default::default()`. /// -/// `#[scylla(rename = "field_name")` +/// `#[scylla(rename = "field_name")]` /// /// By default, the generated implementation will try to match the Rust field /// to a column with the same name. This attribute allows to match to a column From 8b579e638578beec998d97f07bcfee65723deaac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Medina?= Date: Tue, 3 Dec 2024 09:30:00 -0800 Subject: [PATCH 3/3] Allow non-rust idents for renamed field names in SerializeValue this adds support for UDTs that have fields that are not valid rust idents but are valid scylla field names --- scylla-cql/src/types/serialize/value.rs | 28 +++++++++++++++++++++++++ scylla-macros/src/serialize/value.rs | 13 ++++++------ 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/scylla-cql/src/types/serialize/value.rs b/scylla-cql/src/types/serialize/value.rs index 0e7fba6691..78e169aa4e 100644 --- a/scylla-cql/src/types/serialize/value.rs +++ b/scylla-cql/src/types/serialize/value.rs @@ -2824,4 +2824,32 @@ pub(crate) mod tests { assert_eq!(reference, row); } + + #[test] + fn test_udt_with_non_rust_ident() { + #[derive(SerializeValue, Debug)] + #[scylla(crate = crate)] + struct UdtWithNonRustIdent { + #[scylla(rename = "a$a")] + a: i32, + } + + let typ = ColumnType::UserDefinedType { + type_name: "typ".into(), + keyspace: "ks".into(), + field_types: vec![("a$a".into(), ColumnType::Int)], + }; + let value = UdtWithNonRustIdent { a: 42 }; + + let mut reference = Vec::new(); + // Total length of the struct + reference.extend_from_slice(&8i32.to_be_bytes()); + // Field 'a' + reference.extend_from_slice(&(std::mem::size_of_val(&value.a) as i32).to_be_bytes()); + reference.extend_from_slice(&value.a.to_be_bytes()); + + let udt = do_serialize(value, &typ); + + assert_eq!(reference, udt); + } } diff --git a/scylla-macros/src/serialize/value.rs b/scylla-macros/src/serialize/value.rs index a657d6c45c..d66c2930c2 100644 --- a/scylla-macros/src/serialize/value.rs +++ b/scylla-macros/src/serialize/value.rs @@ -2,7 +2,6 @@ use std::collections::HashMap; use darling::FromAttributes; use proc_macro::TokenStream; -use proc_macro2::Span; use syn::parse_quote; use crate::Flavor; @@ -327,14 +326,14 @@ impl Generator for FieldSortingGenerator<'_> { .generate_udt_type_match(parse_quote!(#crate_path::UdtTypeCheckErrorKind::NotUdt)), ); - fn make_visited_flag_ident(field_name: &str) -> syn::Ident { - syn::Ident::new(&format!("visited_flag_{}", field_name), Span::call_site()) + fn make_visited_flag_ident(field_name: &syn::Ident) -> syn::Ident { + syn::Ident::new(&format!("visited_flag_{}", field_name), field_name.span()) } // Generate a "visited" flag for each field - let visited_flag_names = rust_field_names + let visited_flag_names = rust_field_idents .iter() - .map(|s| make_visited_flag_ident(s)) + .map(make_visited_flag_ident) .collect::>(); statements.extend::>(parse_quote! { #(let mut #visited_flag_names = false;)* @@ -347,11 +346,11 @@ impl Generator for FieldSortingGenerator<'_> { .fields .iter() .filter(|f| !f.attrs.ignore_missing) - .map(|f| f.field_name()); + .map(|f| &f.ident); // An iterator over visited flags of Rust fields that can't be ignored // (i.e., if UDT misses a corresponding field, an error should be raised). let nonignorable_visited_flag_names = - nonignorable_rust_field_names.map(|s| make_visited_flag_ident(&s)); + nonignorable_rust_field_names.map(make_visited_flag_ident); // Generate a variable that counts down visited fields. let field_count = self.ctx.fields.len();