From eadc0435df629d179b7c79fbe67d65e9f5ab1954 Mon Sep 17 00:00:00 2001 From: rooooooooob Date: Tue, 18 Jun 2024 14:52:44 +0900 Subject: [PATCH] @newtype on array/map wrappers (#237) * @newtype on array/map wrappers Previously this only worked on primitive wrappers. Now you can do e.g. `foo = [* uint] ; @newtype` which was ignored before. * cargo fmt * tests added --- src/intermediate.rs | 27 ++++++++--- src/main.rs | 13 ++++-- src/parsing.rs | 69 +++++++++++++++++++++++------ tests/core/input.cddl | 3 ++ tests/core/tests.rs | 31 +++++++++++++ tests/preserve-encodings/input.cddl | 3 ++ tests/preserve-encodings/tests.rs | 35 +++++++++++++++ 7 files changed, 160 insertions(+), 21 deletions(-) diff --git a/src/intermediate.rs b/src/intermediate.rs index a10fb1e..9654bb4 100644 --- a/src/intermediate.rs +++ b/src/intermediate.rs @@ -90,6 +90,21 @@ impl AliasInfo { } } +#[derive(Debug, Clone)] +pub struct PlainGroupInfo<'a> { + group: Option>, + rule_metadata: RuleMetadata, +} + +impl<'a> PlainGroupInfo<'a> { + pub fn new(group: Option>, rule_metadata: RuleMetadata) -> Self { + Self { + group, + rule_metadata, + } + } +} + #[derive(Debug)] pub struct IntermediateTypes<'a> { // Storing the cddl::Group is the easiest way to go here even after the parse/codegen split. @@ -99,7 +114,7 @@ pub struct IntermediateTypes<'a> { // delayed until the point where it is referenced via self.set_rep_if_plain_group(rep) // Some(group) = directly defined in .cddl (must call set_plain_group_representatio() later) // None = indirectly generated due to a group choice (no reason to call set_rep_if_plain_group() later but it won't crash) - plain_groups: BTreeMap>>, + plain_groups: BTreeMap>, type_aliases: BTreeMap, rust_structs: BTreeMap, prelude_to_emit: BTreeSet, @@ -642,8 +657,8 @@ impl<'a> IntermediateTypes<'a> { } // see self.plain_groups comments - pub fn mark_plain_group(&mut self, ident: RustIdent, group: Option>) { - self.plain_groups.insert(ident, group); + pub fn mark_plain_group(&mut self, ident: RustIdent, group_info: PlainGroupInfo<'a>) { + self.plain_groups.insert(ident, group_info); } // see self.plain_groups comments @@ -656,7 +671,8 @@ impl<'a> IntermediateTypes<'a> { ) { if let Some(plain_group) = self.plain_groups.get(ident) { // the clone is to get around the borrow checker - if let Some(group) = plain_group.as_ref().cloned() { + let plain_group = plain_group.clone(); + if let Some(group) = plain_group.group.as_ref() { // we are defined via .cddl and thus need to register a concrete // representation of the plain group if let Some(rust_struct) = self.rust_structs.get(ident) { @@ -673,11 +689,12 @@ impl<'a> IntermediateTypes<'a> { crate::parsing::parse_group( self, parent_visitor, - &group, + group, ident, rep, None, None, + &plain_group.rule_metadata, cli, ); } diff --git a/src/main.rs b/src/main.rs index 7c36177..27004ae 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,8 +9,9 @@ pub(crate) mod utils; use clap::Parser; use cli::Cli; +use comment_ast::RuleMetadata; use generation::GenerationScope; -use intermediate::{CDDLIdent, IntermediateTypes, RustIdent}; +use intermediate::{CDDLIdent, IntermediateTypes, PlainGroupInfo, RustIdent}; use once_cell::sync::Lazy; use parsing::{parse_rule, rule_ident, rule_is_scope_marker}; @@ -128,10 +129,16 @@ fn main() -> Result<(), Box> { if let cddl::ast::Rule::Group { rule, .. } = cddl_rule { // Freely defined group - no need to generate anything outside of group module match &rule.entry { - cddl::ast::GroupEntry::InlineGroup { group, .. } => { + cddl::ast::GroupEntry::InlineGroup { + group, + comments_after_group, + .. + } => { + assert_eq!(group.group_choices.len(), 1); + let rule_metadata = RuleMetadata::from(comments_after_group.as_ref()); types.mark_plain_group( RustIdent::new(CDDLIdent::new(rule.name.to_string())), - Some(group.clone()), + PlainGroupInfo::new(Some(group.clone()), rule_metadata), ); } x => panic!("Group rule with non-inline group? {:?}", x), diff --git a/src/parsing.rs b/src/parsing.rs index 62c8fbf..948b047 100644 --- a/src/parsing.rs +++ b/src/parsing.rs @@ -6,8 +6,9 @@ use std::collections::BTreeMap; use crate::comment_ast::{merge_metadata, metadata_from_comments, RuleMetadata}; use crate::intermediate::{ AliasInfo, CBOREncodingOperation, CDDLIdent, ConceptualRustType, EnumVariant, FixedValue, - GenericDef, GenericInstance, IntermediateTypes, ModuleScope, Primitive, Representation, - RustField, RustIdent, RustRecord, RustStruct, RustStructType, RustType, VariantIdent, + GenericDef, GenericInstance, IntermediateTypes, ModuleScope, PlainGroupInfo, Primitive, + Representation, RustField, RustIdent, RustRecord, RustStruct, RustStructType, RustType, + VariantIdent, }; use crate::utils::{ append_number_if_duplicate, convert_to_camel_case, convert_to_snake_case, @@ -640,6 +641,7 @@ fn parse_type( Representation::Map, outer_tag, generic_params, + &rule_metadata, cli, ); } @@ -654,6 +656,7 @@ fn parse_type( Representation::Array, outer_tag, generic_params, + &rule_metadata, cli, ); } @@ -1222,6 +1225,7 @@ fn rust_type_from_type2( Representation::Array, None, None, + &rule_metadata, cli, ); // we aren't returning an array, but rather a struct where the fields are ordered @@ -1446,27 +1450,59 @@ fn parse_group_choice( rep: Representation, tag: Option, generic_params: Option>, + parent_rule_metadata: Option<&RuleMetadata>, cli: &Cli, ) { let rule_metadata = RuleMetadata::from( get_comment_after(parent_visitor, &CDDLType::from(group_choice), None).as_ref(), ); + let rule_metadata = if let Some(parent_rule_metadata) = parent_rule_metadata { + merge_metadata(&rule_metadata, parent_rule_metadata) + } else { + rule_metadata + }; let rust_struct = match parse_group_type(types, parent_visitor, group_choice, rep, cli) { GroupParsingType::HomogenousArray(element_type) => { - // Array - homogeneous element type with proper occurence operator - RustStruct::new_array(name.clone(), tag, Some(&rule_metadata), element_type) + if rule_metadata.is_newtype { + // generate newtype over array + RustStruct::new_wrapper( + name.clone(), + tag, + Some(&rule_metadata), + ConceptualRustType::Array(Box::new(element_type)).into(), + None, + ) + } else { + // Array - homogeneous element type with proper occurence operator + RustStruct::new_array(name.clone(), tag, Some(&rule_metadata), element_type) + } } GroupParsingType::HomogenousMap(key_type, value_type) => { - // Table map - homogeneous key/value types - RustStruct::new_table( - name.clone(), - tag, - Some(&rule_metadata), - key_type, - value_type, - ) + if rule_metadata.is_newtype { + // generate newtype over map + RustStruct::new_wrapper( + name.clone(), + tag, + Some(&rule_metadata), + ConceptualRustType::Map(Box::new(key_type), Box::new(value_type)).into(), + None, + ) + } else { + // Table map - homogeneous key/value types + RustStruct::new_table( + name.clone(), + tag, + Some(&rule_metadata), + key_type, + value_type, + ) + } } GroupParsingType::Heterogenous | GroupParsingType::WrappedBasicGroup(_) => { + assert!( + !rule_metadata.is_newtype, + "Can only use @newtype on primtives + heterogenious arrays/maps" + ); // Heterogenous map or array with defined key/value pairs in the cddl like a struct let record = parse_record_from_group_choice(types, rep, parent_visitor, group_choice, cli); @@ -1489,6 +1525,7 @@ pub fn parse_group( rep: Representation, tag: Option, generic_params: Option>, + parent_rule_metadata: &RuleMetadata, cli: &Cli, ) { if group.group_choices.len() == 1 { @@ -1501,12 +1538,14 @@ pub fn parse_group( rep, tag, generic_params, + Some(parent_rule_metadata), cli, ); } else { if generic_params.is_some() { todo!("{}: generic group choices not supported", name); } + assert!(!parent_rule_metadata.is_newtype); // Generate Enum object that is not exposed to wasm, since wasm can't expose // fully featured rust enums via wasm_bindgen @@ -1570,7 +1609,10 @@ pub fn parse_group( let ident_name = rule_metadata.name.unwrap_or_else(|| format!("{name}{i}")); // General case, GroupN type identifiers and generate group choice since it's inlined here let variant_name = RustIdent::new(CDDLIdent::new(ident_name)); - types.mark_plain_group(variant_name.clone(), None); + types.mark_plain_group( + variant_name.clone(), + PlainGroupInfo::new(None, RuleMetadata::default()), + ); parse_group_choice( types, parent_visitor, @@ -1579,6 +1621,7 @@ pub fn parse_group( rep, None, generic_params.clone(), + None, cli, ); let name = VariantIdent::new_rust(variant_name.clone()); diff --git a/tests/core/input.cddl b/tests/core/input.cddl index 0197907..3ffbe2d 100644 --- a/tests/core/input.cddl +++ b/tests/core/input.cddl @@ -185,6 +185,9 @@ inline_wrapper = [{ * text => text }] top_level_array = [* uint] top_level_single_elem = [uint] +wrapper_table = { * uint => uint } ; @newtype +wrapper_list = [ * uint ] ; @newtype + overlapping_inlined = [ ; @name one 0 // diff --git a/tests/core/tests.rs b/tests/core/tests.rs index 7bd41a6..7273fe0 100644 --- a/tests/core/tests.rs +++ b/tests/core/tests.rs @@ -531,4 +531,35 @@ mod tests { ].into_iter().flatten().clone().collect::>(); assert_eq!(expected_bytes, struct_with_custom_bytes.to_cbor_bytes()); } + + #[test] + fn wrapper_table() { + use cbor_event::Sz; + let bytes = vec![ + map_sz(3, Sz::Inline), + cbor_int(5, Sz::Inline), + cbor_int(4, Sz::Inline), + cbor_int(3, Sz::Inline), + cbor_int(2, Sz::Inline), + cbor_int(1, Sz::Inline), + cbor_int(0, Sz::Inline), + ].into_iter().flatten().clone().collect::>(); + let from_bytes = WrapperTable::from_cbor_bytes(&bytes).unwrap(); + deser_test(&from_bytes); + } + + #[test] + fn wrapper_list() { + use cbor_event::Sz; + let bytes = vec![ + arr_sz(5, Sz::Inline), + cbor_int(5, Sz::Inline), + cbor_int(4, Sz::Inline), + cbor_int(3, Sz::Inline), + cbor_int(2, Sz::Inline), + cbor_int(1, Sz::Inline), + ].into_iter().flatten().clone().collect::>(); + let from_bytes = WrapperList::from_cbor_bytes(&bytes).unwrap(); + deser_test(&from_bytes); + } } diff --git a/tests/preserve-encodings/input.cddl b/tests/preserve-encodings/input.cddl index a06b760..208ff18 100644 --- a/tests/preserve-encodings/input.cddl +++ b/tests/preserve-encodings/input.cddl @@ -211,3 +211,6 @@ struct_with_custom_serialization = [ tagged1: #6.9(custom_bytes), tagged2: #6.9(uint), ; @custom_serialize write_tagged_uint_str @custom_deserialize read_tagged_uint_str ] + +wrapper_table = { * uint => uint } ; @newtype +wrapper_list = [ * uint ] ; @newtype diff --git a/tests/preserve-encodings/tests.rs b/tests/preserve-encodings/tests.rs index 6b8d168..33cb936 100644 --- a/tests/preserve-encodings/tests.rs +++ b/tests/preserve-encodings/tests.rs @@ -1264,4 +1264,39 @@ mod tests { } } } + + #[test] + fn wrapper_table() { + let def_encodings = vec![Sz::Inline, Sz::One, Sz::Two, Sz::Four, Sz::Eight]; + for def_enc in &def_encodings { + let irregular_bytes = vec![ + map_sz(3, *def_enc), + cbor_int(5, *def_enc), + cbor_int(4, *def_enc), + cbor_int(3, *def_enc), + cbor_int(2, *def_enc), + cbor_int(1, *def_enc), + cbor_int(0, *def_enc), + ].into_iter().flatten().clone().collect::>(); + let from_bytes = WrapperTable::from_cbor_bytes(&irregular_bytes).unwrap(); + assert_eq!(from_bytes.to_cbor_bytes(), irregular_bytes); + } + } + + #[test] + fn wrapper_list() { + let def_encodings = vec![Sz::Inline, Sz::One, Sz::Two, Sz::Four, Sz::Eight]; + for def_enc in &def_encodings { + let irregular_bytes = vec![ + arr_sz(5, *def_enc), + cbor_int(5, *def_enc), + cbor_int(4, *def_enc), + cbor_int(3, *def_enc), + cbor_int(2, *def_enc), + cbor_int(1, *def_enc), + ].into_iter().flatten().clone().collect::>(); + let from_bytes = WrapperList::from_cbor_bytes(&irregular_bytes).unwrap(); + assert_eq!(from_bytes.to_cbor_bytes(), irregular_bytes); + } + } }