Skip to content

Commit

Permalink
feat: improve dojo::contract syntax
Browse files Browse the repository at this point in the history
`self` parameter is not mandatory anymore in dojo::contract. If not present,
it will be automatically added during code expansion.

if a `world: IWorldDispatcher` parameter is given, it is removed and replace
by the `let world = self.world_dispatcher.read();` statement at the beginning
of the function body.
  • Loading branch information
Rémy Baranx committed Mar 3, 2024
1 parent 435e3ee commit 6c28119
Show file tree
Hide file tree
Showing 2 changed files with 350 additions and 11 deletions.
184 changes: 177 additions & 7 deletions crates/dojo-lang/src/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
use cairo_lang_defs::plugin::{
DynGeneratedFileAuxData, PluginDiagnostic, PluginGeneratedFile, PluginResult,
};
// use cairo_lang_syntax::node::ast::{MaybeModuleBody, Param};

use cairo_lang_diagnostics::Severity;
use cairo_lang_syntax::node::ast::MaybeModuleBody;
// use cairo_lang_syntax::node::ast::OptionReturnTypeClause::ReturnTypeClause;
use cairo_lang_syntax::node::db::SyntaxGroup;
use cairo_lang_syntax::node::{ast, Terminal, TypedSyntaxNode};
use cairo_lang_syntax::node::{ast, ids, Terminal, TypedSyntaxNode};
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use dojo_types::system::Dependency;

Expand Down Expand Up @@ -42,6 +42,10 @@ impl DojoContract {
has_storage = true;
return system.merge_storage(db, struct_ast.clone());
}
} else if let ast::ModuleItem::FreeFunction(fn_ast) = el {
return system.rewrite_function(db, fn_ast.clone());
} else if let ast::ModuleItem::Impl(impl_ast) = el {
return system.rewrite_impl(db, impl_ast.clone());
}

vec![RewriteNode::Copied(el.as_syntax_node())]
Expand Down Expand Up @@ -197,8 +201,174 @@ impl DojoContract {
.to_string(),
)]
}
}

// fn is_context(db: &dyn SyntaxGroup, param: &Param) -> bool {
// param.type_clause(db).ty(db).as_syntax_node().get_text(db) == "Context"
// }
/// Gets name, modifiers and type from a function parameter.
pub fn get_parameter_info(
&mut self,
db: &dyn SyntaxGroup,
param: ast::Param
) -> (String, String, String) {
let name = param.name(db).text(db).trim().to_string();
let modifiers = param.modifiers(db).as_syntax_node().get_text(db).trim().to_string();
let param_type = param.type_clause(db).ty(db).as_syntax_node().get_text(db).trim().to_string();

(name, modifiers, param_type)
}

/// Rewrites parameter list by:
/// * adding `self` parameter if missing,
/// * removing `world` if present as it will be read from the first function statement.
///
/// Reports an error in case of `ref self` as systems are supposed to be 100% stateless.
///
/// Returns
/// * the list of parameters in a String
/// * a boolean indicating if `self` has been added
// * a boolean indicating if `world` parameter has been removed
pub fn rewrite_parameters(
&mut self,
db: &dyn SyntaxGroup,
param_list: ast::ParamList,
diagnostic_item: ids::SyntaxStablePtrId
) -> (String, bool, bool) {
let mut world_removed = false;

let mut params = param_list.elements(db)
.iter()
.filter_map(|e| {
let (name, modifiers, param_type) = self.get_parameter_info(db, e.clone());

if name.eq(&"world".to_string()) && modifiers.eq(&"".to_string()) && param_type.eq(&"IWorldDispatcher".to_string()) {
world_removed = true;
None
}
else {
Some(e.as_syntax_node().get_text(db))
}
})
.collect::<Vec<_>>();

let mut add_self = true;
if !params.is_empty() {
let (param_name, param_modifiers, param_type) = self.get_parameter_info(
db,
param_list.elements(db)[0].clone()
);

if param_name.eq(&"self".to_string()) {
if param_modifiers.contains(&"ref".to_string()) && param_type.eq(&"ContractState".to_string()) {
self.diagnostics.push(PluginDiagnostic {
stable_ptr: diagnostic_item,
message: "Functions of dojo::contract cannot have `ref self` parameter.".to_string(),
severity: Severity::Error,
});

add_self = false;
}

if param_type.eq(&"@ContractState".to_string()) {
add_self = false;
}
}
};

if add_self {
params.insert(0, "self: @ContractState".to_string());
}

(params.join(", "), add_self, world_removed)
}

/// Rewrites function statements by adding the reading of `world` at first statement.
pub fn rewrite_statements(
&mut self,
db: &dyn SyntaxGroup,
statement_list: ast::StatementList
) -> String {
let mut statements = statement_list.elements(db)
.iter()
.map(|e| e.as_syntax_node().get_text(db))
.collect::<Vec<_>>();

statements.insert(0, "let world = self.world_dispatcher.read();\n".to_string());
statements.join("")
}

/// Rewrites function declaration by:
/// * adding `self` parameter if missing,
/// * removing `world` if present,
/// * adding `let world = self.world_dispatcher.read();` statement
/// at the beginning of the function to restore the removed `world` parameter.
pub fn rewrite_function(
&mut self,
db: &dyn SyntaxGroup,
fn_ast: ast::FunctionWithBody,
) -> Vec<RewriteNode> {
let mut rewritten_fn = RewriteNode::from_ast(&fn_ast);

let (params_str, self_added, world_removed) = self.rewrite_parameters(
db,
fn_ast.declaration(db).signature(db).parameters(db),
fn_ast.stable_ptr().untyped()
);

if self_added || world_removed {
let rewritten_params = rewritten_fn
.modify_child(db, ast::FunctionWithBody::INDEX_DECLARATION)
.modify_child(db, ast::FunctionDeclaration::INDEX_SIGNATURE)
.modify_child(db, ast::FunctionSignature::INDEX_PARAMETERS);
rewritten_params.set_str(params_str);
}

if world_removed {
let rewritten_statements = rewritten_fn
.modify_child(db, ast::FunctionWithBody::INDEX_BODY)
.modify_child(db, ast::ExprBlock::INDEX_STATEMENTS);

rewritten_statements.set_str(
self.rewrite_statements(db, fn_ast.body(db).statements(db))
);
}

vec![rewritten_fn]
}

/// Rewrites all the functions of a Impl block.
fn rewrite_impl(
&mut self,
db: &dyn SyntaxGroup,
impl_ast: ast::ItemImpl,
) -> Vec<RewriteNode> {
if let ast::MaybeImplBody::Some(body) = impl_ast.body(db) {
let body_nodes: Vec<_> = body
.items(db)
.elements(db)
.iter()
.flat_map(|el| {
if let ast::ImplItem::Function(fn_ast) = el {
return self.rewrite_function(db, fn_ast.clone());
}
vec![RewriteNode::Copied(el.as_syntax_node())]
})
.collect();

let mut builder = PatchBuilder::new(db);
builder.add_modified(RewriteNode::interpolate_patched(
"$body$",
&UnorderedHashMap::from([
("body".to_string(), RewriteNode::new_modified(body_nodes)),
])
));

let mut rewritten_impl = RewriteNode::from_ast(&impl_ast);
let rewritten_items = rewritten_impl
.modify_child(db, ast::ItemImpl::INDEX_BODY)
.modify_child(db, ast::ImplBody::INDEX_ITEMS);

rewritten_items.set_str(builder.code);
return vec![rewritten_impl];
}

vec![RewriteNode::Copied(impl_ast.as_syntax_node())]
}
}
Loading

0 comments on commit 6c28119

Please sign in to comment.