diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index ed7155c5c1..66bba6e674 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -18,6 +18,7 @@ RUN echo "deb http://apt.llvm.org/${VARIANT}/ llvm-toolchain-${VARIANT}-17 main" RUN apt-get -y install -t llvm-toolchain-${VARIANT}-17 llvm-17 llvm-17-dev llvm-17-runtime clang-17 clang-tools-17 lld-17 libpolly-17-dev libmlir-17-dev mlir-17-tools RUN curl -L https://foundry.paradigm.xyz/ | bash && . /root/.bashrc && foundryup +ENV PATH="${PATH}:/root/.foundry/bin" # To build Katana with 'native' feature, we need to set the following environment variables ENV MLIR_SYS_170_PREFIX=/usr/lib/llvm-17 diff --git a/Cargo.lock b/Cargo.lock index a5527378df..877b081274 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -14271,6 +14271,7 @@ version = "0.7.0-alpha.5" dependencies = [ "anyhow", "async-trait", + "cainome", "chrono", "crypto-bigint", "dojo-test-utils", diff --git a/bin/sozo/src/commands/migrate.rs b/bin/sozo/src/commands/migrate.rs index 0d36af3fbc..6f6abe7404 100644 --- a/bin/sozo/src/commands/migrate.rs +++ b/bin/sozo/src/commands/migrate.rs @@ -212,8 +212,8 @@ fn is_compatible_version(provided_version: &str, expected_version: &str) -> Resu .map_err(|e| anyhow!("Failed to parse expected version '{}': {}", expected_version, e))?; // Specific backward compatibility rule: 0.6 is compatible with 0.7. - if (provided_ver.major == 0 && provided_ver.minor == 6) - && (expected_ver.major == 0 && expected_ver.minor == 7) + if (provided_ver.major == 0 && provided_ver.minor == 7) + && (expected_ver.major == 0 && expected_ver.minor == 6) { return Ok(true); } @@ -246,7 +246,9 @@ mod tests { #[test] fn test_is_compatible_version_specific_backward_compatibility() { - assert!(is_compatible_version("0.6.0", "0.7.1").unwrap()); + let node_version = "0.7.1"; + let katana_version = "0.6.0"; + assert!(is_compatible_version(node_version, katana_version).unwrap()); } #[test] diff --git a/crates/dojo-bindgen/src/plugins/unity/mod.rs b/crates/dojo-bindgen/src/plugins/unity/mod.rs index 2827519351..45578f2267 100644 --- a/crates/dojo-bindgen/src/plugins/unity/mod.rs +++ b/crates/dojo-bindgen/src/plugins/unity/mod.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::path::{Path, PathBuf}; use async_trait::async_trait; -use cainome::parser::tokens::{Composite, CompositeType, Function, Token}; +use cainome::parser::tokens::{Composite, CompositeType, Function, FunctionOutputKind, Token}; use crate::error::BindgenResult; use crate::plugins::BuiltinPlugin; @@ -87,6 +87,32 @@ impl UnityPlugin { ) } + fn contract_imports() -> String { + "using System; +using System.Threading.Tasks; +using Dojo; +using Dojo.Starknet; +using UnityEngine; +using dojo_bindings; +using System.Collections.Generic; +using System.Linq; +using Enum = Dojo.Starknet.Enum; +" + .to_string() + } + + fn model_imports() -> String { + "using System; +using Dojo; +using Dojo.Starknet; +using System.Reflection; +using System.Linq; +using System.Collections.Generic; +using Enum = Dojo.Starknet.Enum; +" + .to_string() + } + // Token should be a struct // This will be formatted into a C# struct // using C# and unity SDK types @@ -116,7 +142,8 @@ public struct {} {{ // This will be formatted into a C# enum // Enum is mapped using index of cairo enum fn format_enum(token: &Composite) -> String { - let mut name_with_generics = token.type_name(); + let name = token.type_name(); + let mut name_with_generics = name.clone(); if !token.generic_args.is_empty() { name_with_generics += &format!( "<{}>", @@ -127,7 +154,7 @@ public struct {} {{ let mut result = format!( " // Type definition for `{}` enum -public abstract record {}() {{", +public abstract record {}() : Enum {{", token.type_path, name_with_generics ); @@ -189,21 +216,23 @@ public class {} : ModelInstance {{ // Handles a model definition and its referenced tokens // Will map all structs and enums to C# types // Will format the model into a C# class - fn handle_model(&self, model: &DojoModel, handled_tokens: &mut Vec) -> String { + fn handle_model( + &self, + model: &DojoModel, + handled_tokens: &mut HashMap, + ) -> String { let mut out = String::new(); out += UnityPlugin::generated_header().as_str(); - out += "using System;\n"; - out += "using Dojo;\n"; - out += "using Dojo.Starknet;\n"; + out += UnityPlugin::model_imports().as_str(); let mut model_struct: Option<&Composite> = None; let tokens = &model.tokens; for token in &tokens.structs { - if handled_tokens.iter().any(|t| t.type_name() == token.type_name()) { + if handled_tokens.contains_key(&token.type_path()) { continue; } - handled_tokens.push(token.to_composite().unwrap().to_owned()); + handled_tokens.insert(token.type_path(), token.to_composite().unwrap().to_owned()); // first index is our model struct if token.type_name() == model.name { @@ -215,11 +244,11 @@ public class {} : ModelInstance {{ } for token in &tokens.enums { - if handled_tokens.iter().any(|t| t.type_name() == token.type_name()) { + if handled_tokens.contains_key(&token.type_path()) { continue; } - handled_tokens.push(token.to_composite().unwrap().to_owned()); + handled_tokens.insert(token.type_path(), token.to_composite().unwrap().to_owned()); out += UnityPlugin::format_enum(token.to_composite().unwrap()).as_str(); } @@ -233,7 +262,145 @@ public class {} : ModelInstance {{ // Formats a system into a C# method used by the contract class // Handled tokens should be a list of all structs and enums used by the contract // Such as a set of referenced tokens from a model - fn format_system(system: &Function, handled_tokens: &[Composite]) -> String { + fn format_system(system: &Function, handled_tokens: &HashMap) -> String { + fn handle_arg_recursive( + arg_name: &str, + token: &Token, + handled_tokens: &HashMap, + // variant name + // if its an enum variant data + enum_variant: Option, + ) -> Vec<( + // formatted arg + String, + // if its an array + bool, + // enum name and variant name + // if its an enum variant data + Option, + )> { + let mapped_type = UnityPlugin::map_type(token); + + match token { + Token::Composite(t) => { + let t = handled_tokens.get(&t.type_path).unwrap_or(t); + + // Need to flatten the struct members. + match t.r#type { + CompositeType::Struct if t.type_name() == "ByteArray" => vec![( + format!("ByteArray.Serialize({}).Select(f => f.Inner)", arg_name), + true, + enum_variant, + )], + CompositeType::Struct => { + let mut tokens = vec![]; + t.inners.iter().for_each(|f| { + tokens.extend(handle_arg_recursive( + &format!("{}.{}", arg_name, f.name), + &f.token, + handled_tokens, + enum_variant.clone(), + )); + }); + + tokens + } + CompositeType::Enum => { + let mut tokens = vec![( + format!("new FieldElement(Enum.GetIndex({})).Inner", arg_name), + false, + enum_variant, + )]; + + t.inners.iter().for_each(|field| { + if let Token::CoreBasic(basic) = &field.token { + // ignore unit type + if basic.type_path == "()" { + return; + } + } + + tokens.extend(handle_arg_recursive( + &format!( + "(({}.{}){}).value", + mapped_type, + field.name.clone(), + arg_name + ), + &if let Token::GenericArg(generic_arg) = &field.token { + let generic_token = t + .generic_args + .iter() + .find(|(name, _)| name == generic_arg) + .unwrap() + .1 + .clone(); + generic_token + } else { + field.token.clone() + }, + handled_tokens, + Some(field.name.clone()), + )) + }); + + tokens + } + CompositeType::Unknown => panic!("Unknown composite type: {:?}", t), + } + } + Token::Array(array) => { + let is_inner_array = matches!(array.inner.as_ref(), Token::Array(_)); + let inner = handle_arg_recursive( + &format!("{arg_name}Item"), + &array.inner, + handled_tokens, + enum_variant.clone(), + ); + + let inners = inner + .into_iter() + .map(|(arg, _, _)| arg) + .collect::>() + .join(", "); + + vec![( + if is_inner_array { + format!( + "{arg_name}.SelectMany({arg_name}Item => new dojo.FieldElement[] \ + {{ }}.Concat({inners}))" + ) + } else { + format!( + "{arg_name}.SelectMany({arg_name}Item => new [] {{ {inners} }})" + ) + }, + true, + enum_variant.clone(), + )] + } + Token::Tuple(tuple) => tuple + .inners + .iter() + .enumerate() + .flat_map(|(idx, token)| { + handle_arg_recursive( + &format!("{}.Item{}", arg_name, idx + 1), + token, + handled_tokens, + enum_variant.clone(), + ) + }) + .collect(), + _ => match mapped_type.as_str() { + "FieldElement" => vec![(format!("{}.Inner", arg_name), false, enum_variant)], + _ => { + vec![(format!("new FieldElement({}).Inner", arg_name), false, enum_variant)] + } + }, + } + } + let args = system .inputs .iter() @@ -244,35 +411,31 @@ public class {} : ModelInstance {{ let calldata = system .inputs .iter() - .map(|arg| { - let token = &arg.1; - let type_name = &arg.0; - - match handled_tokens.iter().find(|t| t.type_name() == token.type_name()) { - Some(t) => { - // Need to flatten the struct members. - match t.r#type { - CompositeType::Struct => t - .inners - .iter() - .map(|field| { - format!("new FieldElement({}.{}).Inner", type_name, field.name) - }) - .collect::>() - .join(",\n "), - _ => { - format!("new FieldElement({}).Inner", type_name) - } + .flat_map(|(name, token)| { + let tokens = handle_arg_recursive(name, token, handled_tokens, None); + + tokens + .iter() + .map(|(arg, is_array, enum_variant)| { + let calldata_op = if *is_array { + format!("calldata.AddRange({arg});") + } else { + format!("calldata.Add({arg});") + }; + + if let Some(variant) = enum_variant { + let mapped_token = UnityPlugin::map_type(token); + let mapped_variant_type = format!("{}.{}", mapped_token, variant); + + format!("if ({name} is {mapped_variant_type}) {calldata_op}",) + } else { + calldata_op } - } - None => match UnityPlugin::map_type(token).as_str() { - "FieldElement" => format!("{}.Inner", type_name), - _ => format!("new FieldElement({}).Inner", type_name), - }, - } + }) + .collect::>() }) .collect::>() - .join(",\n "); + .join("\n\t\t"); format!( " @@ -280,13 +443,14 @@ public class {} : ModelInstance {{ // Returns the transaction hash. Use `WaitForTransaction` to wait for the transaction to be \ confirmed. public async Task {system_name}(Account account{arg_sep}{args}) {{ + List calldata = new List(); + {calldata} + return await account.ExecuteRaw(new dojo.Call[] {{ new dojo.Call{{ to = contractAddress, selector = \"{system_name}\", - calldata = new dojo.FieldElement[] {{ - {calldata} - }} + calldata = calldata.ToArray() }} }}); }} @@ -315,19 +479,20 @@ public class {} : ModelInstance {{ // Will format the contract into a C# class and // all systems into C# methods // Handled tokens should be a list of all structs and enums used by the contract - fn handle_contract(&self, contract: &DojoContract, handled_tokens: &[Composite]) -> String { + fn handle_contract( + &self, + contract: &DojoContract, + handled_tokens: &HashMap, + ) -> String { let mut out = String::new(); out += UnityPlugin::generated_header().as_str(); - out += "using System;\n"; - out += "using System.Threading.Tasks;\n"; - out += "using Dojo;\n"; - out += "using Dojo.Starknet;\n"; - out += "using UnityEngine;\n"; - out += "using dojo_bindings;\n"; + out += UnityPlugin::contract_imports().as_str(); let systems = contract .systems .iter() + // we assume systems dont have outputs + .filter(|s| s.to_function().unwrap().get_output_kind() as u8 == FunctionOutputKind::NoOutput as u8) .map(|system| UnityPlugin::format_system(system.to_function().unwrap(), handled_tokens)) .collect::>() .join("\n\n "); @@ -356,7 +521,7 @@ public class {} : MonoBehaviour {{ impl BuiltinPlugin for UnityPlugin { async fn generate_code(&self, data: &DojoData) -> BindgenResult>> { let mut out: HashMap> = HashMap::new(); - let mut handled_tokens = Vec::::new(); + let mut handled_tokens = HashMap::::new(); // Handle codegen for models for (name, model) in &data.models { diff --git a/crates/dojo-lang/src/contract.rs b/crates/dojo-lang/src/contract.rs index fac9667ba0..c2a9c292ff 100644 --- a/crates/dojo-lang/src/contract.rs +++ b/crates/dojo-lang/src/contract.rs @@ -4,37 +4,28 @@ use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode}; use cairo_lang_defs::plugin::{ DynGeneratedFileAuxData, PluginDiagnostic, PluginGeneratedFile, PluginResult, }; -use cairo_lang_diagnostics::Severity; -use cairo_lang_syntax::attribute::structured::{ - Attribute, AttributeArg, AttributeArgVariant, AttributeListStructurize, -}; use cairo_lang_syntax::node::ast::MaybeModuleBody; use cairo_lang_syntax::node::db::SyntaxGroup; use cairo_lang_syntax::node::{ast, ids, Terminal, TypedStablePtr, TypedSyntaxNode}; use cairo_lang_utils::unordered_hash_map::UnorderedHashMap; use dojo_types::system::Dependency; -use crate::plugin::{DojoAuxData, SystemAuxData, DOJO_CONTRACT_ATTR}; +use crate::plugin::{DojoAuxData, SystemAuxData}; +use crate::syntax::world_param::{self, WorldParamInjectionKind}; +use crate::syntax::{self_param, utils as syntax_utils}; -const ALLOW_REF_SELF_ARG: &str = "allow_ref_self"; const DOJO_INIT_FN: &str = "dojo_init"; pub struct DojoContract { diagnostics: Vec, dependencies: HashMap, - do_allow_ref_self: bool, } impl DojoContract { pub fn from_module(db: &dyn SyntaxGroup, module_ast: ast::ItemModule) -> PluginResult { let name = module_ast.name(db).text(db); - let attrs = module_ast.attributes(db).structurize(db); - let dojo_contract_attr = attrs.iter().find(|attr| attr.id.as_str() == DOJO_CONTRACT_ATTR); - let do_allow_ref_self = extract_allow_ref_self(dojo_contract_attr, db).unwrap_or_default(); - - let mut system = - DojoContract { diagnostics: vec![], dependencies: HashMap::new(), do_allow_ref_self }; + let mut system = DojoContract { diagnostics: vec![], dependencies: HashMap::new() }; let mut has_event = false; let mut has_storage = false; let mut has_init = false; @@ -182,14 +173,14 @@ impl DojoContract { let fn_decl = fn_ast.declaration(db); let fn_name = fn_decl.name(db).text(db); - let (params_str, _, world_removed) = self.rewrite_parameters( + let (params_str, was_world_injected) = self.rewrite_parameters( db, fn_decl.signature(db).parameters(db), fn_ast.stable_ptr().untyped(), ); let mut world_read = ""; - if world_removed { + if was_world_injected { world_read = "let world = self.world_dispatcher.read();"; } @@ -303,166 +294,61 @@ impl DojoContract { )] } - /// Gets name, modifiers and type from a function parameter. - pub fn get_parameter_info( - &mut self, - db: &dyn SyntaxGroup, - param: ast::Param, - ) -> (String, String, String) { - let name = param.name(db).text(db).trim().to_string(); - let modifiers = param.modifiers(db).as_syntax_node().get_text(db).trim().to_string(); - let param_type = - param.type_clause(db).ty(db).as_syntax_node().get_text(db).trim().to_string(); - - (name, modifiers, param_type) - } - - /// Check if the function has a self parameter. - /// - /// Returns - /// * a boolean indicating if `self` has to be added, - // * a boolean indicating if there is a `ref self` parameter. - pub fn check_self_parameter( - &mut self, - db: &dyn SyntaxGroup, - param_list: ast::ParamList, - ) -> (bool, bool) { - let mut add_self = true; - let mut has_ref_self = false; - if !param_list.elements(db).is_empty() { - let (param_name, param_modifiers, param_type) = - self.get_parameter_info(db, param_list.elements(db)[0].clone()); - - if param_name.eq(&"self".to_string()) { - if param_modifiers.contains(&"ref".to_string()) - && param_type.eq(&"ContractState".to_string()) - { - has_ref_self = true; - add_self = false; - } - - if param_type.eq(&"@ContractState".to_string()) { - add_self = false; - } - } - }; - - (add_self, has_ref_self) - } - - /// Check if the function has multiple IWorldDispatcher parameters. - /// - /// Returns - /// * a boolean indicating if the function has multiple world dispatchers. - pub fn check_world_dispatcher( - &mut self, - db: &dyn SyntaxGroup, - param_list: ast::ParamList, - ) -> bool { - let mut count = 0; - - param_list.elements(db).iter().for_each(|param| { - let (_, _, param_type) = self.get_parameter_info(db, param.clone()); - - if param_type.eq(&"IWorldDispatcher".to_string()) { - count += 1; - } - }); - - count > 1 - } - /// Rewrites parameter list by: - /// * adding `self` parameter if missing, - /// * removing `world` if present as first parameter (self excluded), as it will be read from - /// the first function statement. + /// * adding `self` parameter based on the `world` parameter mutability. If `world` is not + /// provided, a `View` is assumed. + /// * removing `world` if present as first parameter, as it will be read from the first + /// function statement. /// /// Reports an error in case of: - /// * `ref self`, as systems are supposed to be 100% stateless, - /// * multiple IWorldDispatcher parameters. - /// * the `IWorldDispatcher` is not the first parameter (self excluded) and named 'world'. + /// * `self` used explicitly, + /// * multiple world parameters, + /// * the `world` parameter is not the first parameter and named 'world'. /// /// Returns - /// * the list of parameters in a String - /// * a boolean indicating if `self` has been added - // * a boolean indicating if `world` parameter has been removed + /// * the list of parameters in a String. + /// * true if the world has to be injected (found as the first param). pub fn rewrite_parameters( &mut self, db: &dyn SyntaxGroup, param_list: ast::ParamList, - diagnostic_item: ids::SyntaxStablePtrId, - ) -> (String, bool, bool) { - let (add_self, has_ref_self) = self.check_self_parameter(db, param_list.clone()); - let has_multiple_world_dispatchers = self.check_world_dispatcher(db, param_list.clone()); + fn_diagnostic_item: ids::SyntaxStablePtrId, + ) -> (String, bool) { + self_param::check_parameter(db, ¶m_list, fn_diagnostic_item, &mut self.diagnostics); - let mut world_removed = false; + let world_injection = world_param::parse_world_injection( + db, + param_list.clone(), + fn_diagnostic_item, + &mut self.diagnostics, + ); let mut params = param_list .elements(db) .iter() - .enumerate() - .filter_map(|(idx, param)| { - let (name, modifiers, param_type) = self.get_parameter_info(db, param.clone()); - - if param_type.eq(&"IWorldDispatcher".to_string()) - && modifiers.eq(&"".to_string()) - && !has_multiple_world_dispatchers - { - let has_good_pos = (add_self && idx == 0) || (!add_self && idx == 1); - let has_good_name = name.eq(&"world".to_string()); - - if has_good_pos && has_good_name { - world_removed = true; - None - } else { - if !has_good_pos { - self.diagnostics.push(PluginDiagnostic { - stable_ptr: param.stable_ptr().untyped(), - message: "The IWorldDispatcher parameter must be the first \ - parameter of the function (self excluded)." - .to_string(), - severity: Severity::Error, - }); - } + .filter_map(|param| { + let (name, _, param_type) = syntax_utils::get_parameter_info(db, param.clone()); - if !has_good_name { - self.diagnostics.push(PluginDiagnostic { - stable_ptr: param.stable_ptr().untyped(), - message: "The IWorldDispatcher parameter must be named 'world'." - .to_string(), - severity: Severity::Error, - }); - } - Some(param.as_syntax_node().get_text(db)) - } + // If the param is `IWorldDispatcher`, we don't need to keep it in the param list + // as it is flatten in the first statement. + if world_param::is_world_param(&name, ¶m_type) { + None } else { Some(param.as_syntax_node().get_text(db)) } }) .collect::>(); - if has_multiple_world_dispatchers { - self.diagnostics.push(PluginDiagnostic { - stable_ptr: diagnostic_item, - message: "Only one parameter of type IWorldDispatcher is allowed.".to_string(), - severity: Severity::Error, - }); - } - - if has_ref_self && !self.do_allow_ref_self { - self.diagnostics.push(PluginDiagnostic { - stable_ptr: diagnostic_item, - message: "Functions of dojo::contract cannot have 'ref self' parameter." - .to_string(), - severity: Severity::Error, - }); - } - - if add_self { - params.insert(0, "self: @ContractState".to_string()); + match world_injection { + WorldParamInjectionKind::None | WorldParamInjectionKind::View => { + params.insert(0, "self: @ContractState".to_string()); + } + WorldParamInjectionKind::External => { + params.insert(0, "ref self: ContractState".to_string()); + } } - (params.join(", "), add_self, world_removed) + (params.join(", "), world_injection != WorldParamInjectionKind::None) } /// Rewrites function statements by adding the reading of `world` at first statement. @@ -493,21 +379,23 @@ impl DojoContract { ) -> Vec { let mut rewritten_fn = RewriteNode::from_ast(&fn_ast); - let (params_str, self_added, world_removed) = self.rewrite_parameters( + let (params_str, was_world_injected) = self.rewrite_parameters( db, fn_ast.declaration(db).signature(db).parameters(db), fn_ast.stable_ptr().untyped(), ); - if self_added || world_removed { - let rewritten_params = rewritten_fn - .modify_child(db, ast::FunctionWithBody::INDEX_DECLARATION) - .modify_child(db, ast::FunctionDeclaration::INDEX_SIGNATURE) - .modify_child(db, ast::FunctionSignature::INDEX_PARAMETERS); - rewritten_params.set_str(params_str); - } - - if world_removed { + // We always rewrite the params as the self parameter is added based on the + // world mutability. + let rewritten_params = rewritten_fn + .modify_child(db, ast::FunctionWithBody::INDEX_DECLARATION) + .modify_child(db, ast::FunctionDeclaration::INDEX_SIGNATURE) + .modify_child(db, ast::FunctionSignature::INDEX_PARAMETERS); + rewritten_params.set_str(params_str); + + // If the world was injected, we also need to rewrite the statements of the function + // to ensure the `world` injection is effective. + if was_world_injected { let rewritten_statements = rewritten_fn .modify_child(db, ast::FunctionWithBody::INDEX_BODY) .modify_child(db, ast::ExprBlock::INDEX_STATEMENTS); @@ -557,26 +445,3 @@ impl DojoContract { vec![RewriteNode::Copied(impl_ast.as_syntax_node())] } } - -/// Extract the allow_ref_self attribute. -pub(crate) fn extract_allow_ref_self( - allow_ref_self_attr: Option<&Attribute>, - db: &dyn SyntaxGroup, -) -> Option { - let Some(attr) = allow_ref_self_attr else { - return None; - }; - - #[allow(clippy::collapsible_match)] - match &attr.args[..] { - [AttributeArg { variant: AttributeArgVariant::Unnamed(value), .. }] => match value { - ast::Expr::Path(path) - if path.as_syntax_node().get_text_without_trivia(db) == ALLOW_REF_SELF_ARG => - { - Some(true) - } - _ => None, - }, - _ => None, - } -} diff --git a/crates/dojo-lang/src/interface.rs b/crates/dojo-lang/src/interface.rs index 54205335c0..fc9bf2ef1b 100644 --- a/crates/dojo-lang/src/interface.rs +++ b/crates/dojo-lang/src/interface.rs @@ -5,6 +5,9 @@ use cairo_lang_syntax::node::db::SyntaxGroup; use cairo_lang_syntax::node::{ast, ids, Terminal, TypedStablePtr, TypedSyntaxNode}; use cairo_lang_utils::unordered_hash_map::UnorderedHashMap; +use crate::syntax::self_param; +use crate::syntax::world_param::{self, WorldParamInjectionKind}; + pub struct DojoInterface { diagnostics: Vec, } @@ -77,9 +80,7 @@ impl DojoInterface { } } - /// Rewrites parameter list by adding `self` parameter if missing. - /// - /// Reports an error in case of `ref self` as systems are supposed to be 100% stateless. + /// Rewrites parameter list by adding `self` parameter based on the `world` parameter. pub fn rewrite_parameters( &mut self, db: &dyn SyntaxGroup, @@ -92,50 +93,29 @@ impl DojoInterface { .map(|e| e.as_syntax_node().get_text(db)) .collect::>(); - let mut need_to_add_self = true; - if !params.is_empty() { - let first_param = param_list.elements(db)[0].clone(); - let param_name = first_param.name(db).text(db).to_string(); - - if param_name.eq(&"self".to_string()) { - let param_modifiers = first_param - .modifiers(db) - .elements(db) - .iter() - .map(|e| e.as_syntax_node().get_text(db).trim().to_string()) - .collect::>(); - - let param_type = first_param - .type_clause(db) - .ty(db) - .as_syntax_node() - .get_text(db) - .trim() - .to_string(); - - if param_modifiers.contains(&"ref".to_string()) - && param_type.eq(&"TContractState".to_string()) - { - self.diagnostics.push(PluginDiagnostic { - stable_ptr: diagnostic_item, - message: "Functions of dojo::interface cannot have `ref self` parameter." - .to_string(), - severity: Severity::Error, - }); + self_param::check_parameter(db, ¶m_list, diagnostic_item, &mut self.diagnostics); - need_to_add_self = false; - } - - if param_type.eq(&"@TContractState".to_string()) { - need_to_add_self = false; - } + let world_injection = world_param::parse_world_injection( + db, + param_list, + diagnostic_item, + &mut self.diagnostics, + ); + + match world_injection { + WorldParamInjectionKind::None => { + params.insert(0, "self: @TContractState".to_string()); + } + WorldParamInjectionKind::View => { + params.remove(0); + params.insert(0, "self: @TContractState".to_string()); + } + WorldParamInjectionKind::External => { + params.remove(0); + params.insert(0, "ref self: TContractState".to_string()); } }; - if need_to_add_self { - params.insert(0, "self: @TContractState".to_string()); - } - params.join(", ") } @@ -151,11 +131,13 @@ impl DojoInterface { .modify_child(db, ast::FunctionDeclaration::INDEX_SIGNATURE) .modify_child(db, ast::FunctionSignature::INDEX_PARAMETERS); - rewritten_params.set_str(self.rewrite_parameters( + let params_str = self.rewrite_parameters( db, fn_ast.declaration(db).signature(db).parameters(db), fn_ast.stable_ptr().untyped(), - )); + ); + + rewritten_params.set_str(params_str); vec![rewritten_fn] } } diff --git a/crates/dojo-lang/src/lib.rs b/crates/dojo-lang/src/lib.rs index b76ea602c8..4f7e0e4bca 100644 --- a/crates/dojo-lang/src/lib.rs +++ b/crates/dojo-lang/src/lib.rs @@ -13,6 +13,7 @@ pub mod model; pub mod plugin; pub mod print; pub mod semantics; +pub mod syntax; pub(crate) mod version; // Copy of non pub functions from scarb + extension. diff --git a/crates/dojo-lang/src/plugin_test_data/system b/crates/dojo-lang/src/plugin_test_data/system index c23d1a35a0..f7fbef26bb 100644 --- a/crates/dojo-lang/src/plugin_test_data/system +++ b/crates/dojo-lang/src/plugin_test_data/system @@ -85,81 +85,54 @@ trait IEmptyTrait; trait IFaultyTrait { const ONE: u8; - fn do_ref_self(ref self: TContractState); - #[my_attr] fn do_with_attrs(p1: u8) -> u16; } -#[starknet::interface] -trait IAllowedRefSelf { - fn spawn(ref self: T); -} - -#[dojo::contract(allow_ref_self)] -mod ContractAllowedRefSelf { - #[abi(embed_v0)] - impl AllowedImpl of IAllowedRefSelf { - fn spawn(ref self: ContractState) {} - } -} - #[dojo::interface] trait INominalTrait { - fn do_no_param(); - fn do_no_param_but_self(self: @TContractState); - fn do_params(p1: dojo_examples::models::Direction, p2: u8); - fn do_params_and_self(self: @TContractState, p2: u8); - fn do_return_value(p1: u8) -> u16; + fn do_no_param() -> felt252; + fn do_no_param_but_world(world: @IWorldDispatcher) -> felt252; + fn do_no_param_but_world_ref(ref world: IWorldDispatcher) -> felt252; + fn do_params_no_world(p1: felt252, p2: u8) -> felt252; + fn do_params_and_world(world: @IWorldDispatcher, p2: u8) -> felt252; + fn do_params_and_world_ref(ref world: IWorldDispatcher, p2: u8) -> felt252; } #[dojo::interface] -trait IWorldTrait { +trait IFaultyTrait { + fn do_with_self(self: @ContractState) -> felt252; fn do_with_ref_self(ref self: ContractState) -> felt252; fn do_with_several_world_dispatchers( - world: IWorldDispatcher, vec: Vec2, another_world: IWorldDispatcher - ) -> felt252; - fn do_with_world_not_named_world(another_world: IWorldDispatcher) -> felt252; - fn do_with_self_and_world_not_named_world( - self: @ContractState, another_world: IWorldDispatcher - ); - fn do_with_world_not_first(vec: Vec2, world: IWorldDispatcher) -> felt252; - fn do_with_self_and_world_not_first( - self: @ContractState, vec: Vec2, world: IWorldDispatcher + world: @IWorldDispatcher, vec: Vec2, ref another_world: IWorldDispatcher ) -> felt252; + fn do_with_world_not_named_world(another_world: @IWorldDispatcher) -> felt252; + fn do_with_world_not_first(vec: Vec2, ref world: IWorldDispatcher) -> felt252; } #[dojo::contract] mod MyFaultyContract { #[abi(embed_v0)] - impl TestWorldImpl of IWorldTrait { - fn do_with_ref_self(ref self: ContractState) -> felt252 { - 'land' - } - - fn do_with_several_world_dispatchers( - world: IWorldDispatcher, vec: Vec2, another_world: IWorldDispatcher - ) -> felt252 { + impl TestFaultyImpl of IFaultyTrait { + fn do_with_self(ref self: ContractState) -> felt252 { 'land' } - fn do_with_world_not_named_world(another_world: IWorldDispatcher) -> felt252 { + fn do_with_ref_self(ref self: ContractState) -> felt252 { 'land' } - fn do_with_self_and_world_not_named_world( - self: @ContractState, another_world: IWorldDispatcher + fn do_with_several_world_dispatchers( + world: @IWorldDispatcher, vec: Vec2, ref another_world: IWorldDispatcher ) -> felt252 { 'land' } - fn do_with_world_not_first(vec: Vec2, world: IWorldDispatcher) -> felt252 { + fn do_with_world_not_named_world(another_world: @IWorldDispatcher) -> felt252 { 'land' } - fn do_with_self_and_world_not_first( - self: @ContractState, vec: Vec2, world: IWorldDispatcher - ) -> felt252 { + fn do_with_world_not_first(vec: Vec2, ref world: IWorldDispatcher) -> felt252 { 'land' } } @@ -173,22 +146,28 @@ mod MyNominalContract { } #[abi(embed_v0)] - impl TestWorldImpl of IWorldTrait { - fn do(vec: Vec2) -> felt252 { + impl TestNominalImpl of INominalTrait { + fn do_no_param() -> felt252 { 'land' } - fn do_with_self(self: @ContractState, vec: Vec2) -> felt252 { + fn do_no_param_but_world(world: @IWorldDispatcher) -> felt252 { 'land' } - fn do_with_world_first(world: IWorldDispatcher, vec: Vec2) -> felt252 { + fn do_no_param_but_world_ref(ref world: IWorldDispatcher) -> felt252 { 'land' } - fn do_with_self_and_world_first( - self: @ContractState, world: IWorldDispatcher, vec: Vec2 - ) -> felt252 { + fn do_params_no_world(p1: felt252, p2: u8) -> felt252 { + 'land' + } + + fn do_params_and_world(world: @IWorldDispatcher, p2: u8) -> felt252 { + 'land' + } + + fn do_params_and_world_ref(ref world: IWorldDispatcher, p2: u8) -> felt252 { 'land' } } @@ -347,40 +326,40 @@ error: Anything other than functions is not supported in a dojo::interface const ONE: u8; ^************^ -error: Functions of dojo::interface cannot have `ref self` parameter. - --> test_src/lib.cairo:82:5 - fn do_ref_self(ref self: TContractState); - ^***************************************^ +error: In a dojo contract or interface, you should use `world: @IWorldDispatcher` instead of `self: @ContractState`. + --> test_src/lib.cairo:98:5 + fn do_with_self(self: @ContractState) -> felt252; + ^***********************************************^ -error: Functions of dojo::contract cannot have 'ref self' parameter. - --> test_src/lib.cairo:130:9 - fn do_with_ref_self(ref self: ContractState) -> felt252 { - ^*******************************************************^ +error: In a dojo contract or interface, you should use `ref world: IWorldDispatcher` instead of `ref self: ContractState`. + --> test_src/lib.cairo:99:5 + fn do_with_ref_self(ref self: ContractState) -> felt252; + ^******************************************************^ -error: Only one parameter of type IWorldDispatcher is allowed. - --> test_src/lib.cairo:134:9 - fn do_with_several_world_dispatchers( - ^***********************************^ +error: World parameter must be the first parameter. + --> test_src/lib.cairo:104:5 + fn do_with_world_not_first(vec: Vec2, ref world: IWorldDispatcher) -> felt252; + ^****************************************************************************^ -error: The IWorldDispatcher parameter must be named 'world'. - --> test_src/lib.cairo:140:42 - fn do_with_world_not_named_world(another_world: IWorldDispatcher) -> felt252 { - ^*****************************^ +error: In a dojo contract or interface, you should use `ref world: IWorldDispatcher` instead of `ref self: ContractState`. + --> test_src/lib.cairo:111:9 + fn do_with_self(ref self: ContractState) -> felt252 { + ^***************************************************^ -error: The IWorldDispatcher parameter must be named 'world'. - --> test_src/lib.cairo:145:35 - self: @ContractState, another_world: IWorldDispatcher - ^*****************************^ +error: In a dojo contract or interface, you should use `ref world: IWorldDispatcher` instead of `ref self: ContractState`. + --> test_src/lib.cairo:115:9 + fn do_with_ref_self(ref self: ContractState) -> felt252 { + ^*******************************************************^ -error: The IWorldDispatcher parameter must be the first parameter of the function (self excluded). - --> test_src/lib.cairo:150:47 - fn do_with_world_not_first(vec: Vec2, world: IWorldDispatcher) -> felt252 { - ^*********************^ +error: World parameter must be the first parameter. + --> test_src/lib.cairo:129:9 + fn do_with_world_not_first(vec: Vec2, ref world: IWorldDispatcher) -> felt252 { + ^*****************************************************************************^ -error: The IWorldDispatcher parameter must be the first parameter of the function (self excluded). - --> test_src/lib.cairo:155:46 - self: @ContractState, vec: Vec2, world: IWorldDispatcher - ^*********************^ +error: World parameter must be a snapshot if `ref` is not used. + --> test_src/lib.cairo:180:5 + fn dojo_init( + ^***********^ error: Unsupported attribute. --> test_src/lib.cairo:1:1 @@ -408,32 +387,27 @@ error: Unsupported attribute. ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:84:5 + --> test_src/lib.cairo:82:5 #[my_attr] ^********^ error: Unsupported attribute. - --> test_src/lib.cairo:93:1 -#[dojo::contract(allow_ref_self)] -^*******************************^ - -error: Unsupported attribute. - --> test_src/lib.cairo:126:1 + --> test_src/lib.cairo:107:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:162:1 + --> test_src/lib.cairo:135:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:199:1 + --> test_src/lib.cairo:178:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:216:1 + --> test_src/lib.cairo:195:1 #[dojo::contract] ^***************^ @@ -673,217 +647,172 @@ error: Unsupported attribute. ^***************^ error: Unknown inline item macro: 'component'. - --> test_src/lib.cairo:93:1 -#[dojo::contract(allow_ref_self)] -^*******************************^ - -error: Unsupported attribute. - --> test_src/lib.cairo:93:1 -#[dojo::contract(allow_ref_self)] -^*******************************^ - -error: Unsupported attribute. - --> test_src/lib.cairo:93:1 -#[dojo::contract(allow_ref_self)] -^*******************************^ - -error: Unsupported attribute. - --> test_src/lib.cairo:93:1 -#[dojo::contract(allow_ref_self)] -^*******************************^ - -error: Unsupported attribute. - --> test_src/lib.cairo:95:5 - #[abi(embed_v0)] - ^**************^ - -error: Unsupported attribute. - --> test_src/lib.cairo:93:1 -#[dojo::contract(allow_ref_self)] -^*******************************^ - -error: Unsupported attribute. - --> test_src/lib.cairo:93:1 -#[dojo::contract(allow_ref_self)] -^*******************************^ - -error: Unsupported attribute. - --> test_src/lib.cairo:93:1 -#[dojo::contract(allow_ref_self)] -^*******************************^ - -error: Unsupported attribute. - --> test_src/lib.cairo:93:1 -#[dojo::contract(allow_ref_self)] -^*******************************^ - -error: Unknown inline item macro: 'component'. - --> test_src/lib.cairo:126:1 + --> test_src/lib.cairo:107:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:126:1 + --> test_src/lib.cairo:107:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:126:1 + --> test_src/lib.cairo:107:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:126:1 + --> test_src/lib.cairo:107:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:128:5 + --> test_src/lib.cairo:109:5 #[abi(embed_v0)] ^**************^ error: Unsupported attribute. - --> test_src/lib.cairo:126:1 + --> test_src/lib.cairo:107:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:126:1 + --> test_src/lib.cairo:107:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:126:1 + --> test_src/lib.cairo:107:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:126:1 + --> test_src/lib.cairo:107:1 #[dojo::contract] ^***************^ error: Unknown inline item macro: 'component'. - --> test_src/lib.cairo:162:1 + --> test_src/lib.cairo:135:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:162:1 + --> test_src/lib.cairo:135:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:162:1 + --> test_src/lib.cairo:135:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:162:1 + --> test_src/lib.cairo:135:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:169:5 + --> test_src/lib.cairo:142:5 #[abi(embed_v0)] ^**************^ error: Unsupported attribute. - --> test_src/lib.cairo:162:1 + --> test_src/lib.cairo:135:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:162:1 + --> test_src/lib.cairo:135:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:162:1 + --> test_src/lib.cairo:135:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:162:1 + --> test_src/lib.cairo:135:1 #[dojo::contract] ^***************^ error: Unknown inline item macro: 'component'. - --> test_src/lib.cairo:199:1 + --> test_src/lib.cairo:178:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:199:1 + --> test_src/lib.cairo:178:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:199:1 + --> test_src/lib.cairo:178:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:199:1 + --> test_src/lib.cairo:178:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:199:1 + --> test_src/lib.cairo:178:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:199:1 + --> test_src/lib.cairo:178:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:199:1 + --> test_src/lib.cairo:178:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:199:1 + --> test_src/lib.cairo:178:1 #[dojo::contract] ^***************^ error: Unknown inline item macro: 'component'. - --> test_src/lib.cairo:216:1 + --> test_src/lib.cairo:195:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:216:1 + --> test_src/lib.cairo:195:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:216:1 + --> test_src/lib.cairo:195:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:216:1 + --> test_src/lib.cairo:195:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:216:1 + --> test_src/lib.cairo:195:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:216:1 + --> test_src/lib.cairo:195:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:216:1 + --> test_src/lib.cairo:195:1 #[dojo::contract] ^***************^ error: Unsupported attribute. - --> test_src/lib.cairo:216:1 + --> test_src/lib.cairo:195:1 #[dojo::contract] ^***************^ @@ -901,11 +830,6 @@ mod testcomponent2 { struct Storage {} } -#[starknet::interface] -trait IAllowedRefSelf { - fn spawn(ref self: T); -} - #[starknet::contract] mod spawn { use dojo::world; @@ -1210,95 +1134,31 @@ impl EventDrop of core::traits::Drop::; #[starknet::interface] trait IFaultyTrait { - fn do_ref_self(ref self: TContractState); - #[my_attr] fn do_with_attrs(self: @TContractState, p1: u8) -> u16; } - #[starknet::contract] - mod ContractAllowedRefSelf { - use dojo::world; - use dojo::world::IWorldDispatcher; - use dojo::world::IWorldDispatcherTrait; - use dojo::world::IWorldProvider; - use dojo::world::IDojoResourceProvider; - - #[abi(embed_v0)] - impl DojoResourceProviderImpl of IDojoResourceProvider { - fn dojo_resource(self: @ContractState) -> felt252 { - 'ContractAllowedRefSelf' - } - } - - #[abi(embed_v0)] - impl WorldProviderImpl of IWorldProvider { - fn world(self: @ContractState) -> IWorldDispatcher { - self.world_dispatcher.read() - } - } - - #[abi(embed_v0)] - impl UpgradableImpl = dojo::components::upgradeable::upgradeable::UpgradableImpl; - - #[abi(embed_v0)] - impl AllowedImpl of IAllowedRefSelf { - fn spawn(ref self: ContractState) {} - } - - #[starknet::interface] - trait IDojoInit { - fn dojo_init(self: @ContractState); - } - - #[abi(embed_v0)] - impl IDojoInitImpl of IDojoInit { - fn dojo_init(self: @ContractState) { - assert(starknet::get_caller_address() == self.world().contract_address, 'Only world can init'); - } - } - - #[event] - #[derive(Drop, starknet::Event)] - enum Event { - UpgradeableEvent: dojo::components::upgradeable::upgradeable::Event, - } - - #[storage] - struct Storage { - world_dispatcher: IWorldDispatcher, - #[substorage(v0)] - upgradeable: dojo::components::upgradeable::upgradeable::Storage, - } -impl EventDrop of core::traits::Drop::; - - } - #[starknet::interface] trait INominalTrait { - fn do_no_param(self: @TContractState); - fn do_no_param_but_self(self: @TContractState); - fn do_params(self: @TContractState, p1: dojo_examples::models::Direction, p2: u8); - fn do_params_and_self(self: @TContractState, p2: u8); - fn do_return_value(self: @TContractState, p1: u8) -> u16; + fn do_no_param(self: @TContractState) -> felt252; + fn do_no_param_but_world(self: @TContractState) -> felt252; + fn do_no_param_but_world_ref(ref self: TContractState) -> felt252; + fn do_params_no_world(self: @TContractState, p1: felt252, p2: u8) -> felt252; + fn do_params_and_world(self: @TContractState, p2: u8) -> felt252; + fn do_params_and_world_ref(ref self: TContractState, p2: u8) -> felt252; } #[starknet::interface] - trait IWorldTrait { - fn do_with_ref_self(self: @TContractState, ref self: ContractState) -> felt252; + trait IFaultyTrait { + fn do_with_self(self: @TContractState, self: @ContractState) -> felt252; + fn do_with_ref_self(self: @TContractState, ref self: ContractState) -> felt252; fn do_with_several_world_dispatchers( -self: @TContractState, world: IWorldDispatcher, vec: Vec2, another_world: IWorldDispatcher - ) -> felt252; - fn do_with_world_not_named_world(self: @TContractState, another_world: IWorldDispatcher) -> felt252; - fn do_with_self_and_world_not_named_world( -self: @TContractState, self: @ContractState, another_world: IWorldDispatcher - ); - fn do_with_world_not_first(self: @TContractState, vec: Vec2, world: IWorldDispatcher) -> felt252; - fn do_with_self_and_world_not_first( -self: @TContractState, self: @ContractState, vec: Vec2, world: IWorldDispatcher +self: @TContractState, vec: Vec2, ref another_world: IWorldDispatcher ) -> felt252; + fn do_with_world_not_named_world(self: @TContractState, another_world: @IWorldDispatcher) -> felt252; + fn do_with_world_not_first(self: @TContractState, vec: Vec2, ref world: IWorldDispatcher) -> felt252; } @@ -1328,34 +1188,27 @@ self: @TContractState, self: @ContractState, vec: Vec2, world: IWorldDis impl UpgradableImpl = dojo::components::upgradeable::upgradeable::UpgradableImpl; #[abi(embed_v0)] - impl TestWorldImpl of IWorldTrait { - fn do_with_ref_self(ref self: ContractState) -> felt252 { + impl TestFaultyImpl of IFaultyTrait { + fn do_with_self(self: @ContractState, ref self: ContractState) -> felt252 { 'land' } - fn do_with_several_world_dispatchers( -self: @ContractState, world: IWorldDispatcher, vec: Vec2, another_world: IWorldDispatcher - ) -> felt252 { - 'land' - } - - fn do_with_world_not_named_world(self: @ContractState, another_world: IWorldDispatcher) -> felt252 { + fn do_with_ref_self(self: @ContractState, ref self: ContractState) -> felt252 { 'land' } - fn do_with_self_and_world_not_named_world( - self: @ContractState, another_world: IWorldDispatcher + fn do_with_several_world_dispatchers( +self: @ContractState, vec: Vec2, ref another_world: IWorldDispatcher ) -> felt252 { +let world = self.world_dispatcher.read(); 'land' } - fn do_with_world_not_first(self: @ContractState, vec: Vec2, world: IWorldDispatcher) -> felt252 { + fn do_with_world_not_named_world(self: @ContractState, another_world: @IWorldDispatcher) -> felt252 { 'land' } - fn do_with_self_and_world_not_first( - self: @ContractState, vec: Vec2, world: IWorldDispatcher - ) -> felt252 { + fn do_with_world_not_first(self: @ContractState, vec: Vec2) -> felt252 { 'land' } } @@ -1419,23 +1272,31 @@ impl EventDrop of core::traits::Drop::; } #[abi(embed_v0)] - impl TestWorldImpl of IWorldTrait { - fn do(self: @ContractState, vec: Vec2) -> felt252 { + impl TestNominalImpl of INominalTrait { + fn do_no_param(self: @ContractState) -> felt252 { + 'land' + } + + fn do_no_param_but_world(self: @ContractState) -> felt252 { +let world = self.world_dispatcher.read(); + 'land' + } + + fn do_no_param_but_world_ref(ref self: ContractState) -> felt252 { +let world = self.world_dispatcher.read(); 'land' } - fn do_with_self(self: @ContractState, vec: Vec2) -> felt252 { + fn do_params_no_world(self: @ContractState, p1: felt252, p2: u8) -> felt252 { 'land' } - fn do_with_world_first(self: @ContractState, vec: Vec2) -> felt252 { + fn do_params_and_world(self: @ContractState, p2: u8) -> felt252 { let world = self.world_dispatcher.read(); 'land' } - fn do_with_self_and_world_first( - self: @ContractState, vec: Vec2 - ) -> felt252 { + fn do_params_and_world_ref(ref self: ContractState, p2: u8) -> felt252 { let world = self.world_dispatcher.read(); 'land' } diff --git a/crates/dojo-lang/src/syntax/mod.rs b/crates/dojo-lang/src/syntax/mod.rs new file mode 100644 index 0000000000..cfd40715a2 --- /dev/null +++ b/crates/dojo-lang/src/syntax/mod.rs @@ -0,0 +1,3 @@ +pub mod self_param; +pub mod utils; +pub mod world_param; diff --git a/crates/dojo-lang/src/syntax/self_param.rs b/crates/dojo-lang/src/syntax/self_param.rs new file mode 100644 index 0000000000..2f21ae7522 --- /dev/null +++ b/crates/dojo-lang/src/syntax/self_param.rs @@ -0,0 +1,49 @@ +use cairo_lang_defs::plugin::PluginDiagnostic; +use cairo_lang_diagnostics::Severity; +use cairo_lang_syntax::node::db::SyntaxGroup; +use cairo_lang_syntax::node::{ast, ids}; + +use crate::syntax::utils as syntax_utils; + +const SELF_PARAM_NAME: &str = "self"; + +/// Checks if the given function parameter is using `self` instead of `world` param. +/// Adds diagnostic if that case. +/// +/// # Arguments +/// +/// - `db` - The syntax group. +/// - `param_list` - The parameter list of the function. +/// - `fn_diagnostic_item` - The diagnostic item of the function. +/// - `diagnostics` - The diagnostics vector. +pub fn check_parameter( + db: &dyn SyntaxGroup, + param_list: &ast::ParamList, + fn_diagnostic_item: ids::SyntaxStablePtrId, + diagnostics: &mut Vec, +) { + if param_list.elements(db).is_empty() { + return; + } + + let param_0 = param_list.elements(db)[0].clone(); + let (name, modifier, _) = syntax_utils::get_parameter_info(db, param_0.clone()); + + if name.eq(SELF_PARAM_NAME) { + let (expected, actual) = if modifier.eq(&"ref".to_string()) { + ("ref world: IWorldDispatcher", "ref self: ContractState") + } else { + ("world: @IWorldDispatcher", "self: @ContractState") + }; + + diagnostics.push(PluginDiagnostic { + stable_ptr: fn_diagnostic_item, + message: format!( + "In a dojo contract or interface, you should use `{}` instead of `{}`.", + expected, actual + ) + .to_string(), + severity: Severity::Error, + }); + } +} diff --git a/crates/dojo-lang/src/syntax/utils.rs b/crates/dojo-lang/src/syntax/utils.rs new file mode 100644 index 0000000000..b4bf5298a1 --- /dev/null +++ b/crates/dojo-lang/src/syntax/utils.rs @@ -0,0 +1,20 @@ +use cairo_lang_syntax::node::db::SyntaxGroup; +use cairo_lang_syntax::node::{ast, Terminal, TypedSyntaxNode}; + +/// Gets the name, modifiers and type of a function parameter. +/// +/// # Arguments +/// +/// * `db` - The syntax group. +/// * `param` - The parameter. +/// +/// # Returns +/// +/// * A tuple containing the name, modifiers and type of the parameter. +pub fn get_parameter_info(db: &dyn SyntaxGroup, param: ast::Param) -> (String, String, String) { + let name = param.name(db).text(db).trim().to_string(); + let modifiers = param.modifiers(db).as_syntax_node().get_text(db).trim().to_string(); + let param_type = param.type_clause(db).ty(db).as_syntax_node().get_text(db).trim().to_string(); + + (name, modifiers, param_type) +} diff --git a/crates/dojo-lang/src/syntax/world_param.rs b/crates/dojo-lang/src/syntax/world_param.rs new file mode 100644 index 0000000000..aac7939770 --- /dev/null +++ b/crates/dojo-lang/src/syntax/world_param.rs @@ -0,0 +1,92 @@ +use cairo_lang_defs::plugin::PluginDiagnostic; +use cairo_lang_diagnostics::Severity; +use cairo_lang_syntax::node::db::SyntaxGroup; +use cairo_lang_syntax::node::{ast, ids}; + +use super::utils as syntax_utils; + +const WORLD_PARAM_NAME: &str = "world"; +const WORLD_PARAM_TYPE: &str = "IWorldDispatcher"; +const WORLD_PARAM_TYPE_SNAPSHOT: &str = "@IWorldDispatcher"; + +#[derive(Debug, PartialEq, Eq)] +pub enum WorldParamInjectionKind { + None, + View, + External, +} + +/// Checks if the given parameter is the `world` parameter. +/// +/// The `world` must be named `world`, and be placed first in the argument list. +pub fn is_world_param(param_name: &str, param_type: &str) -> bool { + param_name == WORLD_PARAM_NAME + && (param_type == WORLD_PARAM_TYPE || param_type == WORLD_PARAM_TYPE_SNAPSHOT) +} + +/// Extracts the state mutability of a function from the `world` parameter. +/// +/// Checks if the function has only one `world` parameter (or None). +/// The `world` must be named `world`, and be placed first in the argument list. +/// +/// `fn func1(ref world)` // would be external. +/// `fn func2(world)` // would be view. +/// `fn func3()` // would be view. +/// +/// Returns +/// * The [`WorldParamInjectionKind`] determined from the function's params list. +pub fn parse_world_injection( + db: &dyn SyntaxGroup, + param_list: ast::ParamList, + fn_diagnostic_item: ids::SyntaxStablePtrId, + diagnostics: &mut Vec, +) -> WorldParamInjectionKind { + let mut has_world_injected = false; + let mut injection_kind = WorldParamInjectionKind::None; + + param_list.elements(db).iter().enumerate().for_each(|(idx, param)| { + let (name, modifiers, param_type) = syntax_utils::get_parameter_info(db, param.clone()); + + if !is_world_param(&name, ¶m_type) { + return; + } + + if has_world_injected { + diagnostics.push(PluginDiagnostic { + stable_ptr: fn_diagnostic_item, + message: "Only one world parameter is allowed".to_string(), + severity: Severity::Error, + }); + + return; + } else { + has_world_injected = true; + } + + if idx != 0 { + diagnostics.push(PluginDiagnostic { + stable_ptr: fn_diagnostic_item, + message: "World parameter must be the first parameter.".to_string(), + severity: Severity::Error, + }); + + return; + } + + if modifiers.contains(&"ref".to_string()) { + injection_kind = WorldParamInjectionKind::External; + } else { + injection_kind = WorldParamInjectionKind::View; + + if param_type == WORLD_PARAM_TYPE { + diagnostics.push(PluginDiagnostic { + stable_ptr: fn_diagnostic_item, + message: "World parameter must be a snapshot if `ref` is not used.".to_string(), + severity: Severity::Error, + }); + } + } + }); + + injection_kind +} diff --git a/crates/torii/grpc/src/client.rs b/crates/torii/grpc/src/client.rs index 4a11e04777..7f34a60209 100644 --- a/crates/torii/grpc/src/client.rs +++ b/crates/torii/grpc/src/client.rs @@ -3,7 +3,7 @@ use std::num::ParseIntError; use futures_util::stream::MapOk; use futures_util::{Stream, StreamExt, TryStreamExt}; -use starknet::core::types::{FromStrError, StateUpdate}; +use starknet::core::types::{FromStrError, StateDiff, StateUpdate}; use starknet_crypto::FieldElement; use crate::proto::world::{ @@ -105,9 +105,9 @@ impl WorldClient { .map_err(Error::Grpc) .map(|res| res.into_inner())?; - Ok(EntityUpdateStreaming(stream.map_ok(Box::new(|res| { - let entity = res.entity.expect("entity must exist"); - entity.try_into().expect("must able to serialize") + Ok(EntityUpdateStreaming(stream.map_ok(Box::new(|res| match res.entity { + Some(entity) => entity.try_into().expect("must able to serialize"), + None => Entity { hashed_keys: FieldElement::ZERO, models: vec![] }, })))) } @@ -144,9 +144,11 @@ impl WorldClient { .map_err(Error::Grpc) .map(|res| res.into_inner())?; - Ok(ModelDiffsStreaming(stream.map_ok(Box::new(|res| { - let update = res.model_update.expect("qed; state update must exist"); - TryInto::::try_into(update).expect("must able to serialize") + Ok(ModelDiffsStreaming(stream.map_ok(Box::new(|res| match res.model_update { + Some(update) => { + TryInto::::try_into(update).expect("must able to serialize") + } + None => empty_state_update(), })))) } } @@ -184,3 +186,19 @@ impl Stream for EntityUpdateStreaming { self.0.poll_next_unpin(cx) } } + +fn empty_state_update() -> StateUpdate { + StateUpdate { + block_hash: FieldElement::ZERO, + new_root: FieldElement::ZERO, + old_root: FieldElement::ZERO, + state_diff: StateDiff { + declared_classes: vec![], + deployed_contracts: vec![], + deprecated_declared_classes: vec![], + nonces: vec![], + replaced_classes: vec![], + storage_diffs: vec![], + }, + } +} diff --git a/crates/torii/grpc/src/server/subscriptions/entity.rs b/crates/torii/grpc/src/server/subscriptions/entity.rs index 1573b5c61f..f9d4ae0d96 100644 --- a/crates/torii/grpc/src/server/subscriptions/entity.rs +++ b/crates/torii/grpc/src/server/subscriptions/entity.rs @@ -20,6 +20,7 @@ use torii_core::types::Entity; use tracing::{error, trace}; use crate::proto; +use crate::proto::world::SubscribeEntityResponse; pub(crate) const LOG_TARGET: &str = "torii::grpc::server::subscriptions::entity"; @@ -43,6 +44,11 @@ impl EntityManager { let id = rand::thread_rng().gen::(); let (sender, receiver) = channel(1); + // NOTE: unlock issue with firefox/safari + // initially send empty stream message to return from + // initial subscribe call + let _ = sender.send(Ok(SubscribeEntityResponse { entity: None })).await; + self.subscribers.write().await.insert( id, EntitiesSubscriber { hashed_keys: hashed_keys.iter().cloned().collect(), sender }, diff --git a/crates/torii/grpc/src/server/subscriptions/event_message.rs b/crates/torii/grpc/src/server/subscriptions/event_message.rs index 736f88c0f9..67cf1cf172 100644 --- a/crates/torii/grpc/src/server/subscriptions/event_message.rs +++ b/crates/torii/grpc/src/server/subscriptions/event_message.rs @@ -20,6 +20,7 @@ use torii_core::types::EventMessage; use tracing::{error, trace}; use crate::proto; +use crate::proto::world::SubscribeEntityResponse; pub(crate) const LOG_TARGET: &str = "torii::grpc::server::subscriptions::event_message"; pub struct EventMessagesSubscriber { @@ -42,6 +43,11 @@ impl EventMessageManager { let id = rand::thread_rng().gen::(); let (sender, receiver) = channel(1); + // NOTE: unlock issue with firefox/safari + // initially send empty stream message to return from + // initial subscribe call + let _ = sender.send(Ok(SubscribeEntityResponse { entity: None })).await; + self.subscribers.write().await.insert( id, EventMessagesSubscriber { hashed_keys: hashed_keys.iter().cloned().collect(), sender }, diff --git a/crates/torii/grpc/src/server/subscriptions/model_diff.rs b/crates/torii/grpc/src/server/subscriptions/model_diff.rs index ad257c719c..8e1f4e80cf 100644 --- a/crates/torii/grpc/src/server/subscriptions/model_diff.rs +++ b/crates/torii/grpc/src/server/subscriptions/model_diff.rs @@ -20,6 +20,7 @@ use tracing::{debug, error, trace}; use super::error::SubscriptionError; use crate::proto; +use crate::proto::world::SubscribeModelsResponse; use crate::types::KeysClause; pub(crate) const LOG_TARGET: &str = "torii::grpc::server::subscriptions::model_diff"; @@ -82,6 +83,11 @@ impl StateDiffManager { .flatten() .collect::>(); + // NOTE: unlock issue with firefox/safari + // initially send empty stream message to return from + // initial subscribe call + let _ = sender.send(Ok(SubscribeModelsResponse { model_update: None })).await; + self.subscribers .write() .await diff --git a/crates/torii/libp2p/Cargo.toml b/crates/torii/libp2p/Cargo.toml index d95c0ff66a..08c0472988 100644 --- a/crates/torii/libp2p/Cargo.toml +++ b/crates/torii/libp2p/Cargo.toml @@ -26,6 +26,7 @@ starknet.workspace = true thiserror.workspace = true tracing-subscriber = { version = "0.3", features = [ "env-filter" ] } tracing.workspace = true +cainome.workspace = true [dev-dependencies] dojo-test-utils.workspace = true diff --git a/crates/torii/libp2p/src/server/mod.rs b/crates/torii/libp2p/src/server/mod.rs index 3562b5a46f..4a78db033b 100644 --- a/crates/torii/libp2p/src/server/mod.rs +++ b/crates/torii/libp2p/src/server/mod.rs @@ -20,7 +20,6 @@ use libp2p::swarm::{NetworkBehaviour, SwarmEvent}; use libp2p::{identify, identity, noise, ping, relay, tcp, yamux, PeerId, Swarm, Transport}; use libp2p_webrtc as webrtc; use rand::thread_rng; -use serde_json::Number; use starknet::core::types::{BlockId, BlockTag, FunctionCall}; use starknet::core::utils::get_selector_from_name; use starknet::providers::Provider; @@ -692,90 +691,6 @@ fn read_or_create_certificate(path: &Path) -> anyhow::Result { Ok(cert) } -// Deprecated. These should be potentially removed. As Ty -> TypedData is now done -// on the SDKs side -pub fn parse_ty_to_object(ty: &Ty) -> Result, Error> { - match ty { - Ty::Struct(struct_ty) => { - let mut object = IndexMap::new(); - for member in &struct_ty.children { - let mut member_object = IndexMap::new(); - member_object.insert("key".to_string(), PrimitiveType::Bool(member.key)); - member_object.insert( - "type".to_string(), - PrimitiveType::String(ty_to_string_type(&member.ty)), - ); - member_object.insert("value".to_string(), parse_ty_to_primitive(&member.ty)?); - object.insert(member.name.clone(), PrimitiveType::Object(member_object)); - } - Ok(object) - } - _ => Err(Error::InvalidMessageError("Expected Struct type".to_string())), - } -} - -pub fn ty_to_string_type(ty: &Ty) -> String { - match ty { - Ty::Primitive(primitive) => match primitive { - Primitive::U8(_) => "u8".to_string(), - Primitive::U16(_) => "u16".to_string(), - Primitive::U32(_) => "u32".to_string(), - Primitive::USize(_) => "usize".to_string(), - Primitive::U64(_) => "u64".to_string(), - Primitive::U128(_) => "u128".to_string(), - Primitive::U256(_) => "u256".to_string(), - Primitive::Felt252(_) => "felt252".to_string(), - Primitive::ClassHash(_) => "class_hash".to_string(), - Primitive::ContractAddress(_) => "contract_address".to_string(), - Primitive::Bool(_) => "bool".to_string(), - }, - Ty::Struct(_) => "struct".to_string(), - Ty::Tuple(_) => "tuple".to_string(), - Ty::Array(_) => "array".to_string(), - Ty::ByteArray(_) => "bytearray".to_string(), - Ty::Enum(_) => "enum".to_string(), - } -} - -pub fn parse_ty_to_primitive(ty: &Ty) -> Result { - match ty { - Ty::Primitive(primitive) => match primitive { - Primitive::U8(value) => { - Ok(PrimitiveType::Number(Number::from(value.map(|v| v as u64).unwrap_or(0u64)))) - } - Primitive::U16(value) => { - Ok(PrimitiveType::Number(Number::from(value.map(|v| v as u64).unwrap_or(0u64)))) - } - Primitive::U32(value) => { - Ok(PrimitiveType::Number(Number::from(value.map(|v| v as u64).unwrap_or(0u64)))) - } - Primitive::USize(value) => { - Ok(PrimitiveType::Number(Number::from(value.map(|v| v as u64).unwrap_or(0u64)))) - } - Primitive::U64(value) => { - Ok(PrimitiveType::Number(Number::from(value.map(|v| v).unwrap_or(0u64)))) - } - Primitive::U128(value) => Ok(PrimitiveType::String( - value.map(|v| v.to_string()).unwrap_or_else(|| "0".to_string()), - )), - Primitive::U256(value) => Ok(PrimitiveType::String( - value.map(|v| format!("{:#x}", v)).unwrap_or_else(|| "0".to_string()), - )), - Primitive::Felt252(value) => Ok(PrimitiveType::String( - value.map(|v| format!("{:#x}", v)).unwrap_or_else(|| "0".to_string()), - )), - Primitive::ClassHash(value) => Ok(PrimitiveType::String( - value.map(|v| format!("{:#x}", v)).unwrap_or_else(|| "0".to_string()), - )), - Primitive::ContractAddress(value) => Ok(PrimitiveType::String( - value.map(|v| format!("{:#x}", v)).unwrap_or_else(|| "0".to_string()), - )), - Primitive::Bool(value) => Ok(PrimitiveType::Bool(value.unwrap_or(false))), - }, - _ => Err(Error::InvalidMessageError("Expected Primitive type".to_string())), - } -} - #[cfg(test)] mod tests { use tempfile::tempdir; diff --git a/crates/torii/libp2p/src/tests.rs b/crates/torii/libp2p/src/tests.rs index 3d0e5581be..7db5ab677b 100644 --- a/crates/torii/libp2p/src/tests.rs +++ b/crates/torii/libp2p/src/tests.rs @@ -268,20 +268,25 @@ mod test { #[cfg(not(target_arch = "wasm32"))] #[tokio::test] async fn test_client_messaging() -> Result<(), Box> { + use std::time::Duration; + use dojo_test_utils::sequencer::{ get_default_test_starknet_config, SequencerConfig, TestSequencer, }; use dojo_types::schema::{Member, Struct, Ty}; + use dojo_world::contracts::abi::model::Layout; use indexmap::IndexMap; use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; use starknet::providers::jsonrpc::HttpTransport; use starknet::providers::JsonRpcClient; + use starknet::signers::SigningKey; use starknet_crypto::FieldElement; + use tokio::select; use tokio::time::sleep; use torii_core::sql::Sql; - use crate::server::{parse_ty_to_object, Relay}; - use crate::typed_data::{Domain, TypedData}; + use crate::server::Relay; + use crate::typed_data::{Domain, Field, SimpleField, TypedData}; use crate::types::Message; let _ = tracing_subscriber::fmt() @@ -300,7 +305,36 @@ mod test { .await; let provider = JsonRpcClient::new(HttpTransport::new(sequencer.url())); - let db = Sql::new(pool.clone(), FieldElement::from_bytes_be(&[0; 32]).unwrap()).await?; + let account = sequencer.raw_account(); + + let mut db = Sql::new(pool.clone(), FieldElement::from_bytes_be(&[0; 32]).unwrap()).await?; + + // Register the model of our Message + db.register_model( + Ty::Struct(Struct { + name: "Message".to_string(), + children: vec![ + Member { + name: "identity".to_string(), + ty: Ty::Primitive(Primitive::ContractAddress(None)), + key: true, + }, + Member { + name: "message".to_string(), + ty: Ty::ByteArray("".to_string()), + key: false, + }, + ], + }), + Layout::Fixed(vec![]), + FieldElement::ZERO, + FieldElement::ZERO, + 0, + 0, + 0, + ) + .await + .unwrap(); // Initialize the relay server let mut relay_server = Relay::new(db, provider, 9900, 9901, None, None)?; @@ -314,27 +348,57 @@ mod test { client.event_loop.lock().await.run().await; }); - let mut data = Struct { name: "Message".to_string(), children: vec![] }; - - data.children.push(Member { - name: "player".to_string(), - ty: dojo_types::schema::Ty::Primitive( - dojo_types::primitive::Primitive::ContractAddress(Some( - FieldElement::from_bytes_be(&[0; 32]).unwrap(), - )), - ), - key: true, - }); - - data.children.push(Member { - name: "message".to_string(), - ty: dojo_types::schema::Ty::Primitive(dojo_types::primitive::Primitive::U8(Some(0))), - key: false, - }); - let mut typed_data = TypedData::new( - IndexMap::new(), - "Message", + IndexMap::from_iter(vec![ + ( + "OffchainMessage".to_string(), + vec![ + Field::SimpleType(SimpleField { + name: "model".to_string(), + r#type: "shortstring".to_string(), + }), + Field::SimpleType(SimpleField { + name: "Message".to_string(), + r#type: "Model".to_string(), + }), + ], + ), + ( + "Model".to_string(), + vec![ + Field::SimpleType(SimpleField { + name: "identity".to_string(), + r#type: "ContractAddress".to_string(), + }), + Field::SimpleType(SimpleField { + name: "message".to_string(), + r#type: "string".to_string(), + }), + ], + ), + ( + "StarknetDomain".to_string(), + vec![ + Field::SimpleType(SimpleField { + name: "name".to_string(), + r#type: "shortstring".to_string(), + }), + Field::SimpleType(SimpleField { + name: "version".to_string(), + r#type: "shortstring".to_string(), + }), + Field::SimpleType(SimpleField { + name: "chainId".to_string(), + r#type: "shortstring".to_string(), + }), + Field::SimpleType(SimpleField { + name: "revision".to_string(), + r#type: "shortstring".to_string(), + }), + ], + ), + ]), + "OffchainMessage", Domain::new("Message", "1", "0x0", Some("1")), IndexMap::new(), ); @@ -346,37 +410,50 @@ mod test { typed_data.message.insert( "Message".to_string(), crate::typed_data::PrimitiveType::Object( - parse_ty_to_object(&Ty::Struct(data.clone())).unwrap(), + vec![ + ( + "identity".to_string(), + crate::typed_data::PrimitiveType::String( + account.account_address.to_string(), + ), + ), + ( + "message".to_string(), + crate::typed_data::PrimitiveType::String("mimi".to_string()), + ), + ] + .into_iter() + .collect(), ), ); + let message_hash = typed_data.encode(account.account_address).unwrap(); + let signature = + SigningKey::from_secret_scalar(account.private_key).sign(&message_hash).unwrap(); + client .command_sender .publish(Message { message: typed_data, - signature_r: FieldElement::from_bytes_be(&[0; 32]).unwrap(), - signature_s: FieldElement::from_bytes_be(&[0; 32]).unwrap(), + signature_r: signature.r, + signature_s: signature.s, }) .await?; sleep(std::time::Duration::from_secs(2)).await; - Ok(()) - // loop { - // select! { - // entity = sqlx::query("SELECT * FROM entities WHERE id = ?") - // .bind(format!("{:#x}", FieldElement::from_bytes_be(&[0; - // 32]).unwrap())).fetch_one(&pool) => { if let Ok(_) = entity { - // println!("Test OK: Received message within 5 seconds."); - // return Ok(()); - // } - // } - // _ = sleep(Duration::from_secs(5)) => { - // println!("Test Failed: Did not receive message within 5 seconds."); - // return Err("Timeout reached without receiving a message".into()); - // } - // } - // } + loop { + select! { + entity = sqlx::query("SELECT * FROM entities").fetch_one(&pool) => if entity.is_ok() { + println!("Test OK: Received message within 5 seconds."); + return Ok(()); + }, + _ = sleep(Duration::from_secs(5)) => { + println!("Test Failed: Did not receive message within 5 seconds."); + return Err("Timeout reached without receiving a message".into()); + } + } + } } #[cfg(target_arch = "wasm32")] diff --git a/crates/torii/libp2p/src/typed_data.rs b/crates/torii/libp2p/src/typed_data.rs index dc752f751b..733c7ca29d 100644 --- a/crates/torii/libp2p/src/typed_data.rs +++ b/crates/torii/libp2p/src/typed_data.rs @@ -1,11 +1,10 @@ use std::str::FromStr; +use cainome::cairo_serde::ByteArray; use indexmap::IndexMap; use serde::{Deserialize, Serialize}; use serde_json::Number; -use starknet::core::utils::{ - cairo_short_string_to_felt, get_selector_from_name, CairoShortStringToFeltError, -}; +use starknet::core::utils::{cairo_short_string_to_felt, get_selector_from_name}; use starknet_crypto::{poseidon_hash_many, FieldElement}; use crate::errors::Error; @@ -176,39 +175,6 @@ pub fn encode_type(name: &str, types: &IndexMap>) -> Result Result<(Vec, FieldElement, usize), CairoShortStringToFeltError> { - let short_strings: Vec<&str> = split_long_string(target_string); - let remainder = short_strings.last().unwrap_or(&""); - - let mut short_strings_encoded = short_strings - .iter() - .map(|&s| cairo_short_string_to_felt(s)) - .collect::, _>>()?; - - let (pending_word, pending_word_length) = if remainder.is_empty() || remainder.len() == 31 { - (FieldElement::ZERO, 0) - } else { - (short_strings_encoded.pop().unwrap(), remainder.len()) - }; - - Ok((short_strings_encoded, pending_word, pending_word_length)) -} - -fn split_long_string(long_str: &str) -> Vec<&str> { - let mut result = Vec::new(); - - let mut start = 0; - while start < long_str.len() { - let end = (start + 31).min(long_str.len()); - result.push(&long_str[start..end]); - start = end; - } - - result -} - #[derive(Debug, Default)] pub struct Ctx { pub base_type: String, @@ -273,7 +239,7 @@ fn get_hex(value: &str) -> Result { } else { // assume its a short string and encode cairo_short_string_to_felt(value) - .map_err(|_| Error::InvalidMessageError("Invalid short string".to_string())) + .map_err(|e| Error::InvalidMessageError(format!("Invalid shortstring for felt: {}", e))) } } @@ -330,8 +296,11 @@ impl PrimitiveType { let type_hash = encode_type(r#type, if ctx.is_preset { preset_types } else { types })?; - hashes.push(get_selector_from_name(&type_hash).map_err(|_| { - Error::InvalidMessageError(format!("Invalid type {} for selector", r#type)) + hashes.push(get_selector_from_name(&type_hash).map_err(|e| { + Error::InvalidMessageError(format!( + "Invalid type {} for selector: {}", + r#type, e + )) })?); for (field_name, value) in obj { @@ -368,24 +337,23 @@ impl PrimitiveType { "shortstring" => get_hex(string), "string" => { // split the string into short strings and encode - let byte_array = byte_array_from_string(string).map_err(|_| { - Error::InvalidMessageError("Invalid short string".to_string()) + let byte_array = ByteArray::from_string(string).map_err(|e| { + Error::InvalidMessageError(format!("Invalid string for bytearray: {}", e)) })?; - let mut hashes = vec![FieldElement::from(byte_array.0.len())]; + let mut hashes = vec![FieldElement::from(byte_array.data.len())]; - for hash in byte_array.0 { - hashes.push(hash); + for hash in byte_array.data { + hashes.push(hash.felt()); } - hashes.push(byte_array.1); - hashes.push(FieldElement::from(byte_array.2)); + hashes.push(byte_array.pending_word); + hashes.push(FieldElement::from(byte_array.pending_word_len)); Ok(poseidon_hash_many(hashes.as_slice())) } - "selector" => get_selector_from_name(string).map_err(|_| { - Error::InvalidMessageError(format!("Invalid type {} for selector", r#type)) - }), + "selector" => get_selector_from_name(string) + .map_err(|e| Error::InvalidMessageError(format!("Invalid selector: {}", e))), "felt" => get_hex(string), "ContractAddress" => get_hex(string), "ClassHash" => get_hex(string), diff --git a/crates/torii/server/src/proxy.rs b/crates/torii/server/src/proxy.rs index df9f4e26f5..23539b5d49 100644 --- a/crates/torii/server/src/proxy.rs +++ b/crates/torii/server/src/proxy.rs @@ -17,7 +17,7 @@ use tower::ServiceBuilder; use tower_http::cors::{AllowOrigin, CorsLayer}; use tracing::error; -const DEFAULT_ALLOW_HEADERS: [&str; 12] = [ +const DEFAULT_ALLOW_HEADERS: [&str; 11] = [ "accept", "origin", "content-type", @@ -27,7 +27,6 @@ const DEFAULT_ALLOW_HEADERS: [&str; 12] = [ "x-grpc-timeout", "x-user-agent", "connection", - "upgrade", "sec-websocket-key", "sec-websocket-version", ]; diff --git a/crates/torii/types-test/src/contracts.cairo b/crates/torii/types-test/src/contracts.cairo index 821e8f957a..df0ef8be39 100644 --- a/crates/torii/types-test/src/contracts.cairo +++ b/crates/torii/types-test/src/contracts.cairo @@ -1,9 +1,9 @@ use starknet::{ContractAddress, ClassHash}; -#[starknet::interface] -trait IRecords { - fn create(self: @TContractState, num_records: u8); - fn delete(self: @TContractState, record_id: u32); +#[dojo::interface] +trait IRecords { + fn create(ref world: IWorldDispatcher, num_records: u8); + fn delete(ref world: IWorldDispatcher, record_id: u32); } #[dojo::contract] @@ -33,8 +33,7 @@ mod records { #[abi(embed_v0)] impl RecordsImpl of IRecords { - fn create(self: @ContractState, num_records: u8) { - let world = self.world_dispatcher.read(); + fn create(ref world: IWorldDispatcher, num_records: u8) { let mut record_idx = 0; loop { @@ -118,7 +117,7 @@ mod records { return (); } // Implemment fn delete, input param: record_id - fn delete(self: @ContractState, record_id: u32) { + fn delete(ref world: IWorldDispatcher, record_id: u32) { let world = self.world_dispatcher.read(); let (record, record_sibling) = get!(world, record_id, (Record, RecordSibling)); let subrecord_id = record_id + 1; diff --git a/examples/spawn-and-move/manifests/dev/abis/base/contracts/dojo_examples_actions_actions.json b/examples/spawn-and-move/manifests/dev/abis/base/contracts/dojo_examples_actions_actions.json index 21aed968a7..0882f7c56d 100644 --- a/examples/spawn-and-move/manifests/dev/abis/base/contracts/dojo_examples_actions_actions.json +++ b/examples/spawn-and-move/manifests/dev/abis/base/contracts/dojo_examples_actions_actions.json @@ -182,7 +182,7 @@ "name": "spawn", "inputs": [], "outputs": [], - "state_mutability": "view" + "state_mutability": "external" }, { "type": "function", @@ -194,7 +194,7 @@ } ], "outputs": [], - "state_mutability": "view" + "state_mutability": "external" }, { "type": "function", @@ -206,6 +206,17 @@ } ], "outputs": [], + "state_mutability": "external" + }, + { + "type": "function", + "name": "get_player_position", + "inputs": [], + "outputs": [ + { + "type": "dojo_examples::models::Position" + } + ], "state_mutability": "view" } ] diff --git a/examples/spawn-and-move/manifests/dev/abis/deployments/contracts/dojo_examples_actions_actions.json b/examples/spawn-and-move/manifests/dev/abis/deployments/contracts/dojo_examples_actions_actions.json index 21aed968a7..0882f7c56d 100644 --- a/examples/spawn-and-move/manifests/dev/abis/deployments/contracts/dojo_examples_actions_actions.json +++ b/examples/spawn-and-move/manifests/dev/abis/deployments/contracts/dojo_examples_actions_actions.json @@ -182,7 +182,7 @@ "name": "spawn", "inputs": [], "outputs": [], - "state_mutability": "view" + "state_mutability": "external" }, { "type": "function", @@ -194,7 +194,7 @@ } ], "outputs": [], - "state_mutability": "view" + "state_mutability": "external" }, { "type": "function", @@ -206,6 +206,17 @@ } ], "outputs": [], + "state_mutability": "external" + }, + { + "type": "function", + "name": "get_player_position", + "inputs": [], + "outputs": [ + { + "type": "dojo_examples::models::Position" + } + ], "state_mutability": "view" } ] diff --git a/examples/spawn-and-move/manifests/dev/base/contracts/dojo_examples_actions_actions.toml b/examples/spawn-and-move/manifests/dev/base/contracts/dojo_examples_actions_actions.toml index 09f30e5dfa..405be86d28 100644 --- a/examples/spawn-and-move/manifests/dev/base/contracts/dojo_examples_actions_actions.toml +++ b/examples/spawn-and-move/manifests/dev/base/contracts/dojo_examples_actions_actions.toml @@ -1,6 +1,6 @@ kind = "DojoContract" -class_hash = "0x6d905953360cf18e3393d128c6ced40b38fc83b033412c8541fd4aba59d2767" -original_class_hash = "0x6d905953360cf18e3393d128c6ced40b38fc83b033412c8541fd4aba59d2767" +class_hash = "0x69ba4e0f7a03ae24f85aad88bd1a6b4eab5395474bbb6717803ffeb5aa13b8d" +original_class_hash = "0x69ba4e0f7a03ae24f85aad88bd1a6b4eab5395474bbb6717803ffeb5aa13b8d" base_class_hash = "0x0" abi = "manifests/dev/abis/base/contracts/dojo_examples_actions_actions.json" reads = [] diff --git a/examples/spawn-and-move/manifests/dev/manifest.json b/examples/spawn-and-move/manifests/dev/manifest.json index eb4b1821cf..9f2e854a95 100644 --- a/examples/spawn-and-move/manifests/dev/manifest.json +++ b/examples/spawn-and-move/manifests/dev/manifest.json @@ -1020,8 +1020,8 @@ { "kind": "DojoContract", "address": "0x5c70a663d6b48d8e4c6aaa9572e3735a732ac3765700d470463e670587852af", - "class_hash": "0x6d905953360cf18e3393d128c6ced40b38fc83b033412c8541fd4aba59d2767", - "original_class_hash": "0x6d905953360cf18e3393d128c6ced40b38fc83b033412c8541fd4aba59d2767", + "class_hash": "0x69ba4e0f7a03ae24f85aad88bd1a6b4eab5395474bbb6717803ffeb5aa13b8d", + "original_class_hash": "0x69ba4e0f7a03ae24f85aad88bd1a6b4eab5395474bbb6717803ffeb5aa13b8d", "base_class_hash": "0x22f3e55b61d86c2ac5239fa3b3b8761f26b9a5c0b5f61ddbd5d756ced498b46", "abi": [ { @@ -1207,7 +1207,7 @@ "name": "spawn", "inputs": [], "outputs": [], - "state_mutability": "view" + "state_mutability": "external" }, { "type": "function", @@ -1219,7 +1219,7 @@ } ], "outputs": [], - "state_mutability": "view" + "state_mutability": "external" }, { "type": "function", @@ -1231,6 +1231,17 @@ } ], "outputs": [], + "state_mutability": "external" + }, + { + "type": "function", + "name": "get_player_position", + "inputs": [], + "outputs": [ + { + "type": "dojo_examples::models::Position" + } + ], "state_mutability": "view" } ] diff --git a/examples/spawn-and-move/manifests/dev/manifest.toml b/examples/spawn-and-move/manifests/dev/manifest.toml index 6bbf881ab9..ec501b9733 100644 --- a/examples/spawn-and-move/manifests/dev/manifest.toml +++ b/examples/spawn-and-move/manifests/dev/manifest.toml @@ -22,8 +22,8 @@ name = "dojo::base::base" [[contracts]] kind = "DojoContract" address = "0x5c70a663d6b48d8e4c6aaa9572e3735a732ac3765700d470463e670587852af" -class_hash = "0x6d905953360cf18e3393d128c6ced40b38fc83b033412c8541fd4aba59d2767" -original_class_hash = "0x6d905953360cf18e3393d128c6ced40b38fc83b033412c8541fd4aba59d2767" +class_hash = "0x69ba4e0f7a03ae24f85aad88bd1a6b4eab5395474bbb6717803ffeb5aa13b8d" +original_class_hash = "0x69ba4e0f7a03ae24f85aad88bd1a6b4eab5395474bbb6717803ffeb5aa13b8d" base_class_hash = "0x22f3e55b61d86c2ac5239fa3b3b8761f26b9a5c0b5f61ddbd5d756ced498b46" abi = "manifests/dev/abis/deployments/contracts/dojo_examples_actions_actions.json" reads = [] diff --git a/examples/spawn-and-move/src/actions.cairo b/examples/spawn-and-move/src/actions.cairo index 65dc16c55c..143696ed80 100644 --- a/examples/spawn-and-move/src/actions.cairo +++ b/examples/spawn-and-move/src/actions.cairo @@ -2,9 +2,10 @@ use dojo_examples::models::{Direction, Position, Vec2}; #[dojo::interface] trait IActions { - fn spawn(); - fn move(direction: Direction); - fn set_player_config(name: ByteArray); + fn spawn(ref world: IWorldDispatcher); + fn move(ref world: IWorldDispatcher, direction: Direction); + fn set_player_config(ref world: IWorldDispatcher, name: ByteArray); + fn get_player_position(world: @IWorldDispatcher) -> Position; } #[dojo::interface] @@ -61,7 +62,7 @@ mod actions { #[abi(embed_v0)] impl ActionsImpl of IActions { // ContractState is defined by system decorator expansion - fn spawn(world: IWorldDispatcher) { + fn spawn(ref world: IWorldDispatcher) { let player = get_caller_address(); let position = get!(world, player, (Position)); @@ -76,7 +77,7 @@ mod actions { ); } - fn move(world: IWorldDispatcher, direction: Direction) { + fn move(ref world: IWorldDispatcher, direction: Direction) { let player = get_caller_address(); let (mut position, mut moves) = get!(world, player, (Position, Moves)); moves.remaining -= 1; @@ -86,7 +87,7 @@ mod actions { emit!(world, (Moved { player, direction })); } - fn set_player_config(world: IWorldDispatcher, name: ByteArray) { + fn set_player_config(ref world: IWorldDispatcher, name: ByteArray) { let player = get_caller_address(); let items = array![ @@ -97,6 +98,11 @@ mod actions { set!(world, (config)); } + + fn get_player_position(world: @IWorldDispatcher) -> Position { + let player = get_caller_address(); + get!(world, player, (Position)) + } } } diff --git a/examples/spawn-and-move/src/others.cairo b/examples/spawn-and-move/src/others.cairo index 0f27d036a8..8b64df6a23 100644 --- a/examples/spawn-and-move/src/others.cairo +++ b/examples/spawn-and-move/src/others.cairo @@ -16,7 +16,7 @@ mod others { fn dojo_init( - world: IWorldDispatcher, + world: @IWorldDispatcher, actions_address: ContractAddress, actions_class: ClassHash, value: u8