From a064965672873494107297886f3f4b77740bfe4a Mon Sep 17 00:00:00 2001 From: Nasr Date: Mon, 10 Jun 2024 15:00:16 -0400 Subject: [PATCH] refactor: to handle arrays --- crates/dojo-bindgen/src/plugins/unity/mod.rs | 129 +++++++++++++------ 1 file changed, 87 insertions(+), 42 deletions(-) diff --git a/crates/dojo-bindgen/src/plugins/unity/mod.rs b/crates/dojo-bindgen/src/plugins/unity/mod.rs index a2dba05f54..5482767be0 100644 --- a/crates/dojo-bindgen/src/plugins/unity/mod.rs +++ b/crates/dojo-bindgen/src/plugins/unity/mod.rs @@ -231,7 +231,11 @@ public class {} : ModelInstance {{ // Handles a model definition and its referenced tokens // Will map all structs and enums to C# types // Will format the model into a C# class - fn handle_model(&self, model: &DojoModel, handled_tokens: &mut Vec) -> String { + fn handle_model( + &self, + model: &DojoModel, + handled_tokens: &mut HashMap, + ) -> String { let mut out = String::new(); out += UnityPlugin::generated_header().as_str(); out += UnityPlugin::model_imports().as_str(); @@ -239,11 +243,11 @@ public class {} : ModelInstance {{ let mut model_struct: Option<&Composite> = None; let tokens = &model.tokens; for token in &tokens.structs { - if handled_tokens.iter().any(|t| t.type_name() == token.type_name()) { + if handled_tokens.contains_key(&token.type_path()) { continue; } - handled_tokens.push(token.to_composite().unwrap().to_owned()); + handled_tokens.insert(token.type_path(), token.to_composite().unwrap().to_owned()); // first index is our model struct if token.type_name() == model.name { @@ -255,11 +259,11 @@ public class {} : ModelInstance {{ } for token in &tokens.enums { - if handled_tokens.iter().any(|t| t.type_name() == token.type_name()) { + if handled_tokens.contains_key(&token.type_path()) { continue; } - handled_tokens.push(token.to_composite().unwrap().to_owned()); + handled_tokens.insert(token.type_path(), token.to_composite().unwrap().to_owned()); out += UnityPlugin::format_enum(token.to_composite().unwrap()).as_str(); } @@ -273,47 +277,88 @@ public class {} : ModelInstance {{ // Formats a system into a C# method used by the contract class // Handled tokens should be a list of all structs and enums used by the contract // Such as a set of referenced tokens from a model - fn format_system(system: &Function, handled_tokens: &[Composite]) -> String { + fn format_system(system: &Function) -> String { fn handle_arg_recursive( - type_name: &str, + arg_name: &str, token: &Token, - handled_tokens: &[Composite], - ) -> String { + ) -> Vec<( + // formatted arg + String, + // if its an array + bool, + )> { + println!("Handling arg: {:?}", token); let mapped_type = UnityPlugin::map_type(token); - match handled_tokens.iter().find(|t| t.type_name() == token.type_name()) { - Some(t) => { + match token { + Token::Composite(t) => { // Need to flatten the struct members. match t.r#type { - CompositeType::Struct if t.type_name() == "ByteArray" => format!( - "calldata.AddRange(ByteArray.Serialize({}).Select(f => f.Inner));", - type_name - ), - CompositeType::Struct => t - .inners - .iter() - .map(|field| { - handle_arg_recursive( - &format!("{}.{}", type_name, field.name), - &field.token, - handled_tokens, - ) - }) - .collect::>() - .join("\n\t\t"), - CompositeType::Enum => format!( - "calldata.Add(new FieldElement({}.GetIndex({})).Inner);", - t.type_name(), - type_name - ), - _ => { - format!("calldata.Add(new FieldElement({}).Inner);", type_name) + CompositeType::Struct if t.type_name() == "ByteArray" => vec![( + format!("ByteArray.Serialize({}).Select(f => f.Inner)", arg_name), + true, + )], + CompositeType::Struct => { + let tokens = vec![]; + t.inners.iter().for_each(|f| { + tokens.extend(handle_arg_recursive( + &format!("{}.{}", arg_name, f.name), + &f.token, + )); + }); + + tokens + } + CompositeType::Enum => { + let tokens = vec![( + "new FieldElement({}.GetIndex({})).Inner".to_string(), + false, + )]; + + t.inners.iter().for_each(|field| { + if let Token::CoreBasic(basic) = &field.token { + // ignore unit type + if basic.type_path == "()" { + return; + } + } + + tokens.extend(handle_arg_recursive( + &format!("{}.value", arg_name), + &if let Token::GenericArg(generic_arg) = &field.token { + let generic_token = t + .generic_args + .iter() + .find(|(name, _)| name == generic_arg) + .unwrap() + .1 + .clone(); + println!("Generic token: {:?}", generic_token); + generic_token + } else { + field.token.clone() + }, + )) + }); + + tokens } } } - None => match mapped_type.as_str() { - "FieldElement" => format!("calldata.Add({}.Inner);", type_name), - _ => format!("calldata.Add(new FieldElement({}).Inner);", type_name), + Token::Array(array) => { + + } + Token::Tuple(tuple) => { + tuple + .inners + .iter() + .map(|(name, token)| handle_arg_recursive(name, token)) + .flatten() + .collect() + } + _ => match mapped_type.as_str() { + "FieldElement" => format!("calldata.Add({}.Inner);", arg_name), + _ => format!("calldata.Add(new FieldElement({}).Inner);", arg_name), }, } } @@ -328,7 +373,7 @@ public class {} : ModelInstance {{ let calldata = system .inputs .iter() - .map(|(name, token)| handle_arg_recursive(name, token, handled_tokens)) + .map(|(name, token)| handle_arg_recursive(name, token)) .collect::>() .join("\n\t\t"); @@ -374,7 +419,7 @@ public class {} : ModelInstance {{ // Will format the contract into a C# class and // all systems into C# methods // Handled tokens should be a list of all structs and enums used by the contract - fn handle_contract(&self, contract: &DojoContract, handled_tokens: &[Composite]) -> String { + fn handle_contract(&self, contract: &DojoContract) -> String { let mut out = String::new(); out += UnityPlugin::generated_header().as_str(); out += UnityPlugin::contract_imports().as_str(); @@ -384,7 +429,7 @@ public class {} : ModelInstance {{ .iter() // we assume systems dont have outputs .filter(|s| s.to_function().unwrap().get_output_kind() as u8 == FunctionOutputKind::NoOutput as u8) - .map(|system| UnityPlugin::format_system(system.to_function().unwrap(), handled_tokens)) + .map(|system| UnityPlugin::format_system(system.to_function().unwrap())) .collect::>() .join("\n\n "); @@ -412,7 +457,7 @@ public class {} : MonoBehaviour {{ impl BuiltinPlugin for UnityPlugin { async fn generate_code(&self, data: &DojoData) -> BindgenResult>> { let mut out: HashMap> = HashMap::new(); - let mut handled_tokens = Vec::::new(); + let mut handled_tokens = HashMap::::new(); // Handle codegen for models for (name, model) in &data.models { @@ -429,7 +474,7 @@ impl BuiltinPlugin for UnityPlugin { let contracts_path = Path::new(&format!("Contracts/{}.gen.cs", name)).to_owned(); println!("Generating contract: {}", name); - let code = self.handle_contract(contract, &handled_tokens); + let code = self.handle_contract(contract); out.insert(contracts_path, code.as_bytes().to_vec()); }