diff --git a/crates/dojo-bindgen/src/plugins/unity/mod.rs b/crates/dojo-bindgen/src/plugins/unity/mod.rs index 2827519351..45578f2267 100644 --- a/crates/dojo-bindgen/src/plugins/unity/mod.rs +++ b/crates/dojo-bindgen/src/plugins/unity/mod.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::path::{Path, PathBuf}; use async_trait::async_trait; -use cainome::parser::tokens::{Composite, CompositeType, Function, Token}; +use cainome::parser::tokens::{Composite, CompositeType, Function, FunctionOutputKind, Token}; use crate::error::BindgenResult; use crate::plugins::BuiltinPlugin; @@ -87,6 +87,32 @@ impl UnityPlugin { ) } + fn contract_imports() -> String { + "using System; +using System.Threading.Tasks; +using Dojo; +using Dojo.Starknet; +using UnityEngine; +using dojo_bindings; +using System.Collections.Generic; +using System.Linq; +using Enum = Dojo.Starknet.Enum; +" + .to_string() + } + + fn model_imports() -> String { + "using System; +using Dojo; +using Dojo.Starknet; +using System.Reflection; +using System.Linq; +using System.Collections.Generic; +using Enum = Dojo.Starknet.Enum; +" + .to_string() + } + // Token should be a struct // This will be formatted into a C# struct // using C# and unity SDK types @@ -116,7 +142,8 @@ public struct {} {{ // This will be formatted into a C# enum // Enum is mapped using index of cairo enum fn format_enum(token: &Composite) -> String { - let mut name_with_generics = token.type_name(); + let name = token.type_name(); + let mut name_with_generics = name.clone(); if !token.generic_args.is_empty() { name_with_generics += &format!( "<{}>", @@ -127,7 +154,7 @@ public struct {} {{ let mut result = format!( " // Type definition for `{}` enum -public abstract record {}() {{", +public abstract record {}() : Enum {{", token.type_path, name_with_generics ); @@ -189,21 +216,23 @@ public class {} : ModelInstance {{ // Handles a model definition and its referenced tokens // Will map all structs and enums to C# types // Will format the model into a C# class - fn handle_model(&self, model: &DojoModel, handled_tokens: &mut Vec) -> String { + fn handle_model( + &self, + model: &DojoModel, + handled_tokens: &mut HashMap, + ) -> String { let mut out = String::new(); out += UnityPlugin::generated_header().as_str(); - out += "using System;\n"; - out += "using Dojo;\n"; - out += "using Dojo.Starknet;\n"; + out += UnityPlugin::model_imports().as_str(); let mut model_struct: Option<&Composite> = None; let tokens = &model.tokens; for token in &tokens.structs { - if handled_tokens.iter().any(|t| t.type_name() == token.type_name()) { + if handled_tokens.contains_key(&token.type_path()) { continue; } - handled_tokens.push(token.to_composite().unwrap().to_owned()); + handled_tokens.insert(token.type_path(), token.to_composite().unwrap().to_owned()); // first index is our model struct if token.type_name() == model.name { @@ -215,11 +244,11 @@ public class {} : ModelInstance {{ } for token in &tokens.enums { - if handled_tokens.iter().any(|t| t.type_name() == token.type_name()) { + if handled_tokens.contains_key(&token.type_path()) { continue; } - handled_tokens.push(token.to_composite().unwrap().to_owned()); + handled_tokens.insert(token.type_path(), token.to_composite().unwrap().to_owned()); out += UnityPlugin::format_enum(token.to_composite().unwrap()).as_str(); } @@ -233,7 +262,145 @@ public class {} : ModelInstance {{ // Formats a system into a C# method used by the contract class // Handled tokens should be a list of all structs and enums used by the contract // Such as a set of referenced tokens from a model - fn format_system(system: &Function, handled_tokens: &[Composite]) -> String { + fn format_system(system: &Function, handled_tokens: &HashMap) -> String { + fn handle_arg_recursive( + arg_name: &str, + token: &Token, + handled_tokens: &HashMap, + // variant name + // if its an enum variant data + enum_variant: Option, + ) -> Vec<( + // formatted arg + String, + // if its an array + bool, + // enum name and variant name + // if its an enum variant data + Option, + )> { + let mapped_type = UnityPlugin::map_type(token); + + match token { + Token::Composite(t) => { + let t = handled_tokens.get(&t.type_path).unwrap_or(t); + + // Need to flatten the struct members. + match t.r#type { + CompositeType::Struct if t.type_name() == "ByteArray" => vec![( + format!("ByteArray.Serialize({}).Select(f => f.Inner)", arg_name), + true, + enum_variant, + )], + CompositeType::Struct => { + let mut tokens = vec![]; + t.inners.iter().for_each(|f| { + tokens.extend(handle_arg_recursive( + &format!("{}.{}", arg_name, f.name), + &f.token, + handled_tokens, + enum_variant.clone(), + )); + }); + + tokens + } + CompositeType::Enum => { + let mut tokens = vec![( + format!("new FieldElement(Enum.GetIndex({})).Inner", arg_name), + false, + enum_variant, + )]; + + t.inners.iter().for_each(|field| { + if let Token::CoreBasic(basic) = &field.token { + // ignore unit type + if basic.type_path == "()" { + return; + } + } + + tokens.extend(handle_arg_recursive( + &format!( + "(({}.{}){}).value", + mapped_type, + field.name.clone(), + arg_name + ), + &if let Token::GenericArg(generic_arg) = &field.token { + let generic_token = t + .generic_args + .iter() + .find(|(name, _)| name == generic_arg) + .unwrap() + .1 + .clone(); + generic_token + } else { + field.token.clone() + }, + handled_tokens, + Some(field.name.clone()), + )) + }); + + tokens + } + CompositeType::Unknown => panic!("Unknown composite type: {:?}", t), + } + } + Token::Array(array) => { + let is_inner_array = matches!(array.inner.as_ref(), Token::Array(_)); + let inner = handle_arg_recursive( + &format!("{arg_name}Item"), + &array.inner, + handled_tokens, + enum_variant.clone(), + ); + + let inners = inner + .into_iter() + .map(|(arg, _, _)| arg) + .collect::>() + .join(", "); + + vec![( + if is_inner_array { + format!( + "{arg_name}.SelectMany({arg_name}Item => new dojo.FieldElement[] \ + {{ }}.Concat({inners}))" + ) + } else { + format!( + "{arg_name}.SelectMany({arg_name}Item => new [] {{ {inners} }})" + ) + }, + true, + enum_variant.clone(), + )] + } + Token::Tuple(tuple) => tuple + .inners + .iter() + .enumerate() + .flat_map(|(idx, token)| { + handle_arg_recursive( + &format!("{}.Item{}", arg_name, idx + 1), + token, + handled_tokens, + enum_variant.clone(), + ) + }) + .collect(), + _ => match mapped_type.as_str() { + "FieldElement" => vec![(format!("{}.Inner", arg_name), false, enum_variant)], + _ => { + vec![(format!("new FieldElement({}).Inner", arg_name), false, enum_variant)] + } + }, + } + } + let args = system .inputs .iter() @@ -244,35 +411,31 @@ public class {} : ModelInstance {{ let calldata = system .inputs .iter() - .map(|arg| { - let token = &arg.1; - let type_name = &arg.0; - - match handled_tokens.iter().find(|t| t.type_name() == token.type_name()) { - Some(t) => { - // Need to flatten the struct members. - match t.r#type { - CompositeType::Struct => t - .inners - .iter() - .map(|field| { - format!("new FieldElement({}.{}).Inner", type_name, field.name) - }) - .collect::>() - .join(",\n "), - _ => { - format!("new FieldElement({}).Inner", type_name) - } + .flat_map(|(name, token)| { + let tokens = handle_arg_recursive(name, token, handled_tokens, None); + + tokens + .iter() + .map(|(arg, is_array, enum_variant)| { + let calldata_op = if *is_array { + format!("calldata.AddRange({arg});") + } else { + format!("calldata.Add({arg});") + }; + + if let Some(variant) = enum_variant { + let mapped_token = UnityPlugin::map_type(token); + let mapped_variant_type = format!("{}.{}", mapped_token, variant); + + format!("if ({name} is {mapped_variant_type}) {calldata_op}",) + } else { + calldata_op } - } - None => match UnityPlugin::map_type(token).as_str() { - "FieldElement" => format!("{}.Inner", type_name), - _ => format!("new FieldElement({}).Inner", type_name), - }, - } + }) + .collect::>() }) .collect::>() - .join(",\n "); + .join("\n\t\t"); format!( " @@ -280,13 +443,14 @@ public class {} : ModelInstance {{ // Returns the transaction hash. Use `WaitForTransaction` to wait for the transaction to be \ confirmed. public async Task {system_name}(Account account{arg_sep}{args}) {{ + List calldata = new List(); + {calldata} + return await account.ExecuteRaw(new dojo.Call[] {{ new dojo.Call{{ to = contractAddress, selector = \"{system_name}\", - calldata = new dojo.FieldElement[] {{ - {calldata} - }} + calldata = calldata.ToArray() }} }}); }} @@ -315,19 +479,20 @@ public class {} : ModelInstance {{ // Will format the contract into a C# class and // all systems into C# methods // Handled tokens should be a list of all structs and enums used by the contract - fn handle_contract(&self, contract: &DojoContract, handled_tokens: &[Composite]) -> String { + fn handle_contract( + &self, + contract: &DojoContract, + handled_tokens: &HashMap, + ) -> String { let mut out = String::new(); out += UnityPlugin::generated_header().as_str(); - out += "using System;\n"; - out += "using System.Threading.Tasks;\n"; - out += "using Dojo;\n"; - out += "using Dojo.Starknet;\n"; - out += "using UnityEngine;\n"; - out += "using dojo_bindings;\n"; + out += UnityPlugin::contract_imports().as_str(); let systems = contract .systems .iter() + // we assume systems dont have outputs + .filter(|s| s.to_function().unwrap().get_output_kind() as u8 == FunctionOutputKind::NoOutput as u8) .map(|system| UnityPlugin::format_system(system.to_function().unwrap(), handled_tokens)) .collect::>() .join("\n\n "); @@ -356,7 +521,7 @@ public class {} : MonoBehaviour {{ impl BuiltinPlugin for UnityPlugin { async fn generate_code(&self, data: &DojoData) -> BindgenResult>> { let mut out: HashMap> = HashMap::new(); - let mut handled_tokens = Vec::::new(); + let mut handled_tokens = HashMap::::new(); // Handle codegen for models for (name, model) in &data.models {