diff --git a/Cargo.toml b/Cargo.toml index 39e5ed93ee..000b3f970f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,7 +60,7 @@ cairo-lang-formatter = "2.4.0" cairo-lang-language-server = "2.4.0" cairo-lang-lowering = "2.4.0" cairo-lang-parser = "2.4.0" -cairo-lang-plugins = "2.4.0" +cairo-lang-plugins = { version = "2.4.0", features = [ "testing" ] } cairo-lang-project = "2.4.0" cairo-lang-semantic = { version = "2.4.0", features = [ "testing" ] } cairo-lang-sierra = "2.4.0" diff --git a/crates/dojo-lang/src/introspect.rs b/crates/dojo-lang/src/introspect.rs index 67e7a23816..5ebccdb891 100644 --- a/crates/dojo-lang/src/introspect.rs +++ b/crates/dojo-lang/src/introspect.rs @@ -99,7 +99,7 @@ pub fn handle_introspect_struct(db: &dyn SyntaxGroup, struct_ast: ItemStruct) -> /// A handler for Dojo code derives Introspect for an enum /// Parameters: /// * db: The semantic database. -/// * struct_ast: The AST of the struct. +/// * enum_ast: The AST of the enum. /// Returns: /// * A RewriteNode containing the generated code. pub fn handle_introspect_enum( diff --git a/crates/dojo-lang/src/manifest_test_data/compiler_cairo_v240/Scarb.lock b/crates/dojo-lang/src/manifest_test_data/compiler_cairo_v240/Scarb.lock index 6f0435b5b6..bb593d0412 100644 --- a/crates/dojo-lang/src/manifest_test_data/compiler_cairo_v240/Scarb.lock +++ b/crates/dojo-lang/src/manifest_test_data/compiler_cairo_v240/Scarb.lock @@ -10,7 +10,7 @@ dependencies = [ [[package]] name = "dojo" -version = "0.4.4" +version = "0.5.0" dependencies = [ "dojo_plugin", ] diff --git a/crates/dojo-lang/src/plugin.rs b/crates/dojo-lang/src/plugin.rs index e97f5763f5..81972ce36f 100644 --- a/crates/dojo-lang/src/plugin.rs +++ b/crates/dojo-lang/src/plugin.rs @@ -27,7 +27,7 @@ use crate::inline_macros::get::GetMacro; use crate::inline_macros::set::SetMacro; use crate::introspect::{handle_introspect_enum, handle_introspect_struct}; use crate::model::handle_model_struct; -use crate::print::derive_print; +use crate::print::{handle_print_enum, handle_print_struct}; const DOJO_CONTRACT_ATTR: &str = "dojo::contract"; @@ -279,6 +279,7 @@ impl MacroPlugin for BuiltinDojoPlugin { enum_ast.clone(), )); } + "Print" => rewrite_nodes.push(handle_print_enum(db, enum_ast.clone())), _ => continue, } } @@ -355,7 +356,7 @@ impl MacroPlugin for BuiltinDojoPlugin { diagnostics.extend(model_diagnostics); } "Print" => { - rewrite_nodes.push(derive_print(db, struct_ast.clone())); + rewrite_nodes.push(handle_print_struct(db, struct_ast.clone())); } "Introspect" => { rewrite_nodes diff --git a/crates/dojo-lang/src/plugin_test.rs b/crates/dojo-lang/src/plugin_test.rs index d789ff5a5b..899cd8557c 100644 --- a/crates/dojo-lang/src/plugin_test.rs +++ b/crates/dojo-lang/src/plugin_test.rs @@ -1,9 +1,8 @@ use std::sync::Arc; use cairo_lang_defs::db::{DefsDatabase, DefsGroup}; -use cairo_lang_defs::ids::{LanguageElementId, ModuleId, ModuleItemId}; +use cairo_lang_defs::ids::ModuleId; use cairo_lang_defs::plugin::MacroPlugin; -use cairo_lang_diagnostics::{format_diagnostics, DiagnosticLocation}; use cairo_lang_filesystem::cfg::CfgSet; use cairo_lang_filesystem::db::{ init_files_group, AsFilesGroupMut, CrateConfiguration, FilesDatabase, FilesGroup, FilesGroupEx, @@ -11,13 +10,11 @@ use cairo_lang_filesystem::db::{ use cairo_lang_filesystem::ids::{CrateLongId, Directory, FileLongId}; use cairo_lang_parser::db::ParserDatabase; use cairo_lang_plugins::get_base_plugins; +use cairo_lang_plugins::test_utils::expand_module_text; use cairo_lang_syntax::node::db::{SyntaxDatabase, SyntaxGroup}; -use cairo_lang_syntax::node::kind::SyntaxKind; -use cairo_lang_syntax::node::{ast, TypedSyntaxNode}; use cairo_lang_test_utils::parse_test_file::TestRunnerResult; use cairo_lang_test_utils::verify_diagnostics_expectation; use cairo_lang_utils::ordered_hash_map::OrderedHashMap; -use cairo_lang_utils::unordered_hash_set::UnorderedHashSet; use cairo_lang_utils::Upcast; use super::BuiltinDojoPlugin; @@ -118,58 +115,3 @@ pub fn test_expand_plugin_inner( error, } } - -pub fn expand_module_text( - db: &dyn DefsGroup, - module_id: ModuleId, - diagnostics: &mut Vec, -) -> String { - let mut output = String::new(); - // A collection of all the use statements in the module. - let mut uses_list = UnorderedHashSet::default(); - let syntax_db = db.upcast(); - // Collect the module diagnostics. - for (file_id, diag) in db.module_plugin_diagnostics(module_id).unwrap().iter() { - let syntax_node = diag.stable_ptr.lookup(syntax_db); - let location = DiagnosticLocation { - file_id: file_id.file_id(db.upcast()).unwrap(), - span: syntax_node.span_without_trivia(syntax_db), - }; - diagnostics.push(format_diagnostics(db.upcast(), &diag.message, location)); - } - for item_id in db.module_items(module_id).unwrap().iter() { - if let ModuleItemId::Submodule(item) = item_id { - let submodule_item = item.stable_ptr(db).lookup(syntax_db); - if let ast::MaybeModuleBody::Some(body) = submodule_item.body(syntax_db) { - // Recursively expand inline submodules. - output.extend([ - submodule_item.attributes(syntax_db).node.get_text(syntax_db), - submodule_item.module_kw(syntax_db).as_syntax_node().get_text(syntax_db), - submodule_item.name(syntax_db).as_syntax_node().get_text(syntax_db), - body.lbrace(syntax_db).as_syntax_node().get_text(syntax_db), - expand_module_text(db, ModuleId::Submodule(*item), diagnostics), - body.rbrace(syntax_db).as_syntax_node().get_text(syntax_db), - ]); - continue; - } - } else if let ModuleItemId::Use(use_id) = item_id { - let mut use_item = use_id.stable_ptr(db).lookup(syntax_db).as_syntax_node(); - // Climb up the AST until the syntax kind is ItemUse. This is needed since the use item - // points to the use leaf as one use statement can represent multiple use items. - while let Some(parent) = use_item.parent() { - use_item = parent; - if use_item.kind(syntax_db) == SyntaxKind::ItemUse { - break; - } - } - if uses_list.insert(use_item.clone()) { - output.push_str(&use_item.get_text(syntax_db)); - } - continue; - } - let syntax_item = item_id.untyped_stable_ptr(db); - // Output other items as is. - output.push_str(&syntax_item.lookup(syntax_db).get_text(syntax_db)); - } - output -} diff --git a/crates/dojo-lang/src/plugin_test_data/print b/crates/dojo-lang/src/plugin_test_data/print index d7ed7801d2..aec292d7d2 100644 --- a/crates/dojo-lang/src/plugin_test_data/print +++ b/crates/dojo-lang/src/plugin_test_data/print @@ -3,10 +3,14 @@ //! > test_runner_name test_expand_plugin +//! > cfg +["test"] + //! > cairo_code use serde::Serde; +use debug::PrintTrait; -#[derive(Print, Copy, Drop, Serde)] +#[derive(Print)] struct Position { #[key] id: felt252, @@ -15,14 +19,14 @@ struct Position { y: felt252 } -#[derive(Print, Serde)] +#[derive(Print)] struct Roles { role_ids: Array } use starknet::ContractAddress; -#[derive(Print, Copy, Drop, Serde)] +#[derive(Print)] struct Player { #[key] game: felt252, @@ -32,11 +36,18 @@ struct Player { name: felt252, } -//! > generated_cairo_code -use serde::Serde; +#[derive(Print)] +enum Enemy { + Unknown, + Bot: felt252, + OtherPlayer: ContractAddress, +} +//! > expanded_cairo_code +use serde::Serde; +use debug::PrintTrait; -#[derive(Print, Copy, Drop, Serde)] +#[derive(Print)] struct Position { #[key] id: felt252, @@ -45,37 +56,14 @@ struct Position { y: felt252 } -#[cfg(test)] -impl PositionPrintImpl of core::debug::PrintTrait { - fn print(self: Position) { - core::debug::PrintTrait::print('id'); - core::debug::PrintTrait::print(self.id); - core::debug::PrintTrait::print('x'); - core::debug::PrintTrait::print(self.x); - core::debug::PrintTrait::print('y'); - core::debug::PrintTrait::print(self.y); - } -} - - -#[derive(Print, Serde)] +#[derive(Print)] struct Roles { role_ids: Array } -#[cfg(test)] -impl RolesPrintImpl of core::debug::PrintTrait { - fn print(self: Roles) { - core::debug::PrintTrait::print('role_ids'); - core::debug::PrintTrait::print(self.role_ids); - } -} - - use starknet::ContractAddress; - -#[derive(Print, Copy, Drop, Serde)] +#[derive(Print)] struct Player { #[key] game: felt252, @@ -84,87 +72,48 @@ struct Player { name: felt252, } -#[cfg(test)] -impl PlayerPrintImpl of core::debug::PrintTrait { - fn print(self: Player) { - core::debug::PrintTrait::print('game'); - core::debug::PrintTrait::print(self.game); - core::debug::PrintTrait::print('player'); - core::debug::PrintTrait::print(self.player); - core::debug::PrintTrait::print('name'); - core::debug::PrintTrait::print(self.name); - } -} - -//! > expected_diagnostics - -//! > expanded_cairo_code -use serde::Serde; - -#[derive(Print, Copy, Drop, Serde)] -struct Position { - #[key] - id: felt252, - x: felt252, - y: felt252 +#[derive(Print)] +enum Enemy { + Unknown, + Bot: felt252, + OtherPlayer: ContractAddress, } -#[derive(Print, Serde)] -struct Roles { - role_ids: Array +#[cfg(test)] +impl PositionStructPrintImpl of core::debug::PrintTrait { + fn print(self: Position) { + core::debug::PrintTrait::print('id'); core::debug::PrintTrait::print(self.id); +core::debug::PrintTrait::print('x'); core::debug::PrintTrait::print(self.x); +core::debug::PrintTrait::print('y'); core::debug::PrintTrait::print(self.y); + } } -use starknet::ContractAddress; - -#[derive(Print, Copy, Drop, Serde)] -struct Player { - #[key] - game: felt252, - #[key] - player: ContractAddress, - - name: felt252, -} -impl PositionCopy of core::traits::Copy::; -impl PositionDrop of core::traits::Drop::; -impl PositionSerde of core::serde::Serde:: { - fn serialize(self: @Position, ref output: core::array::Array) { - core::serde::Serde::serialize(self.id, ref output); - core::serde::Serde::serialize(self.x, ref output); - core::serde::Serde::serialize(self.y, ref output) - } - fn deserialize(ref serialized: core::array::Span) -> core::option::Option { - core::option::Option::Some(Position { - id: core::serde::Serde::deserialize(ref serialized)?, - x: core::serde::Serde::deserialize(ref serialized)?, - y: core::serde::Serde::deserialize(ref serialized)?, - }) +#[cfg(test)] +impl RolesStructPrintImpl of core::debug::PrintTrait { + fn print(self: Roles) { + core::debug::PrintTrait::print('role_ids'); core::debug::PrintTrait::print(self.role_ids); } } -impl RolesSerde of core::serde::Serde:: { - fn serialize(self: @Roles, ref output: core::array::Array) { - core::serde::Serde::serialize(self.role_ids, ref output) - } - fn deserialize(ref serialized: core::array::Span) -> core::option::Option { - core::option::Option::Some(Roles { - role_ids: core::serde::Serde::deserialize(ref serialized)?, - }) + +#[cfg(test)] +impl PlayerStructPrintImpl of core::debug::PrintTrait { + fn print(self: Player) { + core::debug::PrintTrait::print('game'); core::debug::PrintTrait::print(self.game); +core::debug::PrintTrait::print('player'); core::debug::PrintTrait::print(self.player); +core::debug::PrintTrait::print('name'); core::debug::PrintTrait::print(self.name); } } -impl PlayerCopy of core::traits::Copy::; -impl PlayerDrop of core::traits::Drop::; -impl PlayerSerde of core::serde::Serde:: { - fn serialize(self: @Player, ref output: core::array::Array) { - core::serde::Serde::serialize(self.game, ref output); - core::serde::Serde::serialize(self.player, ref output); - core::serde::Serde::serialize(self.name, ref output) - } - fn deserialize(ref serialized: core::array::Span) -> core::option::Option { - core::option::Option::Some(Player { - game: core::serde::Serde::deserialize(ref serialized)?, - player: core::serde::Serde::deserialize(ref serialized)?, - name: core::serde::Serde::deserialize(ref serialized)?, - }) + +#[cfg(test)] +impl EnemyEnumPrintImpl of core::debug::PrintTrait { + fn print(self: Enemy) { + match self { + Enemy::Unknown => { core::debug::PrintTrait::print('Unknown'); }, +Enemy::Bot(v) => { core::debug::PrintTrait::print('Bot'); core::debug::PrintTrait::print(v); }, +Enemy::OtherPlayer(v) => { core::debug::PrintTrait::print('OtherPlayer'); core::debug::PrintTrait::print(v); } + } } } + +//! > expected_diagnostics diff --git a/crates/dojo-lang/src/print.rs b/crates/dojo-lang/src/print.rs index 06136b592e..168adb79d7 100644 --- a/crates/dojo-lang/src/print.rs +++ b/crates/dojo-lang/src/print.rs @@ -1,5 +1,5 @@ use cairo_lang_defs::patcher::RewriteNode; -use cairo_lang_syntax::node::ast::ItemStruct; +use cairo_lang_syntax::node::ast::{ItemEnum, ItemStruct, OptionTypeClause}; use cairo_lang_syntax::node::db::SyntaxGroup; use cairo_lang_syntax::node::{Terminal, TypedSyntaxNode}; use cairo_lang_utils::unordered_hash_map::UnorderedHashMap; @@ -10,7 +10,7 @@ use cairo_lang_utils::unordered_hash_map::UnorderedHashMap; /// * struct_ast: The AST of the model struct. /// Returns: /// * A RewriteNode containing the generated code. -pub fn derive_print(db: &dyn SyntaxGroup, struct_ast: ItemStruct) -> RewriteNode { +pub fn handle_print_struct(db: &dyn SyntaxGroup, struct_ast: ItemStruct) -> RewriteNode { let prints: Vec<_> = struct_ast .members(db) .elements(db) @@ -25,12 +25,14 @@ pub fn derive_print(db: &dyn SyntaxGroup, struct_ast: ItemStruct) -> RewriteNode .collect(); RewriteNode::interpolate_patched( - "#[cfg(test)] - impl $type_name$PrintImpl of core::debug::PrintTrait<$type_name$> { - fn print(self: $type_name$) { - $print$ - } - }", + " +#[cfg(test)] +impl $type_name$StructPrintImpl of core::debug::PrintTrait<$type_name$> { + fn print(self: $type_name$) { + $print$ + } +} +", &UnorderedHashMap::from([ ( "type_name".to_string(), @@ -40,3 +42,53 @@ pub fn derive_print(db: &dyn SyntaxGroup, struct_ast: ItemStruct) -> RewriteNode ]), ) } + +/// Derives PrintTrait for an enum. +/// Parameters: +/// * db: The semantic database. +/// * enum_ast: The AST of the model enum. +/// Returns: +/// * A RewriteNode containing the generated code. +pub fn handle_print_enum(db: &dyn SyntaxGroup, enum_ast: ItemEnum) -> RewriteNode { + let enum_name = enum_ast.name(db).text(db); + let prints: Vec<_> = enum_ast + .variants(db) + .elements(db) + .iter() + .map(|m| { + let variant_name = m.name(db).text(db).to_string(); + match m.type_clause(db) { + OptionTypeClause::Empty(_) => { + format!( + "{enum_name}::{variant_name} => {{ \ + core::debug::PrintTrait::print('{variant_name}'); }}" + ) + } + OptionTypeClause::TypeClause(_) => { + format!( + "{enum_name}::{variant_name}(v) => {{ \ + core::debug::PrintTrait::print('{variant_name}'); \ + core::debug::PrintTrait::print(v); }}" + ) + } + } + }) + .collect(); + + RewriteNode::interpolate_patched( + " +#[cfg(test)] +impl $type_name$EnumPrintImpl of core::debug::PrintTrait<$type_name$> { + fn print(self: $type_name$) { + match self { + $print$ + } + } +} +", + &UnorderedHashMap::from([ + ("type_name".to_string(), RewriteNode::new_trimmed(enum_ast.name(db).as_syntax_node())), + ("print".to_string(), RewriteNode::Text(prints.join(",\n"))), + ]), + ) +} diff --git a/examples/spawn-and-move/Scarb.lock b/examples/spawn-and-move/Scarb.lock index b61ff06a70..ae8ad6f337 100644 --- a/examples/spawn-and-move/Scarb.lock +++ b/examples/spawn-and-move/Scarb.lock @@ -10,7 +10,7 @@ dependencies = [ [[package]] name = "dojo_examples" -version = "0.4.4" +version = "0.5.0" dependencies = [ "dojo", ]