From d60d0066b4a69846bf8869f098255dd77932fc5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Medina?= Date: Sat, 7 Dec 2024 00:34:05 -0800 Subject: [PATCH] Add `flatten` attribute to derive SerializeRow Currently only the `match_by_name` flavor is supported. All the needed structs/traits to make this work are marked as `#[doc(hidden)]` to not increase the public API surface. Effort was done to not change any of the existing API. --- scylla-cql/src/lib.rs | 5 +- scylla-cql/src/types/serialize/row.rs | 142 +++++++++++++++ scylla-macros/src/serialize/row.rs | 238 +++++++++++++++++--------- scylla/src/lib.rs | 7 +- scylla/src/macros.rs | 8 + 5 files changed, 318 insertions(+), 82 deletions(-) diff --git a/scylla-cql/src/lib.rs b/scylla-cql/src/lib.rs index 228fc43f8..c9b469204 100644 --- a/scylla-cql/src/lib.rs +++ b/scylla-cql/src/lib.rs @@ -62,8 +62,9 @@ pub mod _macro_internal { BuiltinSerializationError as BuiltinRowSerializationError, BuiltinSerializationErrorKind as BuiltinRowSerializationErrorKind, BuiltinTypeCheckError as BuiltinRowTypeCheckError, - BuiltinTypeCheckErrorKind as BuiltinRowTypeCheckErrorKind, RowSerializationContext, - SerializeRow, + BuiltinTypeCheckErrorKind as BuiltinRowTypeCheckErrorKind, + FieldSerializationStatus as FieldRowSerializationStatus, PartialSerializeRowByName, + RowSerializationContext, SerializeRow, SerializeRowByName, }; pub use crate::types::serialize::value::{ BuiltinSerializationError as BuiltinTypeSerializationError, diff --git a/scylla-cql/src/types/serialize/row.rs b/scylla-cql/src/types/serialize/row.rs index 665335ebe..d734dfef5 100644 --- a/scylla-cql/src/types/serialize/row.rs +++ b/scylla-cql/src/types/serialize/row.rs @@ -85,6 +85,96 @@ pub trait SerializeRow { fn is_empty(&self) -> bool; } +/// How to serialize a row column-by-column +/// +/// For now this trait is an implementation detail of `#[derive(SerializeRow)]` when +/// serializing by name +#[doc(hidden)] +pub trait PartialSerializeRowByName { + /// 
Serializes a single column in the row according to the information in the + /// given context + /// + /// It returns whether the column finished the serialization of the struct, did + /// it partially, not at all, or errored + fn serialize_field( + &mut self, + spec: &ColumnSpec, + writer: &mut RowWriter<'_>, + ) -> Result<FieldSerializationStatus, SerializationError>; + + /// Checks if there are any missing columns to finish the serialization + fn check_missing(self) -> Result<(), SerializationError>; +} + +/// Represents a set of values that can be sent along a CQL statement when serializing by name +/// +/// For now this trait is an implementation detail of `#[derive(SerializeRow)]` when +/// serializing by name +#[doc(hidden)] +pub trait SerializeRowByName { + /// A type that can handle serialization of this struct column-by-column + type Partial<'d>: PartialSerializeRowByName + where + Self: 'd; + + /// Returns a type that can serialize this row "column-by-column" + fn partial(&self) -> Self::Partial<'_>; + + /// Serializes all the fields/columns by name + /// + /// Auto-implemented -- do not override + fn serialize_by_name( + &self, + ctx: &RowSerializationContext, + writer: &mut RowWriter<'_>, + ) -> Result<(), SerializationError> { + let mut partial = self.partial(); + + for spec in ctx.columns() { + let serialized = partial.serialize_field(spec, writer).map_err(|err| { + SerializationError::new(BuiltinSerializationError { + rust_name: std::any::type_name::<Self>(), + kind: BuiltinSerializationErrorKind::ColumnSerializationFailed { + name: spec.name().to_owned(), + err, + }, + }) + })?; + + if matches!(serialized, FieldSerializationStatus::NotUsed) { + return Err(SerializationError::new(BuiltinTypeCheckError { + rust_name: std::any::type_name::<Self>(), + kind: BuiltinTypeCheckErrorKind::NoColumnWithName { + name: spec.name().to_owned(), + }, + })); + } + } + + partial.check_missing()?; + + Ok(()) + } +} + +/// Whether a field used a column to finish its serialization or not +/// +/// Used when serializing by 
name as a single column may not have finished a rust +/// field in the case of a flattened struct +/// +/// For now this enum is an implementation detail of `#[derive(SerializeRow)]` when +/// serializing by name +#[derive(Debug)] +#[doc(hidden)] +pub enum FieldSerializationStatus { + /// The column finished the serialization for this field + Done, + /// The column was used but there are other fields not yet serialized + NotDone, + /// The column did not belong to this field + NotUsed, +} + macro_rules! fallback_impl_contents { () => { fn serialize( @@ -1634,4 +1724,56 @@ pub(crate) mod tests { assert_eq!(reference, row); } + + #[test] + fn test_row_serialization_nested_structs() { + #[derive(SerializeRow, Debug)] + #[scylla(crate = crate)] + struct InnerColumnsOne { + x: i32, + y: f64, + } + + #[derive(SerializeRow, Debug)] + #[scylla(crate = crate)] + struct InnerColumnsTwo { + z: bool, + } + + #[derive(SerializeRow, Debug)] + #[scylla(crate = crate)] + struct OuterColumns { + #[scylla(flatten)] + inner_one: InnerColumnsOne, + a: String, + #[scylla(flatten)] + inner_two: InnerColumnsTwo, + } + + let spec = [ + col("a", ColumnType::Text), + col("x", ColumnType::Int), + col("z", ColumnType::Boolean), + col("y", ColumnType::Double), + ]; + + let value = OuterColumns { + inner_one: InnerColumnsOne { x: 5, y: 1.0 }, + a: "something".to_owned(), + inner_two: InnerColumnsTwo { z: true }, + }; + + let reference = do_serialize( + ( + &value.a, + &value.inner_one.x, + &value.inner_two.z, + &value.inner_one.y, + ), + &spec, + ); + let row = do_serialize(value, &spec); + + assert_eq!(reference, row); + } } diff --git a/scylla-macros/src/serialize/row.rs b/scylla-macros/src/serialize/row.rs index ffa2c7a2b..1280d73b5 100644 --- a/scylla-macros/src/serialize/row.rs +++ b/scylla-macros/src/serialize/row.rs @@ -2,7 +2,6 @@ use std::collections::HashMap; use darling::FromAttributes; use proc_macro::TokenStream; -use proc_macro2::Span; use syn::parse_quote; use crate::Flavor; @@ 
-55,6 +54,11 @@ struct FieldAttributes { // instead of the Rust field name. rename: Option, + // If set, then this field's columns are serialized using its own implementation + // of `SerializeRow` and flattened as if they were fields in this struct. + #[darling(default)] + flatten: bool, + // If true, then the field is not serialized at all, but simply ignored. // All other attributes are ignored. #[darling(default)] @@ -64,6 +68,8 @@ struct FieldAttributes { struct Context { attributes: Attributes, fields: Vec, + struct_name: syn::Ident, + generics: syn::Generics, } pub(crate) fn derive_serialize_row(tokens_input: TokenStream) -> Result { @@ -90,7 +96,12 @@ pub(crate) fn derive_serialize_row(tokens_input: TokenStream) -> Result>()?; - let ctx = Context { attributes, fields }; + let ctx = Context { + attributes, + fields, + struct_name: struct_name.clone(), + generics: input.generics.clone(), + }; ctx.validate(&input.ident)?; let gen: Box = match ctx.attributes.flavor { @@ -136,6 +147,30 @@ impl Context { } } + // The `flatten` annotation is not yet supported outside of `match_by_name` + if !matches!(self.attributes.flavor, Flavor::MatchByName) { + if let Some(field) = self.fields.iter().find(|f| f.attrs.flatten) { + let err = darling::Error::custom( + "the `flatten` annotation is only supported with the `match_by_name` flavor", + ) + .with_span(&field.ident); + errors.push(err); + } + } + + // Check that no renames are attempted on flattened fields + let rename_flatten_errors = self + .fields + .iter() + .filter(|f| f.attrs.flatten && f.attrs.rename.is_some()) + .map(|f| { + darling::Error::custom( + "`rename` and `flatten` annotations do not make sense together", + ) + .with_span(&f.ident) + }); + errors.extend(rename_flatten_errors); + // Check for name collisions let mut used_names = HashMap::::new(); for field in self.fields.iter() { @@ -199,94 +234,136 @@ impl Generator for ColumnSortingGenerator<'_> { // Need to: // - Check that all required columns are there 
and no more // - Check that the column types match - let mut statements: Vec = Vec::new(); let crate_path = self.ctx.attributes.crate_path(); + let struct_name = &self.ctx.struct_name; + let (impl_generics, ty_generics, where_clause) = self.ctx.generics.split_for_impl(); + let partial_struct_name = syn::Ident::new( + &format!("_{}ScyllaSerPartial", struct_name), + struct_name.span(), + ); + let mut partial_generics = self.ctx.generics.clone(); + let partial_lt: syn::LifetimeParam = syn::parse_quote!('scylla_ser_partial); + if !self.ctx.fields.is_empty() { + partial_generics + .params + .push(syn::GenericParam::Lifetime(partial_lt.clone())); + } - let rust_field_idents = self - .ctx - .fields - .iter() - .map(|f| f.ident.clone()) - .collect::>(); - let rust_field_names = self + let (partial_impl_generics, partial_ty_generics, partial_where_clause) = + partial_generics.split_for_impl(); + + let flattened: Vec<_> = self.ctx.fields.iter().filter(|f| f.attrs.flatten).collect(); + let flattened_fields: Vec<_> = flattened.iter().map(|f| &f.ident).collect(); + let flattened_tys: Vec<_> = flattened.iter().map(|f| &f.ty).collect(); + + let unflattened: Vec<_> = self .ctx .fields .iter() - .map(|f| f.column_name()) - .collect::>(); - let udt_field_names = rust_field_names.clone(); // For now, it's the same - let field_types = self.ctx.fields.iter().map(|f| &f.ty).collect::>(); - - // Declare a helper lambda for creating errors - statements.push(self.ctx.generate_mk_typck_err()); - statements.push(self.ctx.generate_mk_ser_err()); + .filter(|f| !f.attrs.flatten) + .collect(); + let unflattened_columns: Vec<_> = unflattened.iter().map(|f| f.column_name()).collect(); + let unflattened_fields: Vec<_> = unflattened.iter().map(|f| &f.ident).collect(); + let unflattened_tys: Vec<_> = unflattened.iter().map(|f| &f.ty).collect(); + + let all_names = self.ctx.fields.iter().map(|f| f.column_name()); + + let partial_struct: syn::ItemStruct = parse_quote! 
{ + struct #partial_struct_name #partial_generics { + #(#unflattened_fields: &#partial_lt #unflattened_tys,)* + #(#flattened_fields: <#flattened_tys as #crate_path::SerializeRowByName>::Partial<#partial_lt>,)* + missing: ::std::collections::HashSet<&'static str>, + } + }; + + let serialize_field_block: syn::Block = if self.ctx.fields.is_empty() { + parse_quote! {{ + ::std::result::Result::Ok(#crate_path::FieldRowSerializationStatus::NotUsed) + }} + } else { + parse_quote! {{ + match spec.name() { + #(#unflattened_columns => { + let sub_writer = #crate_path::RowWriter::make_cell_writer(writer); + <#unflattened_tys as #crate_path::SerializeValue>::serialize(&self.#unflattened_fields, spec.typ(), sub_writer)?; + self.missing.remove(#unflattened_columns); + })* + _ => { + #({ + match self.#flattened_fields.serialize_field(spec, writer)? { + #crate_path::FieldRowSerializationStatus::Done => { + self.missing.remove(stringify!(#flattened_fields)); + if self.missing.is_empty() { + return ::std::result::Result::Ok(#crate_path::FieldRowSerializationStatus::Done); + } else { + return ::std::result::Result::Ok(#crate_path::FieldRowSerializationStatus::NotDone); + } + } + #crate_path::FieldRowSerializationStatus::NotDone => { + return ::std::result::Result::Ok(#crate_path::FieldRowSerializationStatus::NotDone) + } + #crate_path::FieldRowSerializationStatus::NotUsed => {} + }; + })* - // Generate a "visited" flag for each field - let visited_flag_names = rust_field_idents - .iter() - .map(|s| syn::Ident::new(&format!("visited_flag_{}", s), Span::call_site())) - .collect::>(); - statements.extend::>(parse_quote! { - #(let mut #visited_flag_names = false;)* - }); + return ::std::result::Result::Ok(#crate_path::FieldRowSerializationStatus::NotUsed); + } + } - // Generate a variable that counts down visited fields. - let field_count = self.ctx.fields.len(); - statements.push(parse_quote! 
{ - let mut remaining_count = #field_count; - }); + if self.missing.is_empty() { + ::std::result::Result::Ok(#crate_path::FieldRowSerializationStatus::Done) + } else { + ::std::result::Result::Ok(#crate_path::FieldRowSerializationStatus::NotDone) + } + }} + }; + + let partial_serialize: syn::ItemImpl = parse_quote! { + impl #partial_impl_generics #crate_path::PartialSerializeRowByName for #partial_struct_name #partial_ty_generics #partial_where_clause { + fn serialize_field( + &mut self, + spec: &#crate_path::ColumnSpec, + writer: &mut #crate_path::RowWriter<'_>, + ) -> ::std::result::Result<#crate_path::FieldRowSerializationStatus, #crate_path::SerializationError> { + #serialize_field_block + } - // Generate a loop over the fields and a `match` block to match on - // the field name. - statements.push(parse_quote! { - for spec in ctx.columns() { - match spec.name() { - #( - #udt_field_names => { - let sub_writer = #crate_path::RowWriter::make_cell_writer(writer); - match <#field_types as #crate_path::SerializeValue>::serialize(&self.#rust_field_idents, spec.typ(), sub_writer) { - ::std::result::Result::Ok(_proof) => {} - ::std::result::Result::Err(err) => { - return ::std::result::Result::Err(mk_ser_err( - #crate_path::BuiltinRowSerializationErrorKind::ColumnSerializationFailed { - name: <_ as ::std::borrow::ToOwned>::to_owned(spec.name()), - err, - } - )); - } - } - if !#visited_flag_names { - #visited_flag_names = true; - remaining_count -= 1; - } - } - )* - _ => return ::std::result::Result::Err(mk_typck_err( - #crate_path::BuiltinRowTypeCheckErrorKind::NoColumnWithName { - name: <_ as ::std::borrow::ToOwned>::to_owned(spec.name()), - } - )), + fn check_missing(self) -> ::std::result::Result<(), #crate_path::SerializationError> { + use ::std::iter::{Iterator as _, IntoIterator as _}; + + let ::std::option::Option::Some(missing) = self.missing.into_iter().nth(0) else { + return ::std::result::Result::Ok(()); + }; + + match missing { + 
#(stringify!(#flattened_fields) => self.#flattened_fields.check_missing(),)* + _ => ::std::result::Result::Err(#crate_path::SerializationError::new(#crate_path::BuiltinRowTypeCheckError { + rust_name: ::std::any::type_name::<#struct_name #ty_generics>(), + kind: #crate_path::BuiltinRowTypeCheckErrorKind::ValueMissingForColumn { + name: <_ as ::std::borrow::ToOwned>::to_owned(missing), + }, + })) + } } } - }); + }; - // Finally, check that all fields were consumed. - // If there are some missing fields, return an error - statements.push(parse_quote! { - if remaining_count > 0 { - #( - if !#visited_flag_names { - return ::std::result::Result::Err(mk_typck_err( - #crate_path::BuiltinRowTypeCheckErrorKind::ValueMissingForColumn { - name: <_ as ::std::string::ToString>::to_string(#rust_field_names), - } - )); + let serialize_by_name: syn::ItemImpl = parse_quote! { + impl #impl_generics #crate_path::SerializeRowByName for #struct_name #ty_generics #where_clause { + type Partial<#partial_lt> = #partial_struct_name #partial_ty_generics where Self: #partial_lt; + + fn partial(&self) -> Self::Partial<'_> { + use ::std::iter::FromIterator as _; + + #partial_struct_name { + #(#unflattened_fields: &self.#unflattened_fields,)* + #(#flattened_fields: self.#flattened_fields.partial(),)* + missing: ::std::collections::HashSet::from_iter([#(#all_names,)*]), } - )* - ::std::unreachable!() + } } - }); + }; parse_quote! 
{ fn serialize<'b>( @@ -294,8 +371,13 @@ impl Generator for ColumnSortingGenerator<'_> { ctx: &#crate_path::RowSerializationContext, writer: &mut #crate_path::RowWriter<'b>, ) -> ::std::result::Result<(), #crate_path::SerializationError> { - #(#statements)* - ::std::result::Result::Ok(()) + #partial_struct + #partial_serialize + + #[allow(non_local_definitions)] + #serialize_by_name + + ::serialize_by_name(self, ctx, writer) } } } diff --git a/scylla/src/lib.rs b/scylla/src/lib.rs index 6724a65d9..16387641a 100644 --- a/scylla/src/lib.rs +++ b/scylla/src/lib.rs @@ -81,7 +81,7 @@ //! .query_unpaged("SELECT a, b FROM ks.tab", &[]) //! .await? //! .into_rows_result()?; -//! +//! //! for row in query_rows.rows()? { //! // Parse row as int and text \ //! let (int_val, text_val): (i32, &str) = row?; @@ -152,7 +152,10 @@ pub mod serialize { /// Contains the [SerializeRow][row::SerializeRow] trait and its implementations. pub mod row { // Main types - pub use scylla_cql::types::serialize::row::{RowSerializationContext, SerializeRow}; + pub use scylla_cql::types::serialize::row::{ + FieldSerializationStatus, PartialSerializeRowByName, RowSerializationContext, + SerializeRow, SerializeRowByName, + }; // Errors pub use scylla_cql::types::serialize::row::{ diff --git a/scylla/src/macros.rs b/scylla/src/macros.rs index ce64153e8..ca7d64942 100644 --- a/scylla/src/macros.rs +++ b/scylla/src/macros.rs @@ -241,6 +241,14 @@ pub use scylla_cql::macros::SerializeValue; /// /// Don't use the field during serialization. /// +/// `#[scylla(flatten)]` +/// +/// Use this field's `SerializeRow` implementation to serialize its columns as part +/// of this struct. Note that the name of this field is ignored and hence the +/// `rename` attribute does not make sense here and will cause a compilation +/// error. Currently this is only supported for the `"match_by_name"` flavor in both +/// the wrapper struct and this flattened struct. +/// /// --- /// pub use scylla_cql::macros::SerializeRow;