Skip to content

Commit

Permalink
@newtype on array/map wrappers (#237)
Browse files Browse the repository at this point in the history
* @newtype on array/map wrappers

Previously this only worked on primitive wrappers.

Now you can do e.g. `foo = [* uint] ; @newtype` which was ignored
before.

* cargo fmt

* tests added
  • Loading branch information
rooooooooob authored Jun 18, 2024
1 parent e75a0ad commit eadc043
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 21 deletions.
27 changes: 22 additions & 5 deletions src/intermediate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,21 @@ impl AliasInfo {
}
}

#[derive(Debug, Clone)]
pub struct PlainGroupInfo<'a> {
group: Option<cddl::ast::Group<'a>>,
rule_metadata: RuleMetadata,
}

impl<'a> PlainGroupInfo<'a> {
pub fn new(group: Option<cddl::ast::Group<'a>>, rule_metadata: RuleMetadata) -> Self {
Self {
group,
rule_metadata,
}
}
}

#[derive(Debug)]
pub struct IntermediateTypes<'a> {
// Storing the cddl::Group is the easiest way to go here even after the parse/codegen split.
Expand All @@ -99,7 +114,7 @@ pub struct IntermediateTypes<'a> {
// delayed until the point where it is referenced via self.set_rep_if_plain_group(rep)
// Some(group) = directly defined in .cddl (must call set_plain_group_representatio() later)
// None = indirectly generated due to a group choice (no reason to call set_rep_if_plain_group() later but it won't crash)
plain_groups: BTreeMap<RustIdent, Option<cddl::ast::Group<'a>>>,
plain_groups: BTreeMap<RustIdent, PlainGroupInfo<'a>>,
type_aliases: BTreeMap<AliasIdent, AliasInfo>,
rust_structs: BTreeMap<RustIdent, RustStruct>,
prelude_to_emit: BTreeSet<String>,
Expand Down Expand Up @@ -642,8 +657,8 @@ impl<'a> IntermediateTypes<'a> {
}

// see self.plain_groups comments
pub fn mark_plain_group(&mut self, ident: RustIdent, group: Option<cddl::ast::Group<'a>>) {
self.plain_groups.insert(ident, group);
pub fn mark_plain_group(&mut self, ident: RustIdent, group_info: PlainGroupInfo<'a>) {
self.plain_groups.insert(ident, group_info);
}

// see self.plain_groups comments
Expand All @@ -656,7 +671,8 @@ impl<'a> IntermediateTypes<'a> {
) {
if let Some(plain_group) = self.plain_groups.get(ident) {
// the clone is to get around the borrow checker
if let Some(group) = plain_group.as_ref().cloned() {
let plain_group = plain_group.clone();
if let Some(group) = plain_group.group.as_ref() {
// we are defined via .cddl and thus need to register a concrete
// representation of the plain group
if let Some(rust_struct) = self.rust_structs.get(ident) {
Expand All @@ -673,11 +689,12 @@ impl<'a> IntermediateTypes<'a> {
crate::parsing::parse_group(
self,
parent_visitor,
&group,
group,
ident,
rep,
None,
None,
&plain_group.rule_metadata,
cli,
);
}
Expand Down
13 changes: 10 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ pub(crate) mod utils;

use clap::Parser;
use cli::Cli;
use comment_ast::RuleMetadata;
use generation::GenerationScope;
use intermediate::{CDDLIdent, IntermediateTypes, RustIdent};
use intermediate::{CDDLIdent, IntermediateTypes, PlainGroupInfo, RustIdent};
use once_cell::sync::Lazy;
use parsing::{parse_rule, rule_ident, rule_is_scope_marker};

Expand Down Expand Up @@ -128,10 +129,16 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
if let cddl::ast::Rule::Group { rule, .. } = cddl_rule {
// Freely defined group - no need to generate anything outside of group module
match &rule.entry {
cddl::ast::GroupEntry::InlineGroup { group, .. } => {
cddl::ast::GroupEntry::InlineGroup {
group,
comments_after_group,
..
} => {
assert_eq!(group.group_choices.len(), 1);
let rule_metadata = RuleMetadata::from(comments_after_group.as_ref());
types.mark_plain_group(
RustIdent::new(CDDLIdent::new(rule.name.to_string())),
Some(group.clone()),
PlainGroupInfo::new(Some(group.clone()), rule_metadata),
);
}
x => panic!("Group rule with non-inline group? {:?}", x),
Expand Down
69 changes: 56 additions & 13 deletions src/parsing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ use std::collections::BTreeMap;
use crate::comment_ast::{merge_metadata, metadata_from_comments, RuleMetadata};
use crate::intermediate::{
AliasInfo, CBOREncodingOperation, CDDLIdent, ConceptualRustType, EnumVariant, FixedValue,
GenericDef, GenericInstance, IntermediateTypes, ModuleScope, Primitive, Representation,
RustField, RustIdent, RustRecord, RustStruct, RustStructType, RustType, VariantIdent,
GenericDef, GenericInstance, IntermediateTypes, ModuleScope, PlainGroupInfo, Primitive,
Representation, RustField, RustIdent, RustRecord, RustStruct, RustStructType, RustType,
VariantIdent,
};
use crate::utils::{
append_number_if_duplicate, convert_to_camel_case, convert_to_snake_case,
Expand Down Expand Up @@ -640,6 +641,7 @@ fn parse_type(
Representation::Map,
outer_tag,
generic_params,
&rule_metadata,
cli,
);
}
Expand All @@ -654,6 +656,7 @@ fn parse_type(
Representation::Array,
outer_tag,
generic_params,
&rule_metadata,
cli,
);
}
Expand Down Expand Up @@ -1222,6 +1225,7 @@ fn rust_type_from_type2(
Representation::Array,
None,
None,
&rule_metadata,
cli,
);
// we aren't returning an array, but rather a struct where the fields are ordered
Expand Down Expand Up @@ -1446,27 +1450,59 @@ fn parse_group_choice(
rep: Representation,
tag: Option<usize>,
generic_params: Option<Vec<RustIdent>>,
parent_rule_metadata: Option<&RuleMetadata>,
cli: &Cli,
) {
let rule_metadata = RuleMetadata::from(
get_comment_after(parent_visitor, &CDDLType::from(group_choice), None).as_ref(),
);
let rule_metadata = if let Some(parent_rule_metadata) = parent_rule_metadata {
merge_metadata(&rule_metadata, parent_rule_metadata)
} else {
rule_metadata
};
let rust_struct = match parse_group_type(types, parent_visitor, group_choice, rep, cli) {
GroupParsingType::HomogenousArray(element_type) => {
// Array - homogeneous element type with proper occurence operator
RustStruct::new_array(name.clone(), tag, Some(&rule_metadata), element_type)
if rule_metadata.is_newtype {
// generate newtype over array
RustStruct::new_wrapper(
name.clone(),
tag,
Some(&rule_metadata),
ConceptualRustType::Array(Box::new(element_type)).into(),
None,
)
} else {
// Array - homogeneous element type with proper occurence operator
RustStruct::new_array(name.clone(), tag, Some(&rule_metadata), element_type)
}
}
GroupParsingType::HomogenousMap(key_type, value_type) => {
// Table map - homogeneous key/value types
RustStruct::new_table(
name.clone(),
tag,
Some(&rule_metadata),
key_type,
value_type,
)
if rule_metadata.is_newtype {
// generate newtype over map
RustStruct::new_wrapper(
name.clone(),
tag,
Some(&rule_metadata),
ConceptualRustType::Map(Box::new(key_type), Box::new(value_type)).into(),
None,
)
} else {
// Table map - homogeneous key/value types
RustStruct::new_table(
name.clone(),
tag,
Some(&rule_metadata),
key_type,
value_type,
)
}
}
GroupParsingType::Heterogenous | GroupParsingType::WrappedBasicGroup(_) => {
assert!(
!rule_metadata.is_newtype,
"Can only use @newtype on primtives + heterogenious arrays/maps"
);
// Heterogenous map or array with defined key/value pairs in the cddl like a struct
let record =
parse_record_from_group_choice(types, rep, parent_visitor, group_choice, cli);
Expand All @@ -1489,6 +1525,7 @@ pub fn parse_group(
rep: Representation,
tag: Option<usize>,
generic_params: Option<Vec<RustIdent>>,
parent_rule_metadata: &RuleMetadata,
cli: &Cli,
) {
if group.group_choices.len() == 1 {
Expand All @@ -1501,12 +1538,14 @@ pub fn parse_group(
rep,
tag,
generic_params,
Some(parent_rule_metadata),
cli,
);
} else {
if generic_params.is_some() {
todo!("{}: generic group choices not supported", name);
}
assert!(!parent_rule_metadata.is_newtype);
// Generate Enum object that is not exposed to wasm, since wasm can't expose
// fully featured rust enums via wasm_bindgen

Expand Down Expand Up @@ -1570,7 +1609,10 @@ pub fn parse_group(
let ident_name = rule_metadata.name.unwrap_or_else(|| format!("{name}{i}"));
// General case, GroupN type identifiers and generate group choice since it's inlined here
let variant_name = RustIdent::new(CDDLIdent::new(ident_name));
types.mark_plain_group(variant_name.clone(), None);
types.mark_plain_group(
variant_name.clone(),
PlainGroupInfo::new(None, RuleMetadata::default()),
);
parse_group_choice(
types,
parent_visitor,
Expand All @@ -1579,6 +1621,7 @@ pub fn parse_group(
rep,
None,
generic_params.clone(),
None,
cli,
);
let name = VariantIdent::new_rust(variant_name.clone());
Expand Down
3 changes: 3 additions & 0 deletions tests/core/input.cddl
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ inline_wrapper = [{ * text => text }]
top_level_array = [* uint]
top_level_single_elem = [uint]

wrapper_table = { * uint => uint } ; @newtype
wrapper_list = [ * uint ] ; @newtype

overlapping_inlined = [
; @name one
0 //
Expand Down
31 changes: 31 additions & 0 deletions tests/core/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -531,4 +531,35 @@ mod tests {
].into_iter().flatten().clone().collect::<Vec<u8>>();
assert_eq!(expected_bytes, struct_with_custom_bytes.to_cbor_bytes());
}

#[test]
fn wrapper_table() {
use cbor_event::Sz;
let bytes = vec![
map_sz(3, Sz::Inline),
cbor_int(5, Sz::Inline),
cbor_int(4, Sz::Inline),
cbor_int(3, Sz::Inline),
cbor_int(2, Sz::Inline),
cbor_int(1, Sz::Inline),
cbor_int(0, Sz::Inline),
].into_iter().flatten().clone().collect::<Vec<u8>>();
let from_bytes = WrapperTable::from_cbor_bytes(&bytes).unwrap();
deser_test(&from_bytes);
}

#[test]
fn wrapper_list() {
use cbor_event::Sz;
let bytes = vec![
arr_sz(5, Sz::Inline),
cbor_int(5, Sz::Inline),
cbor_int(4, Sz::Inline),
cbor_int(3, Sz::Inline),
cbor_int(2, Sz::Inline),
cbor_int(1, Sz::Inline),
].into_iter().flatten().clone().collect::<Vec<u8>>();
let from_bytes = WrapperList::from_cbor_bytes(&bytes).unwrap();
deser_test(&from_bytes);
}
}
3 changes: 3 additions & 0 deletions tests/preserve-encodings/input.cddl
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,6 @@ struct_with_custom_serialization = [
tagged1: #6.9(custom_bytes),
tagged2: #6.9(uint), ; @custom_serialize write_tagged_uint_str @custom_deserialize read_tagged_uint_str
]

wrapper_table = { * uint => uint } ; @newtype
wrapper_list = [ * uint ] ; @newtype
35 changes: 35 additions & 0 deletions tests/preserve-encodings/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1264,4 +1264,39 @@ mod tests {
}
}
}

#[test]
fn wrapper_table() {
let def_encodings = vec![Sz::Inline, Sz::One, Sz::Two, Sz::Four, Sz::Eight];
for def_enc in &def_encodings {
let irregular_bytes = vec![
map_sz(3, *def_enc),
cbor_int(5, *def_enc),
cbor_int(4, *def_enc),
cbor_int(3, *def_enc),
cbor_int(2, *def_enc),
cbor_int(1, *def_enc),
cbor_int(0, *def_enc),
].into_iter().flatten().clone().collect::<Vec<u8>>();
let from_bytes = WrapperTable::from_cbor_bytes(&irregular_bytes).unwrap();
assert_eq!(from_bytes.to_cbor_bytes(), irregular_bytes);
}
}

#[test]
fn wrapper_list() {
let def_encodings = vec![Sz::Inline, Sz::One, Sz::Two, Sz::Four, Sz::Eight];
for def_enc in &def_encodings {
let irregular_bytes = vec![
arr_sz(5, *def_enc),
cbor_int(5, *def_enc),
cbor_int(4, *def_enc),
cbor_int(3, *def_enc),
cbor_int(2, *def_enc),
cbor_int(1, *def_enc),
].into_iter().flatten().clone().collect::<Vec<u8>>();
let from_bytes = WrapperList::from_cbor_bytes(&irregular_bytes).unwrap();
assert_eq!(from_bytes.to_cbor_bytes(), irregular_bytes);
}
}
}

0 comments on commit eadc043

Please sign in to comment.