Skip to content

Commit

Permalink
get version from dojo::model attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
remybar committed Mar 17, 2024
1 parent 26d44f7 commit c82fe04
Show file tree
Hide file tree
Showing 3 changed files with 1,329 additions and 52 deletions.
130 changes: 128 additions & 2 deletions crates/dojo-lang/src/model.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,139 @@
use std::cmp::Ordering;

use cairo_lang_defs::patcher::RewriteNode;
use cairo_lang_defs::plugin::PluginDiagnostic;
use cairo_lang_diagnostics::Severity;
use cairo_lang_syntax::node::ast::ItemStruct;
use cairo_lang_syntax::node::ast::{ArgClause, Expr, ItemStruct, OptionArgListParenthesized};
use cairo_lang_syntax::node::db::SyntaxGroup;
use cairo_lang_syntax::node::helpers::QueryAttrs;
use cairo_lang_syntax::node::{Terminal, TypedSyntaxNode};
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use convert_case::{Case, Casing};
use dojo_world::manifest::Member;

use crate::plugin::{DojoAuxData, Model};
use crate::plugin::{DojoAuxData, Model, DOJO_MODEL_ATTR};

const CURRENT_MODEL_VERSION: u8 = 1;
const MODEL_VERSION_NAME: &str = "version";

/// Get the version associated with the dojo::model attribute.
///
/// Note: dojo::model attribute has already been checked so there is one and only one attribute.
///
/// Parameters:
/// * db: The semantic database.
/// * struct_ast: The AST of the model struct.
/// * diagnostics: vector of compiler diagnostics.
///
/// Returns:
/// * The model version associated with the dojo:model attribute.
pub fn get_model_version(
db: &dyn SyntaxGroup,
struct_ast: ItemStruct,
diagnostics: &mut Vec<PluginDiagnostic>,
) -> u8 {
if let OptionArgListParenthesized::ArgListParenthesized(arguments) =
struct_ast.attributes(db).query_attr(db, DOJO_MODEL_ATTR).first().unwrap().arguments(db)
{
let version_args = arguments
.arguments(db)
.elements(db)
.iter()
.filter_map(|a| match a.arg_clause(db) {
ArgClause::Named(x) => {
let arg_name = x.name(db).text(db).to_string();
if arg_name.eq(MODEL_VERSION_NAME) {
Some(x.value(db))
} else {
diagnostics.push(PluginDiagnostic {
message: format!("Unexpected argument '{}' for dojo::model", arg_name),
stable_ptr: x.stable_ptr().untyped(),
severity: Severity::Warning,
});
None
}
}
ArgClause::Unnamed(x) => {
diagnostics.push(PluginDiagnostic {
message: format!(
"Unexpected argument '{}' for dojo::model",
x.as_syntax_node().get_text(db)
),
stable_ptr: x.stable_ptr().untyped(),
severity: Severity::Warning,
});
None
}
ArgClause::FieldInitShorthand(x) => {
diagnostics.push(PluginDiagnostic {
message: format!(
"Unexpected argument '{}' for dojo::model",
x.name(db).name(db).text(db).to_string()
),
stable_ptr: x.stable_ptr().untyped(),
severity: Severity::Warning,
});
None
}
})
.collect::<Vec<_>>();

let version = match version_args.len().cmp(&1) {
Ordering::Equal => match version_args.first().unwrap() {
Expr::Literal(v) => {
if let Ok(int_value) = v.text(db).parse::<u8>() {
if int_value <= CURRENT_MODEL_VERSION {
Some(int_value)
} else {
diagnostics.push(PluginDiagnostic {
message: format!("dojo::model version {} not supported", int_value),
stable_ptr: v.stable_ptr().untyped(),
severity: Severity::Error,
});
None
}
} else {
diagnostics.push(PluginDiagnostic {
message: format!(
"The argument '{}' of dojo::model must be an integer",
MODEL_VERSION_NAME
),
stable_ptr: struct_ast.stable_ptr().untyped(),
severity: Severity::Error,
});
None
}
}
_ => {
diagnostics.push(PluginDiagnostic {
message: format!(
"The argument '{}' of dojo::model must be an integer",
MODEL_VERSION_NAME
),
stable_ptr: struct_ast.stable_ptr().untyped(),
severity: Severity::Error,
});
None
}
},
Ordering::Greater => {
diagnostics.push(PluginDiagnostic {
message: format!(
"Too many '{}' attributes for dojo::model",
MODEL_VERSION_NAME
),
stable_ptr: struct_ast.stable_ptr().untyped(),
severity: Severity::Error,
});
None
}
Ordering::Less => None,
};

return if let Some(v) = version { v } else { CURRENT_MODEL_VERSION };
}
CURRENT_MODEL_VERSION
}

/// A handler for Dojo code that modifies a model struct.
/// Parameters:
Expand All @@ -24,6 +148,8 @@ pub fn handle_model_struct(
) -> (RewriteNode, Vec<PluginDiagnostic>) {
let mut diagnostics = vec![];

let _version = get_model_version(db, struct_ast.clone(), &mut diagnostics);

let elements = struct_ast.members(db).elements(db);
let members: &Vec<_> = &elements
.iter()
Expand Down
27 changes: 21 additions & 6 deletions crates/dojo-lang/src/plugin.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::cmp::Ordering;

use anyhow::Result;
use cairo_lang_defs::patcher::PatchBuilder;
use cairo_lang_defs::plugin::{
Expand Down Expand Up @@ -35,7 +37,7 @@ use crate::print::{handle_print_enum, handle_print_struct};

const DOJO_CONTRACT_ATTR: &str = "dojo::contract";
const DOJO_INTERFACE_ATTR: &str = "dojo::interface";
const DOJO_MODEL_ATTR: &str = "dojo::model";
pub const DOJO_MODEL_ATTR: &str = "dojo::model";

#[derive(Clone, Debug, PartialEq)]
pub struct Model {
Expand Down Expand Up @@ -393,11 +395,24 @@ impl MacroPlugin for BuiltinDojoPlugin {
}
}

for _ in struct_ast.attributes(db).query_attr(db, "dojo::model") {
let (model_rewrite_nodes, model_diagnostics) =
handle_model_struct(db, &mut aux_data, struct_ast.clone());
rewrite_nodes.push(model_rewrite_nodes);
diagnostics.extend(model_diagnostics);
let attributes = struct_ast.attributes(db).query_attr(db, DOJO_MODEL_ATTR);

match attributes.len().cmp(&1) {
Ordering::Equal => {
let (model_rewrite_nodes, model_diagnostics) =
handle_model_struct(db, &mut aux_data, struct_ast.clone());
rewrite_nodes.push(model_rewrite_nodes);
diagnostics.extend(model_diagnostics);
}
Ordering::Greater => {
diagnostics.push(PluginDiagnostic {
message: "A Dojo model must have zero or one dojo::model attribute."
.into(),
stable_ptr: struct_ast.stable_ptr().untyped(),
severity: Severity::Error,
});
}
_ => {}
}

if rewrite_nodes.is_empty() {
Expand Down
Loading

0 comments on commit c82fe04

Please sign in to comment.