From dcb4cf48f2841d7ad37d6849b570788398ddde41 Mon Sep 17 00:00:00 2001 From: Piotr Dulikowski Date: Fri, 27 Oct 2023 07:44:49 +0200 Subject: [PATCH] scylla-macros: implement enforce_order flavor of SerializeCql Some users might not need the additional robustness of `SerializeCql` that comes from sorting the fields before serializing, as they are used to the current behavior of `Value` and properly set the order of the fields in their Rust struct. In order to give them some performance boost, add an additional mode to `SerializeCql` called "enforce_order" which expects that the order of the fields in the struct is kept in sync with the DB definition of the UDT. It's still safe to use because, as the struct fields are serialized, their names are compared with the fields in the UDT definition order and serialization fails if the field name on some position is mismatched. --- scylla-cql/src/macros.rs | 19 ++- scylla-cql/src/types/serialize/value.rs | 170 ++++++++++++++++++++++++ scylla-macros/src/serialize/cql.rs | 121 ++++++++++++++++- scylla-macros/src/serialize/mod.rs | 18 +++ 4 files changed, 323 insertions(+), 5 deletions(-) diff --git a/scylla-cql/src/macros.rs b/scylla-cql/src/macros.rs index 8f53e24fa9..2b7b0b4ae7 100644 --- a/scylla-cql/src/macros.rs +++ b/scylla-cql/src/macros.rs @@ -16,9 +16,7 @@ pub use scylla_macros::ValueList; /// Derive macro for the [`SerializeCql`](crate::types::serialize::value::SerializeCql) trait /// which serializes given Rust structure as a User Defined Type (UDT). /// -/// At the moment, only structs with named fields are supported. The generated -/// implementation of the trait will match the struct fields to UDT fields -/// by name automatically. +/// At the moment, only structs with named fields are supported. /// /// Serialization will fail if there are some fields in the UDT that don't match /// to any of the Rust struct fields, _or vice versa_. @@ -50,6 +48,21 @@ pub use scylla_macros::ValueList; /// /// # Attributes /// +/// `#[scylla(flavor = "flavor_name")]` +/// +/// Allows to choose one of the possible "flavors", i.e. the way how the +/// generated code will approach serialization. Possible flavors are: +/// +/// - `"match_by_name"` (default) - the generated implementation _does not +/// require_ the fields in the Rust struct to be in the same order as the +/// fields in the UDT. During serialization, the implementation will take +/// care to serialize the fields in the order which the database expects. +/// - `"enforce_order"` - the generated implementation _requires_ the fields +/// in the Rust struct to be in the same order as the fields in the UDT. +/// If the order is incorrect, type checking/serialization will fail. +/// This is a less robust flavor than `"match_by_name"`, but should be +/// slightly more performant as it doesn't need to perform lookups by name. +/// /// `#[scylla(crate = crate_name)]` /// /// By default, the code generated by the derive macro will refer to the items diff --git a/scylla-cql/src/types/serialize/value.rs b/scylla-cql/src/types/serialize/value.rs index 85033dac25..567b59cfab 100644 --- a/scylla-cql/src/types/serialize/value.rs +++ b/scylla-cql/src/types/serialize/value.rs @@ -1314,6 +1314,12 @@ pub enum UdtTypeCheckErrorKind { /// The Rust data contains a field that is not present in the UDT UnexpectedFieldInDestination { field_name: String }, + + /// A different field name was expected at given position. + FieldNameMismatch { + rust_field_name: String, + db_field_name: String, + }, } impl Display for UdtTypeCheckErrorKind { @@ -1337,6 +1343,10 @@ impl Display for UdtTypeCheckErrorKind { f, "the field {field_name} present in the Rust data is not present in the CQL type" ), + UdtTypeCheckErrorKind::FieldNameMismatch { rust_field_name, db_field_name } => write!( + f, + "expected field with name {db_field_name} at given position, but the Rust field name is {rust_field_name}" + ), } } } @@ -1668,4 +1678,164 @@ mod tests { check_with_type(ColumnType::Int, 123_i32, CqlValue::Int(123_i32)); check_with_type(ColumnType::Double, 123_f64, CqlValue::Double(123_f64)); } + + #[derive(SerializeCql, Debug, PartialEq, Eq, Default)] + #[scylla(crate = crate, flavor = "enforce_order")] + struct TestUdtWithEnforcedOrder { + a: String, + b: i32, + c: Vec, + } + + #[test] + fn test_udt_serialization_with_enforced_order_correct_order() { + let typ = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + ], + }; + + let reference = do_serialize( + CqlValue::UserDefinedType { + keyspace: "ks".to_string(), + type_name: "typ".to_string(), + fields: vec![ + ( + "a".to_string(), + Some(CqlValue::Text(String::from("Ala ma kota"))), + ), + ("b".to_string(), Some(CqlValue::Int(42))), + ( + "c".to_string(), + Some(CqlValue::List(vec![ + CqlValue::BigInt(1), + CqlValue::BigInt(2), + CqlValue::BigInt(3), + ])), + ), + ], + }, + &typ, + ); + let udt = do_serialize( + TestUdtWithEnforcedOrder { + a: "Ala ma kota".to_owned(), + b: 42, + c: vec![1, 2, 3], + }, + &typ, + ); + + assert_eq!(reference, udt); + } + + #[test] + fn test_udt_serialization_with_enforced_order_failing_type_check() { + let typ_not_udt = ColumnType::Ascii; + let udt = TestUdtWithEnforcedOrder::default(); + + let mut data = Vec::new(); + + let err = <_ as SerializeCql>::serialize(&udt, &typ_not_udt, CellWriter::new(&mut data)) + .unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::NotUdt) + )); + + let typ = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + // Two first columns are swapped + ("b".to_string(), ColumnType::Int), + ("a".to_string(), ColumnType::Text), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + ], + }; + + let err = + <_ as SerializeCql>::serialize(&udt, &typ, CellWriter::new(&mut data)).unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::FieldNameMismatch { .. }) + )); + + let typ_without_c = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + // Last field is missing + ], + }; + + let err = <_ as SerializeCql>::serialize(&udt, &typ_without_c, CellWriter::new(&mut data)) + .unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::MissingField { .. }) + )); + + let typ_unexpected_field = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ( + "c".to_string(), + ColumnType::List(Box::new(ColumnType::BigInt)), + ), + // Unexpected field + ("d".to_string(), ColumnType::Counter), + ], + }; + + let err = + <_ as SerializeCql>::serialize(&udt, &typ_unexpected_field, CellWriter::new(&mut data)) + .unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::UdtError( + UdtTypeCheckErrorKind::UnexpectedFieldInDestination { .. } + ) + )); + + let typ_unexpected_field = ColumnType::UserDefinedType { + type_name: "typ".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ("c".to_string(), ColumnType::TinyInt), // Wrong column type + ], + }; + + let err = + <_ as SerializeCql>::serialize(&udt, &typ_unexpected_field, CellWriter::new(&mut data)) + .unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinSerializationErrorKind::UdtError( + UdtSerializationErrorKind::FieldSerializationFailed { .. } + ) + )); + } } diff --git a/scylla-macros/src/serialize/cql.rs b/scylla-macros/src/serialize/cql.rs index f19e47b27c..d3c5788401 100644 --- a/scylla-macros/src/serialize/cql.rs +++ b/scylla-macros/src/serialize/cql.rs @@ -3,11 +3,15 @@ use proc_macro::TokenStream; use proc_macro2::Span; use syn::parse_quote; +use super::Flavor; + #[derive(FromAttributes)] #[darling(attributes(scylla))] struct Attributes { #[darling(rename = "crate")] crate_path: Option, + + flavor: Option, } impl Attributes { @@ -36,7 +40,11 @@ pub fn derive_serialize_cql(tokens_input: TokenStream) -> Result = match ctx.attributes.flavor { + Some(Flavor::MatchByName) | None => Box::new(FieldSortingGenerator { ctx: &ctx }), + Some(Flavor::EnforceOrder) => Box::new(FieldOrderedGenerator { ctx: &ctx }), + }; let serialize_item = gen.generate_serialize(); @@ -93,13 +101,17 @@ impl Context { } } +trait Generator { + fn generate_serialize(&self) -> syn::TraitItemFn; +} + // Generates an implementation of the trait which sorts the fields according // to how it is defined in the database. struct FieldSortingGenerator<'a> { ctx: &'a Context, } -impl<'a> FieldSortingGenerator<'a> { +impl<'a> Generator for FieldSortingGenerator<'a> { fn generate_serialize(&self) -> syn::TraitItemFn { // Need to: // - Check that all required fields are there and no more @@ -222,3 +234,108 @@ impl<'a> FieldSortingGenerator<'a> { } } } + +// Generates an implementation of the trait which requires the fields +// to be placed in the same order as they are defined in the struct. +struct FieldOrderedGenerator<'a> { + ctx: &'a Context, +} + +impl<'a> Generator for FieldOrderedGenerator<'a> { + fn generate_serialize(&self) -> syn::TraitItemFn { + let mut statements: Vec = Vec::new(); + + let crate_path = self.ctx.attributes.crate_path(); + + // Declare a helper lambda for creating errors + statements.push(self.ctx.generate_mk_typck_err()); + statements.push(self.ctx.generate_mk_ser_err()); + + // Check that the type we want to serialize to is a UDT + statements.push( + self.ctx + .generate_udt_type_match(parse_quote!(#crate_path::UdtTypeCheckErrorKind::NotUdt)), + ); + + // Turn the cell writer into a value builder + statements.push(parse_quote! { + let mut builder = #crate_path::CellWriter::into_value_builder(writer); + }); + + // Create an iterator over fields + statements.push(parse_quote! { + let mut field_iter = field_types.iter(); + }); + + // Serialize each field + for field in self.ctx.fields.iter() { + let rust_field_ident = field.ident.as_ref().unwrap(); + let rust_field_name = rust_field_ident.to_string(); + let typ = &field.ty; + statements.push(parse_quote! { + match field_iter.next() { + Some((field_name, typ)) => { + if field_name == #rust_field_name { + let sub_builder = #crate_path::CellValueBuilder::make_sub_writer(&mut builder); + match <#typ as #crate_path::SerializeCql>::serialize(&self.#rust_field_ident, typ, sub_builder) { + Ok(_proof) => {}, + Err(err) => { + return ::std::result::Result::Err(mk_ser_err( + #crate_path::UdtSerializationErrorKind::FieldSerializationFailed { + field_name: <_ as ::std::clone::Clone>::clone(field_name), + err, + } + )); + } + } + } else { + return ::std::result::Result::Err(mk_typck_err( + #crate_path::UdtTypeCheckErrorKind::FieldNameMismatch { + rust_field_name: <_ as ::std::string::ToString>::to_string(#rust_field_name), + db_field_name: <_ as ::std::clone::Clone>::clone(field_name), + } + )); + } + } + None => { + return ::std::result::Result::Err(mk_typck_err( + #crate_path::UdtTypeCheckErrorKind::MissingField { + field_name: <_ as ::std::string::ToString>::to_string(#rust_field_name), + } + )); + } + } + }); + } + + // Check whether there are some fields remaining + statements.push(parse_quote! { + if let Some((field_name, typ)) = field_iter.next() { + return ::std::result::Result::Err(mk_typck_err( + #crate_path::UdtTypeCheckErrorKind::UnexpectedFieldInDestination { + field_name: <_ as ::std::clone::Clone>::clone(field_name), + } + )); + } + }); + + parse_quote! { + fn serialize<'b>( + &self, + typ: &#crate_path::ColumnType, + writer: #crate_path::CellWriter<'b>, + ) -> ::std::result::Result<#crate_path::WrittenCellProof<'b>, #crate_path::SerializationError> { + #(#statements)* + let proof = #crate_path::CellValueBuilder::finish(builder) + .map_err(|_| #crate_path::SerializationError::new( + #crate_path::BuiltinTypeSerializationError { + rust_name: ::std::any::type_name::(), + got: <_ as ::std::clone::Clone>::clone(typ), + kind: #crate_path::BuiltinTypeSerializationErrorKind::SizeOverflow, + } + ) as #crate_path::SerializationError)?; + ::std::result::Result::Ok(proof) + } + } + } +} diff --git a/scylla-macros/src/serialize/mod.rs b/scylla-macros/src/serialize/mod.rs index 53abe0f296..183183fa91 100644 --- a/scylla-macros/src/serialize/mod.rs +++ b/scylla-macros/src/serialize/mod.rs @@ -1,2 +1,20 @@ +use darling::FromMeta; + pub(crate) mod cql; pub(crate) mod row; + +#[derive(Copy, Clone, PartialEq, Eq)] +enum Flavor { + MatchByName, + EnforceOrder, +} + +impl FromMeta for Flavor { + fn from_string(value: &str) -> darling::Result { + match value { + "match_by_name" => Ok(Self::MatchByName), + "enforce_order" => Ok(Self::EnforceOrder), + _ => Err(darling::Error::unknown_value(value)), + } + } +}