Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Support deriving Print for enums #1091

Merged
merged 17 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/dojo-lang/src/introspect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ pub fn handle_introspect_struct(db: &dyn SyntaxGroup, struct_ast: ItemStruct) ->
/// A handler for Dojo code derives Introspect for an enum
/// Parameters:
/// * db: The semantic database.
/// * struct_ast: The AST of the struct.
/// * enum_ast: The AST of the enum.
/// Returns:
/// * A RewriteNode containing the generated code.
pub fn handle_introspect_enum(
Expand Down
5 changes: 3 additions & 2 deletions crates/dojo-lang/src/plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use crate::inline_macros::get::GetMacro;
use crate::inline_macros::set::SetMacro;
use crate::introspect::{handle_introspect_enum, handle_introspect_struct};
use crate::model::handle_model_struct;
use crate::print::derive_print;
use crate::print::{handle_print_enum, handle_print_struct};

const DOJO_CONTRACT_ATTR: &str = "dojo::contract";

Expand Down Expand Up @@ -303,6 +303,7 @@ impl MacroPlugin for BuiltinDojoPlugin {
enum_ast.clone(),
));
}
"Print" => rewrite_nodes.push(handle_print_enum(db, enum_ast.clone())),
_ => continue,
}
}
Expand Down Expand Up @@ -379,7 +380,7 @@ impl MacroPlugin for BuiltinDojoPlugin {
diagnostics.extend(model_diagnostics);
}
"Print" => {
rewrite_nodes.push(derive_print(db, struct_ast.clone()));
rewrite_nodes.push(handle_print_struct(db, struct_ast.clone()));
}
"Introspect" => {
rewrite_nodes
Expand Down
35 changes: 35 additions & 0 deletions crates/dojo-lang/src/plugin_test_data/print
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ struct Player {
name: felt252,
}

#[derive(Print, Copy, Drop, Serde)]
0xicosahedron marked this conversation as resolved.
Show resolved Hide resolved
enum Enemy {
Unknown,
Bot: felt252,
OtherPlayer: ContractAddress,
}

//! > generated_cairo_code
use serde::Serde;

Expand Down Expand Up @@ -84,6 +91,7 @@ struct Player {

name: felt252,
}

#[cfg(test)]
impl PlayerPrintImpl of debug::PrintTrait<Player> {
fn print(self: Player) {
Expand Down Expand Up @@ -126,6 +134,13 @@ struct Player {

name: felt252,
}

#[derive(Print, Copy, Drop, Serde)]
0xicosahedron marked this conversation as resolved.
Show resolved Hide resolved
enum Enemy {
Unknown,
Bot: felt252,
OtherPlayer: ContractAddress,
}
impl PositionCopy of Copy::<Position>;
impl PositionDrop of Drop::<Position>;
impl PositionSerde of Serde::<Position> {
Expand Down Expand Up @@ -168,3 +183,23 @@ impl PlayerSerde of Serde::<Player> {
})
}
}
impl EnemyCopy of Copy::<Enemy>;
impl EnemyDrop of Drop::<Enemy>;
impl EnemySerde of Serde::<Enemy> {
fn serialize(self: @Enemy, ref output: array::Array<felt252>) {
match self {
Enemy::Unknown(x) => { serde::Serde::serialize(@0, ref output); serde::Serde::serialize(x, ref output); },
Enemy::Bot(x) => { serde::Serde::serialize(@1, ref output); serde::Serde::serialize(x, ref output); },
Enemy::OtherPlayer(x) => { serde::Serde::serialize(@2, ref output); serde::Serde::serialize(x, ref output); },
}
}
fn deserialize(ref serialized: array::Span<felt252>) -> Option<Enemy> {
let idx: felt252 = serde::Serde::deserialize(ref serialized)?;
Option::Some(
if idx == 0 { Enemy::Unknown(serde::Serde::deserialize(ref serialized)?) }
else if idx == 1 { Enemy::Bot(serde::Serde::deserialize(ref serialized)?) }
else if idx == 2 { Enemy::OtherPlayer(serde::Serde::deserialize(ref serialized)?) }
else { return Option::None; }
)
}
}
45 changes: 42 additions & 3 deletions crates/dojo-lang/src/print.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use cairo_lang_defs::patcher::RewriteNode;
use cairo_lang_syntax::node::ast::ItemStruct;
use cairo_lang_syntax::node::ast::{ItemEnum, ItemStruct};
0xicosahedron marked this conversation as resolved.
Show resolved Hide resolved
use cairo_lang_syntax::node::db::SyntaxGroup;
use cairo_lang_syntax::node::{Terminal, TypedSyntaxNode};
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
Expand All @@ -10,7 +10,7 @@ use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
/// * struct_ast: The AST of the model struct.
/// Returns:
/// * A RewriteNode containing the generated code.
pub fn derive_print(db: &dyn SyntaxGroup, struct_ast: ItemStruct) -> RewriteNode {
pub fn handle_print_struct(db: &dyn SyntaxGroup, struct_ast: ItemStruct) -> RewriteNode {
let prints: Vec<_> = struct_ast
.members(db)
.elements(db)
Expand All @@ -26,7 +26,7 @@ pub fn derive_print(db: &dyn SyntaxGroup, struct_ast: ItemStruct) -> RewriteNode

RewriteNode::interpolate_patched(
"#[cfg(test)]
impl $type_name$PrintImpl of debug::PrintTrait<$type_name$> {
impl $type_name$StructPrintImpl of debug::PrintTrait<$type_name$> {
fn print(self: $type_name$) {
$print$
}
Expand All @@ -40,3 +40,42 @@ pub fn derive_print(db: &dyn SyntaxGroup, struct_ast: ItemStruct) -> RewriteNode
]),
)
}

/// Derives PrintTrait for an enum.
/// Parameters:
/// * db: The semantic database.
/// * enum_ast: The AST of the model enum.
/// Returns:
/// * A RewriteNode containing the generated code.
pub fn handle_print_enum(db: &dyn SyntaxGroup, enum_ast: ItemEnum) -> RewriteNode {
let enum_name = enum_ast.name(db).text(db);
let prints: Vec<_> = enum_ast
.variants(db)
.elements(db)
.iter()
.map(|m| {
format!(
"{}::{}(value) => {{ debug::PrintTrait::print('{}'); \
0xicosahedron marked this conversation as resolved.
Show resolved Hide resolved
debug::PrintTrait::print(value); }}",
0xicosahedron marked this conversation as resolved.
Show resolved Hide resolved
enum_name,
m.name(db).text(db).to_string(),
m.name(db).text(db).to_string()
)
})
.collect();

RewriteNode::interpolate_patched(
"#[cfg(test)]
impl $type_name$EnumPrintImpl of debug::PrintTrait<$type_name$> {
0xicosahedron marked this conversation as resolved.
Show resolved Hide resolved
fn print(self: $type_name$) {
match self {
$print$
}
}
}",
&UnorderedHashMap::from([
("type_name".to_string(), RewriteNode::new_trimmed(enum_ast.name(db).as_syntax_node())),
("print".to_string(), RewriteNode::Text(prints.join("\n"))),
0xicosahedron marked this conversation as resolved.
Show resolved Hide resolved
]),
)
}
Loading