Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: dispatcher_from_tag! macro #2416

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 40 additions & 4 deletions crates/dojo-lang/src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use starknet::core::types::contract::SierraClass;
use starknet::core::types::Felt;
use tracing::{debug, trace, trace_span};

use crate::plugin::{DojoAuxData, Model};
use crate::plugin::{DojoAuxData, Model, Trait};
use crate::scarb_internal::debug::SierraToCairoDebugInfo;

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -350,6 +350,7 @@ fn update_files(

let mut models = BTreeMap::new();
let mut contracts = BTreeMap::new();
let mut dojo_interfaces = vec![];

if let Some(external_contracts) = external_contracts {
let external_crate_ids = collect_external_crate_ids(db, external_contracts);
Expand Down Expand Up @@ -378,6 +379,7 @@ fn update_files(
&naming::get_tag(&contract.namespace, &contract.name),
&compiled_artifacts,
&contract.systems,
&contract.traits,
)?);
}

Expand All @@ -392,6 +394,9 @@ fn update_files(
*module_id,
&compiled_artifacts,
)?);

// update the list of Dojo interface
dojo_interfaces.extend(dojo_aux_data.interfaces.iter().map(|i| i.name.clone()));
}

// StarknetAuxData shouldn't be required. Every dojo contract and model are starknet
Expand Down Expand Up @@ -422,7 +427,37 @@ fn update_files(
std::fs::create_dir_all(&base_contracts_abis_dir)?;
}

for (_, (manifest, module_id, artifact)) in contracts.iter_mut() {
for (qualified_path, (manifest, module_id, artifact, traits)) in contracts.iter_mut() {
// During compilation, a list of traits implemented inside the contract is extracted.
// We need to find which of these traits is a Dojo interface, to be able to store its qualified path,
// and be able to use it in macros such `dispatcher_from_tag!`
let found_interfaces =
traits.iter().filter(|t| dojo_interfaces.contains(&t.name)).collect::<Vec<_>>();

// a contract may or may not implement a Dojo interface, but it cannot implement several Dojo interfaces.
if found_interfaces.len() > 1 {
return Err(anyhow!(
"The contract '{}' cannot implement several Dojo interfaces (found: [{}]).",
manifest.inner.tag,
found_interfaces.iter().map(|t| t.name.clone()).collect::<Vec<_>>().join(", ")
));
}

manifest.inner.interface_path =
found_interfaces.first().map_or(String::new(), |x| x.path.clone());

// the Dojo interface path may start with `super`, referencing the trait which is
// in the same file than the contract.
// In this case, just replace `super` by the contract qualified path.
if manifest.inner.interface_path.starts_with("super") {
let (path, _) = qualified_path
.rsplit_once(CAIRO_PATH_SEPARATOR)
.unwrap_or((qualified_path.as_str(), ""));

manifest.inner.interface_path =
manifest.inner.interface_path.replacen("super", &path, 1);
}

write_manifest_and_abi(
&base_contracts_dir,
&base_contracts_abis_dir,
Expand Down Expand Up @@ -558,7 +593,8 @@ fn get_dojo_contract_artifacts(
tag: &str,
compiled_classes: &CompiledArtifactByPath,
systems: &[String],
) -> Result<HashMap<String, (Manifest<DojoContract>, ModuleId, CompiledArtifact)>> {
traits: &Vec<Trait>,
) -> Result<HashMap<String, (Manifest<DojoContract>, ModuleId, CompiledArtifact, Vec<Trait>)>> {
let mut result = HashMap::new();

if !matches!(naming::get_name_from_tag(tag).as_str(), "world" | "resource_metadata" | "base") {
Expand All @@ -584,7 +620,7 @@ fn get_dojo_contract_artifacts(

result.insert(
contract_qualified_path.to_string(),
(manifest, *module_id, artifact.clone()),
(manifest, *module_id, artifact.clone(), traits.clone()),
);
}
}
Expand Down
83 changes: 82 additions & 1 deletion crates/dojo-lang/src/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use dojo_types::system::Dependency;
use dojo_world::config::NamespaceConfig;
use dojo_world::contracts::naming;

use crate::plugin::{ContractAuxData, DojoAuxData, DOJO_CONTRACT_ATTR};
use crate::plugin::{ContractAuxData, DojoAuxData, Trait, DOJO_CONTRACT_ATTR};
use crate::syntax::world_param::{self, WorldParamInjectionKind};
use crate::syntax::{self_param, utils as syntax_utils};

Expand Down Expand Up @@ -50,6 +50,7 @@ impl DojoContract {

let mut diagnostics = vec![];
let parameters = get_parameters(db, module_ast, &mut diagnostics);
let traits = get_traits(db, module_ast);

let mut contract =
DojoContract { diagnostics, dependencies: HashMap::new(), systems: vec![] };
Expand Down Expand Up @@ -262,8 +263,10 @@ impl DojoContract {
namespace: contract_namespace.clone(),
dependencies: contract.dependencies.values().cloned().collect(),
systems: contract.systems.clone(),
traits,
}],
events: vec![],
interfaces: vec![],
})),
code_mappings,
}),
Expand Down Expand Up @@ -732,3 +735,81 @@ fn get_parameters(

parameters
}

fn get_traits(db: &dyn SyntaxGroup, module_ast: &ast::ItemModule) -> Vec<Trait> {
let traits = if let ast::MaybeModuleBody::Some(body) = module_ast.body(db) {
body.items(db)
.elements(db)
.iter()
.filter_map(|e| {
if let ast::ModuleItem::Impl(x) = e {
let mut path_segments = x.trait_path(db).elements(db);

// in Cairo, there is always a trait path linked to an impl, so path_segments always contain an item
let last_segment = path_segments.pop().unwrap();

// a dojo interface always has a generic <ContractState> argument, so just keep this kind of traits.
match last_segment {
ast::PathSegment::WithGenericArgs(p) => {
let trait_name = p.ident(db).text(db).to_string();

// Here, we have to rebuild the full trait path. There are several cases:
// 1) there is no path with the trait name (example: IActions)
// => find the trait path in `use` clauses.
// 2) the path is relative (example: `players::IActions`)
// => not possible to be sure that there is only one path in `use` clauses that matches
// with this relative path.
// (example: 2 use clauses: `use path1::players` and `use path2::players`).
// 3) the path is absolute (example: `path::to::players::IActions`)
// => just use this path.
//
// At the moment, only cases 1) and 3) are supported.
let trait_path = if path_segments.is_empty() {
get_trait_path(db, &body, trait_name.clone())
.expect("a path must always be found")
} else {
format!(
"{}::{trait_name}",
path_segments
.iter()
.map(|p| p.as_syntax_node().get_text(db))
.collect::<Vec<_>>()
.join("::")
)
};

Some(Trait { name: trait_name.clone(), path: trait_path })
}
_ => None,
}
} else {
None
}
})
.collect::<Vec<_>>()
} else {
vec![]
};

traits
}

fn get_trait_path(
db: &dyn SyntaxGroup,
module_body: &ast::ModuleBody,
trait_name: String,
) -> Option<String> {
for e in module_body.items(db).elements(db) {
if let ast::ModuleItem::Use(u) = e {
if let ast::UsePath::Single(s) = u.use_path(db) {
let trait_path = s.as_syntax_node().get_text(db);

if trait_path.ends_with(&trait_name) {
return Some(trait_path);
}
}
}
}

None
}
93 changes: 93 additions & 0 deletions crates/dojo-lang/src/inline_macros/dispatcher_from_tag.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use cairo_lang_defs::patcher::PatchBuilder;
use cairo_lang_defs::plugin::{
InlineMacroExprPlugin, InlinePluginResult, MacroPluginMetadata, NamedPlugin, PluginDiagnostic,
PluginGeneratedFile,
};
use cairo_lang_defs::plugin_utils::unsupported_bracket_diagnostic;
use cairo_lang_diagnostics::Severity;
use cairo_lang_syntax::node::{ast, TypedStablePtr, TypedSyntaxNode};
use dojo_world::contracts::naming;

use super::utils::find_interface_path;

#[derive(Debug, Default)]
pub struct DispatcherFromTagMacro;

impl NamedPlugin for DispatcherFromTagMacro {
const NAME: &'static str = "dispatcher_from_tag";
}

impl InlineMacroExprPlugin for DispatcherFromTagMacro {
fn generate_code(
&self,
db: &dyn cairo_lang_syntax::node::db::SyntaxGroup,
syntax: &ast::ExprInlineMacro,
metadata: &MacroPluginMetadata<'_>,
) -> InlinePluginResult {
let ast::WrappedArgList::ParenthesizedArgList(arg_list) = syntax.arguments(db) else {
return unsupported_bracket_diagnostic(db, syntax);
};

let args = arg_list.arguments(db).elements(db);

if args.len() != 2 {
return InlinePluginResult {
code: None,
diagnostics: vec![PluginDiagnostic {
stable_ptr: syntax.stable_ptr().untyped(),
message: "Invalid arguments. Expected dispatcher_from_tag!(\"tag\", contract_address)"
.to_string(),
severity: Severity::Error,
}],
};
}

let tag = &args[0].as_syntax_node().get_text(db).replace('\"', "");
let contract_address = args[1].as_syntax_node().get_text(db);

if !naming::is_valid_tag(tag) {
return InlinePluginResult {
code: None,
diagnostics: vec![PluginDiagnostic {
stable_ptr: syntax.stable_ptr().untyped(),
message: "Invalid tag. Tag must be in the format of `namespace-name`."
.to_string(),
severity: Severity::Error,
}],
};
}

// read the interface path from the manifest and generate a dispatcher:
// <interface_path>Dispatcher { contract_address };
let interface_path = match find_interface_path(metadata.cfg_set, tag) {
Ok(interface_path) => interface_path,
Err(_e) => {
return InlinePluginResult {
code: None,
diagnostics: vec![PluginDiagnostic {
stable_ptr: syntax.stable_ptr().untyped(),
message: format!("Failed to find the interface path of `{tag}`"),
severity: Severity::Error,
}],
};
}
};

let mut builder = PatchBuilder::new(db, syntax);
builder.add_str(&format!(
"{interface_path}Dispatcher {{ contract_address: {contract_address}}}",
));

let (code, code_mappings) = builder.build();

InlinePluginResult {
code: Some(PluginGeneratedFile {
name: "dispatcher_from_tag_macro".into(),
content: code,
code_mappings,
aux_data: None,
}),
diagnostics: vec![],
}
}
}
1 change: 1 addition & 0 deletions crates/dojo-lang/src/inline_macros/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use cairo_lang_syntax::node::{ast, Terminal, TypedStablePtr, TypedSyntaxNode};
use smol_str::SmolStr;

pub mod delete;
pub mod dispatcher_from_tag;
pub mod emit;
pub mod get;
pub mod get_models_test_class_hashes;
Expand Down
16 changes: 16 additions & 0 deletions crates/dojo-lang/src/inline_macros/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,22 @@ pub fn parent_of_kind(
None
}

///
pub fn find_interface_path(cfg_set: &CfgSet, contract_tag: &str) -> anyhow::Result<String> {
let dojo_manifests_dir = get_dojo_manifests_dir(cfg_set.clone())?;

let base_dir = dojo_manifests_dir.join("base");
let base_manifest = BaseManifest::load_from_path(&base_dir)?;

for contract in base_manifest.contracts {
if contract.inner.tag == contract_tag {
return Ok(contract.inner.interface_path);
}
}

Err(anyhow::anyhow!("Unable to find the interface path of `{}`", contract_tag))
}

/// Reads all the models and namespaces from base manifests files.
pub fn load_manifest_models_and_namespaces(
cfg_set: &CfgSet,
Expand Down
11 changes: 9 additions & 2 deletions crates/dojo-lang/src/interface.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
use cairo_lang_defs::plugin::{
MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
DynGeneratedFileAuxData, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile,
PluginResult,
};
use cairo_lang_diagnostics::Severity;
use cairo_lang_plugins::plugins::HasItemsInCfgEx;
use cairo_lang_syntax::node::db::SyntaxGroup;
use cairo_lang_syntax::node::{ast, ids, Terminal, TypedStablePtr, TypedSyntaxNode};
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;

use crate::plugin::{DojoAuxData, InterfaceAuxData};
use crate::syntax::self_param;
use crate::syntax::world_param::{self, WorldParamInjectionKind};

Expand Down Expand Up @@ -84,7 +86,12 @@ impl DojoInterface {
code: Some(PluginGeneratedFile {
name: name.clone(),
content: code,
aux_data: None,
aux_data: Some(DynGeneratedFileAuxData::new(DojoAuxData {
models: vec![],
contracts: vec![],
events: vec![],
interfaces: vec![InterfaceAuxData { name: name.to_string() }],
})),
code_mappings,
}),
diagnostics: interface.diagnostics,
Expand Down
Loading
Loading