Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: update & fix unity bindgen #2631

Merged
merged 5 commits into from
Nov 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 64 additions & 26 deletions crates/dojo/bindgen/src/plugins/unity/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,50 +244,82 @@ namespace {namespace} {{
let mut sorted_enums = tokens.enums.clone();
sorted_enums.sort_by(compare_tokens_by_type_name);

// Process structs first
for token in &sorted_structs {
if handled_tokens.contains_key(&token.type_path()) {
continue;
}

handled_tokens.insert(token.type_path(), token.to_composite().unwrap().to_owned());

// first index is our model struct
if token.type_name() == naming::get_name_from_tag(&model.tag) {
model_struct = Some(token.to_composite().unwrap());
continue;
}
}

let model_struct = model_struct.expect("model struct not found");

// Handle struct dependencies
let struct_keys: Vec<String> = handled_tokens
.iter()
.filter(|(_, s)| {
model_struct.inners.iter().any(|inner| {
s.r#type == CompositeType::Struct
&& check_token_in_recursively(&inner.token, &s.type_name())
&& inner.token.type_name() != "ByteArray"
})
})
.map(|(k, _)| k.clone())
.collect();

out += UnityPlugin::format_struct(token.to_composite().unwrap()).as_str();
for key in struct_keys {
if let Some(s) = handled_tokens.remove(&key) {
out += UnityPlugin::format_struct(&s).as_str();
}
}

// Process enums
for token in &sorted_enums {
if handled_tokens.contains_key(&token.type_path()) {
continue;
}

handled_tokens.insert(token.type_path(), token.to_composite().unwrap().to_owned());
out += UnityPlugin::format_enum(token.to_composite().unwrap()).as_str();
}

out += "\n";
// Handle enum dependencies
let enum_keys: Vec<String> = handled_tokens
.iter()
.filter(|(_, s)| {
model_struct.inners.iter().any(|inner| {
s.r#type == CompositeType::Enum
&& check_token_in_recursively(&inner.token, &s.type_name())
})
})
.map(|(k, _)| k.clone())
.collect();

out += UnityPlugin::format_model(
&get_namespace_from_tag(&model.tag),
model_struct.expect("model struct not found"),
)
.as_str();
for key in enum_keys {
if let Some(s) = handled_tokens.remove(&key) {
out += UnityPlugin::format_enum(&s).as_str();
}
}

out += "\n";
out +=
UnityPlugin::format_model(&get_namespace_from_tag(&model.tag), model_struct).as_str();

out
}

// Formats a system into a C# method used by the contract class
// Handled tokens should be a list of all structs and enums used by the contract
// Such as a set of referenced tokens from a model
fn format_system(system: &Function, handled_tokens: &HashMap<String, Composite>) -> String {
fn format_system(system: &Function) -> String {
fn handle_arg_recursive(
arg_name: &str,
token: &Token,
handled_tokens: &HashMap<String, Composite>,
// variant name
// if its an enum variant data
enum_variant: Option<String>,
Expand All @@ -304,8 +336,6 @@ namespace {namespace} {{

match token {
Token::Composite(t) => {
let t = handled_tokens.get(&t.type_path).unwrap_or(t);

// Need to flatten the struct members.
match t.r#type {
CompositeType::Struct if t.type_name() == "ByteArray" => vec![(
Expand All @@ -319,7 +349,6 @@ namespace {namespace} {{
tokens.extend(handle_arg_recursive(
&format!("{}.{}", arg_name, f.name),
&f.token,
handled_tokens,
enum_variant.clone(),
));
});
Expand Down Expand Up @@ -360,7 +389,6 @@ namespace {namespace} {{
} else {
field.token.clone()
},
handled_tokens,
Some(field.name.clone()),
))
});
Expand All @@ -375,7 +403,6 @@ namespace {namespace} {{
let inner = handle_arg_recursive(
&format!("{arg_name}Item"),
&array.inner,
handled_tokens,
enum_variant.clone(),
);

Expand Down Expand Up @@ -416,7 +443,6 @@ namespace {namespace} {{
handle_arg_recursive(
&format!("{}.Item{}", arg_name, idx + 1),
token,
handled_tokens,
enum_variant.clone(),
)
})
Expand All @@ -441,7 +467,7 @@ namespace {namespace} {{
.inputs
.iter()
.flat_map(|(name, token)| {
let tokens = handle_arg_recursive(name, token, handled_tokens, None);
let tokens = handle_arg_recursive(name, token, None);

tokens
.iter()
Expand Down Expand Up @@ -477,7 +503,7 @@ namespace {namespace} {{

return await account.ExecuteRaw(new dojo.Call[] {{
new dojo.Call{{
to = contractAddress,
to = new FieldElement(contractAddress).Inner,
selector = \"{system_name}\",
calldata = calldata.ToArray()
}}
Larkooo marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -505,11 +531,7 @@ namespace {namespace} {{
// Will format the contract into a C# class and
// all systems into C# methods
// Handled tokens should be a list of all structs and enums used by the contract
fn handle_contract(
&self,
contract: &DojoContract,
handled_tokens: &HashMap<String, Composite>,
) -> String {
fn handle_contract(&self, contract: &DojoContract) -> String {
let mut out = String::new();
out += UnityPlugin::generated_header().as_str();
out += UnityPlugin::contract_imports().as_str();
Expand All @@ -519,7 +541,7 @@ namespace {namespace} {{
.iter()
// we assume systems dont have outputs
.filter(|s| s.to_function().unwrap().get_output_kind() as u8 == FunctionOutputKind::NoOutput as u8)
.map(|system| UnityPlugin::format_system(system.to_function().unwrap(), handled_tokens))
.map(|system| UnityPlugin::format_system(system.to_function().unwrap()))
.collect::<Vec<String>>()
.join("\n\n ");

Expand All @@ -543,6 +565,22 @@ public class {} : MonoBehaviour {{
}
}

fn check_token_in_recursively(token: &Token, type_name: &str) -> bool {
match token {
Token::Composite(composite) => {
if composite.type_name() == type_name {
return true;
}
composite.inners.iter().any(|inner| check_token_in_recursively(&inner.token, type_name))
}
Token::Array(array) => check_token_in_recursively(&array.inner, type_name),
Token::Tuple(tuple) => {
tuple.inners.iter().any(|inner| check_token_in_recursively(inner, type_name))
}
_ => token.type_name() == type_name,
}
}

#[async_trait]
impl BuiltinPlugin for UnityPlugin {
async fn generate_code(&self, data: &DojoData) -> BindgenResult<HashMap<PathBuf, Vec<u8>>> {
Expand Down Expand Up @@ -572,7 +610,7 @@ impl BuiltinPlugin for UnityPlugin {
let contracts_path = Path::new(&format!("Contracts/{}.gen.cs", name)).to_owned();

println!("Generating contract: {}", name);
let code = self.handle_contract(contract, &handled_tokens);
let code = self.handle_contract(contract);

out.insert(contracts_path, code.as_bytes().to_vec());
}
Expand Down
Loading