From 640bfd35d83d4a927d77a786fcb4f8e1a666d1ca Mon Sep 17 00:00:00 2001 From: Shramee Srivastav Date: Fri, 27 Oct 2023 20:22:21 +0530 Subject: [PATCH] semantic: model r/w access (#1102) * semantic: model r/w access * chore: tests and review :wrench: * dev: handle module name error --- crates/dojo-lang/src/compiler.rs | 45 ++++- crates/dojo-lang/src/inline_macros/get.rs | 23 ++- crates/dojo-lang/src/inline_macros/mod.rs | 2 + crates/dojo-lang/src/inline_macros/set.rs | 92 ++++++++- crates/dojo-lang/src/inline_macros/utils.rs | 33 ++++ .../dojo-lang/src/manifest_test_data/manifest | 16 +- crates/dojo-lang/src/semantics/mod.rs | 2 + crates/dojo-lang/src/semantics/utils.rs | 179 ++++++++++++++++++ crates/dojo-world/src/manifest.rs | 2 + crates/dojo-world/src/migration/world_test.rs | 4 +- 10 files changed, 383 insertions(+), 15 deletions(-) create mode 100644 crates/dojo-lang/src/inline_macros/utils.rs create mode 100644 crates/dojo-lang/src/semantics/utils.rs diff --git a/crates/dojo-lang/src/compiler.rs b/crates/dojo-lang/src/compiler.rs index 25532adfbd..cd15ba30e2 100644 --- a/crates/dojo-lang/src/compiler.rs +++ b/crates/dojo-lang/src/compiler.rs @@ -1,9 +1,10 @@ -use std::collections::{BTreeMap, HashMap}; +use std::collections::{BTreeMap, BTreeSet, HashMap}; use std::iter::zip; use std::ops::{Deref, DerefMut}; use anyhow::{anyhow, Context, Result}; use cairo_lang_compiler::db::RootDatabase; +use cairo_lang_defs::db::DefsGroup; use cairo_lang_defs::ids::{ModuleId, ModuleItemId}; use cairo_lang_filesystem::db::FilesGroup; use cairo_lang_filesystem::ids::{CrateId, CrateLongId}; @@ -27,7 +28,9 @@ use starknet::core::types::contract::SierraClass; use starknet::core::types::FieldElement; use tracing::{debug, trace, trace_span}; +use crate::inline_macros::utils::{SYSTEM_READS, SYSTEM_WRITES}; use crate::plugin::DojoAuxData; +use crate::semantics::utils::find_module_rw; const CAIRO_PATH_SEPARATOR: &str = "::"; @@ -202,7 +205,7 @@ pub fn collect_external_crate_ids( fn update_manifest( manifest: &mut dojo_world::manifest::Manifest, - db: &dyn SemanticGroup, + db: &RootDatabase, crate_ids: &[CrateId], compiled_artifacts: HashMap)>, ) -> anyhow::Result<()> { @@ -254,7 +257,12 @@ fn update_manifest( .filter_map(|aux_data| aux_data.as_ref().map(|aux_data| aux_data.0.as_any())) { if let Some(aux_data) = aux_data.downcast_ref::() { - contracts.extend(get_dojo_contract_artifacts(aux_data, &compiled_artifacts)?); + contracts.extend(get_dojo_contract_artifacts( + db, + module_id, + aux_data, + &compiled_artifacts, + )?); } if let Some(dojo_aux_data) = aux_data.downcast_ref::() { @@ -315,6 +323,8 @@ fn get_dojo_model_artifacts( } fn get_dojo_contract_artifacts( + db: &RootDatabase, + module_id: &ModuleId, aux_data: &StarkNetContractAuxData, compiled_classes: &HashMap)>, ) -> anyhow::Result> { @@ -323,11 +333,38 @@ fn get_dojo_contract_artifacts( .iter() .filter(|name| !matches!(name.as_ref(), "world" | "executor" | "base")) .map(|name| { + let module_name = module_id.full_path(db); + let module_last_name = module_name.split("::").last().unwrap(); + + let reads = match SYSTEM_READS.lock().unwrap().get(module_last_name) { + Some(models) => { + models.clone().into_iter().collect::>().into_iter().collect() + } + None => vec![], + }; + + let write_entries = SYSTEM_WRITES.lock().unwrap(); + let writes = match write_entries.get(module_last_name) { + Some(write_ops) => find_module_rw(db, module_id, write_ops), + None => vec![], + }; + let (class_hash, abi) = compiled_classes .get(name) .cloned() .ok_or(anyhow!("Contract {name} not found in target."))?; - Ok((name.clone(), Contract { name: name.clone(), class_hash, abi, address: None })) + + Ok(( + name.clone(), + Contract { + name: name.clone(), + class_hash, + abi, + writes, + reads, + ..Default::default() + }, + )) }) .collect::>() } diff --git a/crates/dojo-lang/src/inline_macros/get.rs b/crates/dojo-lang/src/inline_macros/get.rs index 9382921a0c..907d132545 100644 --- a/crates/dojo-lang/src/inline_macros/get.rs +++ b/crates/dojo-lang/src/inline_macros/get.rs @@ -3,10 +3,12 @@ use cairo_lang_defs::plugin::{ InlineMacroExprPlugin, InlinePluginResult, PluginDiagnostic, PluginGeneratedFile, }; use cairo_lang_semantic::inline_macros::unsupported_bracket_diagnostic; -use cairo_lang_syntax::node::ast::Expr; +use cairo_lang_syntax::node::ast::{Expr, ItemModule}; +use cairo_lang_syntax::node::kind::SyntaxKind; use cairo_lang_syntax::node::{ast, TypedSyntaxNode}; use itertools::Itertools; +use super::utils::{parent_of_kind, SYSTEM_READS}; use super::{extract_models, unsupported_arg_diagnostic, CAIRO_ERR_MSG_LEN}; #[derive(Debug)] @@ -78,7 +80,26 @@ impl InlineMacroExprPlugin for GetMacro { let __get_macro_keys__ = array::ArrayTrait::span(@__get_macro_keys__);\n" )); + let mut system_reads = SYSTEM_READS.lock().unwrap(); + + let module_syntax_node = + parent_of_kind(db, &syntax.as_syntax_node(), SyntaxKind::ItemModule); + let module_name = if let Some(module_syntax_node) = &module_syntax_node { + let mod_ast = ItemModule::from_syntax_node(db, module_syntax_node.clone()); + mod_ast.name(db).as_syntax_node().get_text_without_trivia(db) + } else { + eprintln!("Error: Couldn't get the module name."); + "".into() + }; + for model in &models { + if !module_name.is_empty() { + if system_reads.get(&module_name).is_none() { + system_reads.insert(module_name.clone(), vec![model.to_string()]); + } else { + system_reads.get_mut(&module_name).unwrap().push(model.to_string()); + } + } let mut lookup_err_msg = format!("{} not found", model.to_string()); lookup_err_msg.truncate(CAIRO_ERR_MSG_LEN); let mut deser_err_msg = format!("{} failed to deserialize", model.to_string()); diff --git a/crates/dojo-lang/src/inline_macros/mod.rs b/crates/dojo-lang/src/inline_macros/mod.rs index 74aa1c9fb9..2cc237256d 100644 --- a/crates/dojo-lang/src/inline_macros/mod.rs +++ b/crates/dojo-lang/src/inline_macros/mod.rs @@ -6,6 +6,7 @@ use smol_str::SmolStr; pub mod emit; pub mod get; pub mod set; +pub mod utils; const CAIRO_ERR_MSG_LEN: usize = 31; @@ -71,6 +72,7 @@ pub fn extract_models( Ok(models) } + pub fn unsupported_arg_diagnostic( db: &dyn SyntaxGroup, macro_ast: &ast::ExprInlineMacro, diff --git a/crates/dojo-lang/src/inline_macros/set.rs b/crates/dojo-lang/src/inline_macros/set.rs index 8708666c86..af3777d465 100644 --- a/crates/dojo-lang/src/inline_macros/set.rs +++ b/crates/dojo-lang/src/inline_macros/set.rs @@ -1,16 +1,35 @@ +use std::collections::HashMap; + use cairo_lang_defs::patcher::PatchBuilder; use cairo_lang_defs::plugin::{ InlineMacroExprPlugin, InlinePluginResult, PluginDiagnostic, PluginGeneratedFile, }; use cairo_lang_semantic::inline_macros::unsupported_bracket_diagnostic; +use cairo_lang_syntax::node::ast::{ExprPath, ExprStructCtorCall, FunctionWithBody, ItemModule}; +use cairo_lang_syntax::node::kind::SyntaxKind; use cairo_lang_syntax::node::{ast, TypedSyntaxNode}; use super::unsupported_arg_diagnostic; +use super::utils::{parent_of_kind, SystemRWOpRecord, SYSTEM_WRITES}; #[derive(Debug)] pub struct SetMacro; impl SetMacro { pub const NAME: &'static str = "set"; + // Parents of set!() + // ----------------- + // StatementExpr + // StatementList + // ExprBlock + // FunctionWithBody + // ImplItemList + // ImplBody + // ItemImpl + // ItemList + // ModuleBody + // ItemModule + // ItemList + // SyntaxFile } impl InlineMacroExprPlugin for SetMacro { fn generate_code( @@ -46,12 +65,19 @@ impl InlineMacroExprPlugin for SetMacro { match models.value(db) { ast::Expr::Parenthesized(parens) => { - bundle.push(parens.expr(db).as_syntax_node().get_text(db)) + let syntax_node = parens.expr(db).as_syntax_node(); + bundle.push((syntax_node.get_text(db), syntax_node)); + } + ast::Expr::Tuple(list) => { + list.expressions(db).elements(db).into_iter().for_each(|expr| { + let syntax_node = expr.as_syntax_node(); + bundle.push((syntax_node.get_text(db), syntax_node)); + }) + } + ast::Expr::StructCtorCall(ctor) => { + let syntax_node = ctor.as_syntax_node(); + bundle.push((syntax_node.get_text(db), syntax_node)); } - ast::Expr::Tuple(list) => list.expressions(db).elements(db).iter().for_each(|expr| { - bundle.push(expr.as_syntax_node().get_text(db)); - }), - ast::Expr::StructCtorCall(ctor) => bundle.push(ctor.as_syntax_node().get_text(db)), _ => { return InlinePluginResult { code: None, @@ -73,7 +99,61 @@ impl InlineMacroExprPlugin for SetMacro { }; } - for entity in bundle { + let module_syntax_node = + parent_of_kind(db, &syntax.as_syntax_node(), SyntaxKind::ItemModule); + let module_name = if let Some(module_syntax_node) = &module_syntax_node { + let mod_ast = ItemModule::from_syntax_node(db, module_syntax_node.clone()); + mod_ast.name(db).as_syntax_node().get_text_without_trivia(db) + } else { + eprintln!("Error: Couldn't get the module name."); + "".into() + }; + + let fn_syntax_node = + parent_of_kind(db, &syntax.as_syntax_node(), SyntaxKind::FunctionWithBody); + let fn_name = if let Some(fn_syntax_node) = &fn_syntax_node { + let fn_ast = FunctionWithBody::from_syntax_node(db, fn_syntax_node.clone()); + fn_ast.declaration(db).name(db).as_syntax_node().get_text_without_trivia(db) + } else { + // Unlikely to get here, but if we do. + eprintln!("Error: Couldn't get the function name."); + "".into() + }; + + for (entity, syntax_node) in bundle { + // db.lookup_intern_file(key0); + if !module_name.is_empty() && !fn_name.is_empty() { + let mut system_writes = SYSTEM_WRITES.lock().unwrap(); + // fn_syntax_node + if system_writes.get(&module_name).is_none() { + system_writes.insert(module_name.clone(), HashMap::new()); + } + let fns = system_writes.get_mut(&module_name).unwrap(); + if fns.get(&fn_name).is_none() { + fns.insert(fn_name.clone(), vec![]); + } + + match syntax_node.kind(db) { + SyntaxKind::ExprPath => { + fns.get_mut(&fn_name).unwrap().push(SystemRWOpRecord::Path( + ExprPath::from_syntax_node(db, syntax_node), + )); + } + // SyntaxKind::StatementExpr => { + // todo!() + // } + SyntaxKind::ExprStructCtorCall => { + fns.get_mut(&fn_name).unwrap().push(SystemRWOpRecord::StructCtor( + ExprStructCtorCall::from_syntax_node(db, syntax_node.clone()), + )); + } + _ => eprintln!( + "Unsupport component value type {} for semantic writer analysis", + syntax_node.kind(db) + ), + } + } + builder.add_str(&format!( " let __set_macro_value__ = {}; diff --git a/crates/dojo-lang/src/inline_macros/utils.rs b/crates/dojo-lang/src/inline_macros/utils.rs new file mode 100644 index 0000000000..46d23e7280 --- /dev/null +++ b/crates/dojo-lang/src/inline_macros/utils.rs @@ -0,0 +1,33 @@ +use std::collections::HashMap; +use std::sync::Mutex; + +use cairo_lang_syntax::node::ast::{ExprPath, ExprStructCtorCall}; +use cairo_lang_syntax::node::kind::SyntaxKind; +use cairo_lang_syntax::node::SyntaxNode; + +type ModuleName = String; +type FunctionName = String; +lazy_static::lazy_static! { + pub static ref SYSTEM_WRITES: Mutex>>> = Default::default(); + pub static ref SYSTEM_READS: Mutex>> = Default::default(); +} + +pub enum SystemRWOpRecord { + StructCtor(ExprStructCtorCall), + Path(ExprPath), +} + +pub fn parent_of_kind( + db: &dyn cairo_lang_syntax::node::db::SyntaxGroup, + target: &SyntaxNode, + kind: SyntaxKind, +) -> Option { + let mut new_target = target.clone(); + while let Some(parent) = new_target.parent() { + if kind == parent.kind(db) { + return Some(parent); + } + new_target = parent; + } + None +} diff --git a/crates/dojo-lang/src/manifest_test_data/manifest b/crates/dojo-lang/src/manifest_test_data/manifest index a66ad52ff3..fc73b21559 100644 --- a/crates/dojo-lang/src/manifest_test_data/manifest +++ b/crates/dojo-lang/src/manifest_test_data/manifest @@ -739,7 +739,9 @@ test_manifest_file } ] } - ] + ], + "reads": [], + "writes": [] }, "executor": { "name": "executor", @@ -797,7 +799,9 @@ test_manifest_file "kind": "enum", "variants": [] } - ] + ], + "reads": [], + "writes": [] }, "base": { "name": "base", @@ -985,6 +989,14 @@ test_manifest_file } ] } + ], + "reads": [ + "Moves", + "Position" + ], + "writes": [ + "Moves", + "Position" ] } ], diff --git a/crates/dojo-lang/src/semantics/mod.rs b/crates/dojo-lang/src/semantics/mod.rs index faf49772b9..4570554d21 100644 --- a/crates/dojo-lang/src/semantics/mod.rs +++ b/crates/dojo-lang/src/semantics/mod.rs @@ -1,3 +1,5 @@ +pub mod utils; + #[cfg(test)] pub mod test_utils; diff --git a/crates/dojo-lang/src/semantics/utils.rs b/crates/dojo-lang/src/semantics/utils.rs new file mode 100644 index 0000000000..a1bf09bee6 --- /dev/null +++ b/crates/dojo-lang/src/semantics/utils.rs @@ -0,0 +1,179 @@ +use std::collections::{BTreeSet, HashMap}; + +use cairo_lang_compiler::db::RootDatabase; +use cairo_lang_defs::db::DefsGroup; +use cairo_lang_defs::ids::{FunctionWithBodyId, LookupItemId, ModuleId, ModuleItemId}; +use cairo_lang_lowering::db::LoweringGroup; +use cairo_lang_lowering::ids::{self as low, SemanticFunctionWithBodyIdEx}; +use cairo_lang_lowering::Statement; +use cairo_lang_semantic as semantic; +use cairo_lang_syntax::node::{ast, SyntaxNode, TypedSyntaxNode}; +use semantic::db::SemanticGroup; +use semantic::diagnostic::SemanticDiagnostics; +use semantic::expr::compute::{ComputationContext, Environment}; +use semantic::expr::inference::InferenceId; +use semantic::items::function_with_body::SemanticExprLookup; +use semantic::resolve::Resolver; +use semantic::FunctionId; + +use crate::inline_macros::utils::SystemRWOpRecord; + +pub fn find_module_rw( + db: &RootDatabase, + module_id: &ModuleId, + module_writes: &HashMap>, +) -> Vec { + let mut models: BTreeSet = BTreeSet::new(); + if let Ok(module_fns) = db.module_free_functions_ids(*module_id) { + for fn_id in module_fns.iter() { + find_function_rw( + db, + module_id, + module_writes, + FunctionWithBodyId::Free(*fn_id), + &mut models, + ); + } + } + if let Ok(module_impls) = db.module_impls_ids(*module_id) { + for module_impl_id in module_impls.iter() { + if let Ok(module_fns) = db.impl_functions(*module_impl_id) { + for (_, fn_id) in module_fns.iter() { + find_function_rw( + db, + module_id, + module_writes, + FunctionWithBodyId::Impl(*fn_id), + &mut models, + ); + } + } + } + } + + models.into_iter().collect() +} + +pub fn find_function_rw( + db: &RootDatabase, + _module_id: &ModuleId, + module_writes: &HashMap>, + fn_id: FunctionWithBodyId, + models: &mut BTreeSet, +) { + let fn_name: String = fn_id.name(db).into(); + if let Some(module_fn_writes) = module_writes.get(&fn_name) { + // This functions has write ops, find models access + for writer_lookup in module_fn_writes.iter() { + match writer_lookup { + SystemRWOpRecord::StructCtor(expr) => { + let component = expr.path(db).as_syntax_node().get_text_without_trivia(db); + models.insert(component); + } + SystemRWOpRecord::Path(_expr_path) => { + let fn_id_low = fn_id.lowered(db); + + let flat_lowered = db.function_with_body_lowering(fn_id_low).unwrap(); + for (_, flat_block) in flat_lowered.blocks.iter() { + let mut last_layout_fn_semantic: Option = None; + + for statement in flat_block.statements.iter() { + if let Statement::Call(statement_call) = statement { + if let low::FunctionLongId::Semantic(fn_id) = + db.lookup_intern_lowering_function(statement_call.function) + { + if let Ok(Some(conc_body_fn)) = fn_id.get_concrete(db).body(db) + { + let fn_body_id = conc_body_fn.function_with_body_id(db); + let fn_name = fn_body_id.name(db); + if fn_name == "set_entity" { + if let Some(layout_fn) = last_layout_fn_semantic { + match db.concrete_function_signature(layout_fn) { + Ok(signature) => { + if let Some(params) = + signature.params.get(0) + { + // looks like + // "@dojo_examples::models::Position" + let component = params.ty.format(db); + let component_segments = + component.split("::"); + let component = + component_segments.last().expect( + "layout signature params not \ + found", + ); + models.insert(component.into()); + } + } + Err(_) => { + eprintln!( + "error: could't get entity model(s)" + ) + } + } + } else { + eprintln!( + "type reference not found for set_entity" + ); + } + } else if fn_name == "layout" { + last_layout_fn_semantic = Some(fn_id); + } + } + } + } + } + } + } + } + } + } +} + +pub fn function_resolver(db: &RootDatabase, fn_id: FunctionWithBodyId) -> Resolver<'_> { + let resolver_data = match fn_id { + FunctionWithBodyId::Free(fn_id) => { + let interference = InferenceId::LookupItemDefinition(LookupItemId::ModuleItem( + ModuleItemId::FreeFunction(fn_id), + )); + db.free_function_body_resolver_data(fn_id) + .unwrap() + .clone_with_inference_id(db, interference) + } + FunctionWithBodyId::Impl(fn_id) => { + let interference = InferenceId::LookupItemDefinition(LookupItemId::ImplFunction(fn_id)); + db.impl_function_body_resolver_data(fn_id) + .unwrap() + .clone_with_inference_id(db, interference) + } + }; + Resolver::with_data(db, resolver_data) +} +/// Returns the semantic expression for the current node. +pub fn nearest_semantic_expr( + db: &dyn SemanticGroup, + mut node: SyntaxNode, + function_id: FunctionWithBodyId, +) -> Option { + loop { + let syntax_db = db.upcast(); + if ast::Expr::is_variant(node.kind(syntax_db)) { + let expr_node = ast::Expr::from_syntax_node(syntax_db, node.clone()); + if let Ok(expr_id) = db.lookup_expr_by_ptr(function_id, expr_node.stable_ptr()) { + let semantic_expr = db.expr_semantic(function_id, expr_id); + return Some(semantic_expr); + } + } + node = node.parent()?; + } +} + +pub fn semantic_computation_ctx<'a>( + db: &'a RootDatabase, + fn_id: FunctionWithBodyId, + resolver: Resolver<'a>, + diagnostics: &'a mut SemanticDiagnostics, +) -> ComputationContext<'a> { + ComputationContext::new(db, diagnostics, Some(fn_id), resolver, None, Environment::default()) +} diff --git a/crates/dojo-world/src/manifest.rs b/crates/dojo-world/src/manifest.rs index f6e1ad9028..18dccc3bf2 100644 --- a/crates/dojo-world/src/manifest.rs +++ b/crates/dojo-world/src/manifest.rs @@ -104,6 +104,8 @@ pub struct Contract { #[serde_as(as = "UfeHex")] pub class_hash: FieldElement, pub abi: Option, + pub reads: Vec, + pub writes: Vec, } #[serde_as] diff --git a/crates/dojo-world/src/migration/world_test.rs b/crates/dojo-world/src/migration/world_test.rs index 5fef4da033..daf9be2367 100644 --- a/crates/dojo-world/src/migration/world_test.rs +++ b/crates/dojo-world/src/migration/world_test.rs @@ -73,13 +73,13 @@ fn diff_when_local_and_remote_are_different() { name: "my_contract".into(), class_hash: felt!("0x1111"), address: Some(felt!("0x2222")), - abi: None, + ..Contract::default() }, Contract { name: "my_contract_2".into(), class_hash: felt!("0x3333"), address: Some(felt!("4444")), - abi: None, + ..Contract::default() }, ];