diff --git a/scylla-cql/src/lib.rs b/scylla-cql/src/lib.rs index 228fc43f8..84b74ced7 100644 --- a/scylla-cql/src/lib.rs +++ b/scylla-cql/src/lib.rs @@ -76,4 +76,110 @@ pub mod _macro_internal { pub use crate::types::serialize::{ CellValueBuilder, CellWriter, RowWriter, SerializationError, }; + + pub mod ser { + pub mod row { + pub use crate::{ + frame::response::result::ColumnSpec, + types::serialize::{ + row::{ + mk_ser_err, mk_typck_err, BuiltinSerializationErrorKind, + BuiltinTypeCheckError, BuiltinTypeCheckErrorKind, RowSerializationContext, + }, + RowWriter, SerializationError, + }, + }; + + /// Whether a field used a column to finish its serialization or not + /// + /// Used when serializing by name as a single column may not have finished a rust + /// field in the case of a flattened struct + /// + /// For now this enum is an implementation detail of `#[derive(SerializeRow)]` when + /// serializing by name + #[derive(Debug)] + #[doc(hidden)] + pub enum FieldStatus { + /// The column finished the serialization for this field + Done, + /// The column was used but there are other fields not yet serialized + NotDone, + /// The column did not belong to this field + NotUsed, + } + + /// Represents a set of values that can be sent along a CQL statement when serializing by name + /// + /// For now this trait is an implementation detail of `#[derive(SerializeRow)]` when + /// serializing by name + #[doc(hidden)] + pub trait SerializeRowByName { + /// A type that can handle serialization of this struct column-by-column + type Partial<'d>: PartialSerializeRowByName + where + Self: 'd; + + /// Returns a type that can serialize this row "column-by-column" + fn partial(&self) -> Self::Partial<'_>; + } + + /// How to serialize a row column-by-column + /// + /// For now this trait is an implementation detail of `#[derive(SerializeRow)]` when + /// serializing by name + #[doc(hidden)] + pub trait PartialSerializeRowByName { + /// Serializes a single column in the row according to the information in the + /// given context + /// + /// It returns whether the column finished the serialization of the struct, did + /// it partially, none of at all, or errored + fn serialize_field( + &mut self, + spec: &ColumnSpec, + writer: &mut RowWriter<'_>, + ) -> Result; + + /// Checks if there are any missing columns to finish the serialization + fn check_missing(self) -> Result<(), SerializationError>; + } + + pub struct ByName<'t, T: SerializeRowByName>(pub &'t T); + + impl ByName<'_, T> { + /// Serializes all the fields/columns by name + pub fn serialize( + self, + ctx: &RowSerializationContext, + writer: &mut RowWriter<'_>, + ) -> Result<(), SerializationError> { + let mut partial = self.0.partial(); + + for spec in ctx.columns() { + let serialized = partial.serialize_field(spec, writer).map_err(|err| { + mk_ser_err::( + BuiltinSerializationErrorKind::ColumnSerializationFailed { + name: spec.name().to_owned(), + err, + }, + ) + })?; + + if matches!(serialized, FieldStatus::NotUsed) { + return Err(SerializationError::new(BuiltinTypeCheckError { + rust_name: std::any::type_name::(), + kind: BuiltinTypeCheckErrorKind::NoColumnWithName { + name: spec.name().to_owned(), + }, + })); + } + } + + partial.check_missing()?; + + Ok(()) + } + } + } + } } diff --git a/scylla-cql/src/types/serialize/row.rs b/scylla-cql/src/types/serialize/row.rs index 665335ebe..190be8b66 100644 --- a/scylla-cql/src/types/serialize/row.rs +++ b/scylla-cql/src/types/serialize/row.rs @@ -556,7 +556,8 @@ pub struct BuiltinTypeCheckError { pub kind: BuiltinTypeCheckErrorKind, } -fn mk_typck_err(kind: impl Into) -> SerializationError { +#[doc(hidden)] +pub fn mk_typck_err(kind: impl Into) -> SerializationError { mk_typck_err_named(std::any::type_name::(), kind) } @@ -582,7 +583,8 @@ pub struct BuiltinSerializationError { pub kind: BuiltinSerializationErrorKind, } -fn mk_ser_err(kind: impl Into) -> SerializationError { +#[doc(hidden)] +pub fn mk_ser_err(kind: impl Into) -> SerializationError { mk_ser_err_named(std::any::type_name::(), kind) } @@ -1634,4 +1636,56 @@ pub(crate) mod tests { assert_eq!(reference, row); } + + #[test] + fn test_row_serialization_nested_structs() { + #[derive(SerializeRow, Debug)] + #[scylla(crate = crate)] + struct InnerColumnsOne { + x: i32, + y: f64, + } + + #[derive(SerializeRow, Debug)] + #[scylla(crate = crate)] + struct InnerColumnsTwo { + z: bool, + } + + #[derive(SerializeRow, Debug)] + #[scylla(crate = crate)] + struct OuterColumns { + #[scylla(flatten)] + inner_one: InnerColumnsOne, + a: String, + #[scylla(flatten)] + inner_two: InnerColumnsTwo, + } + + let spec = [ + col("a", ColumnType::Text), + col("x", ColumnType::Int), + col("z", ColumnType::Boolean), + col("y", ColumnType::Double), + ]; + + let value = OuterColumns { + inner_one: InnerColumnsOne { x: 5, y: 1.0 }, + a: "something".to_owned(), + inner_two: InnerColumnsTwo { z: true }, + }; + + let reference = do_serialize( + ( + &value.a, + &value.inner_one.x, + &value.inner_two.z, + &value.inner_one.y, + ), + &spec, + ); + let row = do_serialize(value, &spec); + + assert_eq!(reference, row); + } } diff --git a/scylla-macros/src/serialize/row.rs b/scylla-macros/src/serialize/row.rs index ffa2c7a2b..5e89ec387 100644 --- a/scylla-macros/src/serialize/row.rs +++ b/scylla-macros/src/serialize/row.rs @@ -2,7 +2,6 @@ use std::collections::HashMap; use darling::FromAttributes; use proc_macro::TokenStream; -use proc_macro2::Span; use syn::parse_quote; use crate::Flavor; @@ -55,6 +54,11 @@ struct FieldAttributes { // instead of the Rust field name. rename: Option, + // If set, then this field's columns are serialized using its own implementation + // of `SerializeRow` and flattened as if they were fields in this struct. + #[darling(default)] + flatten: bool, + // If true, then the field is not serialized at all, but simply ignored. // All other attributes are ignored. #[darling(default)] @@ -64,6 +68,8 @@ struct FieldAttributes { struct Context { attributes: Attributes, fields: Vec, + struct_name: syn::Ident, + generics: syn::Generics, } pub(crate) fn derive_serialize_row(tokens_input: TokenStream) -> Result { @@ -90,7 +96,12 @@ pub(crate) fn derive_serialize_row(tokens_input: TokenStream) -> Result>()?; - let ctx = Context { attributes, fields }; + let ctx = Context { + attributes, + fields, + struct_name: struct_name.clone(), + generics: input.generics.clone(), + }; ctx.validate(&input.ident)?; let gen: Box = match ctx.attributes.flavor { @@ -136,6 +147,30 @@ impl Context { } } + // `flatten` annotations is not yet supported outside of `match_by_name` + if !matches!(self.attributes.flavor, Flavor::MatchByName) { + if let Some(field) = self.fields.iter().find(|f| f.attrs.flatten) { + let err = darling::Error::custom( + "the `flatten` annotations is only supported wit the `match_by_name` flavor", + ) + .with_span(&field.ident); + errors.push(err); + } + } + + // Check that no renames are attempted on flattened fields + let rename_flatten_errors = self + .fields + .iter() + .filter(|f| f.attrs.flatten && f.attrs.rename.is_some()) + .map(|f| { + darling::Error::custom( + "`rename` and `flatten` annotations do not make sense together", + ) + .with_span(&f.ident) + }); + errors.extend(rename_flatten_errors); + // Check for name collisions let mut used_names = HashMap::::new(); for field in self.fields.iter() { @@ -199,94 +234,148 @@ impl Generator for ColumnSortingGenerator<'_> { // Need to: // - Check that all required columns are there and no more // - Check that the column types match - let mut statements: Vec = Vec::new(); let crate_path = self.ctx.attributes.crate_path(); + let struct_name = &self.ctx.struct_name; + let (impl_generics, ty_generics, where_clause) = self.ctx.generics.split_for_impl(); + let partial_struct_name = syn::Ident::new( + &format!("_{}ScyllaSerPartial", struct_name), + struct_name.span(), + ); + let mut partial_generics = self.ctx.generics.clone(); + let partial_lt: syn::LifetimeParam = syn::parse_quote!('scylla_ser_partial); + if !self.ctx.fields.is_empty() { + partial_generics + .params + .push(syn::GenericParam::Lifetime(partial_lt.clone())); + } - let rust_field_idents = self - .ctx - .fields - .iter() - .map(|f| f.ident.clone()) - .collect::>(); - let rust_field_names = self + let (partial_impl_generics, partial_ty_generics, partial_where_clause) = + partial_generics.split_for_impl(); + + let flattened: Vec<_> = self.ctx.fields.iter().filter(|f| f.attrs.flatten).collect(); + let flattened_fields: Vec<_> = flattened.iter().map(|f| &f.ident).collect(); + let flattened_tys: Vec<_> = flattened.iter().map(|f| &f.ty).collect(); + + let unflattened: Vec<_> = self .ctx .fields .iter() - .map(|f| f.column_name()) - .collect::>(); - let udt_field_names = rust_field_names.clone(); // For now, it's the same - let field_types = self.ctx.fields.iter().map(|f| &f.ty).collect::>(); - - // Declare a helper lambda for creating errors - statements.push(self.ctx.generate_mk_typck_err()); - statements.push(self.ctx.generate_mk_ser_err()); + .filter(|f| !f.attrs.flatten) + .collect(); + let unflattened_columns: Vec<_> = unflattened.iter().map(|f| f.column_name()).collect(); + let unflattened_fields: Vec<_> = unflattened.iter().map(|f| &f.ident).collect(); + let unflattened_tys: Vec<_> = unflattened.iter().map(|f| &f.ty).collect(); + + let all_names = self.ctx.fields.iter().map(|f| f.column_name()); + + let partial_struct: syn::ItemStruct = parse_quote! { + struct #partial_struct_name #partial_generics { + #(#unflattened_fields: &#partial_lt #unflattened_tys,)* + #(#flattened_fields: <#flattened_tys as #crate_path::ser::row::SerializeRowByName>::Partial<#partial_lt>,)* + missing: ::std::collections::HashSet<&'static str>, + } + }; + + let serialize_field_block: syn::Block = if self.ctx.fields.is_empty() { + parse_quote! {{ + ::std::result::Result::Ok(#crate_path::ser::row::FieldStatus::NotUsed) + }} + } else { + parse_quote! {{ + match spec.name() { + #(#unflattened_columns => { + let sub_writer = #crate_path::RowWriter::make_cell_writer(writer); + <#unflattened_tys as #crate_path::SerializeValue>::serialize(&self.#unflattened_fields, spec.typ(), sub_writer)?; + self.missing.remove(#unflattened_columns); + })* + _ => { + let mk_err = |err| { + #crate_path::SerializationError::new(#crate_path::BuiltinRowSerializationError { + rust_name: ::std::any::type_name::<#struct_name #ty_generics>(), + kind: #crate_path::BuiltinRowSerializationErrorKind::ColumnSerializationFailed { + name: <_ as ::std::borrow::ToOwned>::to_owned(spec.name()), + err, + }, + }) + }; + + #({ + let serialized = + self.#flattened_fields.serialize_field(spec, writer).map_err(mk_err)?; + match serialized { + #crate_path::ser::row::FieldStatus::Done => { + self.missing.remove(stringify!(#flattened_fields)); + if self.missing.is_empty() { + return ::std::result::Result::Ok(#crate_path::ser::row::FieldStatus::Done); + } else { + return ::std::result::Result::Ok(#crate_path::ser::row::FieldStatus::NotDone); + } + } + #crate_path::ser::row::FieldStatus::NotDone => { + return ::std::result::Result::Ok(#crate_path::ser::row::FieldStatus::NotDone) + } + #crate_path::ser::row::FieldStatus::NotUsed => {} + }; + })* - // Generate a "visited" flag for each field - let visited_flag_names = rust_field_idents - .iter() - .map(|s| syn::Ident::new(&format!("visited_flag_{}", s), Span::call_site())) - .collect::>(); - statements.extend::>(parse_quote! { - #(let mut #visited_flag_names = false;)* - }); + return ::std::result::Result::Ok(#crate_path::ser::row::FieldStatus::NotUsed); + } + } - // Generate a variable that counts down visited fields. - let field_count = self.ctx.fields.len(); - statements.push(parse_quote! { - let mut remaining_count = #field_count; - }); + if self.missing.is_empty() { + ::std::result::Result::Ok(#crate_path::ser::row::FieldStatus::Done) + } else { + ::std::result::Result::Ok(#crate_path::ser::row::FieldStatus::NotDone) + } + }} + }; + + let partial_serialize: syn::ItemImpl = parse_quote! { + impl #partial_impl_generics #crate_path::ser::row::PartialSerializeRowByName for #partial_struct_name #partial_ty_generics #partial_where_clause { + fn serialize_field( + &mut self, + spec: &#crate_path::ColumnSpec, + writer: &mut #crate_path::RowWriter<'_>, + ) -> ::std::result::Result<#crate_path::ser::row::FieldStatus, #crate_path::SerializationError> { + #serialize_field_block + } - // Generate a loop over the fields and a `match` block to match on - // the field name. - statements.push(parse_quote! { - for spec in ctx.columns() { - match spec.name() { - #( - #udt_field_names => { - let sub_writer = #crate_path::RowWriter::make_cell_writer(writer); - match <#field_types as #crate_path::SerializeValue>::serialize(&self.#rust_field_idents, spec.typ(), sub_writer) { - ::std::result::Result::Ok(_proof) => {} - ::std::result::Result::Err(err) => { - return ::std::result::Result::Err(mk_ser_err( - #crate_path::BuiltinRowSerializationErrorKind::ColumnSerializationFailed { - name: <_ as ::std::borrow::ToOwned>::to_owned(spec.name()), - err, - } - )); - } - } - if !#visited_flag_names { - #visited_flag_names = true; - remaining_count -= 1; - } - } - )* - _ => return ::std::result::Result::Err(mk_typck_err( - #crate_path::BuiltinRowTypeCheckErrorKind::NoColumnWithName { - name: <_ as ::std::borrow::ToOwned>::to_owned(spec.name()), - } - )), + fn check_missing(self) -> ::std::result::Result<(), #crate_path::SerializationError> { + use ::std::iter::{Iterator as _, IntoIterator as _}; + + let ::std::option::Option::Some(missing) = self.missing.into_iter().nth(0) else { + return ::std::result::Result::Ok(()); + }; + + match missing { + #(stringify!(#flattened_fields) => self.#flattened_fields.check_missing(),)* + _ => ::std::result::Result::Err(#crate_path::SerializationError::new(#crate_path::BuiltinRowTypeCheckError { + rust_name: ::std::any::type_name::<#struct_name #ty_generics>(), + kind: #crate_path::BuiltinRowTypeCheckErrorKind::ValueMissingForColumn { + name: <_ as ::std::borrow::ToOwned>::to_owned(missing), + }, + })) + } } } - }); + }; - // Finally, check that all fields were consumed. - // If there are some missing fields, return an error - statements.push(parse_quote! { - if remaining_count > 0 { - #( - if !#visited_flag_names { - return ::std::result::Result::Err(mk_typck_err( - #crate_path::BuiltinRowTypeCheckErrorKind::ValueMissingForColumn { - name: <_ as ::std::string::ToString>::to_string(#rust_field_names), - } - )); + let serialize_by_name: syn::ItemImpl = parse_quote! { + impl #impl_generics #crate_path::ser::row::SerializeRowByName for #struct_name #ty_generics #where_clause { + type Partial<#partial_lt> = #partial_struct_name #partial_ty_generics where Self: #partial_lt; + + fn partial(&self) -> Self::Partial<'_> { + use ::std::iter::FromIterator as _; + + #partial_struct_name { + #(#unflattened_fields: &self.#unflattened_fields,)* + #(#flattened_fields: self.#flattened_fields.partial(),)* + missing: ::std::collections::HashSet::from_iter([#(#all_names,)*]), } - )* - ::std::unreachable!() + } } - }); + }; parse_quote! { fn serialize<'b>( @@ -294,8 +383,13 @@ impl Generator for ColumnSortingGenerator<'_> { ctx: &#crate_path::RowSerializationContext, writer: &mut #crate_path::RowWriter<'b>, ) -> ::std::result::Result<(), #crate_path::SerializationError> { - #(#statements)* - ::std::result::Result::Ok(()) + #partial_struct + #partial_serialize + + #[allow(non_local_definitions)] + #serialize_by_name + + #crate_path::ser::row::ByName(self).serialize(ctx, writer) } } } diff --git a/scylla/src/lib.rs b/scylla/src/lib.rs index 6724a65d9..2a4670f6e 100644 --- a/scylla/src/lib.rs +++ b/scylla/src/lib.rs @@ -81,7 +81,7 @@ //! .query_unpaged("SELECT a, b FROM ks.tab", &[]) //! .await? //! .into_rows_result()?; -//! +//! //! for row in query_rows.rows()? { //! // Parse row as int and text \ //! let (int_val, text_val): (i32, &str) = row?; diff --git a/scylla/src/macros.rs b/scylla/src/macros.rs index ce64153e8..ca7d64942 100644 --- a/scylla/src/macros.rs +++ b/scylla/src/macros.rs @@ -241,6 +241,14 @@ pub use scylla_cql::macros::SerializeValue; /// /// Don't use the field during serialization. /// +/// `#[scylla(flatten)]` +/// +/// Use this field's `SerializeRow` implementation to serialize its columns as part +/// of this struct. Note that the name of this field is ignored and hence the +/// `rename` attribute does not make sense here and will cause a compilation +/// error. Currently this is only supported for the `"match_by_name"` flavor in both +/// the wrapper struct and this flattened struct. +/// /// --- /// pub use scylla_cql::macros::SerializeRow;