From 4c0c8e63f445a637ab71cb5c5d82ca5726322f15 Mon Sep 17 00:00:00 2001 From: Ruslan Piasetskyi Date: Mon, 26 Sep 2022 11:07:30 +0200 Subject: [PATCH] serializable-state-derive: initial commit --- Cargo.lock | 77 ++++++- crypto-common/Cargo.toml | 1 + .../serializable-state-derive/Cargo.toml | 23 +++ .../serializable-state-derive/src/lib.rs | 195 ++++++++++++++++++ .../tests/serialization-tests.rs | 71 +++++++ crypto-common/src/serializable_state.rs | 2 + 6 files changed, 363 insertions(+), 6 deletions(-) create mode 100644 crypto-common/serializable-state-derive/Cargo.toml create mode 100644 crypto-common/serializable-state-derive/src/lib.rs create mode 100644 crypto-common/serializable-state-derive/tests/serialization-tests.rs diff --git a/Cargo.lock b/Cargo.lock index c61a98b76..a4b8a4c56 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -290,6 +290,7 @@ version = "0.2.0-pre" dependencies = [ "hybrid-array", "rand_core 0.6.4", + "serializable-state-derive", ] [[package]] @@ -324,6 +325,41 @@ dependencies = [ "zeroize", ] +[[package]] +name = "darling" +version = "0.20.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0209d94da627ab5605dcccf08bb18afa5009cfbef48d8a8b7d7bdbc79be25c5e" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "177e3443818124b357d8e76f53be906d60937f0d3a90773a664fa63fa253e621" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn", +] + +[[package]] +name = "darling_macro" +version = "0.20.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "836a9bbc7ad63342d6d6e7b815ccab164bc77a2d95d84bc3117a8c0d5c98e2d5" +dependencies = [ + "darling_core", + "quote", + "syn", +] + [[package]] name = "der" version = "0.4.5" @@ -447,6 +483,12 @@ dependencies = [ "subtle", ] +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "generic-array" version = "0.14.7" @@ -602,6 +644,12 @@ dependencies = [ "typenum", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "inout" version = "0.1.3" @@ -788,18 +836,18 @@ checksum = "97e91cb6af081c6daad5fa705f8adb0634c027662052cb3174bdf2957bf07e25" [[package]] name = "proc-macro2" -version = "1.0.63" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b368fba921b0dce7e60f5e04ec15e565b3303972b42bcfde1d0713b881959eb" +checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.28" +version = "1.0.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b9ab9c7eadfd8df19006f1cf1a4aed13540ed5cbc047010ece5826e10825488" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" dependencies = [ "proc-macro2", ] @@ -903,6 +951,17 @@ dependencies = [ "syn", ] +[[package]] +name = "serializable-state-derive" +version = "0.1.0" +dependencies = [ + "crypto-common 0.2.0-pre", + "darling", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "sha2" version = "0.9.9" @@ -981,6 +1040,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + [[package]] name = "subtle" version = "2.5.0" @@ -989,9 +1054,9 @@ checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" [[package]] name = "syn" -version = "2.0.22" +version = "2.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2efbeae7acf4eabd6bcdcbd11c92f45231ddda7539edc7806bd1a04a03b24616" +checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" dependencies = [ "proc-macro2", "quote", diff --git a/crypto-common/Cargo.toml b/crypto-common/Cargo.toml index 11ff47d9a..dad9611f5 100644 --- a/crypto-common/Cargo.toml +++ b/crypto-common/Cargo.toml @@ -14,6 +14,7 @@ categories = ["cryptography", "no-std"] [dependencies] hybrid-array = "=0.2.0-pre.4" +serializable-state-derive = { path = "serializable-state-derive" } # optional dependencies rand_core = { version = "0.6.4", optional = true } diff --git a/crypto-common/serializable-state-derive/Cargo.toml b/crypto-common/serializable-state-derive/Cargo.toml new file mode 100644 index 000000000..bedc5048a --- /dev/null +++ b/crypto-common/serializable-state-derive/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "serializable-state-derive" +description = "Macro implementation of #[derive(SerializableState)]" +version = "0.1.0" +authors = ["Ruslan Piasetskyi "] +license = "MIT OR Apache-2.0" +edition = "2018" +documentation = "https://docs.rs/serializable-state-derive" +repository = "https://github.com/RustCrypto/traits" +keywords = ["serialization", "no_std", "derive"] +categories = ["cryptography", "no-std"] + +[lib] +proc-macro = true + +[dependencies] +darling = "0.20.3" +proc-macro2 = "1.0.69" +quote = "1.0.33" +syn = "2.0.38" + +[dev-dependencies] +crypto-common = { version = "0.2.0-pre", path = ".." } diff --git a/crypto-common/serializable-state-derive/src/lib.rs b/crypto-common/serializable-state-derive/src/lib.rs new file mode 100644 index 000000000..b3ab9747d --- /dev/null +++ b/crypto-common/serializable-state-derive/src/lib.rs @@ -0,0 +1,195 @@ +extern crate proc_macro; + +use darling::FromDeriveInput; +use proc_macro2::{Ident, TokenStream}; +use quote::{format_ident, quote, quote_spanned}; +use syn::{ + parse_macro_input, punctuated::Iter, spanned::Spanned, Data, DeriveInput, Field, Fields, + Generics, Index, +}; + +const CRATE_NAME: &str = "crypto_common"; + +#[derive(FromDeriveInput, Default)] +#[darling(default, attributes(serializable_state))] +struct Opts { + crate_path: Option, +} + +#[proc_macro_derive(SerializableState, attributes(serializable_state))] +pub fn derive_serializable_state(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(input); + let crate_path = get_crate_path(&input); + let struct_name = input.ident; + + let serialized_state_size = generate_serializable_state_size(&input.data, &crate_path); + let serialize_logic = generate_serialize_logic(&input.data); + let deserialize_logic = generate_deserialize_logic(&input.data); + + check_generics(&input.generics); + + let expanded = quote! { + impl #crate_path::SerializableState for #struct_name { + type SerializedStateSize = #serialized_state_size; + + fn serialize(&self) -> #crate_path::SerializedState { + use #crate_path::{SerializableState, SerializedState}; + + #serialize_logic + } + + fn deserialize(_serialized_state: &#crate_path::SerializedState) -> ::core::result::Result { + use #crate_path::SerializableState; + + #deserialize_logic + } + } + }; + + proc_macro::TokenStream::from(expanded) +} + +fn check_generics(generics: &Generics) { + if generics.params.iter().next().is_some() { + panic!("Generics are not supported yet. Please implement SerializableState on your own.") + } +} + +fn generate_serializable_state_size(data: &Data, crate_path: &TokenStream) -> TokenStream { + match *data { + Data::Struct(ref data) => match data.fields { + Fields::Named(ref fields) => { + serializable_state_size_from_fields(fields.named.iter(), crate_path) + } + Fields::Unnamed(ref fields) => { + serializable_state_size_from_fields(fields.unnamed.iter(), crate_path) + } + Fields::Unit => quote! { #crate_path::typenum::U0 }, + }, + Data::Enum(_) | Data::Union(_) => unimplemented!(), + } +} + +fn generate_serialize_logic(data: &Data) -> TokenStream { + match *data { + Data::Struct(ref data) => match data.fields { + Fields::Named(ref fields) => serialize_logic_from_fields(fields.named.iter()), + Fields::Unnamed(ref fields) => serialize_logic_from_fields(fields.unnamed.iter()), + Fields::Unit => quote! { SerializedState::::default() }, + }, + Data::Enum(_) | Data::Union(_) => unimplemented!(), + } +} + +fn generate_deserialize_logic(data: &Data) -> TokenStream { + match *data { + Data::Struct(ref data) => match data.fields { + Fields::Named(ref fields) => deserialize_logic_from_fields(fields.named.iter(), true), + Fields::Unnamed(ref fields) => { + deserialize_logic_from_fields(fields.unnamed.iter(), false) + } + Fields::Unit => quote! { Ok(Self {}) }, + }, + Data::Enum(_) | Data::Union(_) => unimplemented!(), + } +} + +fn serializable_state_size_from_fields( + mut fields: Iter, + crate_path: &TokenStream, +) -> TokenStream { + match fields.next() { + None => quote! { #crate_path::typenum::U0 }, + Some(first) => { + let ty = &first.ty; + let mut size = quote_spanned! { first.span() => <#ty as #crate_path::SerializableState>::SerializedStateSize }; + fields.for_each(|field| { + let ty = &field.ty; + size = quote_spanned! { + field.span() => #crate_path::typenum::Sum<<#ty as #crate_path::SerializableState>::SerializedStateSize, #size> + }; + }); + size + } + } +} + +fn serialize_logic_from_fields(mut fields: Iter) -> TokenStream { + match fields.next() { + None => quote! { SerializedState::::default() }, + Some(first) => { + let field_name = get_field_name(0, &first.ident, None); + let mut code = quote! { self.#field_name.serialize() }; + fields.enumerate().for_each(|(i, field)| { + let field_name = get_field_name(i + 1, &field.ident, None); + code = + quote_spanned! { field.span() => #code.concat(self.#field_name.serialize()) }; + }); + code + } + } +} + +fn deserialize_logic_from_fields(fields: Iter, named: bool) -> TokenStream { + let mut skip_first = fields.clone(); + match skip_first.next() { + None => quote! { Ok(Self {}) }, + Some(first) => { + let mut code = quote!(); + fields.enumerate().for_each(|(i, field)| { + let field_name = get_field_name(i, &field.ident, Some("serialized_")); + let ty = &field.ty; + code = quote_spanned! { + field.span() => + #code + let (#field_name, _serialized_state) = _serialized_state.split_ref::<<#ty as SerializableState>::SerializedStateSize>(); + let #field_name = <#ty>::deserialize(#field_name)?; + }; + }); + let first = get_field_name(0, &first.ident, Some("serialized_")); + let recurce = skip_first.enumerate().map(|(i, field)| { + let field_name = get_field_name(i + 1, &field.ident, Some("serialized_")); + quote_spanned! { field.span() => #field_name } + }); + + let construction = if named { + quote! { Self { #first #(, #recurce)* } } + } else { + quote! { Self ( #first #(, #recurce)* ) } + }; + + quote! { + #code + + Ok(#construction) + } + } + } +} + +fn get_field_name(i: usize, ident: &Option, unnamed_prefix: Option<&str>) -> TokenStream { + match ident { + Some(ident) => quote! { #ident }, + None => match unnamed_prefix { + None => { + let index = Index::from(i); + quote! { #index } + } + Some(unnamed_prefix) => { + let ident = format_ident!("{}{}", unnamed_prefix, i); + quote! { #ident } + } + }, + } +} + +fn get_crate_path(input: &DeriveInput) -> TokenStream { + let crate_path = format_ident!( + "{}", + Opts::from_derive_input(input) + .expect("Unknown options") + .crate_path + .unwrap_or(CRATE_NAME.into()) + ); + quote! { #crate_path } +} diff --git a/crypto-common/serializable-state-derive/tests/serialization-tests.rs b/crypto-common/serializable-state-derive/tests/serialization-tests.rs new file mode 100644 index 000000000..f62ea5d02 --- /dev/null +++ b/crypto-common/serializable-state-derive/tests/serialization-tests.rs @@ -0,0 +1,71 @@ +use crypto_common as crypto_common_crate; +use crypto_common::SerializableState; + +macro_rules! serialization_test { + ($name:ident, $type: ty, $obj: expr, $serialized_state: expr) => { + #[test] + fn $name() { + let obj = $obj; + + let serialized_state = obj.serialize(); + assert_eq!(serialized_state.as_slice(), $serialized_state); + + let deserialized_obj = <$type>::deserialize(&serialized_state).unwrap(); + assert_eq!(deserialized_obj, obj); + } + }; +} + +#[derive(SerializableState, PartialEq, Debug)] +#[serializable_state(crate_path = "crypto_common_crate")] +struct StructWithNamedFields { + a: u8, + b: u64, + c: [u16; 3], +} + +serialization_test!( + struct_with_named_fields_serialization_test, + StructWithNamedFields, + StructWithNamedFields { + a: 0x42, + b: 0x1122334455667788, + c: [0xAABB, 0xCCDD, 0xEEFF], + }, + &[0x42, 0x88, 0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11, 0xBB, 0xAA, 0xDD, 0xCC, 0xFF, 0xEE] +); + +#[derive(SerializableState, PartialEq, Debug)] +struct StructWithZeroNamedFields {} + +serialization_test!( + struct_with_zero_named_fields_serialization_test, + StructWithZeroNamedFields, + StructWithZeroNamedFields {}, + &[] +); + +#[derive(SerializableState, PartialEq, Debug)] +struct StructWithUnnamedFields([u8; 5], u32); + +serialization_test!( + struct_with_unnamed_fields_serialization_test, + StructWithUnnamedFields, + StructWithUnnamedFields([0x11, 0x22, 0x33, 0x44, 0x55], 0xAABBCCDD), + &[0x11, 0x22, 0x33, 0x44, 0x55, 0xDD, 0xCC, 0xBB, 0xAA] +); + +#[derive(SerializableState, PartialEq, Debug)] +struct StructWithZeroUnnamedFields(); + +serialization_test!( + struct_with_zero_unnamed_fields_serialization_test, + StructWithZeroUnnamedFields, + StructWithZeroUnnamedFields(), + &[] +); + +#[derive(SerializableState, PartialEq, Debug)] +struct UnitStruct; + +serialization_test!(unit_struct_serialization_test, UnitStruct, UnitStruct, &[]); diff --git a/crypto-common/src/serializable_state.rs b/crypto-common/src/serializable_state.rs index effaea61e..245a3dc10 100644 --- a/crypto-common/src/serializable_state.rs +++ b/crypto-common/src/serializable_state.rs @@ -5,6 +5,8 @@ use crate::array::{ }; use core::{convert::TryInto, default::Default, fmt}; +pub use serializable_state_derive::*; + /// Serialized internal state. pub type SerializedState = ByteArray<::SerializedStateSize>;