Skip to content

Commit

Permalink
feat: dispatcher_from_tag! macro
Browse files Browse the repository at this point in the history
  • Loading branch information
remybar committed Sep 12, 2024
1 parent 1ef88f8 commit 75a75dd
Show file tree
Hide file tree
Showing 15 changed files with 337 additions and 9 deletions.
44 changes: 40 additions & 4 deletions crates/dojo-lang/src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use starknet::core::types::contract::SierraClass;
use starknet::core::types::Felt;
use tracing::{debug, trace, trace_span};

use crate::plugin::{DojoAuxData, Model};
use crate::plugin::{DojoAuxData, Model, Trait};
use crate::scarb_internal::debug::SierraToCairoDebugInfo;

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -350,6 +350,7 @@ fn update_files(

let mut models = BTreeMap::new();
let mut contracts = BTreeMap::new();
let mut dojo_interfaces = vec![];

if let Some(external_contracts) = external_contracts {
let external_crate_ids = collect_external_crate_ids(db, external_contracts);
Expand Down Expand Up @@ -378,6 +379,7 @@ fn update_files(
&naming::get_tag(&contract.namespace, &contract.name),
&compiled_artifacts,
&contract.systems,
&contract.traits,
)?);
}

Expand All @@ -392,6 +394,9 @@ fn update_files(
*module_id,
&compiled_artifacts,
)?);

// update the list of Dojo interface
dojo_interfaces.extend(dojo_aux_data.interfaces.iter().map(|i| i.name.clone()));
}

// StarknetAuxData shouldn't be required. Every dojo contract and model are starknet
Expand Down Expand Up @@ -422,7 +427,37 @@ fn update_files(
std::fs::create_dir_all(&base_contracts_abis_dir)?;
}

for (_, (manifest, module_id, artifact)) in contracts.iter_mut() {
for (qualified_path, (manifest, module_id, artifact, traits)) in contracts.iter_mut() {
// During compilation, a list of traits implemented inside the contract is extracted.
// We need to find which of these traits is a Dojo interface, to be able to store its qualified path,
// and be able to use it in macros such `dispatcher_from_tag!`
let found_interfaces =
traits.iter().filter(|t| dojo_interfaces.contains(&t.name)).collect::<Vec<_>>();

// a contract may or may not implement a Dojo interface, but it cannot implement several Dojo interfaces.
if found_interfaces.len() > 1 {
return Err(anyhow!(
"The contract '{}' cannot implement several Dojo interfaces (found: [{}]).",
manifest.inner.tag,
found_interfaces.iter().map(|t| t.name.clone()).collect::<Vec<_>>().join(", ")
));
}

manifest.inner.interface_path =
found_interfaces.first().map_or(String::new(), |x| x.path.clone());

// the Dojo interface path may start with `super`, referencing the trait which is
// in the same file than the contract.
// In this case, just replace `super` by the contract qualified path.
if manifest.inner.interface_path.starts_with("super") {
let (path, _) = qualified_path
.rsplit_once(CAIRO_PATH_SEPARATOR)
.unwrap_or((qualified_path.as_str(), ""));

manifest.inner.interface_path =
manifest.inner.interface_path.replacen("super", &path, 1);
}

write_manifest_and_abi(
&base_contracts_dir,
&base_contracts_abis_dir,
Expand Down Expand Up @@ -558,7 +593,8 @@ fn get_dojo_contract_artifacts(
tag: &str,
compiled_classes: &CompiledArtifactByPath,
systems: &[String],
) -> Result<HashMap<String, (Manifest<DojoContract>, ModuleId, CompiledArtifact)>> {
traits: &Vec<Trait>,
) -> Result<HashMap<String, (Manifest<DojoContract>, ModuleId, CompiledArtifact, Vec<Trait>)>> {
let mut result = HashMap::new();

if !matches!(naming::get_name_from_tag(tag).as_str(), "world" | "resource_metadata" | "base") {
Expand All @@ -584,7 +620,7 @@ fn get_dojo_contract_artifacts(

result.insert(
contract_qualified_path.to_string(),
(manifest, *module_id, artifact.clone()),
(manifest, *module_id, artifact.clone(), traits.clone()),
);
}
}
Expand Down
83 changes: 82 additions & 1 deletion crates/dojo-lang/src/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use dojo_types::system::Dependency;
use dojo_world::config::NamespaceConfig;
use dojo_world::contracts::naming;

use crate::plugin::{ContractAuxData, DojoAuxData, DOJO_CONTRACT_ATTR};
use crate::plugin::{ContractAuxData, DojoAuxData, Trait, DOJO_CONTRACT_ATTR};
use crate::syntax::world_param::{self, WorldParamInjectionKind};
use crate::syntax::{self_param, utils as syntax_utils};

Expand Down Expand Up @@ -50,6 +50,7 @@ impl DojoContract {

let mut diagnostics = vec![];
let parameters = get_parameters(db, module_ast, &mut diagnostics);
let traits = get_traits(db, module_ast);

let mut contract =
DojoContract { diagnostics, dependencies: HashMap::new(), systems: vec![] };
Expand Down Expand Up @@ -262,8 +263,10 @@ impl DojoContract {
namespace: contract_namespace.clone(),
dependencies: contract.dependencies.values().cloned().collect(),
systems: contract.systems.clone(),
traits,
}],
events: vec![],
interfaces: vec![],
})),
code_mappings,
}),
Expand Down Expand Up @@ -732,3 +735,81 @@ fn get_parameters(

parameters
}

fn get_traits(db: &dyn SyntaxGroup, module_ast: &ast::ItemModule) -> Vec<Trait> {
let traits = if let ast::MaybeModuleBody::Some(body) = module_ast.body(db) {
body.items(db)
.elements(db)
.iter()
.filter_map(|e| {
if let ast::ModuleItem::Impl(x) = e {
let mut path_segments = x.trait_path(db).elements(db);

// in Cairo, there is always a trait path linked to an impl, so path_segments always contain an item
let last_segment = path_segments.pop().unwrap();

// a dojo interface always has a generic <ContractState> argument, so just keep this kind of traits.
match last_segment {
ast::PathSegment::WithGenericArgs(p) => {
let trait_name = p.ident(db).text(db).to_string();

// Here, we have to rebuild the full trait path. There are several cases:
// 1) there is no path with the trait name (example: IActions)
// => find the trait path in `use` clauses.
// 2) the path is relative (example: `players::IActions`)
// => not possible to be sure that there is only one path in `use` clauses that matches
// with this relative path.
// (example: 2 use clauses: `use path1::players` and `use path2::players`).
// 3) the path is absolute (example: `path::to::players::IActions`)
// => just use this path.
//
// At the moment, only cases 1) and 3) are supported.
let trait_path = if path_segments.is_empty() {
get_trait_path(db, &body, trait_name.clone())
.expect("a path must always be found")
} else {
format!(
"{}::{trait_name}",
path_segments
.iter()
.map(|p| p.as_syntax_node().get_text(db))
.collect::<Vec<_>>()
.join("::")
)
};

Some(Trait { name: trait_name.clone(), path: trait_path })
}
_ => None,
}
} else {
None
}
})
.collect::<Vec<_>>()
} else {
vec![]
};

traits
}

fn get_trait_path(
db: &dyn SyntaxGroup,
module_body: &ast::ModuleBody,
trait_name: String,
) -> Option<String> {
for e in module_body.items(db).elements(db) {
if let ast::ModuleItem::Use(u) = e {
if let ast::UsePath::Single(s) = u.use_path(db) {
let trait_path = s.as_syntax_node().get_text(db);

if trait_path.ends_with(&trait_name) {
return Some(trait_path);
}
}
}
}

None
}
93 changes: 93 additions & 0 deletions crates/dojo-lang/src/inline_macros/dispatcher_from_tag.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use cairo_lang_defs::patcher::PatchBuilder;
use cairo_lang_defs::plugin::{
InlineMacroExprPlugin, InlinePluginResult, MacroPluginMetadata, NamedPlugin, PluginDiagnostic,
PluginGeneratedFile,
};
use cairo_lang_defs::plugin_utils::unsupported_bracket_diagnostic;
use cairo_lang_diagnostics::Severity;
use cairo_lang_syntax::node::{ast, TypedStablePtr, TypedSyntaxNode};
use dojo_world::contracts::naming;

use super::utils::find_interface_path;

#[derive(Debug, Default)]
pub struct DispatcherFromTagMacro;

impl NamedPlugin for DispatcherFromTagMacro {
const NAME: &'static str = "dispatcher_from_tag";
}

impl InlineMacroExprPlugin for DispatcherFromTagMacro {
fn generate_code(
&self,
db: &dyn cairo_lang_syntax::node::db::SyntaxGroup,
syntax: &ast::ExprInlineMacro,
metadata: &MacroPluginMetadata<'_>,
) -> InlinePluginResult {
let ast::WrappedArgList::ParenthesizedArgList(arg_list) = syntax.arguments(db) else {
return unsupported_bracket_diagnostic(db, syntax);
};

let args = arg_list.arguments(db).elements(db);

if args.len() != 2 {
return InlinePluginResult {
code: None,
diagnostics: vec![PluginDiagnostic {
stable_ptr: syntax.stable_ptr().untyped(),
message: "Invalid arguments. Expected dispatcher_from_tag!(\"tag\", contract_address)"
.to_string(),
severity: Severity::Error,
}],
};
}

let tag = &args[0].as_syntax_node().get_text(db).replace('\"', "");
let contract_address = args[1].as_syntax_node().get_text(db);

if !naming::is_valid_tag(tag) {
return InlinePluginResult {
code: None,
diagnostics: vec![PluginDiagnostic {
stable_ptr: syntax.stable_ptr().untyped(),
message: "Invalid tag. Tag must be in the format of `namespace-name`."
.to_string(),
severity: Severity::Error,
}],
};
}

// read the interface path from the manifest and generate a dispatcher:
// <interface_path>Dispatcher { contract_address };
let interface_path = match find_interface_path(metadata.cfg_set, tag) {
Ok(interface_path) => interface_path,
Err(_e) => {
return InlinePluginResult {
code: None,
diagnostics: vec![PluginDiagnostic {
stable_ptr: syntax.stable_ptr().untyped(),
message: format!("Failed to find the interface path of `{tag}`"),
severity: Severity::Error,
}],
};
}
};

let mut builder = PatchBuilder::new(db, syntax);
builder.add_str(&format!(
"{interface_path}Dispatcher {{ contract_address: {contract_address}}}",
));

let (code, code_mappings) = builder.build();

InlinePluginResult {
code: Some(PluginGeneratedFile {
name: "dispatcher_from_tag_macro".into(),
content: code,
code_mappings,
aux_data: None,
}),
diagnostics: vec![],
}
}
}
1 change: 1 addition & 0 deletions crates/dojo-lang/src/inline_macros/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use cairo_lang_syntax::node::{ast, Terminal, TypedStablePtr, TypedSyntaxNode};
use smol_str::SmolStr;

pub mod delete;
pub mod dispatcher_from_tag;
pub mod emit;
pub mod get;
pub mod get_models_test_class_hashes;
Expand Down
16 changes: 16 additions & 0 deletions crates/dojo-lang/src/inline_macros/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,22 @@ pub fn parent_of_kind(
None
}

///
pub fn find_interface_path(cfg_set: &CfgSet, contract_tag: &str) -> anyhow::Result<String> {
let dojo_manifests_dir = get_dojo_manifests_dir(cfg_set.clone())?;

let base_dir = dojo_manifests_dir.join("base");
let base_manifest = BaseManifest::load_from_path(&base_dir)?;

for contract in base_manifest.contracts {
if contract.inner.tag == contract_tag {
return Ok(contract.inner.interface_path);
}
}

Err(anyhow::anyhow!("Unable to find the interface path of `{}`", contract_tag))
}

/// Reads all the models and namespaces from base manifests files.
pub fn load_manifest_models_and_namespaces(
cfg_set: &CfgSet,
Expand Down
11 changes: 9 additions & 2 deletions crates/dojo-lang/src/interface.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
use cairo_lang_defs::plugin::{
MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
DynGeneratedFileAuxData, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile,
PluginResult,
};
use cairo_lang_diagnostics::Severity;
use cairo_lang_plugins::plugins::HasItemsInCfgEx;
use cairo_lang_syntax::node::db::SyntaxGroup;
use cairo_lang_syntax::node::{ast, ids, Terminal, TypedStablePtr, TypedSyntaxNode};
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;

use crate::plugin::{DojoAuxData, InterfaceAuxData};
use crate::syntax::self_param;
use crate::syntax::world_param::{self, WorldParamInjectionKind};

Expand Down Expand Up @@ -84,7 +86,12 @@ impl DojoInterface {
code: Some(PluginGeneratedFile {
name: name.clone(),
content: code,
aux_data: None,
aux_data: Some(DynGeneratedFileAuxData::new(DojoAuxData {
models: vec![],
contracts: vec![],
events: vec![],
interfaces: vec![InterfaceAuxData { name: name.to_string() }],
})),
code_mappings,
}),
diagnostics: interface.diagnostics,
Expand Down
Loading

0 comments on commit 75a75dd

Please sign in to comment.