Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Support deriving Print for enums #1091

Merged
merged 17 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/dojo-lang/src/introspect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ pub fn handle_introspect_struct(db: &dyn SyntaxGroup, struct_ast: ItemStruct) ->
/// A handler for Dojo code derives Introspect for an enum
/// Parameters:
/// * db: The semantic database.
/// * struct_ast: The AST of the struct.
/// * enum_ast: The AST of the enum.
/// Returns:
/// * A RewriteNode containing the generated code.
pub fn handle_introspect_enum(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dependencies = [

[[package]]
name = "dojo"
version = "0.4.2"
version = "0.4.3"
dependencies = [
"dojo_plugin",
]
Expand Down
5 changes: 3 additions & 2 deletions crates/dojo-lang/src/plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::inline_macros::get::GetMacro;
use crate::inline_macros::set::SetMacro;
use crate::introspect::{handle_introspect_enum, handle_introspect_struct};
use crate::model::handle_model_struct;
use crate::print::derive_print;
use crate::print::{handle_print_enum, handle_print_struct};

const DOJO_CONTRACT_ATTR: &str = "dojo::contract";

Expand Down Expand Up @@ -279,6 +279,7 @@ impl MacroPlugin for BuiltinDojoPlugin {
enum_ast.clone(),
));
}
"Print" => rewrite_nodes.push(handle_print_enum(db, enum_ast.clone())),
_ => continue,
}
}
Expand Down Expand Up @@ -355,7 +356,7 @@ impl MacroPlugin for BuiltinDojoPlugin {
diagnostics.extend(model_diagnostics);
}
"Print" => {
rewrite_nodes.push(derive_print(db, struct_ast.clone()));
rewrite_nodes.push(handle_print_struct(db, struct_ast.clone()));
}
"Introspect" => {
rewrite_nodes
Expand Down
72 changes: 22 additions & 50 deletions crates/dojo-lang/src/plugin_test_data/print
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ test_expand_plugin
//! > cairo_code
use serde::Serde;

#[derive(Print, Copy, Drop, Serde)]
#[derive(Print)]
struct Position {
#[key]
id: felt252,
Expand All @@ -15,14 +15,14 @@ struct Position {
y: felt252
}

#[derive(Print, Serde)]
#[derive(Print)]
struct Roles {
role_ids: Array<u8>
}

use starknet::ContractAddress;

#[derive(Print, Copy, Drop, Serde)]
#[derive(Print)]
struct Player {
#[key]
game: felt252,
Expand All @@ -32,11 +32,18 @@ struct Player {
name: felt252,
}

#[derive(Print)]
enum Enemy {
Unknown,
Bot: felt252,
OtherPlayer: ContractAddress,
}

//! > generated_cairo_code
use serde::Serde;


#[derive(Print, Copy, Drop, Serde)]
#[derive(Print)]
struct Position {
#[key]
id: felt252,
Expand All @@ -57,8 +64,7 @@ impl PositionPrintImpl of core::debug::PrintTrait<Position> {
}
}


#[derive(Print, Serde)]
#[derive(Print)]
struct Roles {
role_ids: Array<u8>
}
Expand All @@ -84,6 +90,7 @@ struct Player {

name: felt252,
}

#[cfg(test)]
impl PlayerPrintImpl of core::debug::PrintTrait<Player> {
fn print(self: Player) {
Expand All @@ -101,7 +108,7 @@ impl PlayerPrintImpl of core::debug::PrintTrait<Player> {
//! > expanded_cairo_code
use serde::Serde;

#[derive(Print, Copy, Drop, Serde)]
#[derive(Print)]
struct Position {
#[key]
id: felt252,
Expand All @@ -110,14 +117,14 @@ struct Position {
y: felt252
}

#[derive(Print, Serde)]
#[derive(Print)]
struct Roles {
role_ids: Array<u8>
}

use starknet::ContractAddress;

#[derive(Print, Copy, Drop, Serde)]
#[derive(Print)]
struct Player {
#[key]
game: felt252,
Expand All @@ -126,45 +133,10 @@ struct Player {

name: felt252,
}
impl PositionCopy of core::traits::Copy::<Position>;
impl PositionDrop of core::traits::Drop::<Position>;
impl PositionSerde of core::serde::Serde::<Position> {
fn serialize(self: @Position, ref output: core::array::Array<felt252>) {
core::serde::Serde::serialize(self.id, ref output);
core::serde::Serde::serialize(self.x, ref output);
core::serde::Serde::serialize(self.y, ref output)
}
fn deserialize(ref serialized: core::array::Span<felt252>) -> core::option::Option<Position> {
core::option::Option::Some(Position {
id: core::serde::Serde::deserialize(ref serialized)?,
x: core::serde::Serde::deserialize(ref serialized)?,
y: core::serde::Serde::deserialize(ref serialized)?,
})
}
}
impl RolesSerde of core::serde::Serde::<Roles> {
fn serialize(self: @Roles, ref output: core::array::Array<felt252>) {
core::serde::Serde::serialize(self.role_ids, ref output)
}
fn deserialize(ref serialized: core::array::Span<felt252>) -> core::option::Option<Roles> {
core::option::Option::Some(Roles {
role_ids: core::serde::Serde::deserialize(ref serialized)?,
})
}
}
impl PlayerCopy of core::traits::Copy::<Player>;
impl PlayerDrop of core::traits::Drop::<Player>;
impl PlayerSerde of core::serde::Serde::<Player> {
fn serialize(self: @Player, ref output: core::array::Array<felt252>) {
core::serde::Serde::serialize(self.game, ref output);
core::serde::Serde::serialize(self.player, ref output);
core::serde::Serde::serialize(self.name, ref output)
}
fn deserialize(ref serialized: core::array::Span<felt252>) -> core::option::Option<Player> {
core::option::Option::Some(Player {
game: core::serde::Serde::deserialize(ref serialized)?,
player: core::serde::Serde::deserialize(ref serialized)?,
name: core::serde::Serde::deserialize(ref serialized)?,
})
}

#[derive(Print)]
enum Enemy {
Unknown,
Bot: felt252,
OtherPlayer: ContractAddress,
}
45 changes: 42 additions & 3 deletions crates/dojo-lang/src/print.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use cairo_lang_defs::patcher::RewriteNode;
use cairo_lang_syntax::node::ast::ItemStruct;
use cairo_lang_syntax::node::ast::{ItemEnum, ItemStruct};
0xicosahedron marked this conversation as resolved.
Show resolved Hide resolved
use cairo_lang_syntax::node::db::SyntaxGroup;
use cairo_lang_syntax::node::{Terminal, TypedSyntaxNode};
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
Expand All @@ -10,7 +10,7 @@ use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
/// * struct_ast: The AST of the model struct.
/// Returns:
/// * A RewriteNode containing the generated code.
pub fn derive_print(db: &dyn SyntaxGroup, struct_ast: ItemStruct) -> RewriteNode {
pub fn handle_print_struct(db: &dyn SyntaxGroup, struct_ast: ItemStruct) -> RewriteNode {
let prints: Vec<_> = struct_ast
.members(db)
.elements(db)
Expand All @@ -26,7 +26,7 @@ pub fn derive_print(db: &dyn SyntaxGroup, struct_ast: ItemStruct) -> RewriteNode

RewriteNode::interpolate_patched(
"#[cfg(test)]
impl $type_name$PrintImpl of core::debug::PrintTrait<$type_name$> {
impl $type_name$StructPrintImpl of core::debug::PrintTrait<$type_name$> {
fn print(self: $type_name$) {
$print$
}
Expand All @@ -40,3 +40,42 @@ pub fn derive_print(db: &dyn SyntaxGroup, struct_ast: ItemStruct) -> RewriteNode
]),
)
}

/// Derives PrintTrait for an enum.
/// Parameters:
/// * db: The semantic database.
/// * enum_ast: The AST of the model enum.
/// Returns:
/// * A RewriteNode containing the generated code.
pub fn handle_print_enum(db: &dyn SyntaxGroup, enum_ast: ItemEnum) -> RewriteNode {
let enum_name = enum_ast.name(db).text(db);
let prints: Vec<_> = enum_ast
.variants(db)
.elements(db)
.iter()
.map(|m| {
format!(
"{}::{}(value) => {{ core::debug::PrintTrait::print('{}'); \
core::debug::PrintTrait::print(value); }}",
enum_name,
m.name(db).text(db).to_string(),
m.name(db).text(db).to_string()
)
0xicosahedron marked this conversation as resolved.
Show resolved Hide resolved
})
.collect();

RewriteNode::interpolate_patched(
"#[cfg(test)]
impl $type_name$EnumPrintImpl of core::debug::PrintTrait<$type_name$> {
fn print(self: $type_name$) {
match self {
$print$
}
}
}",
&UnorderedHashMap::from([
("type_name".to_string(), RewriteNode::new_trimmed(enum_ast.name(db).as_syntax_node())),
("print".to_string(), RewriteNode::Text(prints.join(",\n"))),
]),
)
}
2 changes: 1 addition & 1 deletion examples/spawn-and-move/Scarb.lock
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dependencies = [

[[package]]
name = "dojo_examples"
version = "0.4.2"
version = "0.4.3"
dependencies = [
"dojo",
]
Expand Down
Loading