diff --git a/crates/dojo-lang/src/interface.rs b/crates/dojo-lang/src/interface.rs new file mode 100644 index 0000000000..ecb52f5afc --- /dev/null +++ b/crates/dojo-lang/src/interface.rs @@ -0,0 +1,197 @@ +use std::collections::HashMap; + +use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode}; +use cairo_lang_defs::plugin::{ + DynGeneratedFileAuxData, PluginDiagnostic, PluginGeneratedFile, PluginResult, +}; +use cairo_lang_diagnostics::Severity; +use cairo_lang_syntax::node::ids::SyntaxStablePtrId; +use cairo_lang_syntax::node::{ast, Terminal, TypedSyntaxNode}; +use cairo_lang_syntax::node::ast::{MaybeTraitBody, OptionReturnTypeClause, ParamList}; +use cairo_lang_syntax::node::db::SyntaxGroup; +use cairo_lang_utils::unordered_hash_map::UnorderedHashMap; +use dojo_types::system::Dependency; + +use crate::plugin::{DojoAuxData, SystemAuxData}; + +pub struct DojoInterface { + diagnostics: Vec, + dependencies: HashMap, +} + +impl DojoInterface { + pub fn from_trait(db: &dyn SyntaxGroup, trait_ast: ast::ItemTrait) -> PluginResult { + let name = trait_ast.name(db).text(db); + let mut system = DojoInterface { diagnostics: vec![], dependencies: HashMap::new() }; + let mut builder = PatchBuilder::new(db); + + if let MaybeTraitBody::Some(body) = trait_ast.body(db) { + let body_nodes: Vec<_> = body + .items(db) + .elements(db) + .iter() + .flat_map(|el| { + if let ast::TraitItem::Function(fn_ast) = el { + return system.rewrite_function(db, fn_ast.clone()); + } + + //TODO: if the trait body contains other things than functions, + // they are not correctly copied. + // for example: `const ONE: u8 = 1;` is copied as `const ONE` +// vec![RewriteNode::Copied(el.as_syntax_node())] + vec![] + }) + .collect(); + + builder.add_modified(RewriteNode::interpolate_patched( + " + #[starknet::interface] + trait $name$ { + $body$ + } + ", + &UnorderedHashMap::from([ + ("name".to_string(), RewriteNode::Text(name.to_string())), + ("body".to_string(), RewriteNode::new_modified(body_nodes)), + ]), + )); + } + else { + // empty trait + builder.add_modified(RewriteNode::interpolate_patched( + " + #[starknet::interface] + trait $name$ {} + ", + &UnorderedHashMap::from([ + ("name".to_string(), RewriteNode::Text(name.to_string())), + ]), + )); + } + + return PluginResult { + code: Some(PluginGeneratedFile { + name: name.clone(), + content: builder.code, + aux_data: Some(DynGeneratedFileAuxData::new(DojoAuxData { + models: vec![], + systems: vec![SystemAuxData { + name, + dependencies: system.dependencies.values().cloned().collect(), + }], + })), + code_mappings: builder.code_mappings, + }), + diagnostics: system.diagnostics, + remove_original_item: true, + }; + } + + /// Rewrites parameter list by adding `self` parameter if missing. + /// + /// Reports an error in case of `ref self` as systems are supposed to be 100% stateless. + pub fn rewrite_parameters( + &mut self, + db: &dyn SyntaxGroup, + param_list: ParamList, + diagnostic_item: SyntaxStablePtrId + ) -> String { + let mut params = param_list.elements(db) + .iter() + .map(|e| e.as_syntax_node().get_text(db)) + .collect::>(); + + let mut need_to_add_self = true; + if !params.is_empty() { + let first_param = param_list.elements(db)[0].clone(); + let param_name = first_param.name(db).text(db).to_string(); + + if param_name.eq(&"self".to_string()) { + let param_modifiers = first_param + .modifiers(db).elements(db) + .iter().map(|e|e.as_syntax_node().get_text(db).trim().to_string()) + .collect::>(); + + let param_type = first_param + .type_clause(db).ty(db).as_syntax_node() + .get_text(db).trim().to_string(); + + if param_modifiers.contains(&"ref".to_string()) && param_type.eq(&"TContractState".to_string()) { + self.diagnostics.push(PluginDiagnostic { + stable_ptr: diagnostic_item, + message: "Functions of dojo::interface cannot have `ref self` parameter.".to_string(), + severity: Severity::Error, + }); + + need_to_add_self = false; + } + + if param_type.eq(&"@TContractState".to_string()) { + need_to_add_self = false; + } + } + }; + + if need_to_add_self { + params.insert(0, "self: @TContractState".to_string()); + } + + params.join(", ") + } + + /// Rewrites function declaration by adding `self` parameter if missing. + /// + /// Some notes: + /// * as rewritten functions belong to a starknet::interface: + /// - there is no generic parameter, + /// - there is no nopanic and implicits clauses. + /// * there is no function body in a trait. + pub fn rewrite_function( + &mut self, + db: &dyn SyntaxGroup, + fn_ast: ast::TraitItemFunction, + ) -> Vec { + let mut rewrite_nodes = vec![]; + + let declaration = fn_ast.declaration(db); + let fn_name_node = declaration.name(db).clone(); + let fn_name = fn_name_node.text(db).to_string(); + let signature = declaration.signature(db); + + let attributes; + if fn_ast.attributes(db).elements(db).is_empty() { + attributes = "".to_string(); + } + else { + attributes = fn_ast.attributes(db).elements(db) + .iter() + .map(|e| e.as_syntax_node().get_text(db)) + .collect::>() + .join("\n"); + } + + let ret_ty = if let OptionReturnTypeClause::ReturnTypeClause(ty) = signature.ret_ty(db) { + format!(" {}", ty.as_syntax_node().get_text(db).to_string()) + } else { + "".to_string() + }; + + let params = self.rewrite_parameters( + db, + signature.parameters(db), + fn_name_node.stable_ptr().untyped() + ); + + rewrite_nodes.push(RewriteNode::interpolate_patched( + "$attributes$fn $fn_name$($params$)$ret_ty$;\n", + &UnorderedHashMap::from([ + ("attributes".to_string(), RewriteNode::Text(attributes)), + ("fn_name".to_string(), RewriteNode::Text(fn_name)), + ("params".to_string(), RewriteNode::Text(params)), + ("ret_ty".to_string(), RewriteNode::Text(ret_ty)) + ]), + )); + + rewrite_nodes + } +} diff --git a/crates/dojo-lang/src/lib.rs b/crates/dojo-lang/src/lib.rs index ced65c7cde..15c1933311 100644 --- a/crates/dojo-lang/src/lib.rs +++ b/crates/dojo-lang/src/lib.rs @@ -5,6 +5,7 @@ //! Learn more at [dojoengine.gg](http://dojoengine.gg). pub mod compiler; pub mod contract; +pub mod interface; pub mod inline_macros; pub mod introspect; pub mod model; diff --git a/crates/dojo-lang/src/plugin.rs b/crates/dojo-lang/src/plugin.rs index 9900115cb2..17319fca07 100644 --- a/crates/dojo-lang/src/plugin.rs +++ b/crates/dojo-lang/src/plugin.rs @@ -23,6 +23,7 @@ use smol_str::SmolStr; use url::Url; use crate::contract::DojoContract; +use crate::interface::DojoInterface; use crate::inline_macros::array_cap::ArrayCapMacro; use crate::inline_macros::delete::DeleteMacro; use crate::inline_macros::emit::EmitMacro; @@ -33,6 +34,7 @@ use crate::model::handle_model_struct; use crate::print::{handle_print_enum, handle_print_struct}; const DOJO_CONTRACT_ATTR: &str = "dojo::contract"; +const DOJO_INTERFACE_ATTR: &str = "dojo::interface"; const DOJO_PLUGIN_EXPAND_VAR_ENV: &str = "DOJO_PLUGIN_EXPAND"; #[derive(Clone, Debug, PartialEq)] @@ -101,6 +103,14 @@ impl BuiltinDojoPlugin { PluginResult::default() } + fn handle_trait(&self, db: &dyn SyntaxGroup, trait_ast: ast::ItemTrait) -> PluginResult { + if trait_ast.has_attr(db, DOJO_INTERFACE_ATTR) { + return DojoInterface::from_trait(db, trait_ast); + } + + PluginResult::default() + } + fn result_with_diagnostic( &self, stable_ptr: SyntaxStablePtrId, @@ -246,6 +256,7 @@ impl MacroPlugin for BuiltinDojoPlugin { match item_ast { ast::ModuleItem::Module(module_ast) => self.handle_mod(db, module_ast), + ast::ModuleItem::Trait(trait_ast) => self.handle_trait(db, trait_ast), ast::ModuleItem::Enum(enum_ast) => { let aux_data = DojoAuxData::default(); let mut rewrite_nodes = vec![]; diff --git a/crates/dojo-lang/src/plugin_test_data/system b/crates/dojo-lang/src/plugin_test_data/system index 199755e340..108757921e 100644 --- a/crates/dojo-lang/src/plugin_test_data/system +++ b/crates/dojo-lang/src/plugin_test_data/system @@ -84,6 +84,21 @@ mod withcomponent { testcomponent2_event: testcomponent2::Event, } } +#[dojo::interface] +trait IEmptyTrait; + +#[dojo::interface] +trait IActions { + fn do_no_param(); + fn do_no_param_but_self(self: @TContractState); + fn do_params(p1: dojo_examples::models::Direction, p2: u8); + fn do_params_and_self(self: @TContractState, p2: u8); + fn do_return_value(p1: u8) -> u16; + fn do_ref_self(ref self: TContractState); + + #[my_attr] + fn do_with_attrs(p1: u8) -> u16; +} //! > generated_cairo_code #[starknet::contract] @@ -205,6 +220,11 @@ error: Unsupported attribute. #[starknet::component] ^********************^ +error: Functions of dojo::interface cannot have `ref self` parameter. + --> test_src/lib.cairo:91:8 + fn do_ref_self(ref self: TContractState); + ^*********^ + error: Unsupported attribute. --> test_src/lib.cairo[spawn]:2:17 #[starknet::contract] @@ -230,6 +250,11 @@ error: Unsupported attribute. #[starknet::contract] ^*******************^ +error: Unsupported attribute. + --> test_src/lib.cairo[IActions]:11:5 + #[my_attr] + ^********^ + error: Unsupported attribute. --> test_src/lib.cairo:49:5 #[storage] @@ -701,3 +726,20 @@ impl TestEventDrop of core::traits::Drop::; impl EventDrop of core::traits::Drop::; } + + #[starknet::interface] + trait IEmptyTrait {} + + #[starknet::interface] + trait IActions { + fn do_no_param(self: @TContractState); +fn do_no_param_but_self(self: @TContractState); +fn do_params(self: @TContractState, p1: dojo_examples::models::Direction, p2: u8); +fn do_params_and_self(self: @TContractState, p2: u8); +fn do_return_value(self: @TContractState, p1: u8) -> u16; +fn do_ref_self(ref self: TContractState); + + #[my_attr] +fn do_with_attrs(self: @TContractState, p1: u8) -> u16; + + }