diff --git a/crates/dojo-lang/src/contract.rs b/crates/dojo-lang/src/contract.rs index 7a9dd4c122..bf07bbd2d3 100644 --- a/crates/dojo-lang/src/contract.rs +++ b/crates/dojo-lang/src/contract.rs @@ -2,9 +2,11 @@ use std::collections::HashMap; use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode}; use cairo_lang_defs::plugin::{ - DynGeneratedFileAuxData, PluginDiagnostic, PluginGeneratedFile, PluginResult, + DynGeneratedFileAuxData, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, + PluginResult, }; use cairo_lang_diagnostics::Severity; +use cairo_lang_plugins::plugins::HasItemsInCfgEx; use cairo_lang_syntax::node::ast::{ ArgClause, Expr, MaybeModuleBody, OptionArgListParenthesized, OptionReturnTypeClause, }; @@ -39,6 +41,7 @@ impl DojoContract { db: &dyn SyntaxGroup, module_ast: &ast::ItemModule, package_id: String, + metadata: &MacroPluginMetadata<'_>, ) -> PluginResult { let name = module_ast.name(db).text(db); @@ -81,28 +84,26 @@ impl DojoContract { if let MaybeModuleBody::Some(body) = module_ast.body(db) { let mut body_nodes: Vec<_> = body - .items(db) - .elements(db) - .iter() + .iter_items_in_cfg(db, metadata.cfg_set) .flat_map(|el| { - if let ast::ModuleItem::Enum(enum_ast) = el { + if let ast::ModuleItem::Enum(ref enum_ast) = el { if enum_ast.name(db).text(db).to_string() == "Event" { has_event = true; return system.merge_event(db, enum_ast.clone()); } - } else if let ast::ModuleItem::Struct(struct_ast) = el { + } else if let ast::ModuleItem::Struct(ref struct_ast) = el { if struct_ast.name(db).text(db).to_string() == "Storage" { has_storage = true; return system.merge_storage(db, struct_ast.clone()); } - } else if let ast::ModuleItem::Impl(impl_ast) = el { + } else if let ast::ModuleItem::Impl(ref impl_ast) = el { // If an implementation is not targetting the ContractState, // the auto injection of self and world is not applied. let trait_path = impl_ast.trait_path(db).node.get_text(db); if trait_path.contains("") { - return system.rewrite_impl(db, impl_ast.clone()); + return system.rewrite_impl(db, impl_ast.clone(), metadata); } - } else if let ast::ModuleItem::FreeFunction(fn_ast) = el { + } else if let ast::ModuleItem::FreeFunction(ref fn_ast) = el { let fn_decl = fn_ast.declaration(db); let fn_name = fn_decl.name(db).text(db); @@ -553,7 +554,12 @@ impl DojoContract { } /// Rewrites all the functions of a Impl block. - fn rewrite_impl(&mut self, db: &dyn SyntaxGroup, impl_ast: ast::ItemImpl) -> Vec { + fn rewrite_impl( + &mut self, + db: &dyn SyntaxGroup, + impl_ast: ast::ItemImpl, + metadata: &MacroPluginMetadata<'_>, + ) -> Vec { let generate_attrs = impl_ast.attributes(db).query_attr(db, "generate_trait"); let has_generate_trait = !generate_attrs.is_empty(); @@ -570,11 +576,9 @@ impl DojoContract { }; let body_nodes: Vec<_> = body - .items(db) - .elements(db) - .iter() + .iter_items_in_cfg(db, metadata.cfg_set) .flat_map(|el| { - if let ast::ImplItem::Function(fn_ast) = el { + if let ast::ImplItem::Function(ref fn_ast) = el { return self.rewrite_function(db, fn_ast.clone(), has_generate_trait); } vec![RewriteNode::Copied(el.as_syntax_node())] diff --git a/crates/dojo-lang/src/interface.rs b/crates/dojo-lang/src/interface.rs index bc30581ddf..c6a2294043 100644 --- a/crates/dojo-lang/src/interface.rs +++ b/crates/dojo-lang/src/interface.rs @@ -1,6 +1,9 @@ use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode}; -use cairo_lang_defs::plugin::{PluginDiagnostic, PluginGeneratedFile, PluginResult}; +use cairo_lang_defs::plugin::{ + MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult, +}; use cairo_lang_diagnostics::Severity; +use cairo_lang_plugins::plugins::HasItemsInCfgEx; use cairo_lang_syntax::node::db::SyntaxGroup; use cairo_lang_syntax::node::{ast, ids, Terminal, TypedStablePtr, TypedSyntaxNode}; use cairo_lang_utils::unordered_hash_map::UnorderedHashMap; @@ -14,18 +17,20 @@ pub struct DojoInterface { } impl DojoInterface { - pub fn from_trait(db: &dyn SyntaxGroup, trait_ast: ast::ItemTrait) -> PluginResult { + pub fn from_trait( + db: &dyn SyntaxGroup, + trait_ast: ast::ItemTrait, + metadata: &MacroPluginMetadata<'_>, + ) -> PluginResult { let name = trait_ast.name(db).text(db); let mut system = DojoInterface { diagnostics: vec![] }; let mut builder = PatchBuilder::new(db, &trait_ast); if let ast::MaybeTraitBody::Some(body) = trait_ast.body(db) { let body_nodes: Vec<_> = body - .items(db) - .elements(db) - .iter() + .iter_items_in_cfg(db, metadata.cfg_set) .flat_map(|el| { - if let ast::TraitItem::Function(fn_ast) = el { + if let ast::TraitItem::Function(ref fn_ast) = el { return system.rewrite_function(db, fn_ast.clone()); } diff --git a/crates/dojo-lang/src/plugin.rs b/crates/dojo-lang/src/plugin.rs index 2e619464e4..a0f43e8c4a 100644 --- a/crates/dojo-lang/src/plugin.rs +++ b/crates/dojo-lang/src/plugin.rs @@ -112,17 +112,23 @@ impl BuiltinDojoPlugin { db: &dyn SyntaxGroup, module_ast: ast::ItemModule, package_id: String, + metadata: &MacroPluginMetadata<'_>, ) -> PluginResult { if module_ast.has_attr(db, DOJO_CONTRACT_ATTR) { - return DojoContract::from_module(db, &module_ast, package_id); + return DojoContract::from_module(db, &module_ast, package_id, metadata); } PluginResult::default() } - fn handle_trait(&self, db: &dyn SyntaxGroup, trait_ast: ast::ItemTrait) -> PluginResult { + fn handle_trait( + &self, + db: &dyn SyntaxGroup, + trait_ast: ast::ItemTrait, + metadata: &MacroPluginMetadata<'_>, + ) -> PluginResult { if trait_ast.has_attr(db, DOJO_INTERFACE_ATTR) { - return DojoInterface::from_trait(db, trait_ast); + return DojoInterface::from_trait(db, trait_ast, metadata); } PluginResult::default() @@ -330,13 +336,11 @@ fn get_additional_derive_attrs_for_model(derive_attr_names: &[String]) -> Vec - // Not used for now, but it contains a key-value BTreeSet. TBD what we can do with this. fn generate_code( &self, db: &dyn SyntaxGroup, item_ast: ast::ModuleItem, - _metadata: &MacroPluginMetadata<'_>, + metadata: &MacroPluginMetadata<'_>, ) -> PluginResult { let package_id = match get_package_id(db) { Option::Some(x) => x, @@ -356,8 +360,10 @@ impl MacroPlugin for BuiltinDojoPlugin { }; match item_ast { - ast::ModuleItem::Module(module_ast) => self.handle_mod(db, module_ast, package_id), - ast::ModuleItem::Trait(trait_ast) => self.handle_trait(db, trait_ast), + ast::ModuleItem::Module(module_ast) => { + self.handle_mod(db, module_ast, package_id, metadata) + } + ast::ModuleItem::Trait(trait_ast) => self.handle_trait(db, trait_ast, metadata), ast::ModuleItem::Enum(enum_ast) => { let aux_data = DojoAuxData::default(); let mut rewrite_nodes = vec![];