From 5d85be7ab6a2d03c26260877d2b01045ff54cf5b Mon Sep 17 00:00:00 2001 From: Ori Ziv Date: Thu, 26 Dec 2024 10:22:44 +0200 Subject: [PATCH] Using full `Arenas` in constants. commit-id:7edc92a0 --- .../cairo-lang-semantic/src/items/constant.rs | 49 +++++++++---------- crates/cairo-lang-starknet/src/contract.rs | 5 +- 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/crates/cairo-lang-semantic/src/items/constant.rs b/crates/cairo-lang-semantic/src/items/constant.rs index c5bb6dbfc67..8824cff4ff1 100644 --- a/crates/cairo-lang-semantic/src/items/constant.rs +++ b/crates/cairo-lang-semantic/src/items/constant.rs @@ -13,7 +13,6 @@ use cairo_lang_syntax::node::{TypedStablePtr, TypedSyntaxNode}; use cairo_lang_utils::{ Intern, LookupIntern, define_short_id, extract_matches, try_extract_matches, }; -use id_arena::Arena; use itertools::Itertools; use num_bigint::BigInt; use num_traits::{Num, ToPrimitive, Zero}; @@ -37,7 +36,7 @@ use crate::resolve::{Resolver, ResolverData}; use crate::substitution::SemanticRewriter; use crate::types::resolve_type; use crate::{ - ConcreteTypeId, ConcreteVariant, Expr, ExprBlock, ExprConstant, ExprFunctionCall, + Arenas, ConcreteTypeId, ConcreteVariant, Expr, ExprBlock, ExprConstant, ExprFunctionCall, ExprFunctionCallArg, ExprId, ExprMemberAccess, ExprStructCtor, FunctionId, GenericParam, SemanticDiagnostic, TypeId, TypeLongId, semantic_object_for_id, }; @@ -48,12 +47,12 @@ pub struct Constant { /// The actual id of the const expression value. pub value: ExprId, /// The arena of all the expressions for the const calculation. - pub exprs: Arc>, + pub arenas: Arc, } impl Constant { pub fn ty(&self) -> TypeId { - self.exprs[self.value].ty() + self.arenas.exprs[self.value].ty() } } @@ -326,7 +325,7 @@ pub fn constant_semantic_data_helper( .rewrite(const_value) .unwrap_or_else(|_| ConstValue::Missing(skip_diagnostic()).intern(db)); let resolver_data = Arc::new(ctx.resolver.data); - let constant = Constant { value: value.id, exprs: Arc::new(ctx.arenas.exprs) }; + let constant = Constant { value: value.id, arenas: Arc::new(ctx.arenas) }; Ok(ConstantData { diagnostics: diagnostics.build(), const_value, @@ -387,7 +386,7 @@ pub fn resolve_const_expr_and_evaluate( match &value.expr { Expr::Constant(ExprConstant { const_value_id, .. }) => const_value_id.lookup_intern(db), // Check that the expression is a valid constant. - _ => evaluate_constant_expr(db, &ctx.arenas.exprs, value.id, ctx.diagnostics), + _ => evaluate_constant_expr(db, &ctx.arenas, value.id, ctx.diagnostics), } } @@ -425,18 +424,18 @@ pub fn value_as_const_value( /// evaluate the given const expression value. pub fn evaluate_constant_expr( db: &dyn SemanticGroup, - exprs: &Arena, + arenas: &Arenas, expr_id: ExprId, diagnostics: &mut SemanticDiagnostics, ) -> ConstValue { - let expr = &exprs[expr_id]; + let expr = &arenas.exprs[expr_id]; match expr { Expr::Constant(expr) => expr.const_value_id.lookup_intern(db), Expr::Block(ExprBlock { statements, tail: Some(inner), .. }) if statements.is_empty() => { - evaluate_constant_expr(db, exprs, *inner, diagnostics) + evaluate_constant_expr(db, arenas, *inner, diagnostics) } - Expr::FunctionCall(expr) => evaluate_const_function_call(db, exprs, expr, diagnostics) + Expr::FunctionCall(expr) => evaluate_const_function_call(db, arenas, expr, diagnostics) .map(|value| { value_as_const_value(db, expr.ty, &value) .map_err(|err| { @@ -457,7 +456,7 @@ pub fn evaluate_constant_expr( Expr::Tuple(expr) => ConstValue::Struct( expr.items .iter() - .map(|expr_id| evaluate_constant_expr(db, exprs, *expr_id, diagnostics)) + .map(|expr_id| evaluate_constant_expr(db, arenas, *expr_id, diagnostics)) .collect(), expr.ty, ), @@ -480,7 +479,7 @@ pub fn evaluate_constant_expr( .iter() .find(|(member_id, _)| m.id == *member_id) .map(|(_, expr_id)| { - evaluate_constant_expr(db, exprs, *expr_id, diagnostics) + evaluate_constant_expr(db, arenas, *expr_id, diagnostics) }) .unwrap_or_else(|| ConstValue::Missing(skip_diagnostic())) }) @@ -490,18 +489,18 @@ pub fn evaluate_constant_expr( } Expr::EnumVariantCtor(expr) => ConstValue::Enum( expr.variant.clone(), - Box::new(evaluate_constant_expr(db, exprs, expr.value_expr, diagnostics)), + Box::new(evaluate_constant_expr(db, arenas, expr.value_expr, diagnostics)), ), - Expr::MemberAccess(expr) => extract_const_member_access(db, exprs, expr, diagnostics) + Expr::MemberAccess(expr) => extract_const_member_access(db, arenas, expr, diagnostics) .unwrap_or_else(ConstValue::Missing), Expr::FixedSizeArray(expr) => ConstValue::Struct( match &expr.items { crate::FixedSizeArrayItems::Items(items) => items .iter() - .map(|expr_id| evaluate_constant_expr(db, exprs, *expr_id, diagnostics)) + .map(|expr_id| evaluate_constant_expr(db, arenas, *expr_id, diagnostics)) .collect(), crate::FixedSizeArrayItems::ValueAndSize(value, count) => { - let value = evaluate_constant_expr(db, exprs, *value, diagnostics); + let value = evaluate_constant_expr(db, arenas, *value, diagnostics); let count = count.lookup_intern(db); if let Some(count) = count.into_int() { (0..count.to_usize().unwrap()).map(|_| value.clone()).collect() @@ -554,11 +553,11 @@ fn is_function_const(db: &dyn SemanticGroup, function_id: FunctionId) -> bool { /// Attempts to evaluate constants from a function call. fn evaluate_const_function_call( db: &dyn SemanticGroup, - exprs: &Arena, + arenas: &Arenas, expr: &ExprFunctionCall, diagnostics: &mut SemanticDiagnostics, ) -> Maybe { - if let Some(value) = try_extract_minus_literal(db.upcast(), exprs, expr) { + if let Some(value) = try_extract_minus_literal(db.upcast(), &arenas.exprs, expr) { return Ok(value); } let args = expr @@ -566,7 +565,7 @@ fn evaluate_const_function_call( .iter() .filter_map(|arg| try_extract_matches!(arg, ExprFunctionCallArg::Value)) .map(|arg| { - match evaluate_constant_expr(db, exprs, *arg, diagnostics) { + match evaluate_constant_expr(db, arenas, *arg, diagnostics) { ConstValue::Int(v, _ty) => Ok(v), // Handling u256 constants to enable const evaluation of them. ConstValue::Struct(v, _) => { @@ -574,14 +573,14 @@ fn evaluate_const_function_call( Ok(low + (high << 128)) } else { Err(diagnostics.report( - exprs[*arg].stable_ptr().untyped(), + arenas.exprs[*arg].stable_ptr().untyped(), SemanticDiagnosticKind::UnsupportedConstant, )) } } ConstValue::Missing(err) => Err(err), _ => Err(diagnostics.report( - exprs[*arg].stable_ptr().untyped(), + arenas.exprs[*arg].stable_ptr().untyped(), SemanticDiagnosticKind::UnsupportedConstant, )), } @@ -630,21 +629,21 @@ fn evaluate_const_function_call( /// Extract const member access from a const value. fn extract_const_member_access( db: &dyn SemanticGroup, - exprs: &Arena, + arenas: &Arenas, expr: &ExprMemberAccess, diagnostics: &mut SemanticDiagnostics, ) -> Maybe { - let full_struct = evaluate_constant_expr(db, exprs, expr.expr, diagnostics); + let full_struct = evaluate_constant_expr(db, arenas, expr.expr, diagnostics); let ConstValue::Struct(mut values, _) = full_struct else { return Err(diagnostics.report( - exprs[expr.expr].stable_ptr().untyped(), + arenas.exprs[expr.expr].stable_ptr().untyped(), SemanticDiagnosticKind::UnsupportedConstant, )); }; let members = db.concrete_struct_members(expr.concrete_struct_id)?; let Some(member_idx) = members.iter().position(|(_, member)| member.id == expr.member) else { return Err(diagnostics.report( - exprs[expr.expr].stable_ptr().untyped(), + arenas.exprs[expr.expr].stable_ptr().untyped(), SemanticDiagnosticKind::UnsupportedConstant, )); }; diff --git a/crates/cairo-lang-starknet/src/contract.rs b/crates/cairo-lang-starknet/src/contract.rs index bc0bc288b76..f6278d87997 100644 --- a/crates/cairo-lang-starknet/src/contract.rs +++ b/crates/cairo-lang-starknet/src/contract.rs @@ -317,7 +317,10 @@ fn analyze_contract( let constant_id = extract_matches!(item, ModuleItemId::Constant); let constant = db.constant_semantic_data(constant_id).unwrap(); let class_hash: Felt252 = - extract_matches!(&constant.exprs[constant.value], Expr::Literal).value.clone().into(); + extract_matches!(&constant.arenas.exprs[constant.value], Expr::Literal) + .value + .clone() + .into(); // Extract functions. let SemanticEntryPoints { external, l1_handler, constructor } =