Skip to content

Commit

Permalink
scylla-macros: implement enforce_order flavor of SerializeCql
Browse files Browse the repository at this point in the history
Some users might not need the additional robustness of `SerializeCql`
that comes from sorting the fields before serializing, as they are used
to the current behavior of `Value` and properly set the order of the
fields in their Rust struct. In order to give them some performance
boost, add an additional mode to `SerializeCql` called "enforce_order"
which expects that the order of the fields in the struct is kept in
sync with the DB definition of the UDT.

It's still safe to use because, as the struct fields are serialized,
their names are compared, in order, against the fields of the UDT
definition, and serialization fails if the field name at any position
does not match.
  • Loading branch information
piodul committed Dec 9, 2023
1 parent 30a69f8 commit dcb4cf4
Show file tree
Hide file tree
Showing 4 changed files with 323 additions and 5 deletions.
19 changes: 16 additions & 3 deletions scylla-cql/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ pub use scylla_macros::ValueList;
/// Derive macro for the [`SerializeCql`](crate::types::serialize::value::SerializeCql) trait
/// which serializes given Rust structure as a User Defined Type (UDT).
///
/// At the moment, only structs with named fields are supported. The generated
/// implementation of the trait will match the struct fields to UDT fields
/// by name automatically.
/// At the moment, only structs with named fields are supported.
///
/// Serialization will fail if there are some fields in the UDT that don't match
/// to any of the Rust struct fields, _or vice versa_.
Expand Down Expand Up @@ -50,6 +48,21 @@ pub use scylla_macros::ValueList;
///
/// # Attributes
///
/// `#[scylla(flavor = "flavor_name")]`
///
/// Selects one of the available "flavors", i.e. the way in which the
/// generated code approaches serialization. Possible flavors are:
///
/// - `"match_by_name"` (default) - the generated implementation _does not
/// require_ the fields in the Rust struct to be in the same order as the
/// fields in the UDT. During serialization, the implementation will take
/// care to serialize the fields in the order which the database expects.
/// - `"enforce_order"` - the generated implementation _requires_ the fields
/// in the Rust struct to be in the same order as the fields in the UDT.
/// If the order is incorrect, type checking/serialization will fail.
/// This is a less robust flavor than `"match_by_name"`, but should be
/// slightly more performant as it doesn't need to perform lookups by name.
///
/// `#[scylla(crate = crate_name)]`
///
/// By default, the code generated by the derive macro will refer to the items
Expand Down
170 changes: 170 additions & 0 deletions scylla-cql/src/types/serialize/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1314,6 +1314,12 @@ pub enum UdtTypeCheckErrorKind {

/// The Rust data contains a field that is not present in the UDT
UnexpectedFieldInDestination { field_name: String },

/// A different field name was expected at given position.
FieldNameMismatch {
rust_field_name: String,
db_field_name: String,
},
}

impl Display for UdtTypeCheckErrorKind {
Expand All @@ -1337,6 +1343,10 @@ impl Display for UdtTypeCheckErrorKind {
f,
"the field {field_name} present in the Rust data is not present in the CQL type"
),
UdtTypeCheckErrorKind::FieldNameMismatch { rust_field_name, db_field_name } => write!(
f,
"expected field with name {db_field_name} at given position, but the Rust field name is {rust_field_name}"
),
}
}
}
Expand Down Expand Up @@ -1668,4 +1678,164 @@ mod tests {
check_with_type(ColumnType::Int, 123_i32, CqlValue::Int(123_i32));
check_with_type(ColumnType::Double, 123_f64, CqlValue::Double(123_f64));
}

// Test fixture: a struct whose `SerializeCql` implementation is derived with
// the "enforce_order" flavor. The tests below serialize it against UDT types
// whose fields either match or deviate from the `a`, `b`, `c` order used here.
#[derive(SerializeCql, Debug, PartialEq, Eq, Default)]
#[scylla(crate = crate, flavor = "enforce_order")]
struct TestUdtWithEnforcedOrder {
a: String,
b: i32,
c: Vec<i64>,
}

// Serializing `TestUdtWithEnforcedOrder` against a UDT whose fields are in
// the same order as the struct fields must give the same result as
// serializing the equivalent `CqlValue::UserDefinedType`.
#[test]
fn test_udt_serialization_with_enforced_order_correct_order() {
    // UDT definition with fields in the same order as in the Rust struct
    let udt_type = ColumnType::UserDefinedType {
        type_name: "typ".to_string(),
        keyspace: "ks".to_string(),
        field_types: vec![
            ("a".to_string(), ColumnType::Text),
            ("b".to_string(), ColumnType::Int),
            (
                "c".to_string(),
                ColumnType::List(Box::new(ColumnType::BigInt)),
            ),
        ],
    };

    // The same data expressed as a dynamically typed CQL value
    let list_value = CqlValue::List(vec![
        CqlValue::BigInt(1),
        CqlValue::BigInt(2),
        CqlValue::BigInt(3),
    ]);
    let reference_value = CqlValue::UserDefinedType {
        keyspace: "ks".to_string(),
        type_name: "typ".to_string(),
        fields: vec![
            (
                "a".to_string(),
                Some(CqlValue::Text(String::from("Ala ma kota"))),
            ),
            ("b".to_string(), Some(CqlValue::Int(42))),
            ("c".to_string(), Some(list_value)),
        ],
    };
    let expected = do_serialize(reference_value, &udt_type);

    // The data expressed through the derived struct
    let rust_value = TestUdtWithEnforcedOrder {
        a: "Ala ma kota".to_owned(),
        b: 42,
        c: vec![1, 2, 3],
    };
    let actual = do_serialize(rust_value, &udt_type);

    assert_eq!(expected, actual);
}

// Exercises the failure paths of the "enforce_order" flavor: a non-UDT
// target type, swapped field order, a UDT with too few fields, a UDT with
// too many fields, and a field whose column type does not fit the Rust type.
#[test]
fn test_udt_serialization_with_enforced_order_failing_type_check() {
    let typ_not_udt = ColumnType::Ascii;
    let udt = TestUdtWithEnforcedOrder::default();

    let mut data = Vec::new();

    // Case 1: the target type is not a UDT at all
    let err = <_ as SerializeCql>::serialize(&udt, &typ_not_udt, CellWriter::new(&mut data))
        .unwrap_err();
    let err = err.0.downcast_ref::<BuiltinTypeCheckError>().unwrap();
    assert!(matches!(
        err.kind,
        BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::NotUdt)
    ));

    // Case 2: the UDT has the right fields, but in a different order
    let typ_swapped_fields = ColumnType::UserDefinedType {
        type_name: "typ".to_string(),
        keyspace: "ks".to_string(),
        field_types: vec![
            // Two first columns are swapped
            ("b".to_string(), ColumnType::Int),
            ("a".to_string(), ColumnType::Text),
            (
                "c".to_string(),
                ColumnType::List(Box::new(ColumnType::BigInt)),
            ),
        ],
    };

    let err =
        <_ as SerializeCql>::serialize(&udt, &typ_swapped_fields, CellWriter::new(&mut data))
            .unwrap_err();
    let err = err.0.downcast_ref::<BuiltinTypeCheckError>().unwrap();
    assert!(matches!(
        err.kind,
        BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::FieldNameMismatch { .. })
    ));

    // Case 3: the UDT has fewer fields than the Rust struct
    let typ_without_c = ColumnType::UserDefinedType {
        type_name: "typ".to_string(),
        keyspace: "ks".to_string(),
        field_types: vec![
            ("a".to_string(), ColumnType::Text),
            ("b".to_string(), ColumnType::Int),
            // Last field is missing
        ],
    };

    let err = <_ as SerializeCql>::serialize(&udt, &typ_without_c, CellWriter::new(&mut data))
        .unwrap_err();
    let err = err.0.downcast_ref::<BuiltinTypeCheckError>().unwrap();
    assert!(matches!(
        err.kind,
        BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::MissingField { .. })
    ));

    // Case 4: the UDT has more fields than the Rust struct
    let typ_unexpected_field = ColumnType::UserDefinedType {
        type_name: "typ".to_string(),
        keyspace: "ks".to_string(),
        field_types: vec![
            ("a".to_string(), ColumnType::Text),
            ("b".to_string(), ColumnType::Int),
            (
                "c".to_string(),
                ColumnType::List(Box::new(ColumnType::BigInt)),
            ),
            // Unexpected field
            ("d".to_string(), ColumnType::Counter),
        ],
    };

    let err =
        <_ as SerializeCql>::serialize(&udt, &typ_unexpected_field, CellWriter::new(&mut data))
            .unwrap_err();
    let err = err.0.downcast_ref::<BuiltinTypeCheckError>().unwrap();
    assert!(matches!(
        err.kind,
        BuiltinTypeCheckErrorKind::UdtError(
            UdtTypeCheckErrorKind::UnexpectedFieldInDestination { .. }
        )
    ));

    // Case 5: names and order match, but the column type of `c` does not fit
    // `Vec<i64>`. This one is reported as a serialization error of the field,
    // not as a UDT type check error.
    let typ_wrong_field_type = ColumnType::UserDefinedType {
        type_name: "typ".to_string(),
        keyspace: "ks".to_string(),
        field_types: vec![
            ("a".to_string(), ColumnType::Text),
            ("b".to_string(), ColumnType::Int),
            ("c".to_string(), ColumnType::TinyInt), // Wrong column type
        ],
    };

    let err =
        <_ as SerializeCql>::serialize(&udt, &typ_wrong_field_type, CellWriter::new(&mut data))
            .unwrap_err();
    let err = err.0.downcast_ref::<BuiltinSerializationError>().unwrap();
    assert!(matches!(
        err.kind,
        BuiltinSerializationErrorKind::UdtError(
            UdtSerializationErrorKind::FieldSerializationFailed { .. }
        )
    ));
}
}
121 changes: 119 additions & 2 deletions scylla-macros/src/serialize/cql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@ use proc_macro::TokenStream;
use proc_macro2::Span;
use syn::parse_quote;

use super::Flavor;

// Options accepted in the `#[scylla(...)]` attribute of the struct that
// `SerializeCql` is being derived for.
#[derive(FromAttributes)]
#[darling(attributes(scylla))]
struct Attributes {
// Value of `#[scylla(crate = ...)]`, if provided.
#[darling(rename = "crate")]
crate_path: Option<syn::Path>,

// Value of `#[scylla(flavor = "...")]`, if provided.
flavor: Option<Flavor>,
}

impl Attributes {
Expand Down Expand Up @@ -36,7 +40,11 @@ pub fn derive_serialize_cql(tokens_input: TokenStream) -> Result<syn::ItemImpl,

let fields = named_fields.named.iter().cloned().collect();
let ctx = Context { attributes, fields };
let gen = FieldSortingGenerator { ctx: &ctx };

let gen: Box<dyn Generator> = match ctx.attributes.flavor {
Some(Flavor::MatchByName) | None => Box::new(FieldSortingGenerator { ctx: &ctx }),
Some(Flavor::EnforceOrder) => Box::new(FieldOrderedGenerator { ctx: &ctx }),
};

let serialize_item = gen.generate_serialize();

Expand Down Expand Up @@ -93,13 +101,17 @@ impl Context {
}
}

// Common interface of the code generators that back the different flavors
// of the derive macro; it lets the flavor be picked at macro run time.
trait Generator {
// Produces the method item that goes into the generated trait impl.
fn generate_serialize(&self) -> syn::TraitItemFn;
}

// Generates an implementation of the trait which sorts the fields according
// to how they are defined in the database.
struct FieldSortingGenerator<'a> {
// Parsed attributes and fields of the struct being derived for.
ctx: &'a Context,
}

impl<'a> FieldSortingGenerator<'a> {
impl<'a> Generator for FieldSortingGenerator<'a> {
fn generate_serialize(&self) -> syn::TraitItemFn {
// Need to:
// - Check that all required fields are there and no more
Expand Down Expand Up @@ -222,3 +234,108 @@ impl<'a> FieldSortingGenerator<'a> {
}
}
}

// Generates an implementation of the trait which requires the fields of the
// Rust struct to be declared in the same order as the fields of the UDT in
// the database. No sorting is performed; a mismatch results in an error.
struct FieldOrderedGenerator<'a> {
// Parsed attributes and fields of the struct being derived for.
ctx: &'a Context,
}

impl<'a> Generator for FieldOrderedGenerator<'a> {
    // Builds the `serialize` method for the "enforce_order" flavor.
    //
    // The generated method walks the UDT field definitions and the Rust
    // struct fields in lockstep. For every Rust field it takes the next UDT
    // field, checks that the names are equal and serializes the value with
    // the UDT field's type. It fails if the names differ, if the UDT runs
    // out of fields early, or if the UDT still has fields left at the end.
    fn generate_serialize(&self) -> syn::TraitItemFn {
        let crate_path = self.ctx.attributes.crate_path();

        // Preamble of the generated method body
        let mut statements: Vec<syn::Stmt> = vec![
            // Helper lambdas for creating errors
            self.ctx.generate_mk_typck_err(),
            self.ctx.generate_mk_ser_err(),
            // Check that the type we want to serialize to is a UDT
            self.ctx
                .generate_udt_type_match(parse_quote!(#crate_path::UdtTypeCheckErrorKind::NotUdt)),
            // Turn the cell writer into a value builder
            parse_quote! {
                let mut builder = #crate_path::CellWriter::into_value_builder(writer);
            },
            // Create an iterator over the UDT fields
            parse_quote! {
                let mut field_iter = field_types.iter();
            },
        ];

        // One statement per Rust field, in declaration order. Note that the
        // `typ` inside the quoted code is the column type yielded by
        // `field_iter`, while `#rust_field_type` is the Rust type of the field.
        statements.extend(self.ctx.fields.iter().map(|field| -> syn::Stmt {
            let rust_field_ident = field.ident.as_ref().unwrap();
            let rust_field_name = rust_field_ident.to_string();
            let rust_field_type = &field.ty;
            parse_quote! {
                match field_iter.next() {
                    Some((field_name, typ)) => {
                        if field_name == #rust_field_name {
                            let sub_builder = #crate_path::CellValueBuilder::make_sub_writer(&mut builder);
                            match <#rust_field_type as #crate_path::SerializeCql>::serialize(&self.#rust_field_ident, typ, sub_builder) {
                                Ok(_proof) => {},
                                Err(err) => {
                                    return ::std::result::Result::Err(mk_ser_err(
                                        #crate_path::UdtSerializationErrorKind::FieldSerializationFailed {
                                            field_name: <_ as ::std::clone::Clone>::clone(field_name),
                                            err,
                                        }
                                    ));
                                }
                            }
                        } else {
                            return ::std::result::Result::Err(mk_typck_err(
                                #crate_path::UdtTypeCheckErrorKind::FieldNameMismatch {
                                    rust_field_name: <_ as ::std::string::ToString>::to_string(#rust_field_name),
                                    db_field_name: <_ as ::std::clone::Clone>::clone(field_name),
                                }
                            ));
                        }
                    }
                    None => {
                        return ::std::result::Result::Err(mk_typck_err(
                            #crate_path::UdtTypeCheckErrorKind::MissingField {
                                field_name: <_ as ::std::string::ToString>::to_string(#rust_field_name),
                            }
                        ));
                    }
                }
            }
        }));

        // After all Rust fields are consumed, the UDT must have no fields left
        statements.push(parse_quote! {
            if let Some((field_name, typ)) = field_iter.next() {
                return ::std::result::Result::Err(mk_typck_err(
                    #crate_path::UdtTypeCheckErrorKind::UnexpectedFieldInDestination {
                        field_name: <_ as ::std::clone::Clone>::clone(field_name),
                    }
                ));
            }
        });

        // Assemble the method: the statements above, then finish the builder
        parse_quote! {
            fn serialize<'b>(
                &self,
                typ: &#crate_path::ColumnType,
                writer: #crate_path::CellWriter<'b>,
            ) -> ::std::result::Result<#crate_path::WrittenCellProof<'b>, #crate_path::SerializationError> {
                #(#statements)*
                let proof = #crate_path::CellValueBuilder::finish(builder)
                    .map_err(|_| #crate_path::SerializationError::new(
                        #crate_path::BuiltinTypeSerializationError {
                            rust_name: ::std::any::type_name::<Self>(),
                            got: <_ as ::std::clone::Clone>::clone(typ),
                            kind: #crate_path::BuiltinTypeSerializationErrorKind::SizeOverflow,
                        }
                    ) as #crate_path::SerializationError)?;
                ::std::result::Result::Ok(proof)
            }
        }
    }
}
18 changes: 18 additions & 0 deletions scylla-macros/src/serialize/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,20 @@
use darling::FromMeta;

pub(crate) mod cql;
pub(crate) mod row;

// The serialization strategy selected with `#[scylla(flavor = "...")]`.
#[derive(Copy, Clone, PartialEq, Eq)]
enum Flavor {
// Spelled "match_by_name" in the attribute.
MatchByName,
// Spelled "enforce_order" in the attribute.
EnforceOrder,
}

impl FromMeta for Flavor {
fn from_string(value: &str) -> darling::Result<Self> {
match value {
"match_by_name" => Ok(Self::MatchByName),
"enforce_order" => Ok(Self::EnforceOrder),
_ => Err(darling::Error::unknown_value(value)),
}
}
}

0 comments on commit dcb4cf4

Please sign in to comment.