diff --git a/scylla-cql/src/lib.rs b/scylla-cql/src/lib.rs index f8b3e07995..61d9f345c5 100644 --- a/scylla-cql/src/lib.rs +++ b/scylla-cql/src/lib.rs @@ -2,6 +2,8 @@ pub mod errors; pub mod frame; #[macro_use] pub mod macros { + pub use scylla_macros::DeserializeRow; + pub use scylla_macros::DeserializeValue; pub use scylla_macros::FromRow; pub use scylla_macros::FromUserType; pub use scylla_macros::IntoUserType; @@ -27,12 +29,30 @@ pub mod _macro_internal { pub use crate::frame::response::cql_to_rust::{ FromCqlVal, FromCqlValError, FromRow, FromRowError, }; - pub use crate::frame::response::result::{CqlValue, Row}; + pub use crate::frame::response::result::{ColumnSpec, ColumnType, CqlValue, Row}; pub use crate::frame::value::{ LegacySerializedValues, SerializedResult, Value, ValueList, ValueTooBig, }; pub use crate::macros::*; + pub use crate::types::deserialize::row::{ + deser_error_replace_rust_name as row_deser_error_replace_rust_name, + mk_deser_err as mk_row_deser_err, mk_typck_err as mk_row_typck_err, + BuiltinDeserializationError as BuiltinRowDeserializationError, + BuiltinDeserializationErrorKind as BuiltinRowDeserializationErrorKind, + BuiltinTypeCheckErrorKind as DeserBuiltinRowTypeCheckErrorKind, ColumnIterator, + DeserializeRow, + }; + pub use crate::types::deserialize::value::{ + deser_error_replace_rust_name as value_deser_error_replace_rust_name, + mk_deser_err as mk_value_deser_err, mk_typck_err as mk_value_typck_err, + BuiltinDeserializationError as BuiltinTypeDeserializationError, + BuiltinDeserializationErrorKind as BuiltinTypeDeserializationErrorKind, + BuiltinTypeCheckErrorKind as DeserBuiltinTypeTypeCheckErrorKind, DeserializeValue, + UdtDeserializationErrorKind, UdtIterator, + UdtTypeCheckErrorKind as DeserUdtTypeCheckErrorKind, + }; + pub use crate::types::deserialize::{DeserializationError, FrameSlice, TypeCheckError}; pub use crate::types::serialize::row::{ BuiltinSerializationError as BuiltinRowSerializationError, 
BuiltinSerializationErrorKind as BuiltinRowSerializationErrorKind, @@ -51,6 +71,4 @@ pub mod _macro_internal { pub use crate::types::serialize::{ CellValueBuilder, CellWriter, RowWriter, SerializationError, }; - - pub use crate::frame::response::result::ColumnType; } diff --git a/scylla-cql/src/types/deserialize/mod.rs b/scylla-cql/src/types/deserialize/mod.rs index 12e73052ba..2d8f0713e8 100644 --- a/scylla-cql/src/types/deserialize/mod.rs +++ b/scylla-cql/src/types/deserialize/mod.rs @@ -245,7 +245,9 @@ impl Display for DeserializationError { // - BEFORE an error is cloned (because otherwise the Arc::get_mut fails). macro_rules! make_error_replace_rust_name { ($fn_name: ident, $outer_err: ty, $inner_err: ty) => { - fn $fn_name(mut err: $outer_err) -> $outer_err { + // Not part of the public API; used in derive macros. + #[doc(hidden)] + pub fn $fn_name(mut err: $outer_err) -> $outer_err { // Safety: the assumed usage of this function guarantees that the Arc has not yet been cloned. let arc_mut = std::sync::Arc::get_mut(&mut err.0).unwrap(); diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs index 5dfec4b12a..8971e1a711 100644 --- a/scylla-cql/src/types/deserialize/row.rs +++ b/scylla-cql/src/types/deserialize/row.rs @@ -260,7 +260,9 @@ pub struct BuiltinTypeCheckError { pub kind: BuiltinTypeCheckErrorKind, } -fn mk_typck_err( +// Not part of the public API; used in derive macros. +#[doc(hidden)] +pub fn mk_typck_err( cql_types: impl IntoIterator, kind: impl Into, ) -> TypeCheckError { @@ -292,6 +294,38 @@ pub enum BuiltinTypeCheckErrorKind { cql_cols: usize, }, + /// The CQL row contains a column for which a corresponding field is not found + /// in the Rust type. + ColumnWithUnknownName { + /// Index of the excess column. + column_index: usize, + + /// Name of the column that is present in CQL row but not in the Rust type. 
+ column_name: String, + }, + + /// Several values required by the Rust type are not provided by the DB. + ValuesMissingForColumns { + /// Names of the columns in the Rust type for which the DB doesn't + /// provide a value. + column_names: Vec<&'static str>, + }, + + /// A different column name was expected at given position. + ColumnNameMismatch { + /// Index of the field determining the expected name. + field_index: usize, + + /// Index of the column having mismatched name. + column_index: usize, + + /// Name of the column, as expected by the Rust type. + rust_column_name: &'static str, + + /// Name of the column, as reported by the DB. + db_column_name: String, + }, + /// Column type check failed between Rust type and DB type at given position (=in given column). ColumnTypeCheckFailed { /// Index of the column. @@ -303,6 +337,15 @@ pub enum BuiltinTypeCheckErrorKind { /// Inner type check error due to the type mismatch. err: TypeCheckError, }, + + /// Duplicated column in DB metadata. + DuplicatedColumn { + /// Column index of the second occurrence of the column with the same name. + column_index: usize, + + /// The name of the duplicated column.
+ column_name: &'static str, + }, } impl Display for BuiltinTypeCheckErrorKind { @@ -314,6 +357,33 @@ impl Display for BuiltinTypeCheckErrorKind { } => { write!(f, "wrong column count: the statement operates on {cql_cols} columns, but the given rust types contains {rust_cols}") } + BuiltinTypeCheckErrorKind::ColumnWithUnknownName { column_name, column_index } => { + write!( + f, + "the CQL row contains a column {} at column index {}, but the corresponding field is not found in the Rust type", + column_name, + column_index, + ) + } + BuiltinTypeCheckErrorKind::ValuesMissingForColumns { column_names } => { + write!( + f, + "values for columns {:?} are missing from the DB data but are required by the Rust type", + column_names + ) + }, + BuiltinTypeCheckErrorKind::ColumnNameMismatch { + field_index, + column_index,rust_column_name, + db_column_name + } => write!( + f, + "expected column with name {} at column index {}, but the Rust field name at corresponding field index {} is {}", + db_column_name, + column_index, + field_index, + rust_column_name, + ), BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { column_index, column_name, @@ -322,6 +392,12 @@ impl Display for BuiltinTypeCheckErrorKind { f, "mismatched types in column {column_name} at index {column_index}: {err}" ), + BuiltinTypeCheckErrorKind::DuplicatedColumn { column_name, column_index } => write!( + f, + "column {} occurs more than once in DB metadata; second occurence is at column index {}", + column_name, + column_index, + ), } } } @@ -338,13 +414,13 @@ pub struct BuiltinDeserializationError { pub kind: BuiltinDeserializationErrorKind, } -pub(super) fn mk_deser_err( - kind: impl Into, -) -> DeserializationError { +// Not part of the public API; used in derive macros. 
+#[doc(hidden)] +pub fn mk_deser_err(kind: impl Into) -> DeserializationError { mk_deser_err_named(std::any::type_name::(), kind) } -pub(super) fn mk_deser_err_named( +fn mk_deser_err_named( name: &'static str, kind: impl Into, ) -> DeserializationError { @@ -412,272 +488,49 @@ impl Display for BuiltinDeserializationErrorKind { } #[cfg(test)] -mod tests { - use assert_matches::assert_matches; - use bytes::Bytes; - - use crate::frame::response::result::{ColumnSpec, ColumnType}; - use crate::types::deserialize::row::BuiltinDeserializationErrorKind; - use crate::types::deserialize::{DeserializationError, FrameSlice}; - - use super::super::tests::{serialize_cells, spec}; - use super::{BuiltinDeserializationError, ColumnIterator, CqlValue, DeserializeRow, Row}; - use super::{BuiltinTypeCheckError, BuiltinTypeCheckErrorKind}; - - #[test] - fn test_tuple_deserialization() { - // Empty tuple - deserialize::<()>(&[], &Bytes::new()).unwrap(); - - // 1-elem tuple - let (a,) = deserialize::<(i32,)>( - &[spec("i", ColumnType::Int)], - &serialize_cells([val_int(123)]), - ) - .unwrap(); - assert_eq!(a, 123); - - // 3-elem tuple - let (a, b, c) = deserialize::<(i32, i32, i32)>( - &[ - spec("i1", ColumnType::Int), - spec("i2", ColumnType::Int), - spec("i3", ColumnType::Int), - ], - &serialize_cells([val_int(123), val_int(456), val_int(789)]), - ) - .unwrap(); - assert_eq!((a, b, c), (123, 456, 789)); - - // Make sure that column type mismatch is detected - deserialize::<(i32, String, i32)>( - &[ - spec("i1", ColumnType::Int), - spec("i2", ColumnType::Int), - spec("i3", ColumnType::Int), - ], - &serialize_cells([val_int(123), val_int(456), val_int(789)]), - ) - .unwrap_err(); - - // Make sure that borrowing types compile and work correctly - let specs = &[spec("s", ColumnType::Text)]; - let byts = serialize_cells([val_str("abc")]); - let (s,) = deserialize::<(&str,)>(specs, &byts).unwrap(); - assert_eq!(s, "abc"); - } - - #[test] - fn test_deserialization_as_column_iterator() { - 
let col_specs = [ - spec("i1", ColumnType::Int), - spec("i2", ColumnType::Text), - spec("i3", ColumnType::Counter), - ]; - let serialized_values = serialize_cells([val_int(123), val_str("ScyllaDB"), None]); - let mut iter = deserialize::(&col_specs, &serialized_values).unwrap(); - - let col1 = iter.next().unwrap().unwrap(); - assert_eq!(col1.spec.name, "i1"); - assert_eq!(col1.spec.typ, ColumnType::Int); - assert_eq!(col1.slice.unwrap().as_slice(), &123i32.to_be_bytes()); - - let col2 = iter.next().unwrap().unwrap(); - assert_eq!(col2.spec.name, "i2"); - assert_eq!(col2.spec.typ, ColumnType::Text); - assert_eq!(col2.slice.unwrap().as_slice(), "ScyllaDB".as_bytes()); - - let col3 = iter.next().unwrap().unwrap(); - assert_eq!(col3.spec.name, "i3"); - assert_eq!(col3.spec.typ, ColumnType::Counter); - assert!(col3.slice.is_none()); - - assert!(iter.next().is_none()); - } - - fn val_int(i: i32) -> Option> { - Some(i.to_be_bytes().to_vec()) - } - - fn val_str(s: &str) -> Option> { - Some(s.as_bytes().to_vec()) - } - - fn deserialize<'frame, R>( - specs: &'frame [ColumnSpec], - byts: &'frame Bytes, - ) -> Result - where - R: DeserializeRow<'frame>, - { - >::type_check(specs) - .map_err(|typecheck_err| DeserializationError(typecheck_err.0))?; - let slice = FrameSlice::new(byts); - let iter = ColumnIterator::new(specs, slice); - >::deserialize(iter) - } - - #[track_caller] - fn get_typck_err(err: &DeserializationError) -> &BuiltinTypeCheckError { - match err.0.downcast_ref() { - Some(err) => err, - None => panic!("not a BuiltinTypeCheckError: {:?}", err), - } - } - - #[track_caller] - fn get_deser_err(err: &DeserializationError) -> &BuiltinDeserializationError { - match err.0.downcast_ref() { - Some(err) => err, - None => panic!("not a BuiltinDeserializationError: {:?}", err), - } - } - - #[test] - fn test_tuple_errors() { - // Column type check failure - { - let col_name: &str = "i"; - let specs = &[spec(col_name, ColumnType::Int)]; - let err = deserialize::<(i64,)>(specs, 
&serialize_cells([val_int(123)])).unwrap_err(); - let err = get_typck_err(&err); - assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); - assert_eq!( - err.cql_types, - specs - .iter() - .map(|spec| spec.typ.clone()) - .collect::>() - ); - let BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { - column_index, - column_name, - err, - } = &err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - assert_eq!(*column_index, 0); - assert_eq!(column_name, col_name); - let err = super::super::value::tests::get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::Int); - assert_matches!( - &err.kind, - super::super::value::BuiltinTypeCheckErrorKind::MismatchedType { - expected: &[ColumnType::BigInt, ColumnType::Counter] - } - ); - } - - // Column deserialization failure - { - let col_name: &str = "i"; - let err = deserialize::<(i64,)>( - &[spec(col_name, ColumnType::BigInt)], - &serialize_cells([val_int(123)]), - ) - .unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); - let BuiltinDeserializationErrorKind::ColumnDeserializationFailed { - column_name, - err, - .. 
- } = &err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - assert_eq!(column_name, col_name); - let err = super::super::value::tests::get_deser_err(err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::BigInt); - assert_matches!( - err.kind, - super::super::value::BuiltinDeserializationErrorKind::ByteLengthMismatch { - expected: 8, - got: 4 - } - ); - } +#[path = "row_tests.rs"] +mod tests; - // Raw column deserialization failure - { - let col_name: &str = "i"; - let err = deserialize::<(i64,)>( - &[spec(col_name, ColumnType::BigInt)], - &Bytes::from_static(b"alamakota"), - ) - .unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); - let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { - column_index: _column_index, - column_name, - err: _err, - } = &err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - assert_eq!(column_name, col_name); - } - } - - #[test] - fn test_row_errors() { - // Column type check failure - happens never, because Row consists of CqlValues, - // which accept all CQL types. 
- - // Column deserialization failure - { - let col_name: &str = "i"; - let err = deserialize::( - &[spec(col_name, ColumnType::BigInt)], - &serialize_cells([val_int(123)]), - ) - .unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::()); - let BuiltinDeserializationErrorKind::ColumnDeserializationFailed { - column_index: _column_index, - column_name, - err, - } = &err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - assert_eq!(column_name, col_name); - let err = super::super::value::tests::get_deser_err(err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::BigInt); - let super::super::value::BuiltinDeserializationErrorKind::ByteLengthMismatch { - expected: 8, - got: 4, - } = &err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - } +/// ```compile_fail +/// +/// #[derive(scylla_macros::DeserializeRow)] +/// #[scylla(crate = scylla_cql, skip_name_checks)] +/// struct TestRow {} +/// ``` +fn _test_struct_deserialization_name_check_skip_requires_enforce_order() {} - // Raw column deserialization failure - { - let col_name: &str = "i"; - let err = deserialize::( - &[spec(col_name, ColumnType::BigInt)], - &Bytes::from_static(b"alamakota"), - ) - .unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::()); - let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { - column_index: _column_index, - column_name, - err: _err, - } = &err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - assert_eq!(column_name, col_name); - } - } -} +/// ```compile_fail +/// +/// #[derive(scylla_macros::DeserializeRow)] +/// #[scylla(crate = scylla_cql, skip_name_checks)] +/// struct TestRow { +/// #[scylla(rename = "b")] +/// a: i32, +/// } +/// ``` +fn _test_struct_deserialization_skip_name_check_conflicts_with_rename() {} + +/// ```compile_fail +/// +/// #[derive(scylla_macros::DeserializeRow)] 
+/// #[scylla(crate = scylla_cql)] +/// struct TestRow { +/// #[scylla(rename = "b")] +/// a: i32, +/// b: String, +/// } +/// ``` +fn _test_struct_deserialization_skip_rename_collision_with_field() {} + +/// ```compile_fail +/// +/// #[derive(scylla_macros::DeserializeRow)] +/// #[scylla(crate = scylla_cql)] +/// struct TestRow { +/// #[scylla(rename = "c")] +/// a: i32, +/// #[scylla(rename = "c")] +/// b: String, +/// } +/// ``` +fn _test_struct_deserialization_rename_collision_with_another_rename() {} diff --git a/scylla-cql/src/types/deserialize/row_tests.rs b/scylla-cql/src/types/deserialize/row_tests.rs new file mode 100644 index 0000000000..f4c79d66a8 --- /dev/null +++ b/scylla-cql/src/types/deserialize/row_tests.rs @@ -0,0 +1,861 @@ +use assert_matches::assert_matches; +use bytes::Bytes; +use scylla_macros::DeserializeRow; + +use crate::frame::response::result::{ColumnSpec, ColumnType}; +use crate::types::deserialize::row::BuiltinDeserializationErrorKind; +use crate::types::deserialize::{value, DeserializationError, FrameSlice}; + +use super::super::tests::{serialize_cells, spec}; +use super::{BuiltinDeserializationError, ColumnIterator, CqlValue, DeserializeRow, Row}; +use super::{BuiltinTypeCheckError, BuiltinTypeCheckErrorKind}; + +#[test] +fn test_tuple_deserialization() { + // Empty tuple + deserialize::<()>(&[], &Bytes::new()).unwrap(); + + // 1-elem tuple + let (a,) = deserialize::<(i32,)>( + &[spec("i", ColumnType::Int)], + &serialize_cells([val_int(123)]), + ) + .unwrap(); + assert_eq!(a, 123); + + // 3-elem tuple + let (a, b, c) = deserialize::<(i32, i32, i32)>( + &[ + spec("i1", ColumnType::Int), + spec("i2", ColumnType::Int), + spec("i3", ColumnType::Int), + ], + &serialize_cells([val_int(123), val_int(456), val_int(789)]), + ) + .unwrap(); + assert_eq!((a, b, c), (123, 456, 789)); + + // Make sure that column type mismatch is detected + deserialize::<(i32, String, i32)>( + &[ + spec("i1", ColumnType::Int), + spec("i2", ColumnType::Int), + 
spec("i3", ColumnType::Int), + ], + &serialize_cells([val_int(123), val_int(456), val_int(789)]), + ) + .unwrap_err(); + + // Make sure that borrowing types compile and work correctly + let specs = &[spec("s", ColumnType::Text)]; + let byts = serialize_cells([val_str("abc")]); + let (s,) = deserialize::<(&str,)>(specs, &byts).unwrap(); + assert_eq!(s, "abc"); +} + +#[test] +fn test_deserialization_as_column_iterator() { + let col_specs = [ + spec("i1", ColumnType::Int), + spec("i2", ColumnType::Text), + spec("i3", ColumnType::Counter), + ]; + let serialized_values = serialize_cells([val_int(123), val_str("ScyllaDB"), None]); + let mut iter = deserialize::(&col_specs, &serialized_values).unwrap(); + + let col1 = iter.next().unwrap().unwrap(); + assert_eq!(col1.spec.name, "i1"); + assert_eq!(col1.spec.typ, ColumnType::Int); + assert_eq!(col1.slice.unwrap().as_slice(), &123i32.to_be_bytes()); + + let col2 = iter.next().unwrap().unwrap(); + assert_eq!(col2.spec.name, "i2"); + assert_eq!(col2.spec.typ, ColumnType::Text); + assert_eq!(col2.slice.unwrap().as_slice(), "ScyllaDB".as_bytes()); + + let col3 = iter.next().unwrap().unwrap(); + assert_eq!(col3.spec.name, "i3"); + assert_eq!(col3.spec.typ, ColumnType::Counter); + assert!(col3.slice.is_none()); + + assert!(iter.next().is_none()); +} + +// Do not remove. It's not used in tests but we keep it here to check that +// we properly ignore warnings about unused variables, unnecessary `mut`s +// etc. that usually pop up when generating code for empty structs. 
+#[allow(unused)] +#[derive(DeserializeRow)] +#[scylla(crate = crate)] +struct TestUdtWithNoFieldsUnordered {} + +#[allow(unused)] +#[derive(DeserializeRow)] +#[scylla(crate = crate, enforce_order)] +struct TestUdtWithNoFieldsOrdered {} + +#[test] +fn test_struct_deserialization_loose_ordering() { + #[derive(DeserializeRow, PartialEq, Eq, Debug)] + #[scylla(crate = "crate")] + struct MyRow<'a> { + a: &'a str, + b: Option, + #[scylla(skip)] + c: String, + } + + // Original order of columns + let specs = &[spec("a", ColumnType::Text), spec("b", ColumnType::Int)]; + let byts = serialize_cells([val_str("abc"), val_int(123)]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + } + ); + + // Different order of columns - should still work + let specs = &[spec("b", ColumnType::Int), spec("a", ColumnType::Text)]; + let byts = serialize_cells([val_int(123), val_str("abc")]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + } + ); + + // Missing column + let specs = &[spec("a", ColumnType::Text)]; + MyRow::type_check(specs).unwrap_err(); + + // Wrong column type + let specs = &[spec("a", ColumnType::Int), spec("b", ColumnType::Int)]; + MyRow::type_check(specs).unwrap_err(); +} + +#[test] +fn test_struct_deserialization_strict_ordering() { + #[derive(DeserializeRow, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order)] + struct MyRow<'a> { + a: &'a str, + b: Option, + #[scylla(skip)] + c: String, + } + + // Correct order of columns + let specs = &[spec("a", ColumnType::Text), spec("b", ColumnType::Int)]; + let byts = serialize_cells([val_str("abc"), val_int(123)]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + } + ); + + // Wrong order of columns + let specs = &[spec("b", ColumnType::Int), spec("a", ColumnType::Text)]; 
+ MyRow::type_check(specs).unwrap_err(); + + // Missing column + let specs = &[spec("a", ColumnType::Text)]; + MyRow::type_check(specs).unwrap_err(); + + // Wrong column type + let specs = &[spec("a", ColumnType::Int), spec("b", ColumnType::Int)]; + MyRow::type_check(specs).unwrap_err(); +} + +#[test] +fn test_struct_deserialization_no_name_check() { + #[derive(DeserializeRow, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order, skip_name_checks)] + struct MyRow<'a> { + a: &'a str, + b: Option, + #[scylla(skip)] + c: String, + } + + // Correct order of columns + let specs = &[spec("a", ColumnType::Text), spec("b", ColumnType::Int)]; + let byts = serialize_cells([val_str("abc"), val_int(123)]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + } + ); + + // Correct order of columns, but different names - should still succeed + let specs = &[spec("z", ColumnType::Text), spec("x", ColumnType::Int)]; + let byts = serialize_cells([val_str("abc"), val_int(123)]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + } + ); +} + +#[test] +fn test_struct_deserialization_cross_rename_fields() { + #[derive(scylla_macros::DeserializeRow, PartialEq, Eq, Debug)] + #[scylla(crate = crate)] + struct TestRow { + #[scylla(rename = "b")] + a: i32, + #[scylla(rename = "a")] + b: String, + } + + // Columns switched wrt fields - should still work. 
+ { + let row_bytes = + serialize_cells(["The quick brown fox".as_bytes(), &42_i32.to_be_bytes()].map(Some)); + let specs = [spec("a", ColumnType::Text), spec("b", ColumnType::Int)]; + + let row = deserialize::(&specs, &row_bytes).unwrap(); + assert_eq!( + row, + TestRow { + a: 42, + b: "The quick brown fox".to_owned(), + } + ); + } +} + +fn val_int(i: i32) -> Option> { + Some(i.to_be_bytes().to_vec()) +} + +fn val_str(s: &str) -> Option> { + Some(s.as_bytes().to_vec()) +} + +fn deserialize<'frame, R>( + specs: &'frame [ColumnSpec], + byts: &'frame Bytes, +) -> Result +where + R: DeserializeRow<'frame>, +{ + >::type_check(specs) + .map_err(|typecheck_err| DeserializationError(typecheck_err.0))?; + let slice = FrameSlice::new(byts); + let iter = ColumnIterator::new(specs, slice); + >::deserialize(iter) +} + +#[track_caller] +pub(crate) fn get_typck_err_inner<'a>( + err: &'a (dyn std::error::Error + 'static), +) -> &'a BuiltinTypeCheckError { + match err.downcast_ref() { + Some(err) => err, + None => panic!("not a BuiltinTypeCheckError: {:?}", err), + } +} + +#[track_caller] +fn get_typck_err(err: &DeserializationError) -> &BuiltinTypeCheckError { + get_typck_err_inner(err.0.as_ref()) +} + +#[track_caller] +fn get_deser_err(err: &DeserializationError) -> &BuiltinDeserializationError { + match err.0.downcast_ref() { + Some(err) => err, + None => panic!("not a BuiltinDeserializationError: {:?}", err), + } +} + +#[test] +fn test_tuple_errors() { + // Column type check failure + { + let col_name: &str = "i"; + let specs = &[spec(col_name, ColumnType::Int)]; + let err = deserialize::<(i64,)>(specs, &serialize_cells([val_int(123)])).unwrap_err(); + let err = get_typck_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); + assert_eq!( + err.cql_types, + specs + .iter() + .map(|spec| spec.typ.clone()) + .collect::>() + ); + let BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { + column_index, + column_name, + err, + } = &err.kind + else { + 
panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(*column_index, 0); + assert_eq!(column_name, col_name); + let err = super::super::value::tests::get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Int); + assert_matches!( + &err.kind, + super::super::value::BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::BigInt, ColumnType::Counter] + } + ); + } + + // Column deserialization failure + { + let col_name: &str = "i"; + let err = deserialize::<(i64,)>( + &[spec(col_name, ColumnType::BigInt)], + &serialize_cells([val_int(123)]), + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); + let BuiltinDeserializationErrorKind::ColumnDeserializationFailed { + column_name, err, .. + } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(column_name, col_name); + let err = super::super::value::tests::get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::BigInt); + assert_matches!( + err.kind, + super::super::value::BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4 + } + ); + } + + // Raw column deserialization failure + { + let col_name: &str = "i"; + let err = deserialize::<(i64,)>( + &[spec(col_name, ColumnType::BigInt)], + &Bytes::from_static(b"alamakota"), + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); + let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index: _column_index, + column_name, + err: _err, + } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(column_name, col_name); + } +} + +#[test] +fn test_row_errors() { + // Column type check failure - happens never, because Row consists of CqlValues, + // which accept all CQL types. 
+ + // Column deserialization failure + { + let col_name: &str = "i"; + let err = deserialize::( + &[spec(col_name, ColumnType::BigInt)], + &serialize_cells([val_int(123)]), + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::ColumnDeserializationFailed { + column_index: _column_index, + column_name, + err, + } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(column_name, col_name); + let err = super::super::value::tests::get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::BigInt); + let super::super::value::BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4, + } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + } + + // Raw column deserialization failure + { + let col_name: &str = "i"; + let err = deserialize::( + &[spec(col_name, ColumnType::BigInt)], + &Bytes::from_static(b"alamakota"), + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index: _column_index, + column_name, + err: _err, + } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(column_name, col_name); + } +} + +fn specs_to_types(specs: &[ColumnSpec]) -> Vec { + specs.iter().map(|spec| spec.typ.clone()).collect() +} + +#[test] +fn test_struct_deserialization_errors() { + // Loose ordering + { + #[derive(scylla_macros::DeserializeRow, PartialEq, Eq, Debug)] + #[scylla(crate = "crate")] + struct MyRow<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + b: Option, + #[scylla(rename = "c")] + d: bool, + } + + // Type check errors + { + // Missing column + { + let specs = [spec("a", ColumnType::Ascii), spec("b", ColumnType::Int)]; + let err = MyRow::type_check(&specs).unwrap_err(); + let err 
= get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::ValuesMissingForColumns { + column_names: ref missing_fields, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(missing_fields.as_slice(), &["c"]); + } + + // Duplicated column + { + let specs = [ + spec("a", ColumnType::Ascii), + spec("b", ColumnType::Int), + spec("a", ColumnType::Ascii), + ]; + + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::DuplicatedColumn { + column_index, + column_name, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 2); + assert_eq!(column_name, "a"); + } + + // Unknown column + { + let specs = [ + spec("d", ColumnType::Counter), + spec("a", ColumnType::Ascii), + spec("b", ColumnType::Int), + ]; + + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::ColumnWithUnknownName { + column_index, + ref column_name, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 0); + assert_eq!(column_name.as_str(), "d"); + } + + // Column incompatible types - column type check failed + { + let specs = [spec("b", ColumnType::Int), spec("a", ColumnType::Blob)]; + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { + column_index, + ref column_name, + ref 
err, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 1); + assert_eq!(column_name.as_str(), "a"); + let err = value::tests::get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::<&str>()); + assert_eq!(err.cql_type, ColumnType::Blob); + assert_matches!( + err.kind, + value::BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Ascii, ColumnType::Text] + } + ); + } + } + + // Deserialization errors + { + // Got null + { + let specs = [ + spec("c", ColumnType::Boolean), + spec("a", ColumnType::Blob), + spec("b", ColumnType::Int), + ]; + + let err = MyRow::deserialize(ColumnIterator::new( + &specs, + FrameSlice::new(&serialize_cells([Some([true as u8])])), + )) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index, + ref column_name, + .. + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 1); + assert_eq!(column_name, "a"); + } + + // Column deserialization failed + { + let specs = [ + spec("b", ColumnType::Int), + spec("a", ColumnType::Ascii), + spec("c", ColumnType::Boolean), + ]; + + let row_bytes = serialize_cells( + [ + &0_i32.to_be_bytes(), + "alamakota".as_bytes(), + &42_i16.to_be_bytes(), + ] + .map(Some), + ); + + let err = deserialize::(&specs, &row_bytes).unwrap_err(); + + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::ColumnDeserializationFailed { + column_index, + ref column_name, + ref err, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 2); + assert_eq!(column_name.as_str(), "c"); + let err = value::tests::get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Boolean); 
+ assert_matches!( + err.kind, + value::BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 1, + got: 2, + } + ); + } + } + } + + // Strict ordering + { + #[derive(scylla_macros::DeserializeRow, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order)] + struct MyRow<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + b: Option, + c: bool, + } + + // Type check errors + { + // Too few columns + { + let specs = [spec("a", ColumnType::Text)]; + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::WrongColumnCount { + rust_cols, + cql_cols, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(rust_cols, 3); + assert_eq!(cql_cols, 1); + } + + // Excess columns + { + let specs = [ + spec("a", ColumnType::Text), + spec("b", ColumnType::Int), + spec("c", ColumnType::Boolean), + spec("d", ColumnType::Counter), + ]; + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::WrongColumnCount { + rust_cols, + cql_cols, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(rust_cols, 3); + assert_eq!(cql_cols, 4); + } + + // Renamed column name mismatch + { + let specs = [ + spec("a", ColumnType::Text), + spec("b", ColumnType::Int), + spec("d", ColumnType::Boolean), + ]; + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinTypeCheckErrorKind::ColumnNameMismatch { + field_index, + column_index, + rust_column_name, + ref db_column_name, + } = err.kind + else { + panic!("unexpected error kind: {:?}", 
err.kind) + }; + assert_eq!(field_index, 3); + assert_eq!(rust_column_name, "c"); + assert_eq!(column_index, 2); + assert_eq!(db_column_name.as_str(), "d"); + } + + // Columns switched - column name mismatch + { + let specs = [ + spec("b", ColumnType::Int), + spec("a", ColumnType::Text), + spec("c", ColumnType::Boolean), + ]; + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::ColumnNameMismatch { + field_index, + column_index, + rust_column_name, + ref db_column_name, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(field_index, 0); + assert_eq!(column_index, 0); + assert_eq!(rust_column_name, "a"); + assert_eq!(db_column_name.as_str(), "b"); + } + + // Column incompatible types - column type check failed + { + let specs = [ + spec("a", ColumnType::Blob), + spec("b", ColumnType::Int), + spec("c", ColumnType::Boolean), + ]; + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { + column_index, + ref column_name, + ref err, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 0); + assert_eq!(column_name.as_str(), "a"); + let err = value::tests::get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::<&str>()); + assert_eq!(err.cql_type, ColumnType::Blob); + assert_matches!( + err.kind, + value::BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Ascii, ColumnType::Text] + } + ); + } + } + + // Deserialization errors + { + // Too few columns + { + let specs = [ + spec("a", ColumnType::Text), + spec("b", ColumnType::Int), + spec("c", 
ColumnType::Boolean), + ]; + + let err = MyRow::deserialize(ColumnIterator::new( + &specs, + FrameSlice::new(&serialize_cells([Some([true as u8])])), + )) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index, + ref column_name, + .. + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 1); + assert_eq!(column_name, "b"); + } + + // Bad field format + { + let typ = [ + spec("a", ColumnType::Text), + spec("b", ColumnType::Int), + spec("c", ColumnType::Boolean), + ]; + + let row_bytes = serialize_cells( + [(&b"alamakota"[..]), &42_i32.to_be_bytes(), &[true as u8]].map(Some), + ); + + let row_bytes_too_short = row_bytes.slice(..row_bytes.len() - 1); + assert!(row_bytes.len() > row_bytes_too_short.len()); + + let err = deserialize::(&typ, &row_bytes_too_short).unwrap_err(); + + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index, + ref column_name, + .. 
+ } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 2); + assert_eq!(column_name, "c"); + } + + // Column deserialization failed + { + let specs = [ + spec("a", ColumnType::Text), + spec("b", ColumnType::Int), + spec("c", ColumnType::Boolean), + ]; + + let row_bytes = serialize_cells( + [&b"alamakota"[..], &42_i64.to_be_bytes(), &[true as u8]].map(Some), + ); + + let err = deserialize::(&specs, &row_bytes).unwrap_err(); + + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::ColumnDeserializationFailed { + column_index: field_index, + ref column_name, + ref err, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_name.as_str(), "b"); + assert_eq!(field_index, 2); + let err = value::tests::get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Int); + assert_matches!( + err.kind, + value::BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 4, + got: 8, + } + ); + } + } + } +} diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 8431ea17cc..074a7c298a 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -1335,7 +1335,9 @@ pub struct BuiltinTypeCheckError { pub kind: BuiltinTypeCheckErrorKind, } -fn mk_typck_err( +// Not part of the public API; used in derive macros. +#[doc(hidden)] +pub fn mk_typck_err( cql_type: &ColumnType, kind: impl Into, ) -> TypeCheckError { @@ -1547,6 +1549,52 @@ impl Display for TupleTypeCheckErrorKind { pub enum UdtTypeCheckErrorKind { /// The CQL type is not a user defined type. NotUdt, + + /// The CQL UDT type does not have some fields that are required in the Rust struct. + ValuesMissingForUdtFields { + /// Names of fields that the Rust struct requires but are missing in the CQL UDT.
+ field_names: Vec<&'static str>, + }, + + /// A different field name was expected at given position. + FieldNameMismatch { + /// Index of the field in the Rust struct. + position: usize, + + /// The name of the Rust field. + rust_field_name: String, + + /// The name of the CQL UDT field. + db_field_name: String, + }, + + /// UDT contains an excess field, which does not correspond to any Rust struct's field. + ExcessFieldInUdt { + /// The name of the CQL UDT field. + db_field_name: String, + }, + + /// Duplicated field in serialized data. + DuplicatedField { + /// The name of the duplicated field. + field_name: String, + }, + + /// Fewer fields present in the UDT than required by the Rust type. + TooFewFields { + // TODO: decide whether we are OK with restricting to `&'static str` here. + required_fields: Vec<&'static str>, + present_fields: Vec, + }, + + /// Type check failed between UDT and Rust type field. + FieldTypeCheckFailed { + /// The name of the field whose type check failed. + field_name: String, + + /// Inner type check error that occurred.
+ err: TypeCheckError, + }, } impl Display for UdtTypeCheckErrorKind { @@ -1556,6 +1604,35 @@ impl Display for UdtTypeCheckErrorKind { f, "the CQL type the Rust type was attempted to be type checked against is not a UDT" ), + UdtTypeCheckErrorKind::ValuesMissingForUdtFields { field_names } => { + write!(f, "the fields {field_names:?} are missing from the DB data but are required by the Rust type") + }, + UdtTypeCheckErrorKind::FieldNameMismatch { rust_field_name, db_field_name, position } => write!( + f, + "expected field with name {db_field_name} at position {position}, but the Rust field name is {rust_field_name}" + ), + UdtTypeCheckErrorKind::ExcessFieldInUdt { db_field_name } => write!( + f, + "UDT contains an excess field {}, which does not correspond to any Rust struct's field.", + db_field_name + ), + UdtTypeCheckErrorKind::DuplicatedField { field_name } => write!( + f, + "field {} occurs more than once in CQL UDT type", + field_name + ), + UdtTypeCheckErrorKind::TooFewFields { required_fields, present_fields } => write!( + f, + "fewer fields present in the UDT than required by the Rust type: UDT has {:?}, Rust type requires {:?}", + present_fields, + required_fields, + ), + UdtTypeCheckErrorKind::FieldTypeCheckFailed { field_name, err } => write!( + f, + "the UDT field {} types between the CQL type and the Rust type failed to type check against each other: {}", + field_name, + err + ), } } } @@ -1574,7 +1651,9 @@ pub struct BuiltinDeserializationError { pub kind: BuiltinDeserializationErrorKind, } -pub(crate) fn mk_deser_err( +// Not part of the public API; used in derive macros. +#[doc(hidden)] +pub fn mk_deser_err( cql_type: &ColumnType, kind: impl Into, ) -> DeserializationError { @@ -1639,6 +1718,9 @@ pub enum BuiltinDeserializationErrorKind { /// A deserialization failure specific to a CQL tuple. TupleError(TupleDeserializationErrorKind), + + /// A deserialization failure specific to a CQL UDT. 
+ UdtError(UdtDeserializationErrorKind), } impl Display for BuiltinDeserializationErrorKind { @@ -1671,6 +1753,7 @@ impl Display for BuiltinDeserializationErrorKind { BuiltinDeserializationErrorKind::SetOrListError(err) => err.fmt(f), BuiltinDeserializationErrorKind::MapError(err) => err.fmt(f), BuiltinDeserializationErrorKind::TupleError(err) => err.fmt(f), + BuiltinDeserializationErrorKind::UdtError(err) => err.fmt(f), BuiltinDeserializationErrorKind::CustomTypeNotSupported(typ) => write!(f, "Support for custom types is not yet implemented: {}", typ), } } @@ -1776,1162 +1859,105 @@ impl From for BuiltinDeserializationErrorKind { } } -#[cfg(test)] -pub(super) mod tests { - use assert_matches::assert_matches; - use bytes::{BufMut, Bytes, BytesMut}; - use uuid::Uuid; - - use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; - use std::fmt::Debug; - use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; - - use crate::frame::response::result::{ColumnType, CqlValue}; - use crate::frame::value::{ - Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid, CqlVarint, - }; - use crate::types::deserialize::value::{ - TupleDeserializationErrorKind, TupleTypeCheckErrorKind, - }; - use crate::types::deserialize::{DeserializationError, FrameSlice, TypeCheckError}; - use crate::types::serialize::value::SerializeValue; - use crate::types::serialize::CellWriter; - - use super::{ - mk_deser_err, BuiltinDeserializationError, BuiltinDeserializationErrorKind, - BuiltinTypeCheckError, BuiltinTypeCheckErrorKind, DeserializeValue, ListlikeIterator, - MapDeserializationErrorKind, MapIterator, MapTypeCheckErrorKind, MaybeEmpty, - SetOrListDeserializationErrorKind, SetOrListTypeCheckErrorKind, - }; - - #[test] - fn test_deserialize_bytes() { - const ORIGINAL_BYTES: &[u8] = &[1, 5, 2, 4, 3]; - - let bytes = make_bytes(ORIGINAL_BYTES); - - let decoded_slice = deserialize::<&[u8]>(&ColumnType::Blob, &bytes).unwrap(); - let decoded_vec = deserialize::>(&ColumnType::Blob, 
&bytes).unwrap(); - let decoded_bytes = deserialize::(&ColumnType::Blob, &bytes).unwrap(); - - assert_eq!(decoded_slice, ORIGINAL_BYTES); - assert_eq!(decoded_vec, ORIGINAL_BYTES); - assert_eq!(decoded_bytes, ORIGINAL_BYTES); - - // ser/de identity - - // Nonempty blob - assert_ser_de_identity(&ColumnType::Blob, &ORIGINAL_BYTES, &mut Bytes::new()); - - // Empty blob - assert_ser_de_identity(&ColumnType::Blob, &(&[] as &[u8]), &mut Bytes::new()); - } - - #[test] - fn test_deserialize_ascii() { - const ASCII_TEXT: &str = "The quick brown fox jumps over the lazy dog"; - - let ascii = make_bytes(ASCII_TEXT.as_bytes()); - - for typ in [ColumnType::Ascii, ColumnType::Text].iter() { - let decoded_str = deserialize::<&str>(typ, &ascii).unwrap(); - let decoded_string = deserialize::(typ, &ascii).unwrap(); - - assert_eq!(decoded_str, ASCII_TEXT); - assert_eq!(decoded_string, ASCII_TEXT); - - // ser/de identity - - // Empty string - assert_ser_de_identity(typ, &"", &mut Bytes::new()); - assert_ser_de_identity(typ, &"".to_owned(), &mut Bytes::new()); - - // Nonempty string - assert_ser_de_identity(typ, &ASCII_TEXT, &mut Bytes::new()); - assert_ser_de_identity(typ, &ASCII_TEXT.to_owned(), &mut Bytes::new()); - } - } - - #[test] - fn test_deserialize_text() { - const UNICODE_TEXT: &str = "Zażółć gęślą jaźń"; - - let unicode = make_bytes(UNICODE_TEXT.as_bytes()); - - // Should fail because it's not an ASCII string - deserialize::<&str>(&ColumnType::Ascii, &unicode).unwrap_err(); - deserialize::(&ColumnType::Ascii, &unicode).unwrap_err(); - - let decoded_text_str = deserialize::<&str>(&ColumnType::Text, &unicode).unwrap(); - let decoded_text_string = deserialize::(&ColumnType::Text, &unicode).unwrap(); - assert_eq!(decoded_text_str, UNICODE_TEXT); - assert_eq!(decoded_text_string, UNICODE_TEXT); - - // ser/de identity - - assert_ser_de_identity(&ColumnType::Text, &UNICODE_TEXT, &mut Bytes::new()); - assert_ser_de_identity( - &ColumnType::Text, - &UNICODE_TEXT.to_owned(), - &mut 
Bytes::new(), - ); - } - - #[test] - fn test_integral() { - let tinyint = make_bytes(&[0x01]); - let decoded_tinyint = deserialize::(&ColumnType::TinyInt, &tinyint).unwrap(); - assert_eq!(decoded_tinyint, 0x01); - - let smallint = make_bytes(&[0x01, 0x02]); - let decoded_smallint = deserialize::(&ColumnType::SmallInt, &smallint).unwrap(); - assert_eq!(decoded_smallint, 0x0102); - - let int = make_bytes(&[0x01, 0x02, 0x03, 0x04]); - let decoded_int = deserialize::(&ColumnType::Int, &int).unwrap(); - assert_eq!(decoded_int, 0x01020304); - - let bigint = make_bytes(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]); - let decoded_bigint = deserialize::(&ColumnType::BigInt, &bigint).unwrap(); - assert_eq!(decoded_bigint, 0x0102030405060708); - - // ser/de identity - assert_ser_de_identity(&ColumnType::TinyInt, &42_i8, &mut Bytes::new()); - assert_ser_de_identity(&ColumnType::SmallInt, &2137_i16, &mut Bytes::new()); - assert_ser_de_identity(&ColumnType::Int, &21372137_i32, &mut Bytes::new()); - assert_ser_de_identity(&ColumnType::BigInt, &0_i64, &mut Bytes::new()); - } - - #[test] - fn test_bool() { - for boolean in [true, false] { - let boolean_bytes = make_bytes(&[boolean as u8]); - let decoded_bool = deserialize::(&ColumnType::Boolean, &boolean_bytes).unwrap(); - assert_eq!(decoded_bool, boolean); - - // ser/de identity - assert_ser_de_identity(&ColumnType::Boolean, &boolean, &mut Bytes::new()); - } - } - - #[test] - fn test_floating_point() { - let float = make_bytes(&[63, 0, 0, 0]); - let decoded_float = deserialize::(&ColumnType::Float, &float).unwrap(); - assert_eq!(decoded_float, 0.5); - - let double = make_bytes(&[64, 0, 0, 0, 0, 0, 0, 0]); - let decoded_double = deserialize::(&ColumnType::Double, &double).unwrap(); - assert_eq!(decoded_double, 2.0); - - // ser/de identity - assert_ser_de_identity(&ColumnType::Float, &21.37_f32, &mut Bytes::new()); - assert_ser_de_identity(&ColumnType::Double, &2137.2137_f64, &mut Bytes::new()); - } - - #[test] - fn 
test_varlen_numbers() { - // varint - assert_ser_de_identity( - &ColumnType::Varint, - &CqlVarint::from_signed_bytes_be_slice(b"Ala ma kota"), - &mut Bytes::new(), - ); - - #[cfg(feature = "num-bigint-03")] - assert_ser_de_identity( - &ColumnType::Varint, - &num_bigint_03::BigInt::from_signed_bytes_be(b"Kot ma Ale"), - &mut Bytes::new(), - ); - - #[cfg(feature = "num-bigint-04")] - assert_ser_de_identity( - &ColumnType::Varint, - &num_bigint_04::BigInt::from_signed_bytes_be(b"Kot ma Ale"), - &mut Bytes::new(), - ); - - // decimal - assert_ser_de_identity( - &ColumnType::Decimal, - &CqlDecimal::from_signed_be_bytes_slice_and_exponent(b"Ala ma kota", 42), - &mut Bytes::new(), - ); - - #[cfg(feature = "bigdecimal-04")] - assert_ser_de_identity( - &ColumnType::Decimal, - &bigdecimal_04::BigDecimal::new( - bigdecimal_04::num_bigint::BigInt::from_signed_bytes_be(b"Ala ma kota"), - 42, - ), - &mut Bytes::new(), - ); - } - - #[test] - fn test_date_time_types() { - // duration - assert_ser_de_identity( - &ColumnType::Duration, - &CqlDuration { - months: 21, - days: 37, - nanoseconds: 42, - }, - &mut Bytes::new(), - ); - - // date - assert_ser_de_identity(&ColumnType::Date, &CqlDate(0xbeaf), &mut Bytes::new()); - - #[cfg(feature = "chrono-04")] - assert_ser_de_identity( - &ColumnType::Date, - &chrono_04::NaiveDate::from_yo_opt(1999, 99).unwrap(), - &mut Bytes::new(), - ); - - #[cfg(feature = "time-03")] - assert_ser_de_identity( - &ColumnType::Date, - &time_03::Date::from_ordinal_date(1999, 99).unwrap(), - &mut Bytes::new(), - ); - - // time - assert_ser_de_identity(&ColumnType::Time, &CqlTime(0xdeed), &mut Bytes::new()); - - #[cfg(feature = "chrono-04")] - assert_ser_de_identity( - &ColumnType::Time, - &chrono_04::NaiveTime::from_hms_micro_opt(21, 37, 21, 37).unwrap(), - &mut Bytes::new(), - ); - - #[cfg(feature = "time-03")] - assert_ser_de_identity( - &ColumnType::Time, - &time_03::Time::from_hms_micro(21, 37, 21, 37).unwrap(), - &mut Bytes::new(), - ); - - // timestamp - 
assert_ser_de_identity( - &ColumnType::Timestamp, - &CqlTimestamp(0xceed), - &mut Bytes::new(), - ); - - #[cfg(feature = "chrono-04")] - assert_ser_de_identity( - &ColumnType::Timestamp, - &chrono_04::DateTime::::from_timestamp_millis(0xdead_cafe_deaf) - .unwrap(), - &mut Bytes::new(), - ); - - #[cfg(feature = "time-03")] - assert_ser_de_identity( - &ColumnType::Timestamp, - &time_03::OffsetDateTime::from_unix_timestamp(0xdead_cafe).unwrap(), - &mut Bytes::new(), - ); - } - - #[test] - fn test_inet() { - assert_ser_de_identity( - &ColumnType::Inet, - &IpAddr::V4(Ipv4Addr::BROADCAST), - &mut Bytes::new(), - ); - - assert_ser_de_identity( - &ColumnType::Inet, - &IpAddr::V6(Ipv6Addr::LOCALHOST), - &mut Bytes::new(), - ); - } - - #[test] - fn test_uuid() { - assert_ser_de_identity( - &ColumnType::Uuid, - &Uuid::from_u128(0xdead_cafe_deaf_feed_beaf_bead), - &mut Bytes::new(), - ); - - assert_ser_de_identity( - &ColumnType::Timeuuid, - &CqlTimeuuid::from_u128(0xdead_cafe_deaf_feed_beaf_bead), - &mut Bytes::new(), - ); - } - - #[test] - fn test_null_and_empty() { - // non-nullable emptiable deserialization, non-empty value - let int = make_bytes(&[21, 37, 0, 0]); - let decoded_int = deserialize::>(&ColumnType::Int, &int).unwrap(); - assert_eq!(decoded_int, MaybeEmpty::Value((21 << 24) + (37 << 16))); - - // non-nullable emptiable deserialization, empty value - let int = make_bytes(&[]); - let decoded_int = deserialize::>(&ColumnType::Int, &int).unwrap(); - assert_eq!(decoded_int, MaybeEmpty::Empty); - - // nullable non-emptiable deserialization, non-null value - let int = make_bytes(&[21, 37, 0, 0]); - let decoded_int = deserialize::>(&ColumnType::Int, &int).unwrap(); - assert_eq!(decoded_int, Some((21 << 24) + (37 << 16))); - - // nullable non-emptiable deserialization, null value - let int = make_null(); - let decoded_int = deserialize::>(&ColumnType::Int, &int).unwrap(); - assert_eq!(decoded_int, None); - - // nullable emptiable deserialization, non-null non-empty 
value - let int = make_bytes(&[]); - let decoded_int = deserialize::>>(&ColumnType::Int, &int).unwrap(); - assert_eq!(decoded_int, Some(MaybeEmpty::Empty)); - - // ser/de identity - assert_ser_de_identity(&ColumnType::Int, &Some(12321_i32), &mut Bytes::new()); - assert_ser_de_identity(&ColumnType::Double, &None::, &mut Bytes::new()); - assert_ser_de_identity( - &ColumnType::Set(Box::new(ColumnType::Ascii)), - &None::>, - &mut Bytes::new(), - ); - } - - #[test] - fn test_maybe_empty() { - let empty = make_bytes(&[]); - let decoded_empty = deserialize::>(&ColumnType::TinyInt, &empty).unwrap(); - assert_eq!(decoded_empty, MaybeEmpty::Empty); - - let non_empty = make_bytes(&[0x01]); - let decoded_non_empty = - deserialize::>(&ColumnType::TinyInt, &non_empty).unwrap(); - assert_eq!(decoded_non_empty, MaybeEmpty::Value(0x01)); - } - - #[test] - fn test_cql_value() { - assert_ser_de_identity( - &ColumnType::Counter, - &CqlValue::Counter(Counter(765)), - &mut Bytes::new(), - ); - - assert_ser_de_identity( - &ColumnType::Timestamp, - &CqlValue::Timestamp(CqlTimestamp(2136)), - &mut Bytes::new(), - ); - - assert_ser_de_identity(&ColumnType::Boolean, &CqlValue::Empty, &mut Bytes::new()); - - assert_ser_de_identity( - &ColumnType::Text, - &CqlValue::Text("kremówki".to_owned()), - &mut Bytes::new(), - ); - assert_ser_de_identity( - &ColumnType::Ascii, - &CqlValue::Ascii("kremowy".to_owned()), - &mut Bytes::new(), - ); - - assert_ser_de_identity( - &ColumnType::Set(Box::new(ColumnType::Text)), - &CqlValue::Set(vec![CqlValue::Text("Ala ma kota".to_owned())]), - &mut Bytes::new(), - ); - } - - #[test] - fn test_list_and_set() { - let mut collection_contents = BytesMut::new(); - collection_contents.put_i32(3); - append_bytes(&mut collection_contents, "quick".as_bytes()); - append_bytes(&mut collection_contents, "brown".as_bytes()); - append_bytes(&mut collection_contents, "fox".as_bytes()); - - let collection = make_bytes(&collection_contents); - - let list_typ = 
ColumnType::List(Box::new(ColumnType::Ascii)); - let set_typ = ColumnType::Set(Box::new(ColumnType::Ascii)); - - // iterator - let mut iter = deserialize::>(&list_typ, &collection).unwrap(); - assert_eq!(iter.next().transpose().unwrap(), Some("quick")); - assert_eq!(iter.next().transpose().unwrap(), Some("brown")); - assert_eq!(iter.next().transpose().unwrap(), Some("fox")); - assert_eq!(iter.next().transpose().unwrap(), None); - - let expected_vec_str = vec!["quick", "brown", "fox"]; - let expected_vec_string = vec!["quick".to_string(), "brown".to_string(), "fox".to_string()]; - - // list - let decoded_vec_str = deserialize::>(&list_typ, &collection).unwrap(); - let decoded_vec_string = deserialize::>(&list_typ, &collection).unwrap(); - assert_eq!(decoded_vec_str, expected_vec_str); - assert_eq!(decoded_vec_string, expected_vec_string); - - // hash set - let decoded_hash_str = deserialize::>(&set_typ, &collection).unwrap(); - let decoded_hash_string = deserialize::>(&set_typ, &collection).unwrap(); - assert_eq!( - decoded_hash_str, - expected_vec_str.clone().into_iter().collect(), - ); - assert_eq!( - decoded_hash_string, - expected_vec_string.clone().into_iter().collect(), - ); - - // btree set - let decoded_btree_str = deserialize::>(&set_typ, &collection).unwrap(); - let decoded_btree_string = deserialize::>(&set_typ, &collection).unwrap(); - assert_eq!( - decoded_btree_str, - expected_vec_str.clone().into_iter().collect(), - ); - assert_eq!( - decoded_btree_string, - expected_vec_string.into_iter().collect(), - ); - - // ser/de identity - assert_ser_de_identity(&list_typ, &vec!["qwik"], &mut Bytes::new()); - assert_ser_de_identity(&set_typ, &vec!["qwik"], &mut Bytes::new()); - assert_ser_de_identity( - &set_typ, - &HashSet::<&str, std::collections::hash_map::RandomState>::from_iter(["qwik"]), - &mut Bytes::new(), - ); - assert_ser_de_identity( - &set_typ, - &BTreeSet::<&str>::from_iter(["qwik"]), - &mut Bytes::new(), - ); - } - - #[test] - fn test_map() { - 
let mut collection_contents = BytesMut::new(); - collection_contents.put_i32(3); - append_bytes(&mut collection_contents, &1i32.to_be_bytes()); - append_bytes(&mut collection_contents, "quick".as_bytes()); - append_bytes(&mut collection_contents, &2i32.to_be_bytes()); - append_bytes(&mut collection_contents, "brown".as_bytes()); - append_bytes(&mut collection_contents, &3i32.to_be_bytes()); - append_bytes(&mut collection_contents, "fox".as_bytes()); - - let collection = make_bytes(&collection_contents); - - let typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Ascii)); - - // iterator - let mut iter = deserialize::>(&typ, &collection).unwrap(); - assert_eq!(iter.next().transpose().unwrap(), Some((1, "quick"))); - assert_eq!(iter.next().transpose().unwrap(), Some((2, "brown"))); - assert_eq!(iter.next().transpose().unwrap(), Some((3, "fox"))); - assert_eq!(iter.next().transpose().unwrap(), None); - - let expected_str = vec![(1, "quick"), (2, "brown"), (3, "fox")]; - let expected_string = vec![ - (1, "quick".to_string()), - (2, "brown".to_string()), - (3, "fox".to_string()), - ]; - - // hash set - let decoded_hash_str = deserialize::>(&typ, &collection).unwrap(); - let decoded_hash_string = deserialize::>(&typ, &collection).unwrap(); - assert_eq!(decoded_hash_str, expected_str.clone().into_iter().collect()); - assert_eq!( - decoded_hash_string, - expected_string.clone().into_iter().collect(), - ); - - // btree set - let decoded_btree_str = deserialize::>(&typ, &collection).unwrap(); - let decoded_btree_string = deserialize::>(&typ, &collection).unwrap(); - assert_eq!( - decoded_btree_str, - expected_str.clone().into_iter().collect(), - ); - assert_eq!(decoded_btree_string, expected_string.into_iter().collect()); - - // ser/de identity - assert_ser_de_identity( - &typ, - &HashMap::::from_iter([( - -42, "qwik", - )]), - &mut Bytes::new(), - ); - assert_ser_de_identity( - &typ, - &BTreeMap::::from_iter([(-42, "qwik")]), - &mut Bytes::new(), - ); - } 
- - #[test] - fn test_tuples() { - let mut tuple_contents = BytesMut::new(); - append_bytes(&mut tuple_contents, &42i32.to_be_bytes()); - append_bytes(&mut tuple_contents, "foo".as_bytes()); - append_null(&mut tuple_contents); - - let tuple = make_bytes(&tuple_contents); - - let typ = ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Ascii, ColumnType::Uuid]); - - let tup = deserialize::<(i32, &str, Option)>(&typ, &tuple).unwrap(); - assert_eq!(tup, (42, "foo", None)); - - // ser/de identity - - // () does not implement SerializeValue, yet it does implement DeserializeValue. - // assert_ser_de_identity(&ColumnType::Tuple(vec![]), &(), &mut Bytes::new()); - - // nonempty, varied tuple - assert_ser_de_identity( - &ColumnType::Tuple(vec![ - ColumnType::List(Box::new(ColumnType::Boolean)), - ColumnType::BigInt, - ColumnType::Uuid, - ColumnType::Inet, - ]), - &( - vec![true, false, true], - 42_i64, - Uuid::from_u128(0xdead_cafe_deaf_feed_beaf_bead), - IpAddr::V6(Ipv6Addr::new(0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x10, 0x11)), - ), - &mut Bytes::new(), - ); - - // nested tuples - assert_ser_de_identity( - &ColumnType::Tuple(vec![ColumnType::Tuple(vec![ColumnType::Tuple(vec![ - ColumnType::Text, - ])])]), - &((("",),),), - &mut Bytes::new(), - ); - } - - #[test] - fn test_custom_type_parser() { - #[derive(Default, Debug, PartialEq, Eq)] - struct SwappedPair(B, A); - impl<'frame, A, B> DeserializeValue<'frame> for SwappedPair - where - A: DeserializeValue<'frame>, - B: DeserializeValue<'frame>, - { - fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { - <(B, A) as DeserializeValue<'frame>>::type_check(typ) - } - - fn deserialize( - typ: &'frame ColumnType, - v: Option>, - ) -> Result { - <(B, A) as DeserializeValue<'frame>>::deserialize(typ, v).map(|(b, a)| Self(b, a)) - } - } - - let mut tuple_contents = BytesMut::new(); - append_bytes(&mut tuple_contents, "foo".as_bytes()); - append_bytes(&mut tuple_contents, &42i32.to_be_bytes()); - let tuple = 
make_bytes(&tuple_contents); - - let typ = ColumnType::Tuple(vec![ColumnType::Ascii, ColumnType::Int]); - - let tup = deserialize::>(&typ, &tuple).unwrap(); - assert_eq!(tup, SwappedPair("foo", 42)); - } - - fn deserialize<'frame, T>( - typ: &'frame ColumnType, - bytes: &'frame Bytes, - ) -> Result - where - T: DeserializeValue<'frame>, - { - >::type_check(typ) - .map_err(|typecheck_err| DeserializationError(typecheck_err.0))?; - let mut frame_slice = FrameSlice::new(bytes); - let value = frame_slice.read_cql_bytes().map_err(|err| { - mk_deser_err::( - typ, - BuiltinDeserializationErrorKind::RawCqlBytesReadError(err), - ) - })?; - >::deserialize(typ, value) - } - - fn make_bytes(cell: &[u8]) -> Bytes { - let mut b = BytesMut::new(); - append_bytes(&mut b, cell); - b.freeze() - } - - fn serialize(typ: &ColumnType, value: &dyn SerializeValue) -> Bytes { - let mut bytes = Bytes::new(); - serialize_to_buf(typ, value, &mut bytes); - bytes - } - - fn serialize_to_buf(typ: &ColumnType, value: &dyn SerializeValue, buf: &mut Bytes) { - let mut v = Vec::new(); - let writer = CellWriter::new(&mut v); - value.serialize(typ, writer).unwrap(); - *buf = v.into(); - } - - fn append_bytes(b: &mut impl BufMut, cell: &[u8]) { - b.put_i32(cell.len() as i32); - b.put_slice(cell); - } - - fn make_null() -> Bytes { - let mut b = BytesMut::new(); - append_null(&mut b); - b.freeze() - } - - fn append_null(b: &mut impl BufMut) { - b.put_i32(-1); - } - - fn assert_ser_de_identity<'f, T: SerializeValue + DeserializeValue<'f> + PartialEq + Debug>( - typ: &'f ColumnType, - v: &'f T, - buf: &'f mut Bytes, // `buf` must be passed as a reference from outside, because otherwise - // we cannot specify the lifetime for DeserializeValue. 
- ) { - serialize_to_buf(typ, v, buf); - let deserialized = deserialize::(typ, buf).unwrap(); - assert_eq!(&deserialized, v); - } - - /* Errors checks */ - - #[track_caller] - pub(crate) fn get_typeck_err_inner<'a>( - err: &'a (dyn std::error::Error + 'static), - ) -> &'a BuiltinTypeCheckError { - match err.downcast_ref() { - Some(err) => err, - None => panic!("not a BuiltinTypeCheckError: {:?}", err), - } - } - - #[track_caller] - pub(crate) fn get_typeck_err(err: &DeserializationError) -> &BuiltinTypeCheckError { - get_typeck_err_inner(err.0.as_ref()) - } - - #[track_caller] - pub(crate) fn get_deser_err(err: &DeserializationError) -> &BuiltinDeserializationError { - match err.0.downcast_ref() { - Some(err) => err, - None => panic!("not a BuiltinDeserializationError: {:?}", err), - } - } - - macro_rules! assert_given_error { - ($get_err:ident, $bytes:expr, $DestT:ty, $cql_typ:expr, $kind:pat) => { - let cql_typ = $cql_typ.clone(); - let err = deserialize::<$DestT>(&cql_typ, $bytes).unwrap_err(); - let err = $get_err(&err); - assert_eq!(err.rust_name, std::any::type_name::<$DestT>()); - assert_eq!(err.cql_type, cql_typ); - assert_matches::assert_matches!(err.kind, $kind); - }; - } - - macro_rules! assert_type_check_error { - ($bytes:expr, $DestT:ty, $cql_typ:expr, $kind:pat) => { - assert_given_error!(get_typeck_err, $bytes, $DestT, $cql_typ, $kind); - }; - } - - macro_rules! assert_deser_error { - ($bytes:expr, $DestT:ty, $cql_typ:expr, $kind:pat) => { - assert_given_error!(get_deser_err, $bytes, $DestT, $cql_typ, $kind); - }; - } - - #[test] - fn test_native_errors() { - // Simple type mismatch - { - let v = 123_i32; - let bytes = serialize(&ColumnType::Int, &v); - - // Incompatible types render type check error. 
- assert_type_check_error!( - &bytes, - f64, - ColumnType::Int, - super::BuiltinTypeCheckErrorKind::MismatchedType { - expected: &[ColumnType::Double], - } - ); - - // ColumnType is said to be Double (8 bytes expected), but in reality the serialized form has 4 bytes only. - assert_deser_error!( - &bytes, - f64, - ColumnType::Double, - BuiltinDeserializationErrorKind::ByteLengthMismatch { - expected: 8, - got: 4, - } - ); - - // ColumnType is said to be Float, but in reality Int was serialized. - // As these types have the same size, though, and every binary number in [0, 2^32] is a valid - // value for both of them, this always succeeds. - { - deserialize::(&ColumnType::Float, &bytes).unwrap(); - } - } - - // str (and also Uuid) are interesting because they accept two types. - { - let v = "Ala ma kota"; - let bytes = serialize(&ColumnType::Ascii, &v); - - assert_type_check_error!( - &bytes, - &str, - ColumnType::Double, - BuiltinTypeCheckErrorKind::MismatchedType { - expected: &[ColumnType::Ascii, ColumnType::Text], - } - ); - - // ColumnType is said to be BigInt (8 bytes expected), but in reality the serialized form - // (the string) has 11 bytes. - assert_deser_error!( - &bytes, - i64, - ColumnType::BigInt, - BuiltinDeserializationErrorKind::ByteLengthMismatch { - expected: 8, - got: 11, // str len - } - ); - } - { - // -126 is not a valid ASCII nor UTF-8 byte. 
- let v = -126_i8; - let bytes = serialize(&ColumnType::TinyInt, &v); - - assert_deser_error!( - &bytes, - &str, - ColumnType::Ascii, - BuiltinDeserializationErrorKind::ExpectedAscii - ); - - assert_deser_error!( - &bytes, - &str, - ColumnType::Text, - BuiltinDeserializationErrorKind::InvalidUtf8(_) - ); - } - } - - #[test] - fn test_set_or_list_errors() { - // Not a set or list - { - assert_type_check_error!( - &Bytes::new(), - Vec, - ColumnType::Float, - BuiltinTypeCheckErrorKind::SetOrListError( - SetOrListTypeCheckErrorKind::NotSetOrList - ) - ); - - // Type check of Rust set against CQL list must fail, because it would be lossy. - assert_type_check_error!( - &Bytes::new(), - BTreeSet, - ColumnType::List(Box::new(ColumnType::Int)), - BuiltinTypeCheckErrorKind::SetOrListError(SetOrListTypeCheckErrorKind::NotSet) - ); - } - - // Got null - { - type RustTyp = Vec; - let ser_typ = ColumnType::List(Box::new(ColumnType::Int)); - - let err = RustTyp::deserialize(&ser_typ, None).unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ser_typ); - assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); - } - - // Bad element type - { - assert_type_check_error!( - &Bytes::new(), - Vec, - ColumnType::List(Box::new(ColumnType::Ascii)), - BuiltinTypeCheckErrorKind::SetOrListError( - SetOrListTypeCheckErrorKind::ElementTypeCheckFailed(_) - ) - ); +/// Describes why deserialization of a user defined type failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum UdtDeserializationErrorKind { + /// One of the fields failed to deserialize. + FieldDeserializationFailed { + /// Name of the field which failed to deserialize. 
+ field_name: String, - let err = deserialize::>( - &ColumnType::List(Box::new(ColumnType::Varint)), - &Bytes::new(), - ) - .unwrap_err(); - let err = get_typeck_err(&err); - assert_eq!(err.rust_name, std::any::type_name::>()); - assert_eq!(err.cql_type, ColumnType::List(Box::new(ColumnType::Varint)),); - let BuiltinTypeCheckErrorKind::SetOrListError( - SetOrListTypeCheckErrorKind::ElementTypeCheckFailed(ref err), - ) = err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - let err = get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::Varint); - assert_matches!( - err.kind, - BuiltinTypeCheckErrorKind::MismatchedType { - expected: &[ColumnType::BigInt, ColumnType::Counter] - } - ); - } + /// The error that caused the UDT field deserialization to fail. + err: DeserializationError, + }, +} - { - let ser_typ = ColumnType::List(Box::new(ColumnType::Int)); - let v = vec![123_i32]; - let bytes = serialize(&ser_typ, &v); - - { - let err = deserialize::>( - &ColumnType::List(Box::new(ColumnType::BigInt)), - &bytes, - ) - .unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::>()); - assert_eq!(err.cql_type, ColumnType::List(Box::new(ColumnType::BigInt)),); - let BuiltinDeserializationErrorKind::SetOrListError( - SetOrListDeserializationErrorKind::ElementDeserializationFailed(err), - ) = &err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - let err = get_deser_err(err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::BigInt); - assert_matches!( - err.kind, - BuiltinDeserializationErrorKind::ByteLengthMismatch { - expected: 8, - got: 4 - } - ); +impl Display for UdtDeserializationErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + UdtDeserializationErrorKind::FieldDeserializationFailed { field_name, err } => { + write!(f, 
"field {field_name} failed to deserialize: {err}") } } } +} - #[test] - fn test_map_errors() { - // Not a map - { - let ser_typ = ColumnType::Float; - let v = 2.12_f32; - let bytes = serialize(&ser_typ, &v); - - assert_type_check_error!( - &bytes, - HashMap, - ser_typ, - BuiltinTypeCheckErrorKind::MapError( - MapTypeCheckErrorKind::NotMap, - ) - ); - } - - // Got null - { - type RustTyp = HashMap; - let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); - - let err = RustTyp::deserialize(&ser_typ, None).unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ser_typ); - assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); - } - - // Key type mismatch - { - let err = deserialize::>( - &ColumnType::Map(Box::new(ColumnType::Varint), Box::new(ColumnType::Boolean)), - &Bytes::new(), - ) - .unwrap_err(); - let err = get_typeck_err(&err); - assert_eq!(err.rust_name, std::any::type_name::>()); - assert_eq!( - err.cql_type, - ColumnType::Map(Box::new(ColumnType::Varint), Box::new(ColumnType::Boolean)) - ); - let BuiltinTypeCheckErrorKind::MapError(MapTypeCheckErrorKind::KeyTypeCheckFailed( - ref err, - )) = err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - let err = get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::Varint); - assert_matches!( - err.kind, - BuiltinTypeCheckErrorKind::MismatchedType { - expected: &[ColumnType::BigInt, ColumnType::Counter] - } - ); - } - - // Value type mismatch - { - let err = deserialize::>( - &ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)), - &Bytes::new(), - ) - .unwrap_err(); - let err = get_typeck_err(&err); - assert_eq!(err.rust_name, std::any::type_name::>()); - assert_eq!( - err.cql_type, - ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)) - ); - 
let BuiltinTypeCheckErrorKind::MapError(MapTypeCheckErrorKind::ValueTypeCheckFailed( - ref err, - )) = err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - let err = get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::<&str>()); - assert_eq!(err.cql_type, ColumnType::Boolean); - assert_matches!( - err.kind, - BuiltinTypeCheckErrorKind::MismatchedType { - expected: &[ColumnType::Ascii, ColumnType::Text] - } - ); - } - - // Key length mismatch - { - let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); - let v = HashMap::from([(42, false), (2137, true)]); - let bytes = serialize(&ser_typ, &v as &dyn SerializeValue); - - let err = deserialize::>( - &ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)), - &bytes, - ) - .unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::>()); - assert_eq!( - err.cql_type, - ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)) - ); - let BuiltinDeserializationErrorKind::MapError( - MapDeserializationErrorKind::KeyDeserializationFailed(err), - ) = &err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - let err = get_deser_err(err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::BigInt); - assert_matches!( - err.kind, - BuiltinDeserializationErrorKind::ByteLengthMismatch { - expected: 8, - got: 4 - } - ); - } - - // Value length mismatch - { - let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); - let v = HashMap::from([(42, false), (2137, true)]); - let bytes = serialize(&ser_typ, &v as &dyn SerializeValue); - - let err = deserialize::>( - &ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::SmallInt)), - &bytes, - ) - .unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::>()); - assert_eq!( - 
err.cql_type, - ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::SmallInt)) - ); - let BuiltinDeserializationErrorKind::MapError( - MapDeserializationErrorKind::ValueDeserializationFailed(err), - ) = &err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - let err = get_deser_err(err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::SmallInt); - assert_matches!( - err.kind, - BuiltinDeserializationErrorKind::ByteLengthMismatch { - expected: 2, - got: 1 - } - ); - } +impl From<UdtDeserializationErrorKind> for BuiltinDeserializationErrorKind { + fn from(err: UdtDeserializationErrorKind) -> Self { + Self::UdtError(err) } +} - #[test] - fn test_tuple_errors() { - // Not a tuple - { - assert_type_check_error!( - &Bytes::new(), - (i64,), - ColumnType::BigInt, - BuiltinTypeCheckErrorKind::TupleError(TupleTypeCheckErrorKind::NotTuple) - ); - } - // Wrong element count - { - assert_type_check_error!( - &Bytes::new(), - (i64,), - ColumnType::Tuple(vec![]), - BuiltinTypeCheckErrorKind::TupleError(TupleTypeCheckErrorKind::WrongElementCount { - rust_type_el_count: 1, - cql_type_el_count: 0, - }) - ); - - assert_type_check_error!( - &Bytes::new(), - (f32,), - ColumnType::Tuple(vec![ColumnType::Float, ColumnType::Float]), - BuiltinTypeCheckErrorKind::TupleError(TupleTypeCheckErrorKind::WrongElementCount { - rust_type_el_count: 1, - cql_type_el_count: 2, - }) - ); - } - - // Bad field type - { - { - let err = deserialize::<(i64,)>( - &ColumnType::Tuple(vec![ColumnType::SmallInt]), - &Bytes::new(), - ) - .unwrap_err(); - let err = get_typeck_err(&err); - assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); - assert_eq!(err.cql_type, ColumnType::Tuple(vec![ColumnType::SmallInt])); - let BuiltinTypeCheckErrorKind::TupleError( - TupleTypeCheckErrorKind::FieldTypeCheckFailed { ref err, position }, - ) = err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - assert_eq!(position, 0); - let err =
get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::SmallInt); - assert_matches!( - err.kind, - BuiltinTypeCheckErrorKind::MismatchedType { - expected: &[ColumnType::BigInt, ColumnType::Counter] - } - ); - } - } - - { - let ser_typ = ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Float]); - let v = (123_i32, 123.123_f32); - let bytes = serialize(&ser_typ, &v); - - { - let err = deserialize::<(i32, f64)>( - &ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Double]), - &bytes, - ) - .unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::<(i32, f64)>()); - assert_eq!( - err.cql_type, - ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Double]) - ); - let BuiltinDeserializationErrorKind::TupleError( - TupleDeserializationErrorKind::FieldDeserializationFailed { - ref err, - position: index, - }, - ) = err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - assert_eq!(index, 1); - let err = get_deser_err(err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::Double); - assert_matches!( - err.kind, - BuiltinDeserializationErrorKind::ByteLengthMismatch { - expected: 8, - got: 4 - } - ); - } - } - } +#[cfg(test)] +#[path = "value_tests.rs"] +pub(super) mod tests; - #[test] - fn test_null_errors() { - let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); - let v = HashMap::from([(42, false), (2137, true)]); - let bytes = serialize(&ser_typ, &v as &dyn SerializeValue); +/// ```compile_fail +/// +/// #[derive(scylla_macros::DeserializeValue)] +/// #[scylla(crate = scylla_cql, skip_name_checks)] +/// struct TestUdt {} +/// ``` +fn _test_udt_bad_attributes_skip_name_check_requires_enforce_order() {} - deserialize::>(&ser_typ, &bytes).unwrap_err(); - } -} +/// ```compile_fail +/// +/// #[derive(scylla_macros::DeserializeValue)] +/// #[scylla(crate = 
scylla_cql, enforce_order, skip_name_checks)] +/// struct TestUdt { +/// #[scylla(rename = "b")] +/// a: i32, +/// } +/// ``` +fn _test_udt_bad_attributes_skip_name_check_conflicts_with_rename() {} + +/// ```compile_fail +/// +/// #[derive(scylla_macros::DeserializeValue)] +/// #[scylla(crate = scylla_cql)] +/// struct TestUdt { +/// #[scylla(rename = "b")] +/// a: i32, +/// b: String, +/// } +/// ``` +fn _test_udt_bad_attributes_rename_collision_with_field() {} + +/// ```compile_fail +/// +/// #[derive(scylla_macros::DeserializeValue)] +/// #[scylla(crate = scylla_cql)] +/// struct TestUdt { +/// #[scylla(rename = "c")] +/// a: i32, +/// #[scylla(rename = "c")] +/// b: String, +/// } +/// ``` +fn _test_udt_bad_attributes_rename_collision_with_another_rename() {} + +/// ```compile_fail +/// +/// #[derive(scylla_macros::DeserializeValue)] +/// #[scylla(crate = scylla_cql, enforce_order, skip_name_checks)] +/// struct TestUdt { +/// a: i32, +/// #[scylla(allow_missing)] +/// b: bool, +/// c: String, +/// } +/// ``` +fn _test_udt_bad_attributes_name_skip_name_checks_limitations_on_allow_missing() {} + +/// ``` +/// #[derive(scylla_macros::DeserializeValue)] +/// #[scylla(crate = scylla_cql)] +/// struct TestUdt { +/// a: i32, +/// #[scylla(allow_missing)] +/// b: bool, +/// c: String, +/// } +/// ``` +fn _test_udt_unordered_flavour_no_limitations_on_allow_missing() {} diff --git a/scylla-cql/src/types/deserialize/value_tests.rs b/scylla-cql/src/types/deserialize/value_tests.rs new file mode 100644 index 0000000000..9375ce47f6 --- /dev/null +++ b/scylla-cql/src/types/deserialize/value_tests.rs @@ -0,0 +1,1946 @@ +use assert_matches::assert_matches; +use bytes::{BufMut, Bytes, BytesMut}; +use uuid::Uuid; + +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; +use std::fmt::Debug; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + +use crate::frame::response::result::{ColumnType, CqlValue}; +use crate::frame::value::{ + Counter, CqlDate, CqlDecimal, CqlDuration, 
CqlTime, CqlTimestamp, CqlTimeuuid, CqlVarint, +}; +use crate::types::deserialize::value::{TupleDeserializationErrorKind, TupleTypeCheckErrorKind}; +use crate::types::deserialize::{DeserializationError, FrameSlice, TypeCheckError}; +use crate::types::serialize::value::SerializeValue; +use crate::types::serialize::CellWriter; + +use super::{ + mk_deser_err, BuiltinDeserializationError, BuiltinDeserializationErrorKind, + BuiltinTypeCheckError, BuiltinTypeCheckErrorKind, DeserializeValue, ListlikeIterator, + MapDeserializationErrorKind, MapIterator, MapTypeCheckErrorKind, MaybeEmpty, + SetOrListDeserializationErrorKind, SetOrListTypeCheckErrorKind, UdtDeserializationErrorKind, + UdtTypeCheckErrorKind, +}; + +#[test] +fn test_deserialize_bytes() { + const ORIGINAL_BYTES: &[u8] = &[1, 5, 2, 4, 3]; + + let bytes = make_bytes(ORIGINAL_BYTES); + + let decoded_slice = deserialize::<&[u8]>(&ColumnType::Blob, &bytes).unwrap(); + let decoded_vec = deserialize::<Vec<u8>>(&ColumnType::Blob, &bytes).unwrap(); + let decoded_bytes = deserialize::<Bytes>(&ColumnType::Blob, &bytes).unwrap(); + + assert_eq!(decoded_slice, ORIGINAL_BYTES); + assert_eq!(decoded_vec, ORIGINAL_BYTES); + assert_eq!(decoded_bytes, ORIGINAL_BYTES); + + // ser/de identity + + // Nonempty blob + assert_ser_de_identity(&ColumnType::Blob, &ORIGINAL_BYTES, &mut Bytes::new()); + + // Empty blob + assert_ser_de_identity(&ColumnType::Blob, &(&[] as &[u8]), &mut Bytes::new()); +} + +#[test] +fn test_deserialize_ascii() { + const ASCII_TEXT: &str = "The quick brown fox jumps over the lazy dog"; + + let ascii = make_bytes(ASCII_TEXT.as_bytes()); + + for typ in [ColumnType::Ascii, ColumnType::Text].iter() { + let decoded_str = deserialize::<&str>(typ, &ascii).unwrap(); + let decoded_string = deserialize::<String>(typ, &ascii).unwrap(); + + assert_eq!(decoded_str, ASCII_TEXT); + assert_eq!(decoded_string, ASCII_TEXT); + + // ser/de identity + + // Empty string + assert_ser_de_identity(typ, &"", &mut Bytes::new()); + assert_ser_de_identity(typ,
&"".to_owned(), &mut Bytes::new()); + + // Nonempty string + assert_ser_de_identity(typ, &ASCII_TEXT, &mut Bytes::new()); + assert_ser_de_identity(typ, &ASCII_TEXT.to_owned(), &mut Bytes::new()); + } +} + +#[test] +fn test_deserialize_text() { + const UNICODE_TEXT: &str = "Zażółć gęślą jaźń"; + + let unicode = make_bytes(UNICODE_TEXT.as_bytes()); + + // Should fail because it's not an ASCII string + deserialize::<&str>(&ColumnType::Ascii, &unicode).unwrap_err(); + deserialize::<String>(&ColumnType::Ascii, &unicode).unwrap_err(); + + let decoded_text_str = deserialize::<&str>(&ColumnType::Text, &unicode).unwrap(); + let decoded_text_string = deserialize::<String>(&ColumnType::Text, &unicode).unwrap(); + assert_eq!(decoded_text_str, UNICODE_TEXT); + assert_eq!(decoded_text_string, UNICODE_TEXT); + + // ser/de identity + + assert_ser_de_identity(&ColumnType::Text, &UNICODE_TEXT, &mut Bytes::new()); + assert_ser_de_identity( + &ColumnType::Text, + &UNICODE_TEXT.to_owned(), + &mut Bytes::new(), + ); +} + +#[test] +fn test_integral() { + let tinyint = make_bytes(&[0x01]); + let decoded_tinyint = deserialize::<i8>(&ColumnType::TinyInt, &tinyint).unwrap(); + assert_eq!(decoded_tinyint, 0x01); + + let smallint = make_bytes(&[0x01, 0x02]); + let decoded_smallint = deserialize::<i16>(&ColumnType::SmallInt, &smallint).unwrap(); + assert_eq!(decoded_smallint, 0x0102); + + let int = make_bytes(&[0x01, 0x02, 0x03, 0x04]); + let decoded_int = deserialize::<i32>(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, 0x01020304); + + let bigint = make_bytes(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]); + let decoded_bigint = deserialize::<i64>(&ColumnType::BigInt, &bigint).unwrap(); + assert_eq!(decoded_bigint, 0x0102030405060708); + + // ser/de identity + assert_ser_de_identity(&ColumnType::TinyInt, &42_i8, &mut Bytes::new()); + assert_ser_de_identity(&ColumnType::SmallInt, &2137_i16, &mut Bytes::new()); + assert_ser_de_identity(&ColumnType::Int, &21372137_i32, &mut Bytes::new()); +
assert_ser_de_identity(&ColumnType::BigInt, &0_i64, &mut Bytes::new()); +} + +#[test] +fn test_bool() { + for boolean in [true, false] { + let boolean_bytes = make_bytes(&[boolean as u8]); + let decoded_bool = deserialize::<bool>(&ColumnType::Boolean, &boolean_bytes).unwrap(); + assert_eq!(decoded_bool, boolean); + + // ser/de identity + assert_ser_de_identity(&ColumnType::Boolean, &boolean, &mut Bytes::new()); + } +} + +#[test] +fn test_floating_point() { + let float = make_bytes(&[63, 0, 0, 0]); + let decoded_float = deserialize::<f32>(&ColumnType::Float, &float).unwrap(); + assert_eq!(decoded_float, 0.5); + + let double = make_bytes(&[64, 0, 0, 0, 0, 0, 0, 0]); + let decoded_double = deserialize::<f64>(&ColumnType::Double, &double).unwrap(); + assert_eq!(decoded_double, 2.0); + + // ser/de identity + assert_ser_de_identity(&ColumnType::Float, &21.37_f32, &mut Bytes::new()); + assert_ser_de_identity(&ColumnType::Double, &2137.2137_f64, &mut Bytes::new()); +} + +#[test] +fn test_varlen_numbers() { + // varint + assert_ser_de_identity( + &ColumnType::Varint, + &CqlVarint::from_signed_bytes_be_slice(b"Ala ma kota"), + &mut Bytes::new(), + ); + + #[cfg(feature = "num-bigint-03")] + assert_ser_de_identity( + &ColumnType::Varint, + &num_bigint_03::BigInt::from_signed_bytes_be(b"Kot ma Ale"), + &mut Bytes::new(), + ); + + #[cfg(feature = "num-bigint-04")] + assert_ser_de_identity( + &ColumnType::Varint, + &num_bigint_04::BigInt::from_signed_bytes_be(b"Kot ma Ale"), + &mut Bytes::new(), + ); + + // decimal + assert_ser_de_identity( + &ColumnType::Decimal, + &CqlDecimal::from_signed_be_bytes_slice_and_exponent(b"Ala ma kota", 42), + &mut Bytes::new(), + ); + + #[cfg(feature = "bigdecimal-04")] + assert_ser_de_identity( + &ColumnType::Decimal, + &bigdecimal_04::BigDecimal::new( + bigdecimal_04::num_bigint::BigInt::from_signed_bytes_be(b"Ala ma kota"), + 42, + ), + &mut Bytes::new(), + ); +} + +#[test] +fn test_date_time_types() { + // duration + assert_ser_de_identity( +
&ColumnType::Duration, + &CqlDuration { + months: 21, + days: 37, + nanoseconds: 42, + }, + &mut Bytes::new(), + ); + + // date + assert_ser_de_identity(&ColumnType::Date, &CqlDate(0xbeaf), &mut Bytes::new()); + + #[cfg(feature = "chrono-04")] + assert_ser_de_identity( + &ColumnType::Date, + &chrono_04::NaiveDate::from_yo_opt(1999, 99).unwrap(), + &mut Bytes::new(), + ); + + #[cfg(feature = "time-03")] + assert_ser_de_identity( + &ColumnType::Date, + &time_03::Date::from_ordinal_date(1999, 99).unwrap(), + &mut Bytes::new(), + ); + + // time + assert_ser_de_identity(&ColumnType::Time, &CqlTime(0xdeed), &mut Bytes::new()); + + #[cfg(feature = "chrono-04")] + assert_ser_de_identity( + &ColumnType::Time, + &chrono_04::NaiveTime::from_hms_micro_opt(21, 37, 21, 37).unwrap(), + &mut Bytes::new(), + ); + + #[cfg(feature = "time-03")] + assert_ser_de_identity( + &ColumnType::Time, + &time_03::Time::from_hms_micro(21, 37, 21, 37).unwrap(), + &mut Bytes::new(), + ); + + // timestamp + assert_ser_de_identity( + &ColumnType::Timestamp, + &CqlTimestamp(0xceed), + &mut Bytes::new(), + ); + + #[cfg(feature = "chrono-04")] + assert_ser_de_identity( + &ColumnType::Timestamp, + &chrono_04::DateTime::::from_timestamp_millis(0xdead_cafe_deaf).unwrap(), + &mut Bytes::new(), + ); + + #[cfg(feature = "time-03")] + assert_ser_de_identity( + &ColumnType::Timestamp, + &time_03::OffsetDateTime::from_unix_timestamp(0xdead_cafe).unwrap(), + &mut Bytes::new(), + ); +} + +#[test] +fn test_inet() { + assert_ser_de_identity( + &ColumnType::Inet, + &IpAddr::V4(Ipv4Addr::BROADCAST), + &mut Bytes::new(), + ); + + assert_ser_de_identity( + &ColumnType::Inet, + &IpAddr::V6(Ipv6Addr::LOCALHOST), + &mut Bytes::new(), + ); +} + +#[test] +fn test_uuid() { + assert_ser_de_identity( + &ColumnType::Uuid, + &Uuid::from_u128(0xdead_cafe_deaf_feed_beaf_bead), + &mut Bytes::new(), + ); + + assert_ser_de_identity( + &ColumnType::Timeuuid, + &CqlTimeuuid::from_u128(0xdead_cafe_deaf_feed_beaf_bead), + &mut 
Bytes::new(), + ); +} + +#[test] +fn test_null_and_empty() { + // non-nullable emptiable deserialization, non-empty value + let int = make_bytes(&[21, 37, 0, 0]); + let decoded_int = deserialize::<MaybeEmpty<i32>>(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, MaybeEmpty::Value((21 << 24) + (37 << 16))); + + // non-nullable emptiable deserialization, empty value + let int = make_bytes(&[]); + let decoded_int = deserialize::<MaybeEmpty<i32>>(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, MaybeEmpty::Empty); + + // nullable non-emptiable deserialization, non-null value + let int = make_bytes(&[21, 37, 0, 0]); + let decoded_int = deserialize::<Option<i32>>(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, Some((21 << 24) + (37 << 16))); + + // nullable non-emptiable deserialization, null value + let int = make_null(); + let decoded_int = deserialize::<Option<i32>>(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, None); + + // nullable emptiable deserialization, non-null empty value + let int = make_bytes(&[]); + let decoded_int = deserialize::<Option<MaybeEmpty<i32>>>(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, Some(MaybeEmpty::Empty)); + + // ser/de identity + assert_ser_de_identity(&ColumnType::Int, &Some(12321_i32), &mut Bytes::new()); + assert_ser_de_identity(&ColumnType::Double, &None::<f64>, &mut Bytes::new()); + assert_ser_de_identity( + &ColumnType::Set(Box::new(ColumnType::Ascii)), + &None::<Vec<&str>>, + &mut Bytes::new(), + ); +} + +#[test] +fn test_maybe_empty() { + let empty = make_bytes(&[]); + let decoded_empty = deserialize::<MaybeEmpty<i8>>(&ColumnType::TinyInt, &empty).unwrap(); + assert_eq!(decoded_empty, MaybeEmpty::Empty); + + let non_empty = make_bytes(&[0x01]); + let decoded_non_empty = + deserialize::<MaybeEmpty<i8>>(&ColumnType::TinyInt, &non_empty).unwrap(); + assert_eq!(decoded_non_empty, MaybeEmpty::Value(0x01)); +} + +#[test] +fn test_cql_value() { + assert_ser_de_identity( + &ColumnType::Counter, + &CqlValue::Counter(Counter(765)), + &mut Bytes::new(), + ); + + assert_ser_de_identity( +
&ColumnType::Timestamp, + &CqlValue::Timestamp(CqlTimestamp(2136)), + &mut Bytes::new(), + ); + + assert_ser_de_identity(&ColumnType::Boolean, &CqlValue::Empty, &mut Bytes::new()); + + assert_ser_de_identity( + &ColumnType::Text, + &CqlValue::Text("kremówki".to_owned()), + &mut Bytes::new(), + ); + assert_ser_de_identity( + &ColumnType::Ascii, + &CqlValue::Ascii("kremowy".to_owned()), + &mut Bytes::new(), + ); + + assert_ser_de_identity( + &ColumnType::Set(Box::new(ColumnType::Text)), + &CqlValue::Set(vec![CqlValue::Text("Ala ma kota".to_owned())]), + &mut Bytes::new(), + ); +} + +#[test] +fn test_list_and_set() { + let mut collection_contents = BytesMut::new(); + collection_contents.put_i32(3); + append_bytes(&mut collection_contents, "quick".as_bytes()); + append_bytes(&mut collection_contents, "brown".as_bytes()); + append_bytes(&mut collection_contents, "fox".as_bytes()); + + let collection = make_bytes(&collection_contents); + + let list_typ = ColumnType::List(Box::new(ColumnType::Ascii)); + let set_typ = ColumnType::Set(Box::new(ColumnType::Ascii)); + + // iterator + let mut iter = deserialize::<ListlikeIterator<&str>>(&list_typ, &collection).unwrap(); + assert_eq!(iter.next().transpose().unwrap(), Some("quick")); + assert_eq!(iter.next().transpose().unwrap(), Some("brown")); + assert_eq!(iter.next().transpose().unwrap(), Some("fox")); + assert_eq!(iter.next().transpose().unwrap(), None); + + let expected_vec_str = vec!["quick", "brown", "fox"]; + let expected_vec_string = vec!["quick".to_string(), "brown".to_string(), "fox".to_string()]; + + // list + let decoded_vec_str = deserialize::<Vec<&str>>(&list_typ, &collection).unwrap(); + let decoded_vec_string = deserialize::<Vec<String>>(&list_typ, &collection).unwrap(); + assert_eq!(decoded_vec_str, expected_vec_str); + assert_eq!(decoded_vec_string, expected_vec_string); + + // hash set + let decoded_hash_str = deserialize::<HashSet<&str>>(&set_typ, &collection).unwrap(); + let decoded_hash_string = deserialize::<HashSet<String>>(&set_typ, &collection).unwrap(); + assert_eq!( +
decoded_hash_str, + expected_vec_str.clone().into_iter().collect(), + ); + assert_eq!( + decoded_hash_string, + expected_vec_string.clone().into_iter().collect(), + ); + + // btree set + let decoded_btree_str = deserialize::<BTreeSet<&str>>(&set_typ, &collection).unwrap(); + let decoded_btree_string = deserialize::<BTreeSet<String>>(&set_typ, &collection).unwrap(); + assert_eq!( + decoded_btree_str, + expected_vec_str.clone().into_iter().collect(), + ); + assert_eq!( + decoded_btree_string, + expected_vec_string.into_iter().collect(), + ); + + // ser/de identity + assert_ser_de_identity(&list_typ, &vec!["qwik"], &mut Bytes::new()); + assert_ser_de_identity(&set_typ, &vec!["qwik"], &mut Bytes::new()); + assert_ser_de_identity( + &set_typ, + &HashSet::<&str, std::collections::hash_map::RandomState>::from_iter(["qwik"]), + &mut Bytes::new(), + ); + assert_ser_de_identity( + &set_typ, + &BTreeSet::<&str>::from_iter(["qwik"]), + &mut Bytes::new(), + ); +} + +#[test] +fn test_map() { + let mut collection_contents = BytesMut::new(); + collection_contents.put_i32(3); + append_bytes(&mut collection_contents, &1i32.to_be_bytes()); + append_bytes(&mut collection_contents, "quick".as_bytes()); + append_bytes(&mut collection_contents, &2i32.to_be_bytes()); + append_bytes(&mut collection_contents, "brown".as_bytes()); + append_bytes(&mut collection_contents, &3i32.to_be_bytes()); + append_bytes(&mut collection_contents, "fox".as_bytes()); + + let collection = make_bytes(&collection_contents); + + let typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Ascii)); + + // iterator + let mut iter = deserialize::<MapIterator<i32, &str>>(&typ, &collection).unwrap(); + assert_eq!(iter.next().transpose().unwrap(), Some((1, "quick"))); + assert_eq!(iter.next().transpose().unwrap(), Some((2, "brown"))); + assert_eq!(iter.next().transpose().unwrap(), Some((3, "fox"))); + assert_eq!(iter.next().transpose().unwrap(), None); + + let expected_str = vec![(1, "quick"), (2, "brown"), (3, "fox")]; + let expected_string = vec![ + (1,
"quick".to_string()), + (2, "brown".to_string()), + (3, "fox".to_string()), + ]; + + // hash map + let decoded_hash_str = deserialize::<HashMap<i32, &str>>(&typ, &collection).unwrap(); + let decoded_hash_string = deserialize::<HashMap<i32, String>>(&typ, &collection).unwrap(); + assert_eq!(decoded_hash_str, expected_str.clone().into_iter().collect()); + assert_eq!( + decoded_hash_string, + expected_string.clone().into_iter().collect(), + ); + + // btree map + let decoded_btree_str = deserialize::<BTreeMap<i32, &str>>(&typ, &collection).unwrap(); + let decoded_btree_string = deserialize::<BTreeMap<i32, String>>(&typ, &collection).unwrap(); + assert_eq!( + decoded_btree_str, + expected_str.clone().into_iter().collect(), + ); + assert_eq!(decoded_btree_string, expected_string.into_iter().collect()); + + // ser/de identity + assert_ser_de_identity( + &typ, + &HashMap::<i32, &str>::from_iter([(-42, "qwik")]), + &mut Bytes::new(), + ); + assert_ser_de_identity( + &typ, + &BTreeMap::<i32, &str>::from_iter([(-42, "qwik")]), + &mut Bytes::new(), + ); +} + +#[test] +fn test_tuples() { + let mut tuple_contents = BytesMut::new(); + append_bytes(&mut tuple_contents, &42i32.to_be_bytes()); + append_bytes(&mut tuple_contents, "foo".as_bytes()); + append_null(&mut tuple_contents); + + let tuple = make_bytes(&tuple_contents); + + let typ = ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Ascii, ColumnType::Uuid]); + + let tup = deserialize::<(i32, &str, Option<Uuid>)>(&typ, &tuple).unwrap(); + assert_eq!(tup, (42, "foo", None)); + + // ser/de identity + + // () does not implement SerializeValue, yet it does implement DeserializeValue.
+ // assert_ser_de_identity(&ColumnType::Tuple(vec![]), &(), &mut Bytes::new()); + + // nonempty, varied tuple + assert_ser_de_identity( + &ColumnType::Tuple(vec![ + ColumnType::List(Box::new(ColumnType::Boolean)), + ColumnType::BigInt, + ColumnType::Uuid, + ColumnType::Inet, + ]), + &( + vec![true, false, true], + 42_i64, + Uuid::from_u128(0xdead_cafe_deaf_feed_beaf_bead), + IpAddr::V6(Ipv6Addr::new(0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x10, 0x11)), + ), + &mut Bytes::new(), + ); + + // nested tuples + assert_ser_de_identity( + &ColumnType::Tuple(vec![ColumnType::Tuple(vec![ColumnType::Tuple(vec![ + ColumnType::Text, + ])])]), + &((("",),),), + &mut Bytes::new(), + ); +} + +fn udt_def_with_fields( + fields: impl IntoIterator<Item = (impl Into<String>, ColumnType)>, +) -> ColumnType { + ColumnType::UserDefinedType { + type_name: "udt".to_owned(), + keyspace: "ks".to_owned(), + field_types: fields.into_iter().map(|(s, t)| (s.into(), t)).collect(), + } +} + +#[must_use] +struct UdtSerializer { + buf: BytesMut, +} + +impl UdtSerializer { + fn new() -> Self { + Self { + buf: BytesMut::default(), + } + } + + fn field(mut self, field_bytes: &[u8]) -> Self { + append_bytes(&mut self.buf, field_bytes); + self + } + + fn null_field(mut self) -> Self { + append_null(&mut self.buf); + self + } + + fn finalize(&self) -> Bytes { + make_bytes(&self.buf) + } +} + +// Do not remove. It's not used in tests but we keep it here to check that +// we properly ignore warnings about unused variables, unnecessary `mut`s +// etc. that usually pop up when generating code for empty structs.
+#[allow(unused)] +#[derive(scylla_macros::DeserializeValue)] +#[scylla(crate = crate)] +struct TestUdtWithNoFieldsUnordered {} + +#[allow(unused)] +#[derive(scylla_macros::DeserializeValue)] +#[scylla(crate = crate, enforce_order)] +struct TestUdtWithNoFieldsOrdered {} + +#[test] +fn test_udt_loose_ordering() { + #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] + #[scylla(crate = "crate")] + struct Udt<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + #[scylla(allow_missing)] + b: Option<i32>, + #[scylla(default_when_null)] + c: i64, + } + + // UDT fields in correct same order. + { + let udt_bytes = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .field(&42_i32.to_be_bytes()) + .field(&2137_i64.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("c", ColumnType::BigInt), + ]); + + let udt = deserialize::<Udt<'_>>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + c: 2137, + } + ); + } + + // The last two UDT fields are missing in serialized form - it should treat it + // as if there were nulls at the end. + { + let udt_bytes = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .finalize(); + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("c", ColumnType::BigInt), + ]); + + let udt = deserialize::<Udt<'_>>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: None, + c: 0, + } + ); + } + + // UDT fields switched - should still work.
+ { + let udt_bytes = UdtSerializer::new() + .field(&42_i32.to_be_bytes()) + .field("The quick brown fox".as_bytes()) + .field(&2137_i64.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([ + ("b", ColumnType::Int), + ("a", ColumnType::Text), + ("c", ColumnType::BigInt), + ]); + + let udt = deserialize::<Udt<'_>>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + c: 2137, + } + ); + } + + // An excess UDT field - should still work. + { + let udt_bytes = UdtSerializer::new() + .field(&12_i8.to_be_bytes()) + .field(&42_i32.to_be_bytes()) + .field("The quick brown fox".as_bytes()) + .field(&2137_i64.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([ + ("d", ColumnType::TinyInt), + ("b", ColumnType::Int), + ("a", ColumnType::Text), + ("c", ColumnType::BigInt), + ]); + + Udt::type_check(&typ).unwrap(); + let udt = deserialize::<Udt<'_>>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + c: 2137, + } + ); + } + + // Only field 'a' is present + { + let udt_bytes = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .finalize(); + let typ = udt_def_with_fields([("a", ColumnType::Text), ("c", ColumnType::BigInt)]); + + let udt = deserialize::<Udt<'_>>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: None, + c: 0, + } + ); + } + + // Wrong column type + { + let typ = udt_def_with_fields([("a", ColumnType::Text)]); + Udt::type_check(&typ).unwrap_err(); + } + + // Missing required column + { + let typ = udt_def_with_fields([("b", ColumnType::Int)]); + Udt::type_check(&typ).unwrap_err(); + } +} + +#[test] +fn test_udt_strict_ordering() { + #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order)] + struct Udt<'a> { + #[scylla(default_when_null)] + a: &'a str, + #[scylla(skip)] + x: String, +
#[scylla(allow_missing)] + b: Option, + } + + // UDT fields in correct same order + { + let udt_bytes = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .field(&42i32.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([("a", ColumnType::Text), ("b", ColumnType::Int)]); + + let udt = deserialize::>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + } + ); + } + + // The last UDT field is missing in serialized form - it should treat + // as if there were null at the end + { + let udt_bytes = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .finalize(); + let typ = udt_def_with_fields([("a", ColumnType::Text), ("b", ColumnType::Int)]); + + let udt = deserialize::>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: None, + } + ); + } + + // An excess field at the end of UDT + { + let udt_bytes = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .field(&42_i32.to_be_bytes()) + .field(&(true as i8).to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("d", ColumnType::Boolean), + ]); + let udt = deserialize::>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + } + ); + } + + // An excess field at the end of UDT, when such are forbidden + { + #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order, forbid_excess_udt_fields)] + struct Udt<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + b: Option, + } + + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("d", ColumnType::Boolean), + ]); + + Udt::type_check(&typ).unwrap_err(); + } + + // UDT fields switched - will not work + { + let typ = udt_def_with_fields([("b", ColumnType::Int), ("a", ColumnType::Text)]); + 
Udt::type_check(&typ).unwrap_err(); + } + + // Wrong column type + { + let typ = udt_def_with_fields([("a", ColumnType::Int), ("b", ColumnType::Int)]); + Udt::type_check(&typ).unwrap_err(); + } + + // Missing required column + { + let typ = udt_def_with_fields([("b", ColumnType::Int)]); + Udt::type_check(&typ).unwrap_err(); + } + + // Missing non-required column + { + let udt_bytes = UdtSerializer::new().field(b"kotmaale").finalize(); + let typ = udt_def_with_fields([("a", ColumnType::Text)]); + + let udt = deserialize::>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "kotmaale", + x: String::new(), + b: None, + } + ); + } + + // The first field is null, but `default_when_null` prevents failure. + { + let udt_bytes = UdtSerializer::new() + .null_field() + .field(&42i32.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([("a", ColumnType::Text), ("b", ColumnType::Int)]); + + let udt = deserialize::>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "", + x: String::new(), + b: Some(42), + } + ); + } +} + +#[test] +fn test_udt_no_name_check() { + #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order, skip_name_checks)] + struct Udt<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + b: Option, + } + + // UDT fields in correct same order + { + let udt_bytes = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .field(&42i32.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([("a", ColumnType::Text), ("b", ColumnType::Int)]); + + let udt = deserialize::>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + } + ); + } + + // Correct order of UDT fields, but different names - should still succeed + { + let udt_bytes = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .field(&42i32.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([("k", ColumnType::Text), ("l", 
ColumnType::Int)]); + + let udt = deserialize::>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + } + ); + } +} + +#[test] +fn test_udt_cross_rename_fields() { + #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] + #[scylla(crate = crate)] + struct TestUdt { + #[scylla(rename = "b")] + a: i32, + #[scylla(rename = "a")] + b: String, + } + + // UDT fields switched - should still work. + { + let udt_bytes = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .field(&42_i32.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([("a", ColumnType::Text), ("b", ColumnType::Int)]); + + let udt = deserialize::(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + TestUdt { + a: 42, + b: "The quick brown fox".to_owned(), + } + ); + } +} + +#[test] +fn test_custom_type_parser() { + #[derive(Default, Debug, PartialEq, Eq)] + struct SwappedPair(B, A); + impl<'frame, A, B> DeserializeValue<'frame> for SwappedPair + where + A: DeserializeValue<'frame>, + B: DeserializeValue<'frame>, + { + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + <(B, A) as DeserializeValue<'frame>>::type_check(typ) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + <(B, A) as DeserializeValue<'frame>>::deserialize(typ, v).map(|(b, a)| Self(b, a)) + } + } + + let mut tuple_contents = BytesMut::new(); + append_bytes(&mut tuple_contents, "foo".as_bytes()); + append_bytes(&mut tuple_contents, &42i32.to_be_bytes()); + let tuple = make_bytes(&tuple_contents); + + let typ = ColumnType::Tuple(vec![ColumnType::Ascii, ColumnType::Int]); + + let tup = deserialize::>(&typ, &tuple).unwrap(); + assert_eq!(tup, SwappedPair("foo", 42)); +} + +fn deserialize<'frame, T>( + typ: &'frame ColumnType, + bytes: &'frame Bytes, +) -> Result +where + T: DeserializeValue<'frame>, +{ + >::type_check(typ) + .map_err(|typecheck_err| DeserializationError(typecheck_err.0))?; + let mut 
frame_slice = FrameSlice::new(bytes); + let value = frame_slice.read_cql_bytes().map_err(|err| { + mk_deser_err::( + typ, + BuiltinDeserializationErrorKind::RawCqlBytesReadError(err), + ) + })?; + >::deserialize(typ, value) +} + +fn make_bytes(cell: &[u8]) -> Bytes { + let mut b = BytesMut::new(); + append_bytes(&mut b, cell); + b.freeze() +} + +fn serialize(typ: &ColumnType, value: &dyn SerializeValue) -> Bytes { + let mut bytes = Bytes::new(); + serialize_to_buf(typ, value, &mut bytes); + bytes +} + +fn serialize_to_buf(typ: &ColumnType, value: &dyn SerializeValue, buf: &mut Bytes) { + let mut v = Vec::new(); + let writer = CellWriter::new(&mut v); + value.serialize(typ, writer).unwrap(); + *buf = v.into(); +} + +fn append_bytes(b: &mut impl BufMut, cell: &[u8]) { + b.put_i32(cell.len() as i32); + b.put_slice(cell); +} + +fn make_null() -> Bytes { + let mut b = BytesMut::new(); + append_null(&mut b); + b.freeze() +} + +fn append_null(b: &mut impl BufMut) { + b.put_i32(-1); +} + +fn assert_ser_de_identity<'f, T: SerializeValue + DeserializeValue<'f> + PartialEq + Debug>( + typ: &'f ColumnType, + v: &'f T, + buf: &'f mut Bytes, // `buf` must be passed as a reference from outside, because otherwise + // we cannot specify the lifetime for DeserializeValue. 
+) { + serialize_to_buf(typ, v, buf); + let deserialized = deserialize::(typ, buf).unwrap(); + assert_eq!(&deserialized, v); +} + +/* Errors checks */ + +#[track_caller] +pub(crate) fn get_typeck_err_inner<'a>( + err: &'a (dyn std::error::Error + 'static), +) -> &'a BuiltinTypeCheckError { + match err.downcast_ref() { + Some(err) => err, + None => panic!("not a BuiltinTypeCheckError: {:?}", err), + } +} + +#[track_caller] +pub(crate) fn get_typeck_err(err: &DeserializationError) -> &BuiltinTypeCheckError { + get_typeck_err_inner(err.0.as_ref()) +} + +#[track_caller] +pub(crate) fn get_deser_err(err: &DeserializationError) -> &BuiltinDeserializationError { + match err.0.downcast_ref() { + Some(err) => err, + None => panic!("not a BuiltinDeserializationError: {:?}", err), + } +} + +macro_rules! assert_given_error { + ($get_err:ident, $bytes:expr, $DestT:ty, $cql_typ:expr, $kind:pat) => { + let cql_typ = $cql_typ.clone(); + let err = deserialize::<$DestT>(&cql_typ, $bytes).unwrap_err(); + let err = $get_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<$DestT>()); + assert_eq!(err.cql_type, cql_typ); + assert_matches::assert_matches!(err.kind, $kind); + }; +} + +macro_rules! assert_type_check_error { + ($bytes:expr, $DestT:ty, $cql_typ:expr, $kind:pat) => { + assert_given_error!(get_typeck_err, $bytes, $DestT, $cql_typ, $kind); + }; +} + +macro_rules! assert_deser_error { + ($bytes:expr, $DestT:ty, $cql_typ:expr, $kind:pat) => { + assert_given_error!(get_deser_err, $bytes, $DestT, $cql_typ, $kind); + }; +} + +#[test] +fn test_native_errors() { + // Simple type mismatch + { + let v = 123_i32; + let bytes = serialize(&ColumnType::Int, &v); + + // Incompatible types render type check error. + assert_type_check_error!( + &bytes, + f64, + ColumnType::Int, + super::BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Double], + } + ); + + // ColumnType is said to be Double (8 bytes expected), but in reality the serialized form has 4 bytes only. 
+ assert_deser_error!( + &bytes, + f64, + ColumnType::Double, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4, + } + ); + + // ColumnType is said to be Float, but in reality Int was serialized. + // As these types have the same size, though, and every binary number in [0, 2^32] is a valid + // value for both of them, this always succeeds. + { + deserialize::(&ColumnType::Float, &bytes).unwrap(); + } + } + + // str (and also Uuid) are interesting because they accept two types. + { + let v = "Ala ma kota"; + let bytes = serialize(&ColumnType::Ascii, &v); + + assert_type_check_error!( + &bytes, + &str, + ColumnType::Double, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Ascii, ColumnType::Text], + } + ); + + // ColumnType is said to be BigInt (8 bytes expected), but in reality the serialized form + // (the string) has 11 bytes. + assert_deser_error!( + &bytes, + i64, + ColumnType::BigInt, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 11, // str len + } + ); + } + { + // -126 is not a valid ASCII nor UTF-8 byte. + let v = -126_i8; + let bytes = serialize(&ColumnType::TinyInt, &v); + + assert_deser_error!( + &bytes, + &str, + ColumnType::Ascii, + BuiltinDeserializationErrorKind::ExpectedAscii + ); + + assert_deser_error!( + &bytes, + &str, + ColumnType::Text, + BuiltinDeserializationErrorKind::InvalidUtf8(_) + ); + } +} + +#[test] +fn test_set_or_list_errors() { + // Not a set or list + { + assert_type_check_error!( + &Bytes::new(), + Vec, + ColumnType::Float, + BuiltinTypeCheckErrorKind::SetOrListError(SetOrListTypeCheckErrorKind::NotSetOrList) + ); + + // Type check of Rust set against CQL list must fail, because it would be lossy. 
+ assert_type_check_error!( + &Bytes::new(), + BTreeSet, + ColumnType::List(Box::new(ColumnType::Int)), + BuiltinTypeCheckErrorKind::SetOrListError(SetOrListTypeCheckErrorKind::NotSet) + ); + } + + // Got null + { + type RustTyp = Vec; + let ser_typ = ColumnType::List(Box::new(ColumnType::Int)); + + let err = RustTyp::deserialize(&ser_typ, None).unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ser_typ); + assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); + } + + // Bad element type + { + assert_type_check_error!( + &Bytes::new(), + Vec, + ColumnType::List(Box::new(ColumnType::Ascii)), + BuiltinTypeCheckErrorKind::SetOrListError( + SetOrListTypeCheckErrorKind::ElementTypeCheckFailed(_) + ) + ); + + let err = deserialize::>( + &ColumnType::List(Box::new(ColumnType::Varint)), + &Bytes::new(), + ) + .unwrap_err(); + let err = get_typeck_err(&err); + assert_eq!(err.rust_name, std::any::type_name::>()); + assert_eq!(err.cql_type, ColumnType::List(Box::new(ColumnType::Varint)),); + let BuiltinTypeCheckErrorKind::SetOrListError( + SetOrListTypeCheckErrorKind::ElementTypeCheckFailed(ref err), + ) = err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Varint); + assert_matches!( + err.kind, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::BigInt, ColumnType::Counter] + } + ); + } + + { + let ser_typ = ColumnType::List(Box::new(ColumnType::Int)); + let v = vec![123_i32]; + let bytes = serialize(&ser_typ, &v); + + { + let err = + deserialize::>(&ColumnType::List(Box::new(ColumnType::BigInt)), &bytes) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::>()); + assert_eq!(err.cql_type, ColumnType::List(Box::new(ColumnType::BigInt)),); + let 
BuiltinDeserializationErrorKind::SetOrListError( + SetOrListDeserializationErrorKind::ElementDeserializationFailed(err), + ) = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::BigInt); + assert_matches!( + err.kind, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4 + } + ); + } + } +} + +#[test] +fn test_map_errors() { + // Not a map + { + let ser_typ = ColumnType::Float; + let v = 2.12_f32; + let bytes = serialize(&ser_typ, &v); + + assert_type_check_error!( + &bytes, + HashMap, + ser_typ, + BuiltinTypeCheckErrorKind::MapError( + MapTypeCheckErrorKind::NotMap, + ) + ); + } + + // Got null + { + type RustTyp = HashMap; + let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); + + let err = RustTyp::deserialize(&ser_typ, None).unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ser_typ); + assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); + } + + // Key type mismatch + { + let err = deserialize::>( + &ColumnType::Map(Box::new(ColumnType::Varint), Box::new(ColumnType::Boolean)), + &Bytes::new(), + ) + .unwrap_err(); + let err = get_typeck_err(&err); + assert_eq!(err.rust_name, std::any::type_name::>()); + assert_eq!( + err.cql_type, + ColumnType::Map(Box::new(ColumnType::Varint), Box::new(ColumnType::Boolean)) + ); + let BuiltinTypeCheckErrorKind::MapError(MapTypeCheckErrorKind::KeyTypeCheckFailed(ref err)) = + err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Varint); + assert_matches!( + err.kind, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::BigInt, ColumnType::Counter] 
+ } + ); + } + + // Value type mismatch + { + let err = deserialize::>( + &ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)), + &Bytes::new(), + ) + .unwrap_err(); + let err = get_typeck_err(&err); + assert_eq!(err.rust_name, std::any::type_name::>()); + assert_eq!( + err.cql_type, + ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)) + ); + let BuiltinTypeCheckErrorKind::MapError(MapTypeCheckErrorKind::ValueTypeCheckFailed( + ref err, + )) = err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::<&str>()); + assert_eq!(err.cql_type, ColumnType::Boolean); + assert_matches!( + err.kind, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Ascii, ColumnType::Text] + } + ); + } + + // Key length mismatch + { + let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); + let v = HashMap::from([(42, false), (2137, true)]); + let bytes = serialize(&ser_typ, &v as &dyn SerializeValue); + + let err = deserialize::>( + &ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)), + &bytes, + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::>()); + assert_eq!( + err.cql_type, + ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)) + ); + let BuiltinDeserializationErrorKind::MapError( + MapDeserializationErrorKind::KeyDeserializationFailed(err), + ) = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::BigInt); + assert_matches!( + err.kind, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4 + } + ); + } + + // Value length mismatch + { + let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), 
Box::new(ColumnType::Boolean)); + let v = HashMap::from([(42, false), (2137, true)]); + let bytes = serialize(&ser_typ, &v as &dyn SerializeValue); + + let err = deserialize::>( + &ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::SmallInt)), + &bytes, + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::>()); + assert_eq!( + err.cql_type, + ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::SmallInt)) + ); + let BuiltinDeserializationErrorKind::MapError( + MapDeserializationErrorKind::ValueDeserializationFailed(err), + ) = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::SmallInt); + assert_matches!( + err.kind, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 2, + got: 1 + } + ); + } +} + +#[test] +fn test_tuple_errors() { + // Not a tuple + { + assert_type_check_error!( + &Bytes::new(), + (i64,), + ColumnType::BigInt, + BuiltinTypeCheckErrorKind::TupleError(TupleTypeCheckErrorKind::NotTuple) + ); + } + // Wrong element count + { + assert_type_check_error!( + &Bytes::new(), + (i64,), + ColumnType::Tuple(vec![]), + BuiltinTypeCheckErrorKind::TupleError(TupleTypeCheckErrorKind::WrongElementCount { + rust_type_el_count: 1, + cql_type_el_count: 0, + }) + ); + + assert_type_check_error!( + &Bytes::new(), + (f32,), + ColumnType::Tuple(vec![ColumnType::Float, ColumnType::Float]), + BuiltinTypeCheckErrorKind::TupleError(TupleTypeCheckErrorKind::WrongElementCount { + rust_type_el_count: 1, + cql_type_el_count: 2, + }) + ); + } + + // Bad field type + { + { + let err = deserialize::<(i64,)>( + &ColumnType::Tuple(vec![ColumnType::SmallInt]), + &Bytes::new(), + ) + .unwrap_err(); + let err = get_typeck_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); + assert_eq!(err.cql_type, 
ColumnType::Tuple(vec![ColumnType::SmallInt])); + let BuiltinTypeCheckErrorKind::TupleError( + TupleTypeCheckErrorKind::FieldTypeCheckFailed { ref err, position }, + ) = err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(position, 0); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::SmallInt); + assert_matches!( + err.kind, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::BigInt, ColumnType::Counter] + } + ); + } + } + + { + let ser_typ = ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Float]); + let v = (123_i32, 123.123_f32); + let bytes = serialize(&ser_typ, &v); + + { + let err = deserialize::<(i32, f64)>( + &ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Double]), + &bytes, + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<(i32, f64)>()); + assert_eq!( + err.cql_type, + ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Double]) + ); + let BuiltinDeserializationErrorKind::TupleError( + TupleDeserializationErrorKind::FieldDeserializationFailed { + ref err, + position: index, + }, + ) = err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(index, 1); + let err = get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Double); + assert_matches!( + err.kind, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4 + } + ); + } + } +} + +#[test] +fn test_null_errors() { + let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); + let v = HashMap::from([(42, false), (2137, true)]); + let bytes = serialize(&ser_typ, &v as &dyn SerializeValue); + + deserialize::>(&ser_typ, &bytes).unwrap_err(); +} + +#[test] +fn test_udt_errors() { + // Loose ordering + { + #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, 
Debug)] + #[scylla(crate = "crate", forbid_excess_udt_fields)] + struct Udt<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + #[scylla(allow_missing)] + b: Option, + #[scylla(default_when_null)] + c: bool, + } + + // Type check errors + { + // Not UDT + { + let typ = ColumnType::Map(Box::new(ColumnType::Ascii), Box::new(ColumnType::Blob)); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::NotUdt) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + } + + // UDT missing fields + { + let typ = udt_def_with_fields([("c", ColumnType::Boolean)]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError( + UdtTypeCheckErrorKind::ValuesMissingForUdtFields { + field_names: ref missing_fields, + }, + ) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(missing_fields.as_slice(), &["a"]); + } + + // excess fields in UDT + { + let typ = udt_def_with_fields([ + ("d", ColumnType::Boolean), + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::ExcessFieldInUdt { + ref db_field_name, + }) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(db_field_name.as_str(), "d"); + } + + // missing UDT field + { + let typ = udt_def_with_fields([("b", ColumnType::Int), ("a", ColumnType::Text)]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = 
get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError( + UdtTypeCheckErrorKind::ValuesMissingForUdtFields { ref field_names }, + ) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(field_names, &["c"]); + } + + // UDT fields incompatible types - field type check failed + { + let typ = udt_def_with_fields([("a", ColumnType::Blob), ("b", ColumnType::Int)]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError( + UdtTypeCheckErrorKind::FieldTypeCheckFailed { + ref field_name, + ref err, + }, + ) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(field_name.as_str(), "a"); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::<&str>()); + assert_eq!(err.cql_type, ColumnType::Blob); + assert_matches!( + err.kind, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Ascii, ColumnType::Text] + } + ); + } + } + + // Deserialization errors + { + // Got null + { + let typ = udt_def_with_fields([ + ("c", ColumnType::Boolean), + ("a", ColumnType::Blob), + ("b", ColumnType::Int), + ]); + + let err = Udt::deserialize(&typ, None).unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); + } + + // UDT field deserialization failed + { + let typ = + udt_def_with_fields([("a", ColumnType::Ascii), ("c", ColumnType::Boolean)]); + + let udt_bytes = UdtSerializer::new() + .field("alamakota".as_bytes()) + .field(&42_i16.to_be_bytes()) + .finalize(); + + let err = deserialize::(&typ, &udt_bytes).unwrap_err(); + + let 
err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinDeserializationErrorKind::UdtError( + UdtDeserializationErrorKind::FieldDeserializationFailed { + ref field_name, + ref err, + }, + ) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(field_name.as_str(), "c"); + let err = get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Boolean); + assert_matches!( + err.kind, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 1, + got: 2, + } + ); + } + } + } + + // Strict ordering + { + #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order, forbid_excess_udt_fields)] + struct Udt<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + b: Option, + #[scylla(allow_missing)] + c: bool, + } + + // Type check errors + { + // Not UDT + { + let typ = ColumnType::Map(Box::new(ColumnType::Ascii), Box::new(ColumnType::Blob)); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::NotUdt) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + } + + // UDT too few fields + { + let typ = udt_def_with_fields([("a", ColumnType::Text)]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::TooFewFields { + ref required_fields, + ref present_fields, + }) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(required_fields.as_slice(), &["a", "b"]); + assert_eq!(present_fields.as_slice(), &["a".to_string()]); + } + + 
// excess fields in UDT + { + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("d", ColumnType::Boolean), + ]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::ExcessFieldInUdt { + ref db_field_name, + }) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(db_field_name.as_str(), "d"); + } + + // UDT fields switched - field name mismatch + { + let typ = udt_def_with_fields([("b", ColumnType::Int), ("a", ColumnType::Text)]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::FieldNameMismatch { + position, + ref rust_field_name, + ref db_field_name, + }) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(position, 0); + assert_eq!(rust_field_name.as_str(), "a".to_owned()); + assert_eq!(db_field_name.as_str(), "b".to_owned()); + } + + // UDT fields incompatible types - field type check failed + { + let typ = udt_def_with_fields([("a", ColumnType::Blob), ("b", ColumnType::Int)]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError( + UdtTypeCheckErrorKind::FieldTypeCheckFailed { + ref field_name, + ref err, + }, + ) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(field_name.as_str(), "a"); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::<&str>()); + assert_eq!(err.cql_type, ColumnType::Blob); + 
assert_matches!( + err.kind, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Ascii, ColumnType::Text] + } + ); + } + } + + // Deserialization errors + { + // Got null + { + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("c", ColumnType::Boolean), + ]); + + let err = Udt::deserialize(&typ, None).unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); + } + + // Bad field format + { + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("c", ColumnType::Boolean), + ]); + + let udt_bytes = UdtSerializer::new() + .field(b"alamakota") + .field(&42_i64.to_be_bytes()) + .field(&[true as u8]) + .finalize(); + + let udt_bytes_too_short = udt_bytes.slice(..udt_bytes.len() - 1); + assert!(udt_bytes.len() > udt_bytes_too_short.len()); + + let err = deserialize::(&typ, &udt_bytes_too_short).unwrap_err(); + + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinDeserializationErrorKind::RawCqlBytesReadError(_) = err.kind else { + panic!("unexpected error kind: {:?}", err.kind) + }; + } + + // UDT field deserialization failed + { + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("c", ColumnType::Boolean), + ]); + + let udt_bytes = UdtSerializer::new() + .field(b"alamakota") + .field(&42_i64.to_be_bytes()) + .field(&[true as u8]) + .finalize(); + + let err = deserialize::(&typ, &udt_bytes).unwrap_err(); + + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinDeserializationErrorKind::UdtError( + UdtDeserializationErrorKind::FieldDeserializationFailed { + ref field_name, + ref err, + }, + ) = err.kind + else { + panic!("unexpected error 
kind: {:?}", err.kind) + }; + assert_eq!(field_name.as_str(), "b"); + let err = get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Int); + assert_matches!( + err.kind, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 4, + got: 8, + } + ); + } + } + } +} diff --git a/scylla-macros/src/deserialize/mod.rs b/scylla-macros/src/deserialize/mod.rs new file mode 100644 index 0000000000..9b038ecfb6 --- /dev/null +++ b/scylla-macros/src/deserialize/mod.rs @@ -0,0 +1,183 @@ +use darling::{FromAttributes, FromField}; +use proc_macro2::Span; +use syn::parse_quote; + +pub(crate) mod row; +pub(crate) mod value; + +/// Common attributes that all deserialize impls should understand. +trait DeserializeCommonStructAttrs { + /// The path to either `scylla` or `scylla_cql` crate. + fn crate_path(&self) -> Option<&syn::Path>; + + /// The path to `macro_internal` module, + /// which contains exports used by macros. + fn macro_internal_path(&self) -> syn::Path { + match self.crate_path() { + Some(path) => parse_quote!(#path::_macro_internal), + None => parse_quote!(scylla::_macro_internal), + } + } +} + +/// Provides access to attributes that are common to DeserializeValue +/// and DeserializeRow traits. +trait DeserializeCommonFieldAttrs { + /// Does the type of this field need Default to be implemented? + fn needs_default(&self) -> bool; + + /// The type of the field, i.e. what this field deserializes to. + fn deserialize_target(&self) -> &syn::Type; +} + +/// A structure helpful in implementing DeserializeValue and DeserializeRow. 
+/// +/// It implements some common logic for both traits: +/// - Generates a unique lifetime that binds all other lifetimes in both structs, +/// - Adds appropriate trait bounds (DeserializeValue + Default) +struct StructDescForDeserialize { + name: syn::Ident, + attrs: Attrs, + fields: Vec, + constraint_trait: syn::Path, + constraint_lifetime: syn::Lifetime, + + generics: syn::Generics, +} + +impl StructDescForDeserialize +where + Attrs: FromAttributes + DeserializeCommonStructAttrs, + Field: FromField + DeserializeCommonFieldAttrs, +{ + fn new( + input: &syn::DeriveInput, + trait_name: &str, + constraint_trait: syn::Path, + ) -> Result { + let attrs = Attrs::from_attributes(&input.attrs)?; + + // TODO: support structs with unnamed fields. + // A few things to consider: + // - such support would necessarily require `enforce_order` and `skip_name_checks` attributes to be passed, + // - either: + // - the inner code would have to represent unnamed fields differently and handle the errors differently, + // - or we could use `.0, .1` or `0`, `1` as names for consecutive fields, making representation and error handling uniform. 
+ let fields = crate::parser::parse_named_fields(input, trait_name) + .unwrap_or_else(|err| panic!("{}", err)) + .named + .iter() + .map(Field::from_field) + .collect::>()?; + + let constraint_lifetime = generate_unique_lifetime_for_impl(&input.generics); + + Ok(Self { + name: input.ident.clone(), + attrs, + fields, + constraint_trait, + constraint_lifetime, + generics: input.generics.clone(), + }) + } + + fn struct_attrs(&self) -> &Attrs { + &self.attrs + } + + fn constraint_lifetime(&self) -> &syn::Lifetime { + &self.constraint_lifetime + } + + fn fields(&self) -> &[Field] { + &self.fields + } + + fn generate_impl( + &self, + trait_: syn::Path, + items: impl IntoIterator, + ) -> syn::ItemImpl { + let constraint_lifetime = &self.constraint_lifetime; + let (_, ty_generics, _) = self.generics.split_for_impl(); + let impl_generics = &self.generics.params; + + let macro_internal = self.attrs.macro_internal_path(); + let struct_name = &self.name; + let predicates = generate_lifetime_constraints_for_impl( + &self.generics, + self.constraint_trait.clone(), + &self.constraint_lifetime, + ) + .chain(generate_default_constraints(&self.fields)); + let trait_: syn::Path = parse_quote!(#macro_internal::#trait_); + let items = items.into_iter(); + + parse_quote! { + impl<#constraint_lifetime, #impl_generics> #trait_<#constraint_lifetime> for #struct_name #ty_generics + where #(#predicates),* + { + #(#items)* + } + } + } +} + +/// Generates T: Default constraints for those fields that need it. +fn generate_default_constraints( + fields: &[Field], +) -> impl Iterator + '_ { + fields.iter().filter(|f| f.needs_default()).map(|f| { + let t = f.deserialize_target(); + parse_quote!(#t: std::default::Default) + }) +} + +/// Helps introduce a lifetime to an `impl` definition that constrains +/// other lifetimes and types. +/// +/// The original use case is DeserializeValue and DeserializeRow. Both of those traits +/// are parametrized with a lifetime. 
If T: DeserializeValue<'a> then this means +/// that you can deserialize T as some CQL value from bytes that have +/// lifetime 'a, similarly for DeserializeRow. In impls for those traits, +/// an additional lifetime must be introduced and properly constrained. +fn generate_lifetime_constraints_for_impl<'a>( + generics: &'a syn::Generics, + trait_full_name: syn::Path, + constraint_lifetime: &'a syn::Lifetime, +) -> impl Iterator + 'a { + // Constrain the new lifetime with the existing lifetime parameters + // 'lifetime: 'a + 'b + 'c ... + let mut lifetimes = generics.lifetimes().map(|l| &l.lifetime).peekable(); + let lifetime_constraints = std::iter::from_fn(move || { + let lifetimes = lifetimes.by_ref(); + lifetimes + .peek() + .is_some() + .then::(|| parse_quote!(#constraint_lifetime: #(#lifetimes)+*)) + }); + + // For each type parameter T, constrain it like this: + // T: DeserializeValue<'lifetime>, + let type_constraints = generics.type_params().map(move |t| { + let t_ident = &t.ident; + parse_quote!(#t_ident: #trait_full_name<#constraint_lifetime>) + }); + + lifetime_constraints.chain(type_constraints) +} + +/// Generates a new lifetime parameter, with a different name to any of the +/// existing generic lifetimes. +fn generate_unique_lifetime_for_impl(generics: &syn::Generics) -> syn::Lifetime { + let mut constraint_lifetime_name = "'lifetime".to_string(); + while generics + .lifetimes() + .any(|l| l.lifetime.to_string() == constraint_lifetime_name) + { + // Extend the lifetime name with another underscore. 
+ constraint_lifetime_name += "_"; + } + syn::Lifetime::new(&constraint_lifetime_name, Span::call_site()) +} diff --git a/scylla-macros/src/deserialize/row.rs b/scylla-macros/src/deserialize/row.rs new file mode 100644 index 0000000000..1a43c8b343 --- /dev/null +++ b/scylla-macros/src/deserialize/row.rs @@ -0,0 +1,611 @@ +use std::collections::HashMap; + +use darling::{FromAttributes, FromField}; +use proc_macro2::Span; +use syn::ext::IdentExt; +use syn::parse_quote; + +use super::{DeserializeCommonFieldAttrs, DeserializeCommonStructAttrs}; + +#[derive(FromAttributes)] +#[darling(attributes(scylla))] +struct StructAttrs { + #[darling(rename = "crate")] + crate_path: Option, + + // If true, then the type checking code will require the order of the fields + // to be the same in both the Rust struct and the columns. This allows the + // deserialization to be slightly faster because looking struct fields up + // by name can be avoided, though it is less convenient. + #[darling(default)] + enforce_order: bool, + + // If true, then the type checking code won't verify the column names. + // Columns will be matched to struct fields based solely on the order. + // + // This annotation only works if `enforce_order` is specified. + #[darling(default)] + skip_name_checks: bool, +} + +impl DeserializeCommonStructAttrs for StructAttrs { + fn crate_path(&self) -> Option<&syn::Path> { + self.crate_path.as_ref() + } +} + +#[derive(FromField)] +#[darling(attributes(scylla))] +struct Field { + // If true, then the field is not parsed at all, but it is initialized + // with Default::default() instead. All other attributes are ignored. + #[darling(default)] + skip: bool, + + // If set, then deserialization will look for the column with given name + // and deserialize it to this Rust field, instead of just using the Rust + // field name. 
+ #[darling(default)] + rename: Option, + + ident: Option, + ty: syn::Type, +} + +impl DeserializeCommonFieldAttrs for Field { + fn needs_default(&self) -> bool { + self.skip + } + + fn deserialize_target(&self) -> &syn::Type { + &self.ty + } +} + +// derive(DeserializeRow) for the new DeserializeRow trait +pub(crate) fn deserialize_row_derive( + tokens_input: proc_macro::TokenStream, +) -> Result { + let input = syn::parse(tokens_input)?; + + let implemented_trait: syn::Path = parse_quote! { DeserializeRow }; + let implemented_trait_name = implemented_trait + .segments + .last() + .unwrap() + .ident + .unraw() + .to_string(); + let constraining_trait = parse_quote! { DeserializeValue }; + let s = StructDesc::new(&input, &implemented_trait_name, constraining_trait)?; + + validate_attrs(&s.attrs, &s.fields)?; + + let items = [ + s.generate_type_check_method().into(), + s.generate_deserialize_method().into(), + ]; + + Ok(s.generate_impl(implemented_trait, items)) +} + +fn validate_attrs(attrs: &StructAttrs, fields: &[Field]) -> Result<(), darling::Error> { + let mut errors = darling::Error::accumulator(); + + if attrs.skip_name_checks { + // Skipping name checks is only available in enforce_order mode + if !attrs.enforce_order { + let error = + darling::Error::custom("attribute requires ."); + errors.push(error); + } + + // annotations don't make sense with skipped name checks + for field in fields { + if field.rename.is_some() { + let err = darling::Error::custom( + " annotations don't make sense with attribute", + ) + .with_span(&field.ident); + errors.push(err); + } + } + } else { + // Detect name collisions caused by `rename`. 
+ let mut used_names = HashMap::::new(); + for field in fields { + let column_name = field.column_name(); + if let Some(other_field) = used_names.get(&column_name) { + let other_field_ident = other_field.ident.as_ref().unwrap(); + let msg = format!("the column name `{column_name}` used by this struct field is already used by field `{other_field_ident}`"); + let err = darling::Error::custom(msg).with_span(&field.ident); + errors.push(err); + } else { + used_names.insert(column_name, field); + } + } + } + + errors.finish() +} + +impl Field { + // Returns whether this field is mandatory for deserialization. + fn is_required(&self) -> bool { + !self.skip + } + + // The name of the column corresponding to this Rust struct field + fn column_name(&self) -> String { + match self.rename.as_ref() { + Some(rename) => rename.to_owned(), + None => self.ident.as_ref().unwrap().unraw().to_string(), + } + } + + // A Rust literal representing the name of this field + fn cql_name_literal(&self) -> syn::LitStr { + syn::LitStr::new(&self.column_name(), Span::call_site()) + } +} + +type StructDesc = super::StructDescForDeserialize; + +impl StructDesc { + fn generate_type_check_method(&self) -> syn::ImplItemFn { + if self.attrs.enforce_order { + TypeCheckAssumeOrderGenerator(self).generate() + } else { + TypeCheckUnorderedGenerator(self).generate() + } + } + + fn generate_deserialize_method(&self) -> syn::ImplItemFn { + if self.attrs.enforce_order { + DeserializeAssumeOrderGenerator(self).generate() + } else { + DeserializeUnorderedGenerator(self).generate() + } + } +} + +struct TypeCheckAssumeOrderGenerator<'sd>(&'sd StructDesc); + +impl<'sd> TypeCheckAssumeOrderGenerator<'sd> { + fn generate_name_verification( + &self, + field_index: usize, // These two indices can be different because of `skip` attribute + column_index: usize, // applied to some field. 
+ field: &Field, + column_spec: &syn::Ident, + ) -> Option { + (!self.0.attrs.skip_name_checks).then(|| { + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let rust_field_name = field.cql_name_literal(); + + parse_quote! { + if #column_spec.name != #rust_field_name { + return ::std::result::Result::Err( + #macro_internal::mk_row_typck_err::( + column_types_iter(), + #macro_internal::DeserBuiltinRowTypeCheckErrorKind::ColumnNameMismatch { + field_index: #field_index, + column_index: #column_index, + rust_column_name: #rust_field_name, + db_column_name: ::std::clone::Clone::clone(&#column_spec.name), + } + ) + ); + } + } + }) + } + + fn generate(&self) -> syn::ImplItemFn { + // The generated method will check that the order and the types + // of the columns correspond fields' names/types. + + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + + let required_fields_iter = || { + self.0 + .fields() + .iter() + .enumerate() + .filter(|(_, f)| f.is_required()) + }; + let required_fields_count = required_fields_iter().count(); + let required_fields_idents: Vec<_> = (0..required_fields_count) + .map(|i| quote::format_ident!("f_{}", i)) + .collect(); + let name_verifications = required_fields_iter() + .zip(required_fields_idents.iter().enumerate()) + .map(|((field_idx, field), (col_idx, fidents))| { + self.generate_name_verification(field_idx, col_idx, field, fidents) + }); + + let required_fields_deserializers = + required_fields_iter().map(|(_, f)| f.deserialize_target()); + let numbers = 0usize..; + + parse_quote! 
{ + fn type_check( + specs: &[#macro_internal::ColumnSpec], + ) -> ::std::result::Result<(), #macro_internal::TypeCheckError> { + let column_types_iter = || specs.iter().map(|spec| ::std::clone::Clone::clone(&spec.typ)); + + match specs { + [#(#required_fields_idents),*] => { + #( + // Verify the name (unless `skip_name_checks' is specified) + #name_verifications + + // Verify the type + <#required_fields_deserializers as #macro_internal::DeserializeValue<#constraint_lifetime>>::type_check(&#required_fields_idents.typ) + .map_err(|err| #macro_internal::mk_row_typck_err::( + column_types_iter(), + #macro_internal::DeserBuiltinRowTypeCheckErrorKind::ColumnTypeCheckFailed { + column_index: #numbers, + column_name: ::std::clone::Clone::clone(&#required_fields_idents.name), + err, + } + ))?; + )* + ::std::result::Result::Ok(()) + }, + _ => ::std::result::Result::Err( + #macro_internal::mk_row_typck_err::( + column_types_iter(), + #macro_internal::DeserBuiltinRowTypeCheckErrorKind::WrongColumnCount { + rust_cols: #required_fields_count, + cql_cols: specs.len(), + } + ), + ), + } + } + } + } +} + +struct DeserializeAssumeOrderGenerator<'sd>(&'sd StructDesc); + +impl<'sd> DeserializeAssumeOrderGenerator<'sd> { + fn generate_finalize_field(&self, field_index: usize, field: &Field) -> syn::Expr { + if field.skip { + // Skipped fields are initialized with Default::default() + return parse_quote!(::std::default::Default::default()); + } + + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let cql_name_literal = field.cql_name_literal(); + let deserializer = field.deserialize_target(); + let constraint_lifetime = self.0.constraint_lifetime(); + + let name_check: Option = (!self.0.struct_attrs().skip_name_checks).then(|| parse_quote! { + if col.spec.name.as_str() != #cql_name_literal { + panic!( + "Typecheck should have prevented this scenario - field-column name mismatch! 
Rust field name {}, CQL column name {}", + #cql_name_literal, + col.spec.name.as_str() + ); + } + }); + + parse_quote!( + { + let col = row.next() + .expect("Typecheck should have prevented this scenario! Too few columns in the serialized data.") + .map_err(#macro_internal::row_deser_error_replace_rust_name::)?; + + #name_check + + <#deserializer as #macro_internal::DeserializeValue<#constraint_lifetime>>::deserialize(&col.spec.typ, col.slice) + .map_err(|err| #macro_internal::mk_row_deser_err::( + #macro_internal::BuiltinRowDeserializationErrorKind::ColumnDeserializationFailed { + column_index: #field_index, + column_name: <_ as std::clone::Clone>::clone(&col.spec.name), + err, + } + ))? + } + ) + } + + fn generate(&self) -> syn::ImplItemFn { + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + + let fields = self.0.fields(); + let field_idents = fields.iter().map(|f| f.ident.as_ref().unwrap()); + let field_finalizers = fields + .iter() + .enumerate() + .map(|(field_idx, f)| self.generate_finalize_field(field_idx, f)); + + parse_quote! { + fn deserialize( + #[allow(unused_mut)] + mut row: #macro_internal::ColumnIterator<#constraint_lifetime>, + ) -> ::std::result::Result { + ::std::result::Result::Ok(Self { + #(#field_idents: #field_finalizers,)* + }) + } + } + } +} + +struct TypeCheckUnorderedGenerator<'sd>(&'sd StructDesc); + +impl<'sd> TypeCheckUnorderedGenerator<'sd> { + // An identifier for a bool variable that represents whether given + // field was already visited during type check + fn visited_flag_variable(field: &Field) -> syn::Ident { + quote::format_ident!("visited_{}", field.ident.as_ref().unwrap().unraw()) + } + + // Generates a declaration of a "visited" flag for the purpose of type check. 
+ // We generate it even if the flag is not required in order to protect + // from fields appearing more than once + fn generate_visited_flag_decl(field: &Field) -> Option { + (!field.skip).then(|| { + let visited_flag = Self::visited_flag_variable(field); + parse_quote! { + let mut #visited_flag = false; + } + }) + } + + // Generates code that, given variable `typ`, type-checks given field + fn generate_type_check(&self, field: &Field) -> Option { + (!field.skip).then(|| { + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let visited_flag = Self::visited_flag_variable(field); + let typ = field.deserialize_target(); + let cql_name_literal = field.cql_name_literal(); + let decrement_if_required: Option:: = field.is_required().then(|| parse_quote! { + remaining_required_fields -= 1; + }); + + parse_quote! { + { + if !#visited_flag { + <#typ as #macro_internal::DeserializeValue<#constraint_lifetime>>::type_check(&spec.typ) + .map_err(|err| { + #macro_internal::mk_row_typck_err::( + column_types_iter(), + #macro_internal::DeserBuiltinRowTypeCheckErrorKind::ColumnTypeCheckFailed { + column_index, + column_name: <_ as ::std::borrow::ToOwned>::to_owned(#cql_name_literal), + err, + } + ) + })?; + #visited_flag = true; + #decrement_if_required + } else { + return ::std::result::Result::Err( + #macro_internal::mk_row_typck_err::( + column_types_iter(), + #macro_internal::DeserBuiltinRowTypeCheckErrorKind::DuplicatedColumn { + column_index, + column_name: #cql_name_literal, + } + ) + ) + } + } + } + }) + } + + // Generates code that appends the flag name if it is missing. + // The generated code is used to construct a nice error message. + fn generate_append_name(field: &Field) -> Option { + field.is_required().then(|| { + let visited_flag = Self::visited_flag_variable(field); + let cql_name_literal = field.cql_name_literal(); + parse_quote! 
{ + { + if !#visited_flag { + missing_fields.push(#cql_name_literal); + } + } + } + }) + } + + fn generate(&self) -> syn::ImplItemFn { + let macro_internal = self.0.struct_attrs().macro_internal_path(); + + let fields = self.0.fields(); + let visited_field_declarations = fields.iter().flat_map(Self::generate_visited_flag_decl); + let type_check_blocks = fields.iter().flat_map(|f| self.generate_type_check(f)); + let append_name_blocks = fields.iter().flat_map(Self::generate_append_name); + let nonskipped_field_names = fields + .iter() + .filter(|f| !f.skip) + .map(|f| f.cql_name_literal()); + let field_count_lit = fields.iter().filter(|f| f.is_required()).count(); + + parse_quote! { + fn type_check( + specs: &[#macro_internal::ColumnSpec], + ) -> ::std::result::Result<(), #macro_internal::TypeCheckError> { + // Counts down how many required fields are remaining + let mut remaining_required_fields: ::std::primitive::usize = #field_count_lit; + + // For each required field, generate a "visited" boolean flag + #(#visited_field_declarations)* + + let column_types_iter = || specs.iter().map(|spec| ::std::clone::Clone::clone(&spec.typ)); + + for (column_index, spec) in specs.iter().enumerate() { + // Pattern match on the name and verify that the type is correct. 
+ match spec.name.as_str() { + #(#nonskipped_field_names => #type_check_blocks,)* + _unknown => { + return ::std::result::Result::Err( + #macro_internal::mk_row_typck_err::( + column_types_iter(), + #macro_internal::DeserBuiltinRowTypeCheckErrorKind::ColumnWithUnknownName { + column_index, + column_name: <_ as ::std::clone::Clone>::clone(&spec.name) + } + ) + ) + } + } + } + + if remaining_required_fields > 0 { + // If there are some missing required fields, generate an error + // which contains missing field names + let mut missing_fields = ::std::vec::Vec::<&'static str>::with_capacity(remaining_required_fields); + #(#append_name_blocks)* + return ::std::result::Result::Err( + #macro_internal::mk_row_typck_err::( + column_types_iter(), + #macro_internal::DeserBuiltinRowTypeCheckErrorKind::ValuesMissingForColumns { + column_names: missing_fields + } + ) + ) + } + + ::std::result::Result::Ok(()) + } + } + } +} + +struct DeserializeUnorderedGenerator<'sd>(&'sd StructDesc); + +impl<'sd> DeserializeUnorderedGenerator<'sd> { + // An identifier for a variable that is meant to store the parsed variable + // before being ultimately moved to the struct on deserialize + fn deserialize_field_variable(field: &Field) -> syn::Ident { + quote::format_ident!("f_{}", field.ident.as_ref().unwrap().unraw()) + } + + // Generates an expression which produces a value ready to be put into a field + // of the target structure + fn generate_finalize_field(&self, field: &Field) -> syn::Expr { + if field.skip { + // Skipped fields are initialized with Default::default() + return parse_quote! { + ::std::default::Default::default() + }; + } + + let deserialize_field = Self::deserialize_field_variable(field); + let cql_name_literal = field.cql_name_literal(); + parse_quote! 
{ + #deserialize_field.unwrap_or_else(|| panic!( + "column {} missing in DB row - type check should have prevented this!", + #cql_name_literal + )) + } + } + + // Generated code that performs deserialization when the raw field + // is being processed + fn generate_deserialization(&self, column_index: usize, field: &Field) -> syn::Expr { + assert!(!field.skip); + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let deserialize_field = Self::deserialize_field_variable(field); + let deserializer = field.deserialize_target(); + + parse_quote! { + { + assert!( + #deserialize_field.is_none(), + "duplicated column {} - type check should have prevented this!", + stringify!(#deserialize_field) + ); + + #deserialize_field = ::std::option::Option::Some( + <#deserializer as #macro_internal::DeserializeValue<#constraint_lifetime>>::deserialize(&col.spec.typ, col.slice) + .map_err(|err| { + #macro_internal::mk_row_deser_err::( + #macro_internal::BuiltinRowDeserializationErrorKind::ColumnDeserializationFailed { + column_index: #column_index, + column_name: <_ as std::clone::Clone>::clone(&col.spec.name), + err, + } + ) + })? + ); + } + } + } + + // Generate a declaration of a variable that temporarily keeps + // the deserialized value + fn generate_deserialize_field_decl(field: &Field) -> Option { + (!field.skip).then(|| { + let deserialize_field = Self::deserialize_field_variable(field); + parse_quote! 
{ + let mut #deserialize_field = ::std::option::Option::None; + } + }) + } + + fn generate(&self) -> syn::ImplItemFn { + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let fields = self.0.fields(); + + let deserialize_field_decls = fields + .iter() + .flat_map(Self::generate_deserialize_field_decl); + let deserialize_blocks = fields + .iter() + .filter(|f| !f.skip) + .enumerate() + .map(|(col_idx, f)| self.generate_deserialization(col_idx, f)); + let field_idents = fields.iter().map(|f| f.ident.as_ref().unwrap()); + let nonskipped_field_names = fields + .iter() + .filter(|&f| (!f.skip)) + .map(|f| f.cql_name_literal()); + + let field_finalizers = fields.iter().map(|f| self.generate_finalize_field(f)); + + // TODO: Allow collecting unrecognized fields into some special field + + parse_quote! { + fn deserialize( + #[allow(unused_mut)] + mut row: #macro_internal::ColumnIterator<#constraint_lifetime>, + ) -> ::std::result::Result { + + // Generate fields that will serve as temporary storage + // for the fields' values. Those are of type Option. + #(#deserialize_field_decls)* + + for col in row { + let col = col.map_err(#macro_internal::row_deser_error_replace_rust_name::)?; + // Pattern match on the field name and deserialize. + match col.spec.name.as_str() { + #(#nonskipped_field_names => #deserialize_blocks,)* + unknown => unreachable!("Typecheck should have prevented this scenario! Unknown column name: {}", unknown), + } + } + + // Create the final struct. The finalizer expressions convert + // the temporary storage fields to the final field values. + // For example, if a field is missing but marked as + // `default_when_null` it will create a default value, otherwise + // it will report an error. 
+ Ok(Self { + #(#field_idents: #field_finalizers,)* + }) + } + } + } +} diff --git a/scylla-macros/src/deserialize/value.rs b/scylla-macros/src/deserialize/value.rs new file mode 100644 index 0000000000..6b89bcfb58 --- /dev/null +++ b/scylla-macros/src/deserialize/value.rs @@ -0,0 +1,898 @@ +use std::collections::HashMap; + +use darling::{FromAttributes, FromField}; +use proc_macro::TokenStream; +use proc_macro2::Span; +use syn::{ext::IdentExt, parse_quote}; + +use super::{DeserializeCommonFieldAttrs, DeserializeCommonStructAttrs}; + +#[derive(FromAttributes)] +#[darling(attributes(scylla))] +struct StructAttrs { + #[darling(rename = "crate")] + crate_path: Option, + + // If true, then the type checking code will require the order of the fields + // to be the same in both the Rust struct and the UDT. This allows the + // deserialization to be slightly faster because looking struct fields up + // by name can be avoided, though it is less convenient. + #[darling(default)] + enforce_order: bool, + + // If true, then the type checking code won't verify the UDT field names. + // UDT fields will be matched to struct fields based solely on the order. + // + // This annotation only works if `enforce_order` is specified. + #[darling(default)] + skip_name_checks: bool, + + // If true, then the type checking code will require that the UDT does not + // contain excess fields at its suffix. Otherwise, if UDT has some fields + // at its suffix that do not correspond to Rust struct's fields, + // they will be ignored. With true, an error will be raised. + #[darling(default)] + forbid_excess_udt_fields: bool, +} + +impl DeserializeCommonStructAttrs for StructAttrs { + fn crate_path(&self) -> Option<&syn::Path> { + self.crate_path.as_ref() + } +} + +#[derive(FromField)] +#[darling(attributes(scylla))] +struct Field { + // If true, then the field is not parsed at all, but it is initialized + // with Default::default() instead. All other attributes are ignored. 
+ #[darling(default)] + skip: bool, + + // If true, then - if this field is missing from the UDT fields metadata + // - it will be initialized to Default::default(). + #[darling(default)] + #[darling(rename = "allow_missing")] + default_when_missing: bool, + + // If true, then - if this field is present among UDT fields metadata + // but at the same time missing from serialized data or set to null + // - it will be initialized to Default::default(). + #[darling(default)] + default_when_null: bool, + + // If set, then deserializes from the UDT field with this particular name + // instead of the Rust field name. + #[darling(default)] + rename: Option, + + ident: Option, + ty: syn::Type, +} + +impl DeserializeCommonFieldAttrs for Field { + fn needs_default(&self) -> bool { + self.skip || self.default_when_missing + } + + fn deserialize_target(&self) -> &syn::Type { + &self.ty + } +} + +// derive(DeserializeValue) for the DeserializeValue trait +pub(crate) fn deserialize_value_derive( + tokens_input: TokenStream, +) -> Result { + let input = syn::parse(tokens_input)?; + + let implemented_trait: syn::Path = parse_quote!(DeserializeValue); + let implemented_trait_name = implemented_trait + .segments + .last() + .unwrap() + .ident + .unraw() + .to_string(); + let constraining_trait = implemented_trait.clone(); + let s = StructDesc::new(&input, &implemented_trait_name, constraining_trait)?; + + validate_attrs(&s.attrs, s.fields())?; + + let items = [ + s.generate_type_check_method().into(), + s.generate_deserialize_method().into(), + ]; + + Ok(s.generate_impl(implemented_trait, items)) +} + +fn validate_attrs(attrs: &StructAttrs, fields: &[Field]) -> Result<(), darling::Error> { + let mut errors = darling::Error::accumulator(); + + if attrs.skip_name_checks { + // Skipping name checks is only available in enforce_order mode + if !attrs.enforce_order { + let error = + darling::Error::custom("attribute requires ."); + errors.push(error); + } + + // Fields with 
`allow_missing` are only permitted at the end of the + // struct, i.e. no field without `allow_missing` and `skip` is allowed + // to be after any field with `allow_missing`. + let invalid_default_when_missing_field = fields + .iter() + .rev() + // Skip the whole suffix of `allow_missing` and `skip` fields. + .skip_while(|field| !field.is_required()) + // skip_while finished either because the iterator is empty or it found a field without both `allow_missing` and `skip`. + // In either case, there aren't allowed to be any more fields with `allow_missing`. + .find(|field| field.default_when_missing); + if let Some(invalid) = invalid_default_when_missing_field { + let error = + darling::Error::custom( + "when is on, fields with are only permitted at the end of the struct, \ + i.e. no field without and is allowed to be after any field with ." + ).with_span(&invalid.ident); + errors.push(error); + } + + // `rename` annotations don't make sense with skipped name checks + for field in fields { + if field.rename.is_some() { + let err = darling::Error::custom( + " annotations don't make sense with attribute", + ) + .with_span(&field.ident); + errors.push(err); + } + } + } else { + // Detect name collisions caused by `rename`. + let mut used_names = HashMap::::new(); + for field in fields { + let udt_field_name = field.udt_field_name(); + if let Some(other_field) = used_names.get(&udt_field_name) { + let other_field_ident = other_field.ident.as_ref().unwrap(); + let msg = format!("the UDT field name `{udt_field_name}` used by this struct field is already used by field `{other_field_ident}`"); + let err = darling::Error::custom(msg).with_span(&field.ident); + errors.push(err); + } else { + used_names.insert(udt_field_name, field); + } + } + } + + errors.finish() +} + +impl Field { + // Returns whether this field is mandatory for deserialization. 
+ fn is_required(&self) -> bool { + !self.skip && !self.default_when_missing + } + + // The name of UDT field corresponding to this Rust struct field + fn udt_field_name(&self) -> String { + match self.rename.as_ref() { + Some(rename) => rename.to_owned(), + None => self.ident.as_ref().unwrap().unraw().to_string(), + } + } + + // A Rust literal representing the name of this field + fn cql_name_literal(&self) -> syn::LitStr { + syn::LitStr::new(&self.udt_field_name(), Span::call_site()) + } +} + +type StructDesc = super::StructDescForDeserialize; + +impl StructDesc { + /// Generates an expression which extracts the UDT fields or returns an error. + fn generate_extract_fields_from_type(&self, typ_expr: syn::Expr) -> syn::Expr { + let macro_internal = &self.struct_attrs().macro_internal_path(); + parse_quote!( + match #typ_expr { + #macro_internal::ColumnType::UserDefinedType { field_types, .. } => field_types, + other => return ::std::result::Result::Err( + #macro_internal::mk_value_typck_err::( + &other, + #macro_internal::DeserUdtTypeCheckErrorKind::NotUdt, + ) + ), + } + ) + } + + fn generate_type_check_method(&self) -> syn::ImplItemFn { + if self.attrs.enforce_order { + TypeCheckAssumeOrderGenerator(self).generate() + } else { + TypeCheckUnorderedGenerator(self).generate() + } + } + + fn generate_deserialize_method(&self) -> syn::ImplItemFn { + if self.attrs.enforce_order { + DeserializeAssumeOrderGenerator(self).generate() + } else { + DeserializeUnorderedGenerator(self).generate() + } + } +} + +struct TypeCheckAssumeOrderGenerator<'sd>(&'sd StructDesc); + +impl<'sd> TypeCheckAssumeOrderGenerator<'sd> { + // Generates name and type validation for given Rust struct's field. 
+ fn generate_field_validation(&self, rust_field_idx: usize, field: &Field) -> syn::Expr { + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let rust_field_name = field.cql_name_literal(); + let rust_field_typ = field.deserialize_target(); + let default_when_missing = field.default_when_missing; + let skip_name_checks = self.0.attrs.skip_name_checks; + + // Action performed in case of field name mismatch. + let name_mismatch: syn::Expr = if default_when_missing { + parse_quote! { + { + // If the Rust struct's field is marked as `default_when_missing`, then let's assume + // optimistically that the remaining UDT fields match required Rust struct fields. + // For that, store the read UDT field to be fit against the next Rust struct field. + saved_cql_field = ::std::option::Option::Some(next_cql_field); + break 'verifications; // Skip type verification, because the UDT field is absent. + } + } + } else { + parse_quote! { + { + // Error - required value for field not present among the CQL fields. + return ::std::result::Result::Err( + #macro_internal::mk_value_typck_err::( + typ, + #macro_internal::DeserUdtTypeCheckErrorKind::FieldNameMismatch { + position: #rust_field_idx, + rust_field_name: <_ as ::std::borrow::ToOwned>::to_owned(#rust_field_name), + db_field_name: <_ as ::std::borrow::ToOwned>::to_owned(cql_field_name), + } + ) + ); + } + } + }; + + // Optional name verification. + let name_verification: Option = (!skip_name_checks).then(|| { + parse_quote! { + if #rust_field_name != cql_field_name { + // The read UDT field is not the one expected by the Rust struct. + #name_mismatch + } + } + }); + + parse_quote! { + 'field: { + let next_cql_field = match saved_cql_field + // We may have a stored CQL UDT field that did not match the previous Rust struct's field. + .take() + // If not, simply fetch another CQL UDT field from the iterator. 
+ .or_else(|| cql_field_iter.next()) { + ::std::option::Option::Some(cql_field) => cql_field, + // In case the Rust field allows default-initialisation and there are no more CQL fields, + // simply assume it's going to be default-initialised. + ::std::option::Option::None if #default_when_missing => break 'field, + ::std::option::Option::None => return Err(too_few_fields()), + }; + let (cql_field_name, cql_field_typ) = next_cql_field; + + 'verifications: { + // Verify the name (unless `skip_name_checks` is specified) + // In a specific case when this Rust field is going to be default-initialised + // due to no corresponding CQL UDT field, the below type verification will be skipped. + #name_verification + + // Verify the type + <#rust_field_typ as #macro_internal::DeserializeValue<#constraint_lifetime>>::type_check(cql_field_typ) + .map_err(|err| #macro_internal::mk_value_typck_err::( + typ, + #macro_internal::DeserUdtTypeCheckErrorKind::FieldTypeCheckFailed { + field_name: <_ as ::std::borrow::ToOwned>::to_owned(#rust_field_name), + err, + } + ))?; + } + } + } + } + + // Generates the type_check method for when enforce_order == true. 
+ fn generate(&self) -> syn::ImplItemFn { + // The generated method will: + // - Check that every required field appears on the list in the same order as struct fields + // - Every type on the list is correct + + let macro_internal = self.0.struct_attrs().macro_internal_path(); + + let extract_fields_expr = self.0.generate_extract_fields_from_type(parse_quote!(typ)); + + let required_fields_iter = || self.0.fields().iter().filter(|f| f.is_required()); + + let required_field_count = required_fields_iter().count(); + let required_field_count_lit = + syn::LitInt::new(&required_field_count.to_string(), Span::call_site()); + + let required_fields_names = required_fields_iter().map(|field| field.ident.as_ref()); + + let nonskipped_fields_iter = || { + self.0 + .fields() + .iter() + // It is important that we enumerate **before** filtering, because otherwise we would not + // count the skipped fields, which might be confusing. + .enumerate() + .filter(|(_idx, f)| !f.skip) + }; + + let field_validations = + nonskipped_fields_iter().map(|(idx, field)| self.generate_field_validation(idx, field)); + + let check_excess_udt_fields: Option = + self.0.attrs.forbid_excess_udt_fields.then(|| { + parse_quote! { + if let ::std::option::Option::Some((cql_field_name, cql_field_typ)) = saved_cql_field + .take() + .or_else(|| cql_field_iter.next()) { + return ::std::result::Result::Err(#macro_internal::mk_value_typck_err::( + typ, + #macro_internal::DeserUdtTypeCheckErrorKind::ExcessFieldInUdt { + db_field_name: <_ as ::std::clone::Clone>::clone(cql_field_name), + } + )); + } + } + }); + + parse_quote! { + fn type_check( + typ: &#macro_internal::ColumnType, + ) -> ::std::result::Result<(), #macro_internal::TypeCheckError> { + // Extract information about the field types from the UDT + // type definition. 
+ let fields = #extract_fields_expr; + + let too_few_fields = || #macro_internal::mk_value_typck_err::( + typ, + #macro_internal::DeserUdtTypeCheckErrorKind::TooFewFields { + required_fields: vec![ + #(stringify!(#required_fields_names),)* + ], + present_fields: fields.iter().map(|(name, _typ)| name.clone()).collect(), + } + ); + + // Verify that the field count is correct + if fields.len() < #required_field_count_lit { + return ::std::result::Result::Err(too_few_fields()); + } + + let mut cql_field_iter = fields.iter(); + // A CQL UDT field that has already been fetched from the field iterator, + // but not yet matched to a Rust struct field (because the previous + // Rust struct field didn't match it and had #[allow_missing] specified). + let mut saved_cql_field = ::std::option::Option::None::< + &(::std::string::String, #macro_internal::ColumnType), + >; + #( + #field_validations + )* + + #check_excess_udt_fields + + // All is good! + ::std::result::Result::Ok(()) + } + } + } +} + +struct DeserializeAssumeOrderGenerator<'sd>(&'sd StructDesc); + +impl<'sd> DeserializeAssumeOrderGenerator<'sd> { + fn generate_finalize_field(&self, field: &Field) -> syn::Expr { + if field.skip { + // Skipped fields are initialized with Default::default() + return parse_quote! { + ::std::default::Default::default() + }; + } + + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let cql_name_literal = field.cql_name_literal(); + let deserializer = field.deserialize_target(); + let constraint_lifetime = self.0.constraint_lifetime(); + let default_when_missing = field.default_when_missing; + let default_when_null = field.default_when_null; + let skip_name_checks = self.0.attrs.skip_name_checks; + + let deserialize: syn::Expr = parse_quote! 
{ + <#deserializer as #macro_internal::DeserializeValue<#constraint_lifetime>>::deserialize(cql_field_typ, value) + .map_err(|err| #macro_internal::mk_value_deser_err::( + typ, + #macro_internal::UdtDeserializationErrorKind::FieldDeserializationFailed { + field_name: #cql_name_literal.to_owned(), + err, + } + ))? + }; + + let maybe_default_deserialize: syn::Expr = if default_when_null { + parse_quote! { + if value.is_none() { + ::std::default::Default::default() + } else { + #deserialize + } + } + } else { + parse_quote! { + #deserialize + } + }; + + // Action performed in case of field name mismatch. + let name_mismatch: syn::Expr = if default_when_missing { + parse_quote! { + { + // If the Rust struct's field is marked as `default_when_missing`, then let's assume + // optimistically that the remaining UDT fields match required Rust struct fields. + // For that, store the read UDT field to be fit against the next Rust struct field. + saved_cql_field = ::std::option::Option::Some(next_cql_field); + + ::std::default::Default::default() + } + } + } else { + parse_quote! { + { + panic!( + "type check should have prevented this scenario - field name mismatch! Rust field name {}, CQL field name {}", + #cql_name_literal, + cql_field_name + ); + } + } + }; + + let maybe_name_check_and_deserialize_or_save: syn::Expr = if skip_name_checks { + parse_quote! { + #maybe_default_deserialize + } + } else { + parse_quote! { + if #cql_name_literal == cql_field_name { + #maybe_default_deserialize + } else { + #name_mismatch + } + } + }; + + let no_more_fields: syn::Expr = if default_when_missing { + parse_quote! { + ::std::default::Default::default() + } + } else { + parse_quote! { + // Type check has ensured that there are enough CQL UDT fields. + panic!("Too few CQL UDT fields - type check should have prevented this scenario!") + } + }; + + parse_quote! 
{ + { + let maybe_next_cql_field = saved_cql_field + .take() + .map(::std::result::Result::Ok) + .or_else(|| { + cql_field_iter.next() + .map(|(specs, value_res)| value_res.map(|value| (specs, value))) + }) + .transpose() + // Propagate deserialization errors. + .map_err(|err| #macro_internal::mk_value_deser_err::( + typ, + #macro_internal::UdtDeserializationErrorKind::FieldDeserializationFailed { + field_name: #cql_name_literal.to_owned(), + err, + } + ))?; + + if let Some(next_cql_field) = maybe_next_cql_field { + let ((cql_field_name, cql_field_typ), value) = next_cql_field; + + // The value can be either + // - None - missing from the serialized representation + // - Some(None) - present in the serialized representation but null + // For now, we treat both cases as "null". + let value = value.flatten(); + + #maybe_name_check_and_deserialize_or_save + } else { + #no_more_fields + } + } + } + } + + fn generate(&self) -> syn::ImplItemFn { + // We can assume that type_check was called. + + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let fields = self.0.fields(); + + let field_idents = fields.iter().map(|f| f.ident.as_ref().unwrap()); + let field_finalizers = fields.iter().map(|f| self.generate_finalize_field(f)); + + #[allow(unused_mut)] + let mut iterator_type: syn::Type = + parse_quote!(#macro_internal::UdtIterator<#constraint_lifetime>); + + parse_quote! { + fn deserialize( + typ: &#constraint_lifetime #macro_internal::ColumnType, + v: ::std::option::Option<#macro_internal::FrameSlice<#constraint_lifetime>>, + ) -> ::std::result::Result { + // Create an iterator over the fields of the UDT. 
+ let mut cql_field_iter = <#iterator_type as #macro_internal::DeserializeValue<#constraint_lifetime>>::deserialize(typ, v) + .map_err(#macro_internal::value_deser_error_replace_rust_name::)?; + + // This is to hold another field that already popped up from the field iterator but appeared to not match + // the expected nonrequired field. Therefore, that field is stored here, while the expected field + // is default-initialized. + let mut saved_cql_field = ::std::option::Option::None::<( + &(::std::string::String, #macro_internal::ColumnType), + ::std::option::Option<::std::option::Option<#macro_internal::FrameSlice>> + )>; + + ::std::result::Result::Ok(Self { + #(#field_idents: #field_finalizers,)* + }) + } + } + } +} + +struct TypeCheckUnorderedGenerator<'sd>(&'sd StructDesc); + +impl<'sd> TypeCheckUnorderedGenerator<'sd> { + // An identifier for a bool variable that represents whether given + // field was already visited during type check + fn visited_flag_variable(field: &Field) -> syn::Ident { + quote::format_ident!("visited_{}", field.ident.as_ref().unwrap().unraw()) + } + + // Generates a declaration of a "visited" flag for the purpose of type check. + // We generate it even if the flag is not required in order to protect + // from fields appearing more than once + fn generate_visited_flag_decl(field: &Field) -> Option { + (!field.skip).then(|| { + let visited_flag = Self::visited_flag_variable(field); + parse_quote! 
{ + let mut #visited_flag = false; + } + }) + } + + // Generates code that, given variable `typ`, type-checks given field + fn generate_type_check(&self, field: &Field) -> Option { + (!field.skip).then(|| { + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let visited_flag = Self::visited_flag_variable(field); + let typ = field.deserialize_target(); + let cql_name_literal = field.cql_name_literal(); + let decrement_if_required: Option = field + .is_required() + .then(|| parse_quote! {remaining_required_cql_fields -= 1;}); + + parse_quote! { + { + if !#visited_flag { + <#typ as #macro_internal::DeserializeValue<#constraint_lifetime>>::type_check(cql_field_typ) + .map_err(|err| #macro_internal::mk_value_typck_err::( + typ, + #macro_internal::DeserUdtTypeCheckErrorKind::FieldTypeCheckFailed { + field_name: <_ as ::std::clone::Clone>::clone(cql_field_name), + err, + } + ))?; + #visited_flag = true; + #decrement_if_required + } else { + return ::std::result::Result::Err( + #macro_internal::mk_value_typck_err::( + typ, + #macro_internal::DeserUdtTypeCheckErrorKind::DuplicatedField { + field_name: <_ as ::std::borrow::ToOwned>::to_owned(#cql_name_literal), + } + ) + ) + } + } + } + }) + } + + // Generates code that appends the field name if it is missing. + // The generated code is used to construct a nice error message. + fn generate_append_name(field: &Field) -> Option { + field.is_required().then(|| { + let visited_flag = Self::visited_flag_variable(field); + let cql_name_literal = field.cql_name_literal(); + parse_quote!( + { + if !#visited_flag { + missing_fields.push(#cql_name_literal); + } + } + ) + }) + } + + // Generates the type_check method for when enforce_order == false. 
+ fn generate(&self) -> syn::ImplItemFn { + // The generated method will: + // - Check that every required field appears on the list exactly once, in any order + // - Every type on the list is correct + + let macro_internal = &self.0.struct_attrs().macro_internal_path(); + let forbid_excess_udt_fields = self.0.attrs.forbid_excess_udt_fields; + let rust_fields = self.0.fields(); + let visited_field_declarations = rust_fields + .iter() + .flat_map(Self::generate_visited_flag_decl); + let type_check_blocks = rust_fields.iter().flat_map(|f| self.generate_type_check(f)); + let append_name_blocks = rust_fields.iter().flat_map(Self::generate_append_name); + let rust_nonskipped_field_names = rust_fields + .iter() + .filter(|f| !f.skip) + .map(|f| f.cql_name_literal()); + let required_cql_field_count = rust_fields.iter().filter(|f| f.is_required()).count(); + let required_cql_field_count_lit = + syn::LitInt::new(&required_cql_field_count.to_string(), Span::call_site()); + let extract_cql_fields_expr = self.0.generate_extract_fields_from_type(parse_quote!(typ)); + + // If UDT contains a field with an unknown name, an error is raised iff + // `forbid_excess_udt_fields` attribute is specified. + let excess_udt_field_action: syn::Expr = if forbid_excess_udt_fields { + parse_quote! { + return ::std::result::Result::Err( + #macro_internal::mk_value_typck_err::( + typ, + #macro_internal::DeserUdtTypeCheckErrorKind::ExcessFieldInUdt { + db_field_name: unknown.to_owned(), + } + ) + ) + } + } else { + parse_quote! { + // We ignore excess UDT fields, as this facilitates the process of adding new fields + // to a UDT in running production cluster & clients. + () + } + }; + + parse_quote! { + fn type_check( + typ: &#macro_internal::ColumnType, + ) -> ::std::result::Result<(), #macro_internal::TypeCheckError> { + // Extract information about the field types from the UDT + // type definition. 
+ let cql_fields = #extract_cql_fields_expr; + + // Counts down how many required fields are remaining + let mut remaining_required_cql_fields: ::std::primitive::usize = #required_cql_field_count_lit; + + // For each required field, generate a "visited" boolean flag + #(#visited_field_declarations)* + + for (cql_field_name, cql_field_typ) in cql_fields { + // Pattern match on the name and verify that the type is correct. + match cql_field_name.as_str() { + #(#rust_nonskipped_field_names => #type_check_blocks,)* + unknown => #excess_udt_field_action, + } + } + + if remaining_required_cql_fields > 0 { + // If there are some missing required fields, generate an error + // which contains missing field names + let mut missing_fields = ::std::vec::Vec::<&'static str>::with_capacity(remaining_required_cql_fields); + #(#append_name_blocks)* + return ::std::result::Result::Err( + #macro_internal::mk_value_typck_err::( + typ, + #macro_internal::DeserUdtTypeCheckErrorKind::ValuesMissingForUdtFields { + field_names: missing_fields, + } + ) + ) + } + + ::std::result::Result::Ok(()) + } + } + } +} + +struct DeserializeUnorderedGenerator<'sd>(&'sd StructDesc); + +impl<'sd> DeserializeUnorderedGenerator<'sd> { + /// An identifier for a variable that is meant to store the parsed variable + /// before being ultimately moved to the struct on deserialize. + fn deserialize_field_variable(field: &Field) -> syn::Ident { + quote::format_ident!("f_{}", field.ident.as_ref().unwrap().unraw()) + } + + /// Generates an expression which produces a value ready to be put into a field + /// of the target structure. + fn generate_finalize_field(&self, field: &Field) -> syn::Expr { + if field.skip { + // Skipped fields are initialized with Default::default() + return parse_quote!(::std::default::Default::default()); + } + + let deserialize_field = Self::deserialize_field_variable(field); + if field.default_when_missing { + // Generate Default::default if the field was missing + parse_quote! 
{ + #deserialize_field.unwrap_or_default() + } + } else { + let cql_name_literal = field.cql_name_literal(); + parse_quote! { + #deserialize_field.unwrap_or_else(|| panic!( + "field {} missing in UDT - type check should have prevented this!", + #cql_name_literal + )) + } + } + } + + /// Generates code that performs deserialization when the raw field + /// is being processed. + fn generate_deserialization(&self, field: &Field) -> Option { + (!field.skip).then(|| { + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let deserialize_field = Self::deserialize_field_variable(field); + let cql_name_literal = field.cql_name_literal(); + let deserializer = field.deserialize_target(); + + let do_deserialize: syn::Expr = parse_quote! { + <#deserializer as #macro_internal::DeserializeValue<#constraint_lifetime>>::deserialize(cql_field_typ, value) + .map_err(|err| #macro_internal::mk_value_deser_err::( + typ, + #macro_internal::UdtDeserializationErrorKind::FieldDeserializationFailed { + field_name: #cql_name_literal.to_owned(), + err, + } + ))? + }; + + let deserialize_action: syn::Expr = if field.default_when_null { + parse_quote! { + if value.is_some() { + #do_deserialize + } else { + ::std::default::Default::default() + } + } + } else { + do_deserialize + }; + + parse_quote! { + { + assert!( + #deserialize_field.is_none(), + "duplicated field {} - type check should have prevented this!", + stringify!(#deserialize_field) + ); + + // The value can be either + // - None - missing from the serialized representation + // - Some(None) - present in the serialized representation but null + // For now, we treat both cases as "null". 
+ let value = value.flatten(); + + #deserialize_field = ::std::option::Option::Some( + #deserialize_action + ); + } + } + }) + } + + // Generate a declaration of a variable that temporarily keeps + // the deserialized value + fn generate_deserialize_field_decl(field: &Field) -> Option { + (!field.skip).then(|| { + let deserialize_field = Self::deserialize_field_variable(field); + parse_quote! { + let mut #deserialize_field = ::std::option::Option::None; + } + }) + } + + fn generate(&self) -> syn::ImplItemFn { + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let fields = self.0.fields(); + + let deserialize_field_decls = fields.iter().map(Self::generate_deserialize_field_decl); + let deserialize_blocks = fields.iter().flat_map(|f| self.generate_deserialization(f)); + let rust_field_idents = fields.iter().map(|f| f.ident.as_ref().unwrap()); + let rust_nonskipped_field_names = fields + .iter() + .filter(|f| !f.skip) + .map(|f| f.cql_name_literal()); + + let field_finalizers = fields.iter().map(|f| self.generate_finalize_field(f)); + + let iterator_type: syn::Type = parse_quote! { + #macro_internal::UdtIterator<#constraint_lifetime> + }; + + // TODO: Allow collecting unrecognized fields into some special field + + parse_quote! { + fn deserialize( + typ: &#constraint_lifetime #macro_internal::ColumnType, + v: ::std::option::Option<#macro_internal::FrameSlice<#constraint_lifetime>>, + ) -> ::std::result::Result { + // Create an iterator over the fields of the UDT. + let cql_field_iter = <#iterator_type as #macro_internal::DeserializeValue<#constraint_lifetime>>::deserialize(typ, v) + .map_err(#macro_internal::value_deser_error_replace_rust_name::)?; + + // Generate fields that will serve as temporary storage + // for the fields' values. Those are of type Option. 
+ #(#deserialize_field_decls)* + + for item in cql_field_iter { + let ((cql_field_name, cql_field_typ), value_res) = item; + let value = value_res.map_err(|err| #macro_internal::mk_value_deser_err::( + typ, + #macro_internal::UdtDeserializationErrorKind::FieldDeserializationFailed { + field_name: ::std::clone::Clone::clone(cql_field_name), + err, + } + ))?; + // Pattern match on the field name and deserialize. + match cql_field_name.as_str() { + #(#rust_nonskipped_field_names => #deserialize_blocks,)* + unknown => { + // Assuming we type checked successfully, this must be an excess field. + // Let's skip it. + }, + } + } + + // Create the final struct. The finalizer expressions convert + // the temporary storage fields to the final field values. + // For example, if a field is missing but marked as + // `allow_missing` it will create a default value, otherwise + // it will panic (type check should have prevented that). + ::std::result::Result::Ok(Self { + #(#rust_field_idents: #field_finalizers,)* + }) + } + } + } +} diff --git a/scylla-macros/src/lib.rs b/scylla-macros/src/lib.rs index 5022f09f15..05e24362f4 100644 --- a/scylla-macros/src/lib.rs +++ b/scylla-macros/src/lib.rs @@ -1,5 +1,5 @@ +use darling::ToTokens; use proc_macro::TokenStream; -use quote::ToTokens; mod from_row; mod from_user_type; @@ -66,3 +66,21 @@ pub fn value_list_derive(tokens_input: TokenStream) -> TokenStream { let res = value_list::value_list_derive(tokens_input); res.unwrap_or_else(|e| e.into_compile_error().into()) } + +mod deserialize; + +#[proc_macro_derive(DeserializeRow, attributes(scylla))] +pub fn deserialize_row_derive(tokens_input: TokenStream) -> TokenStream { + match deserialize::row::deserialize_row_derive(tokens_input) { + Ok(tokens) => tokens.into_token_stream().into(), + Err(err) => err.into_compile_error().into(), + } +} + +#[proc_macro_derive(DeserializeValue, attributes(scylla))] +pub fn deserialize_value_derive(tokens_input: TokenStream) -> TokenStream { + match 
deserialize::value::deserialize_value_derive(tokens_input) { + Ok(tokens) => tokens.into_token_stream().into(), + Err(err) => err.into_compile_error().into(), + } +} diff --git a/scylla-macros/src/parser.rs b/scylla-macros/src/parser.rs index ec72a81b1c..7c376c16f8 100644 --- a/scylla-macros/src/parser.rs +++ b/scylla-macros/src/parser.rs @@ -1,5 +1,4 @@ -use syn::{Data, DeriveInput, ExprLit, Fields, FieldsNamed, FieldsUnnamed, Lit}; -use syn::{Expr, Meta}; +use syn::{Data, DeriveInput, Expr, ExprLit, Fields, FieldsNamed, FieldsUnnamed, Lit, Meta}; /// Parses a struct DeriveInput and returns named fields of this struct. pub(crate) fn parse_named_fields<'a>( @@ -51,8 +50,8 @@ pub(crate) fn parse_struct_fields<'a>( } } -pub(crate) fn get_path(input: &DeriveInput) -> Result { - let mut this_path: Option = None; +pub(crate) fn get_path(input: &DeriveInput) -> Result { + let mut this_path: Option = None; for attr in input.attrs.iter() { if !attr.path().is_ident("scylla_crate") { continue; @@ -65,7 +64,7 @@ pub(crate) fn get_path(input: &DeriveInput) -> Result Result { +/// a: i32, +/// b: Option, +/// c: &'a [u8], +/// } +/// ``` +/// +/// # Attributes +/// +/// The macro supports a number of attributes that customize the generated +/// implementation. Many of the attributes were inspired by procedural macros +/// from `serde` and try to follow the same naming conventions. +/// +/// ## Struct attributes +/// +/// `#[scylla(crate = "crate_name")]` +/// +/// By default, the code generated by the derive macro will refer to the items +/// defined by the driver (types, traits, etc.) via the `::scylla` path. +/// For example, it will refer to the [`DeserializeValue`](crate::deserialize::DeserializeValue) +/// trait using the following path: +/// +/// ```rust,ignore +/// use ::scylla::_macro_internal::DeserializeValue; +/// ``` +/// +/// Most users will simply add `scylla` to their dependencies, then use +/// the derive macro and the path above will work. 
However, there are some +/// niche cases where this path will _not_ work: +/// +/// - The `scylla` crate is imported under a different name, +/// - The `scylla` crate is _not imported at all_ - the macro actually +/// is defined in the `scylla-macros` crate and the generated code depends +/// on items defined in `scylla-cql`. +/// +/// It's not possible to automatically resolve those issues in the procedural +/// macro itself, so in those cases the user must provide an alternative path +/// to either the `scylla` or `scylla-cql` crate. +/// +/// `#[scylla(enforce_order)]` +/// +/// By default, the generated deserialization code will be insensitive +/// to the UDT field order - when processing a UDT field, it will look up +/// the corresponding field in the Rust struct and set it. However, +/// if the UDT field order is known to be the same both in the UDT +/// and the Rust struct, then the `enforce_order` annotation can be used +/// so that a more efficient implementation that does not perform lookups +/// is generated. The UDT field names will still be checked during the +/// type check phase. +/// +/// `#[scylla(skip_name_checks)]` +/// +/// This attribute only works when used with `enforce_order`. +/// +/// If set, the generated implementation will not verify the UDT field names at +/// all. Because it only works with `enforce_order`, it will deserialize the first +/// UDT field into the first struct field, the second UDT field into the second +/// struct field and so on. It will still verify that the UDT field types +/// and struct field types match. +/// +/// `#[scylla(forbid_excess_udt_fields)]` +/// +/// By default, the generated deserialization code ignores excess UDT fields. +/// I.e., the `enforce_order` flavour ignores excess UDT fields in the suffix +/// of the UDT definition, and the default unordered flavour ignores excess +/// UDT fields anywhere. 
+/// If more strictness is desired, this flag makes sure that no excess fields +/// are present and raises an error if there are any. +/// +/// ## Field attributes +/// +/// `#[scylla(skip)]` +/// +/// The field will be completely ignored during deserialization and will +/// be initialized with `Default::default()`. +/// +/// `#[scylla(allow_missing)]` +/// +/// If the UDT definition does not contain this field, it will be initialized +/// with `Default::default()`. +/// +/// `#[scylla(default_when_null)]` +/// +/// If the value of the field received from the DB is null, the field will be +/// initialized with `Default::default()`. +/// +/// `#[scylla(rename = "field_name")]` +/// +/// By default, the generated implementation will try to match the Rust field +/// to a UDT field with the same name. This attribute instead allows matching +/// to a UDT field with the provided name. +pub use scylla_macros::DeserializeValue; + +/// Derive macro for the `DeserializeRow` trait that generates an implementation +/// which deserializes a row with a similar layout to the Rust struct. +/// +/// At the moment, only structs with named fields are supported. +/// +/// This macro properly supports structs with lifetimes, meaning that you can +/// deserialize columns that borrow memory from the serialized response. +/// +/// # Example +/// +/// Having a table defined like this: +/// +/// ```text +/// CREATE TABLE ks.my_table (a int PRIMARY KEY, b text, c blob); +/// ``` +/// +/// results of a query "SELECT * FROM ks.my_table" +/// or "SELECT a, b, c FROM ks.my_table" +/// can be deserialized using the following struct: +/// +/// ```rust +/// # use scylla_cql::macros::DeserializeRow; +/// #[derive(DeserializeRow)] +/// # #[scylla(crate = "scylla_cql")] +/// struct MyRow<'a> { +/// a: i32, +/// b: Option<String>, +/// c: &'a [u8], +/// } +/// ``` +/// +/// In general, the struct must match the queried names and types, +/// not the table itself. 
For example, the query +/// "SELECT a AS b FROM ks.my_table" executed against +/// the aforementioned table can be deserialized to the struct: +/// ```rust +/// # use scylla_cql::macros::DeserializeRow; +/// #[derive(DeserializeRow)] +/// # #[scylla(crate = "scylla_cql")] +/// struct MyRow { +/// b: i32, +/// } +/// ``` +/// +/// # Attributes +/// +/// The macro supports a number of attributes that customize the generated +/// implementation. Many of the attributes were inspired by procedural macros +/// from `serde` and try to follow the same naming conventions. +/// +/// ## Struct attributes +/// +/// `#[scylla(crate = "crate_name")]` +/// +/// By default, the code generated by the derive macro will refer to the items +/// defined by the driver (types, traits, etc.) via the `::scylla` path. +/// For example, it will refer to the [`DeserializeValue`](crate::deserialize::DeserializeValue) +/// trait using the following path: +/// +/// ```rust,ignore +/// use ::scylla::_macro_internal::DeserializeValue; +/// ``` +/// +/// Most users will simply add `scylla` to their dependencies, then use +/// the derive macro and the path above will work. However, there are some +/// niche cases where this path will _not_ work: +/// +/// - The `scylla` crate is imported under a different name, +/// - The `scylla` crate is _not imported at all_ - the macro actually +/// is defined in the `scylla-macros` crate and the generated code depends +/// on items defined in `scylla-cql`. +/// +/// It's not possible to automatically resolve those issues in the procedural +/// macro itself, so in those cases the user must provide an alternative path +/// to either the `scylla` or `scylla-cql` crate. +/// +/// `#[scylla(enforce_order)]` +/// +/// By default, the generated deserialization code will be insensitive +/// to the column order - when processing a column, the corresponding Rust field +/// will be looked up and the column will be deserialized based on its type. 
+/// However, if the column order and the Rust field order are known to be the +/// same, then the `enforce_order` annotation can be used so that a more +/// efficient implementation that does not perform lookups is generated. +/// The generated code will still check that the column and field names match. +/// +/// `#[scylla(skip_name_checks)]` +/// +/// This attribute only works when used with `enforce_order`. +/// +/// If set, the generated implementation will not verify the column names at +/// all. Because it only works with `enforce_order`, it will deserialize the first +/// column into the first field, the second column into the second field and so on. +/// It will still verify that the column types and field types match. +/// +/// ## Field attributes +/// +/// `#[scylla(skip)]` +/// +/// The field will be completely ignored during deserialization and will +/// be initialized with `Default::default()`. +/// +/// `#[scylla(rename = "field_name")]` +/// +/// By default, the generated implementation will try to match the Rust field +/// to a column with the same name. This attribute allows matching it to a column +/// with the provided name. +pub use scylla_macros::DeserializeRow; + /// #[derive(ValueList)] allows to pass struct as a list of values for a query /// /// ---