Skip to content

Commit

Permalink
feat: add [dojo::interface] attribute
Browse files Browse the repository at this point in the history
Add a new `[dojo::interface]` attribute which is expanded in `[starknet::interface]`,
allowing the user to ignore `TContractState` type and `self` parameter.
  • Loading branch information
Rémy Baranx committed Mar 2, 2024
1 parent be5469e commit a3dca4a
Show file tree
Hide file tree
Showing 4 changed files with 251 additions and 0 deletions.
197 changes: 197 additions & 0 deletions crates/dojo-lang/src/interface.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
use std::collections::HashMap;

use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
use cairo_lang_defs::plugin::{
DynGeneratedFileAuxData, PluginDiagnostic, PluginGeneratedFile, PluginResult,
};
use cairo_lang_diagnostics::Severity;
use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
use cairo_lang_syntax::node::{ast, Terminal, TypedSyntaxNode};
use cairo_lang_syntax::node::ast::{MaybeTraitBody, OptionReturnTypeClause, ParamList};
use cairo_lang_syntax::node::db::SyntaxGroup;
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use dojo_types::system::Dependency;

use crate::plugin::{DojoAuxData, SystemAuxData};

pub struct DojoInterface {
diagnostics: Vec<PluginDiagnostic>,
dependencies: HashMap<smol_str::SmolStr, Dependency>,
}

impl DojoInterface {
pub fn from_trait(db: &dyn SyntaxGroup, trait_ast: ast::ItemTrait) -> PluginResult {
let name = trait_ast.name(db).text(db);
let mut system = DojoInterface { diagnostics: vec![], dependencies: HashMap::new() };
let mut builder = PatchBuilder::new(db);

if let MaybeTraitBody::Some(body) = trait_ast.body(db) {
let body_nodes: Vec<_> = body
.items(db)
.elements(db)
.iter()
.flat_map(|el| {
if let ast::TraitItem::Function(fn_ast) = el {
return system.rewrite_function(db, fn_ast.clone());
}

//TODO: if the trait body contains other things than functions,
// they are not correctly copied.
// for example: `const ONE: u8 = 1;` is copied as `const ONE`
// vec![RewriteNode::Copied(el.as_syntax_node())]
vec![]

Check warning on line 42 in crates/dojo-lang/src/interface.rs

View check run for this annotation

Codecov / codecov/patch

crates/dojo-lang/src/interface.rs#L36-L42

Added lines #L36 - L42 were not covered by tests
})
.collect();

builder.add_modified(RewriteNode::interpolate_patched(
"
#[starknet::interface]
trait $name$<TContractState> {
$body$
}
",
&UnorderedHashMap::from([
("name".to_string(), RewriteNode::Text(name.to_string())),
("body".to_string(), RewriteNode::new_modified(body_nodes)),
]),
));
}
else {
// empty trait
builder.add_modified(RewriteNode::interpolate_patched(
"
#[starknet::interface]
trait $name$<TContractState> {}
",
&UnorderedHashMap::from([
("name".to_string(), RewriteNode::Text(name.to_string())),
]),
));
}

return PluginResult {
code: Some(PluginGeneratedFile {
name: name.clone(),
content: builder.code,
aux_data: Some(DynGeneratedFileAuxData::new(DojoAuxData {
models: vec![],
systems: vec![SystemAuxData {
name,
dependencies: system.dependencies.values().cloned().collect(),
}],
})),
code_mappings: builder.code_mappings,
}),
diagnostics: system.diagnostics,
remove_original_item: true,
};
}

/// Rewrites parameter list by adding `self` parameter if missing.
///
/// Reports an error in case of `ref self` as systems are supposed to be 100% stateless.
pub fn rewrite_parameters(
&mut self,
db: &dyn SyntaxGroup,
param_list: ParamList,
diagnostic_item: SyntaxStablePtrId
) -> String {
let mut params = param_list.elements(db)
.iter()
.map(|e| e.as_syntax_node().get_text(db))
.collect::<Vec<_>>();

let mut need_to_add_self = true;
if !params.is_empty() {
let first_param = param_list.elements(db)[0].clone();
let param_name = first_param.name(db).text(db).to_string();

if param_name.eq(&"self".to_string()) {
let param_modifiers = first_param
.modifiers(db).elements(db)
.iter().map(|e|e.as_syntax_node().get_text(db).trim().to_string())
.collect::<Vec<_>>();

let param_type = first_param
.type_clause(db).ty(db).as_syntax_node()
.get_text(db).trim().to_string();

if param_modifiers.contains(&"ref".to_string()) && param_type.eq(&"TContractState".to_string()) {
self.diagnostics.push(PluginDiagnostic {
stable_ptr: diagnostic_item,
message: "Functions of dojo::interface cannot have `ref self` parameter.".to_string(),
severity: Severity::Error,
});

need_to_add_self = false;
}

if param_type.eq(&"@TContractState".to_string()) {
need_to_add_self = false;
}
}
};

if need_to_add_self {
params.insert(0, "self: @TContractState".to_string());
}

params.join(", ")
}

/// Rewrites function declaration by adding `self` parameter if missing.
///
/// Some notes:
/// * as rewritten functions belong to a starknet::interface:
/// - there is no generic parameter,
/// - there is no nopanic and implicits clauses.
/// * there is no function body in a trait.
pub fn rewrite_function(
&mut self,
db: &dyn SyntaxGroup,
fn_ast: ast::TraitItemFunction,
) -> Vec<RewriteNode> {
let mut rewrite_nodes = vec![];

let declaration = fn_ast.declaration(db);
let fn_name_node = declaration.name(db).clone();
let fn_name = fn_name_node.text(db).to_string();
let signature = declaration.signature(db);

let attributes;
if fn_ast.attributes(db).elements(db).is_empty() {
attributes = "".to_string();
}
else {
attributes = fn_ast.attributes(db).elements(db)
.iter()
.map(|e| e.as_syntax_node().get_text(db))
.collect::<Vec<_>>()
.join("\n");
}

let ret_ty = if let OptionReturnTypeClause::ReturnTypeClause(ty) = signature.ret_ty(db) {
format!(" {}", ty.as_syntax_node().get_text(db).to_string())
} else {
"".to_string()
};

let params = self.rewrite_parameters(
db,
signature.parameters(db),
fn_name_node.stable_ptr().untyped()
);

rewrite_nodes.push(RewriteNode::interpolate_patched(
"$attributes$fn $fn_name$($params$)$ret_ty$;\n",
&UnorderedHashMap::from([
("attributes".to_string(), RewriteNode::Text(attributes)),
("fn_name".to_string(), RewriteNode::Text(fn_name)),
("params".to_string(), RewriteNode::Text(params)),
("ret_ty".to_string(), RewriteNode::Text(ret_ty))
]),
));

rewrite_nodes
}
}
1 change: 1 addition & 0 deletions crates/dojo-lang/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
//! Learn more at [dojoengine.gg](http://dojoengine.gg).
pub mod compiler;
pub mod contract;
pub mod interface;
pub mod inline_macros;
pub mod introspect;
pub mod model;
Expand Down
11 changes: 11 additions & 0 deletions crates/dojo-lang/src/plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use smol_str::SmolStr;
use url::Url;

use crate::contract::DojoContract;
use crate::interface::DojoInterface;
use crate::inline_macros::array_cap::ArrayCapMacro;
use crate::inline_macros::delete::DeleteMacro;
use crate::inline_macros::emit::EmitMacro;
Expand All @@ -33,6 +34,7 @@ use crate::model::handle_model_struct;
use crate::print::{handle_print_enum, handle_print_struct};

const DOJO_CONTRACT_ATTR: &str = "dojo::contract";
const DOJO_INTERFACE_ATTR: &str = "dojo::interface";
const DOJO_PLUGIN_EXPAND_VAR_ENV: &str = "DOJO_PLUGIN_EXPAND";

#[derive(Clone, Debug, PartialEq)]
Expand Down Expand Up @@ -101,6 +103,14 @@ impl BuiltinDojoPlugin {
PluginResult::default()
}

fn handle_trait(&self, db: &dyn SyntaxGroup, trait_ast: ast::ItemTrait) -> PluginResult {
if trait_ast.has_attr(db, DOJO_INTERFACE_ATTR) {
return DojoInterface::from_trait(db, trait_ast);
}

PluginResult::default()
}

fn result_with_diagnostic(
&self,
stable_ptr: SyntaxStablePtrId,
Expand Down Expand Up @@ -246,6 +256,7 @@ impl MacroPlugin for BuiltinDojoPlugin {

match item_ast {
ast::ModuleItem::Module(module_ast) => self.handle_mod(db, module_ast),
ast::ModuleItem::Trait(trait_ast) => self.handle_trait(db, trait_ast),
ast::ModuleItem::Enum(enum_ast) => {
let aux_data = DojoAuxData::default();
let mut rewrite_nodes = vec![];
Expand Down
42 changes: 42 additions & 0 deletions crates/dojo-lang/src/plugin_test_data/system
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,21 @@ mod withcomponent {
testcomponent2_event: testcomponent2::Event,
}
}
#[dojo::interface]
trait IEmptyTrait;

#[dojo::interface]
trait IActions {
fn do_no_param();
fn do_no_param_but_self(self: @TContractState);
fn do_params(p1: dojo_examples::models::Direction, p2: u8);
fn do_params_and_self(self: @TContractState, p2: u8);
fn do_return_value(p1: u8) -> u16;
fn do_ref_self(ref self: TContractState);

#[my_attr]
fn do_with_attrs(p1: u8) -> u16;
}

//! > generated_cairo_code
#[starknet::contract]
Expand Down Expand Up @@ -205,6 +220,11 @@ error: Unsupported attribute.
#[starknet::component]
^********************^

error: Functions of dojo::interface cannot have `ref self` parameter.
--> test_src/lib.cairo:91:8
fn do_ref_self(ref self: TContractState);
^*********^

error: Unsupported attribute.
--> test_src/lib.cairo[spawn]:2:17
#[starknet::contract]
Expand All @@ -230,6 +250,11 @@ error: Unsupported attribute.
#[starknet::contract]
^*******************^

error: Unsupported attribute.
--> test_src/lib.cairo[IActions]:11:5
#[my_attr]
^********^

error: Unsupported attribute.
--> test_src/lib.cairo:49:5
#[storage]
Expand Down Expand Up @@ -701,3 +726,20 @@ impl TestEventDrop of core::traits::Drop::<TestEvent>;
impl EventDrop of core::traits::Drop::<Event>;

}

#[starknet::interface]
trait IEmptyTrait<TContractState> {}

#[starknet::interface]
trait IActions<TContractState> {
fn do_no_param(self: @TContractState);
fn do_no_param_but_self(self: @TContractState);
fn do_params(self: @TContractState, p1: dojo_examples::models::Direction, p2: u8);
fn do_params_and_self(self: @TContractState, p2: u8);
fn do_return_value(self: @TContractState, p1: u8) -> u16;
fn do_ref_self(ref self: TContractState);

#[my_attr]
fn do_with_attrs(self: @TContractState, p1: u8) -> u16;

}

0 comments on commit a3dca4a

Please sign in to comment.