diff --git a/scylla-cql/src/lib.rs b/scylla-cql/src/lib.rs index ab94470e10..6d74b680ba 100644 --- a/scylla-cql/src/lib.rs +++ b/scylla-cql/src/lib.rs @@ -21,6 +21,13 @@ pub mod _macro_internal { }; pub use crate::macros::*; + pub use crate::types::serialize::row::{ + BuiltinSerializationError as BuiltinRowSerializationError, + BuiltinSerializationErrorKind as BuiltinRowSerializationErrorKind, + BuiltinTypeCheckError as BuiltinRowTypeCheckError, + BuiltinTypeCheckErrorKind as BuiltinRowTypeCheckErrorKind, RowSerializationContext, + SerializeRow, + }; pub use crate::types::serialize::value::{ BuiltinSerializationError as BuiltinTypeSerializationError, BuiltinSerializationErrorKind as BuiltinTypeSerializationErrorKind, @@ -29,7 +36,9 @@ pub mod _macro_internal { UdtSerializationErrorKind, UdtTypeCheckErrorKind, }; pub use crate::types::serialize::writers::WrittenCellProof; - pub use crate::types::serialize::{CellValueBuilder, CellWriter, SerializationError}; + pub use crate::types::serialize::{ + CellValueBuilder, CellWriter, RowWriter, SerializationError, + }; pub use crate::frame::response::result::ColumnType; } diff --git a/scylla-cql/src/macros.rs b/scylla-cql/src/macros.rs index 56f1f43cf3..8f53e24fa9 100644 --- a/scylla-cql/src/macros.rs +++ b/scylla-cql/src/macros.rs @@ -75,6 +75,70 @@ pub use scylla_macros::ValueList; /// to either the `scylla` or `scylla-cql` crate. pub use scylla_macros::SerializeCql; +/// Derive macro for the [`SerializeRow`](crate::types::serialize::row::SerializeRow) trait +/// which serializes given Rust structure into bind markers for a CQL statement. +/// +/// At the moment, only structs with named fields are supported. The generated +/// implementation of the trait will match the struct fields to bind markers/columns +/// by name automatically. +/// +/// Serialization will fail if there are some bind markers/columns in the statement +/// that don't match to any of the Rust struct fields, _or vice versa_. +/// +/// In case of failure, either [`BuiltinTypeCheckError`](crate::types::serialize::row::BuiltinTypeCheckError) +/// or [`BuiltinSerializationError`](crate::types::serialize::row::BuiltinSerializationError) +/// will be returned. +/// +/// # Example +/// +/// A UDT defined like this: +/// Given a table and a query: +/// +/// ```notrust +/// CREATE TABLE ks.my_t (a int PRIMARY KEY, b text, c blob); +/// INSERT INTO ks.my_t (a, b, c) VALUES (?, ?, ?); +/// ``` +/// +/// ...the values for the query can be serialized using the following struct: +/// +/// ```rust +/// # use scylla_cql::macros::SerializeRow; +/// #[derive(SerializeRow)] +/// # #[scylla(crate = scylla_cql)] +/// struct MyValues { +/// a: i32, +/// b: Option, +/// c: Vec, +/// } +/// ``` +/// +/// # Attributes +/// +/// `#[scylla(crate = crate_name)]` +/// +/// By default, the code generated by the derive macro will refer to the items +/// defined by the driver (types, traits, etc.) via the `::scylla` path. +/// For example, it will refer to the [`SerializeRow`](crate::types::serialize::row::SerializeRow) trait +/// using the following path: +/// +/// ```rust,ignore +/// use ::scylla::_macro_internal::SerializeRow; +/// ``` +/// +/// Most users will simply add `scylla` to their dependencies, then use +/// the derive macro and the path above will work. However, there are some +/// niche cases where this path will _not_ work: +/// +/// - The `scylla` crate is imported under a different name, +/// - The `scylla` crate is _not imported at all_ - the macro actually +/// is defined in the `scylla-macros` crate and the generated code depends +/// on items defined in `scylla-cql`. +/// +/// It's not possible to automatically resolve those issues in the procedural +/// macro itself, so in those cases the user must provide an alternative path +/// to either the `scylla` or `scylla-cql` crate. +pub use scylla_macros::SerializeRow; + // Reexports for derive(IntoUserType) pub use bytes::{BufMut, Bytes, BytesMut}; diff --git a/scylla-cql/src/types/serialize/row.rs b/scylla-cql/src/types/serialize/row.rs index d8702100b6..d398a42281 100644 --- a/scylla-cql/src/types/serialize/row.rs +++ b/scylla-cql/src/types/serialize/row.rs @@ -561,7 +561,12 @@ mod tests { use crate::frame::value::{MaybeUnset, SerializedValues, ValueList}; use crate::types::serialize::RowWriter; - use super::{RowSerializationContext, SerializeRow}; + use super::{ + BuiltinSerializationError, BuiltinSerializationErrorKind, BuiltinTypeCheckError, + BuiltinTypeCheckErrorKind, RowSerializationContext, SerializeCql, SerializeRow, + }; + + use scylla_macros::SerializeRow; fn col_spec(name: &str, typ: ColumnType) -> ColumnSpec { ColumnSpec { @@ -672,4 +677,165 @@ mod tests { ); assert_eq!(typed_data, erased_data); } + + fn do_serialize(t: T, columns: &[ColumnSpec]) -> Vec { + let ctx = RowSerializationContext { columns }; + let mut ret = Vec::new(); + let mut builder = RowWriter::new(&mut ret); + t.serialize(&ctx, &mut builder).unwrap(); + ret + } + + fn col(name: &str, typ: ColumnType) -> ColumnSpec { + ColumnSpec { + table_spec: TableSpec { + ks_name: "ks".to_string(), + table_name: "tbl".to_string(), + }, + name: name.to_string(), + typ, + } + } + + // Do not remove. It's not used in tests but we keep it here to check that + // we properly ignore warnings about unused variables, unnecessary `mut`s + // etc. that usually pop up when generating code for empty structs. + #[derive(SerializeRow)] + #[scylla(crate = crate)] + struct TestRowWithNoColumns {} + + #[derive(SerializeRow, Debug, PartialEq, Eq, Default)] + #[scylla(crate = crate)] + struct TestRowWithColumnSorting { + a: String, + b: i32, + c: Vec, + } + + #[test] + fn test_row_serialization_with_column_sorting_correct_order() { + let spec = [ + col("a", ColumnType::Text), + col("b", ColumnType::Int), + col("c", ColumnType::List(Box::new(ColumnType::BigInt))), + ]; + + let reference = do_serialize(("Ala ma kota", 42i32, vec![1i64, 2i64, 3i64]), &spec); + let row = do_serialize( + TestRowWithColumnSorting { + a: "Ala ma kota".to_owned(), + b: 42, + c: vec![1, 2, 3], + }, + &spec, + ); + + assert_eq!(reference, row); + } + + #[test] + fn test_row_serialization_with_column_sorting_incorrect_order() { + // The order of two last columns is swapped + let spec = [ + col("a", ColumnType::Text), + col("c", ColumnType::List(Box::new(ColumnType::BigInt))), + col("b", ColumnType::Int), + ]; + + let reference = do_serialize(("Ala ma kota", vec![1i64, 2i64, 3i64], 42i32), &spec); + let row = do_serialize( + TestRowWithColumnSorting { + a: "Ala ma kota".to_owned(), + b: 42, + c: vec![1, 2, 3], + }, + &spec, + ); + + assert_eq!(reference, row); + } + + #[test] + fn test_row_serialization_failing_type_check() { + let row = TestRowWithColumnSorting::default(); + let mut data = Vec::new(); + let mut row_writer = RowWriter::new(&mut data); + + let spec_without_c = [ + col("a", ColumnType::Text), + col("b", ColumnType::Int), + // Missing column c + ]; + + let ctx = RowSerializationContext { + columns: &spec_without_c, + }; + let err = <_ as SerializeRow>::serialize(&row, &ctx, &mut row_writer).unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::ColumnMissingForValue { .. } + )); + + let spec_duplicate_column = [ + col("a", ColumnType::Text), + col("b", ColumnType::Int), + col("c", ColumnType::List(Box::new(ColumnType::BigInt))), + // Unexpected last column + col("d", ColumnType::Counter), + ]; + + let ctx = RowSerializationContext { + columns: &spec_duplicate_column, + }; + let err = <_ as SerializeRow>::serialize(&row, &ctx, &mut row_writer).unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinTypeCheckErrorKind::MissingValueForColumn { .. } + )); + + let spec_wrong_type = [ + col("a", ColumnType::Text), + col("b", ColumnType::Int), + col("c", ColumnType::TinyInt), // Wrong type + ]; + + let ctx = RowSerializationContext { + columns: &spec_wrong_type, + }; + let err = <_ as SerializeRow>::serialize(&row, &ctx, &mut row_writer).unwrap_err(); + let err = err.0.downcast_ref::().unwrap(); + assert!(matches!( + err.kind, + BuiltinSerializationErrorKind::ColumnSerializationFailed { .. } + )); + } + + #[derive(SerializeRow)] + #[scylla(crate = crate)] + struct TestRowWithGenerics<'a, T: SerializeCql> { + a: &'a str, + b: T, + } + + #[test] + fn test_row_serialization_with_generics() { + // A minimal smoke test just to test that it works. + fn check_with_type(typ: ColumnType, t: T) { + let spec = [col("a", ColumnType::Text), col("b", typ)]; + let reference = do_serialize(("Ala ma kota", t), &spec); + let row = do_serialize( + TestRowWithGenerics { + a: "Ala ma kota", + b: t, + }, + &spec, + ); + assert_eq!(reference, row); + } + + check_with_type(ColumnType::Int, 123_i32); + check_with_type(ColumnType::Double, 123_f64); + } } diff --git a/scylla-macros/src/lib.rs b/scylla-macros/src/lib.rs index 84ee58bca0..64ce0ee06e 100644 --- a/scylla-macros/src/lib.rs +++ b/scylla-macros/src/lib.rs @@ -18,6 +18,15 @@ pub fn serialize_cql_derive(tokens_input: TokenStream) -> TokenStream { } } +/// See the documentation for this item in the `scylla` crate. +#[proc_macro_derive(SerializeRow, attributes(scylla))] +pub fn serialize_row_derive(tokens_input: TokenStream) -> TokenStream { + match serialize::row::derive_serialize_row(tokens_input) { + Ok(t) => t.into_token_stream().into(), + Err(e) => e.into_compile_error().into(), + } +} + /// #[derive(FromRow)] derives FromRow for struct /// Works only on simple structs without generics etc #[proc_macro_derive(FromRow, attributes(scylla_crate))] diff --git a/scylla-macros/src/serialize/mod.rs b/scylla-macros/src/serialize/mod.rs index 15fd9ae87c..53abe0f296 100644 --- a/scylla-macros/src/serialize/mod.rs +++ b/scylla-macros/src/serialize/mod.rs @@ -1 +1,2 @@ pub(crate) mod cql; +pub(crate) mod row; diff --git a/scylla-macros/src/serialize/row.rs b/scylla-macros/src/serialize/row.rs new file mode 100644 index 0000000000..0dd2356041 --- /dev/null +++ b/scylla-macros/src/serialize/row.rs @@ -0,0 +1,202 @@ +use darling::FromAttributes; +use proc_macro::TokenStream; +use proc_macro2::Span; +use syn::parse_quote; + +#[derive(FromAttributes)] +#[darling(attributes(scylla))] +struct Attributes { + #[darling(rename = "crate")] + crate_path: Option, +} + +impl Attributes { + fn crate_path(&self) -> syn::Path { + self.crate_path + .as_ref() + .map(|p| parse_quote!(#p::_macro_internal)) + .unwrap_or_else(|| parse_quote!(::scylla::_macro_internal)) + } +} + +struct Context { + attributes: Attributes, + fields: Vec, +} + +pub fn derive_serialize_row(tokens_input: TokenStream) -> Result { + let input: syn::DeriveInput = syn::parse(tokens_input)?; + let struct_name = input.ident.clone(); + let named_fields = crate::parser::parse_named_fields(&input, "SerializeRow")?; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let attributes = Attributes::from_attributes(&input.attrs)?; + + let crate_path = attributes.crate_path(); + let implemented_trait: syn::Path = parse_quote!(#crate_path::SerializeRow); + + let fields = named_fields.named.iter().cloned().collect(); + let ctx = Context { attributes, fields }; + let gen = ColumnSortingGenerator { ctx: &ctx }; + + let serialize_item = gen.generate_serialize(); + let is_empty_item = gen.generate_is_empty(); + + let res = parse_quote! { + impl #impl_generics #implemented_trait for #struct_name #ty_generics #where_clause { + #serialize_item + #is_empty_item + } + }; + Ok(res) +} + +impl Context { + fn generate_mk_typck_err(&self) -> syn::Stmt { + let crate_path = self.attributes.crate_path(); + parse_quote! { + let mk_typck_err = |kind: #crate_path::BuiltinRowTypeCheckErrorKind| -> #crate_path::SerializationError { + #crate_path::SerializationError::new( + #crate_path::BuiltinRowTypeCheckError { + rust_name: ::std::any::type_name::(), + kind, + } + ) + }; + } + } + + fn generate_mk_ser_err(&self) -> syn::Stmt { + let crate_path = self.attributes.crate_path(); + parse_quote! { + let mk_ser_err = |kind: #crate_path::BuiltinRowSerializationErrorKind| -> #crate_path::SerializationError { + #crate_path::SerializationError::new( + #crate_path::BuiltinRowSerializationError { + rust_name: ::std::any::type_name::(), + kind, + } + ) + }; + } + } +} + +// Generates an implementation of the trait which sorts the columns according +// to how they are defined in prepared statement metadata. +struct ColumnSortingGenerator<'a> { + ctx: &'a Context, +} + +impl<'a> ColumnSortingGenerator<'a> { + fn generate_serialize(&self) -> syn::TraitItemFn { + // Need to: + // - Check that all required columns are there and no more + // - Check that the column types match + let mut statements: Vec = Vec::new(); + + let crate_path = self.ctx.attributes.crate_path(); + + let rust_field_idents = self + .ctx + .fields + .iter() + .map(|f| f.ident.clone()) + .collect::>(); + let rust_field_names = rust_field_idents + .iter() + .map(|i| i.as_ref().unwrap().to_string()) + .collect::>(); + let udt_field_names = rust_field_names.clone(); // For now, it's the same + let field_types = self.ctx.fields.iter().map(|f| &f.ty).collect::>(); + + // Declare a helper lambda for creating errors + statements.push(self.ctx.generate_mk_typck_err()); + statements.push(self.ctx.generate_mk_ser_err()); + + // Generate a "visited" flag for each field + let visited_flag_names = rust_field_names + .iter() + .map(|s| syn::Ident::new(&format!("visited_flag_{}", s), Span::call_site())) + .collect::>(); + statements.extend::>(parse_quote! { + #(let mut #visited_flag_names = false;)* + }); + + // Generate a variable that counts down visited fields. + let field_count = self.ctx.fields.len(); + statements.push(parse_quote! { + let mut remaining_count = #field_count; + }); + + // Generate a loop over the fields and a `match` block to match on + // the field name. + statements.push(parse_quote! { + for spec in ctx.columns() { + match ::std::string::String::as_str(&spec.name) { + #( + #udt_field_names => { + let sub_writer = #crate_path::RowWriter::make_cell_writer(writer); + match <#field_types as #crate_path::SerializeCql>::serialize(&self.#rust_field_idents, &spec.typ, sub_writer) { + ::std::result::Result::Ok(_proof) => {} + ::std::result::Result::Err(err) => { + return ::std::result::Result::Err(mk_ser_err( + #crate_path::BuiltinRowSerializationErrorKind::ColumnSerializationFailed { + name: <_ as ::std::clone::Clone>::clone(&spec.name), + err, + } + )); + } + } + if !#visited_flag_names { + #visited_flag_names = true; + remaining_count -= 1; + } + } + )* + _ => return ::std::result::Result::Err(mk_typck_err( + #crate_path::BuiltinRowTypeCheckErrorKind::MissingValueForColumn { + name: <_ as ::std::clone::Clone>::clone(&&spec.name), + } + )), + } + } + }); + + // Finally, check that all fields were consumed. + // If there are some missing fields, return an error + statements.push(parse_quote! { + if remaining_count > 0 { + #( + if !#visited_flag_names { + return ::std::result::Result::Err(mk_typck_err( + #crate_path::BuiltinRowTypeCheckErrorKind::ColumnMissingForValue { + name: <_ as ::std::string::ToString>::to_string(#rust_field_names), + } + )); + } + )* + ::std::unreachable!() + } + }); + + parse_quote! { + fn serialize<'b>( + &self, + ctx: &#crate_path::RowSerializationContext, + writer: &mut #crate_path::RowWriter<'b>, + ) -> ::std::result::Result<(), #crate_path::SerializationError> { + #(#statements)* + ::std::result::Result::Ok(()) + } + } + } + + fn generate_is_empty(&self) -> syn::TraitItemFn { + let is_empty = self.ctx.fields.is_empty(); + parse_quote! { + #[inline] + fn is_empty(&self) -> bool { + #is_empty + } + } + } +} diff --git a/scylla/tests/integration/hygiene.rs b/scylla/tests/integration/hygiene.rs index 12d55ccb61..cf2aaed7b3 100644 --- a/scylla/tests/integration/hygiene.rs +++ b/scylla/tests/integration/hygiene.rs @@ -64,7 +64,7 @@ macro_rules! test_crate { assert_eq!(sv, sv2); } - #[derive(_scylla::macros::SerializeCql)] + #[derive(_scylla::macros::SerializeCql, _scylla::macros::SerializeRow)] #[scylla(crate = _scylla)] struct TestStructNew { x: ::core::primitive::i32,