diff --git a/crates/dojo-lang/src/introspect.rs b/crates/dojo-lang/src/introspect.rs index 3cdc602d9e..cef29813fc 100644 --- a/crates/dojo-lang/src/introspect.rs +++ b/crates/dojo-lang/src/introspect.rs @@ -99,7 +99,7 @@ pub fn handle_introspect_struct(db: &dyn SyntaxGroup, struct_ast: ItemStruct) -> /// A handler for Dojo code derives Introspect for an enum /// Parameters: /// * db: The semantic database. -/// * struct_ast: The AST of the struct. +/// * enum_ast: The AST of the enum. /// Returns: /// * A RewriteNode containing the generated code. pub fn handle_introspect_enum( diff --git a/crates/dojo-lang/src/plugin.rs b/crates/dojo-lang/src/plugin.rs index 18a8b819a3..1ec7496a0a 100644 --- a/crates/dojo-lang/src/plugin.rs +++ b/crates/dojo-lang/src/plugin.rs @@ -27,7 +27,7 @@ use crate::inline_macros::get::GetMacro; use crate::inline_macros::set::SetMacro; use crate::introspect::{handle_introspect_enum, handle_introspect_struct}; use crate::model::handle_model_struct; -use crate::print::derive_print; +use crate::print::handle_print_struct; const DOJO_CONTRACT_ATTR: &str = "dojo::contract"; @@ -56,7 +56,11 @@ impl GeneratedFileAuxData for DojoAuxData { self } fn eq(&self, other: &dyn GeneratedFileAuxData) -> bool { - if let Some(other) = other.as_any().downcast_ref::() { self == other } else { false } + if let Some(other) = other.as_any().downcast_ref::() { + self == other + } else { + false + } } } @@ -163,6 +167,7 @@ impl MacroPlugin for BuiltinDojoPlugin { enum_ast.clone(), )); } + "Print" => rewrite_nodes.push(handle_print_enum(db, enum_ast.clone())), _ => continue, } } @@ -239,7 +244,7 @@ impl MacroPlugin for BuiltinDojoPlugin { diagnostics.extend(model_diagnostics); } "Print" => { - rewrite_nodes.push(derive_print(db, struct_ast.clone())); + rewrite_nodes.push(handle_print_struct(db, struct_ast.clone())); } "Introspect" => { rewrite_nodes diff --git a/crates/dojo-lang/src/print.rs b/crates/dojo-lang/src/print.rs index 0506fa368c..88eccfa9fa 100644 --- a/crates/dojo-lang/src/print.rs +++ b/crates/dojo-lang/src/print.rs @@ -1,5 +1,5 @@ use cairo_lang_defs::patcher::RewriteNode; -use cairo_lang_syntax::node::ast::ItemStruct; +use cairo_lang_syntax::node::ast::{ItemEnum, ItemStruct}; use cairo_lang_syntax::node::db::SyntaxGroup; use cairo_lang_syntax::node::{Terminal, TypedSyntaxNode}; use cairo_lang_utils::unordered_hash_map::UnorderedHashMap; @@ -10,7 +10,44 @@ use cairo_lang_utils::unordered_hash_map::UnorderedHashMap; /// * struct_ast: The AST of the model struct. /// Returns: /// * A RewriteNode containing the generated code. -pub fn derive_print(db: &dyn SyntaxGroup, struct_ast: ItemStruct) -> RewriteNode { +pub fn handle_print_struct(db: &dyn SyntaxGroup, struct_ast: ItemStruct) -> RewriteNode { + let prints: Vec<_> = struct_ast + .members(db) + .elements(db) + .iter() + .map(|m| { + format!( + "debug::PrintTrait::print('{}'); debug::PrintTrait::print(self.{});", + m.name(db).text(db).to_string(), + m.name(db).text(db).to_string() + ) + }) + .collect(); + + RewriteNode::interpolate_patched( + "#[cfg(test)] + impl $type_name$PrintImpl of debug::PrintTrait<$type_name$> { + fn print(self: $type_name$) { + $print$ + } + }", + UnorderedHashMap::from([ + ( + "type_name".to_string(), + RewriteNode::new_trimmed(struct_ast.name(db).as_syntax_node()), + ), + ("print".to_string(), RewriteNode::Text(prints.join("\n"))), + ]), + ) +} + +/// Derives PrintTrait for an enum. +/// Parameters: +/// * db: The semantic database. +/// * enum_ast: The AST of the model enum. +/// Returns: +/// * A RewriteNode containing the generated code. +pub fn handle_print_enum(db: &dyn SyntaxGroup, struct_ast: ItemEnum) -> RewriteNode { let prints: Vec<_> = struct_ast .members(db) .elements(db)