diff --git a/crates/cairo-lang-semantic/src/db.rs b/crates/cairo-lang-semantic/src/db.rs index 0515093ff9e..a5cdb3995a5 100644 --- a/crates/cairo-lang-semantic/src/db.rs +++ b/crates/cairo-lang-semantic/src/db.rs @@ -24,7 +24,7 @@ use smol_str::SmolStr; use crate::diagnostic::SemanticDiagnosticKind; use crate::expr::inference::{self, ImplVar, ImplVarId}; -use crate::items::constant::{ConstValueId, Constant, ImplConstantId}; +use crate::items::constant::{ConstCalcInfo, ConstValueId, Constant, ImplConstantId}; use crate::items::function_with_body::FunctionBody; use crate::items::functions::{ImplicitPrecedence, InlineConfiguration}; use crate::items::generics::{GenericParam, GenericParamData, GenericParamsData}; @@ -163,6 +163,9 @@ pub trait SemanticGroup: #[salsa::invoke(items::constant::constant_const_type)] #[salsa::cycle(items::constant::constant_const_type_cycle)] fn constant_const_type(&self, const_id: ConstantId) -> Maybe; + /// Returns information required for const calculations. + #[salsa::invoke(items::constant::const_calc_info)] + fn const_calc_info(&self) -> Arc; // Use. // ==== diff --git a/crates/cairo-lang-semantic/src/items/constant.rs b/crates/cairo-lang-semantic/src/items/constant.rs index 3bcb27640e7..9534a2dada0 100644 --- a/crates/cairo-lang-semantic/src/items/constant.rs +++ b/crates/cairo-lang-semantic/src/items/constant.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use cairo_lang_debug::DebugWithDb; use cairo_lang_defs::ids::{ ConstantId, GenericParamId, LanguageElementId, LookupItemId, ModuleItemId, - NamedLanguageElementId, TraitConstantId, VarId, + NamedLanguageElementId, TraitConstantId, TraitId, VarId, }; use cairo_lang_diagnostics::{DiagnosticAdded, Diagnostics, Maybe, ToMaybe, skip_diagnostic}; use cairo_lang_proc_macros::{DebugWithDb, SemanticObject}; @@ -366,8 +366,10 @@ pub fn constant_semantic_data_cycle_helper( /// Checks if the given expression only involved constant calculations. pub fn validate_const_expr(ctx: &mut ComputationContext<'_>, expr_id: ExprId) { + let info = ctx.db.const_calc_info(); let mut eval_ctx = ConstantEvaluateContext { db: ctx.db, + info: info.as_ref(), arenas: &ctx.arenas, vars: Default::default(), diagnostics: ctx.diagnostics, @@ -406,10 +408,12 @@ pub fn resolve_const_expr_and_evaluate( // Check that the expression is a valid constant. _ if ctx.diagnostics.error_count > prev_err_count => ConstValue::Missing(skip_diagnostic()), _ => { + let info = db.const_calc_info(); let mut eval_ctx = ConstantEvaluateContext { db, arenas: &ctx.arenas, vars: Default::default(), + info: info.as_ref(), diagnostics: ctx.diagnostics, }; eval_ctx.validate(value.id); @@ -456,6 +460,7 @@ pub fn value_as_const_value( /// A context for evaluating constant expressions. struct ConstantEvaluateContext<'a> { db: &'a dyn SemanticGroup, + info: &'a ConstCalcInfo, arenas: &'a Arenas, vars: OrderedHashMap, diagnostics: &'a mut SemanticDiagnostics, @@ -588,19 +593,18 @@ impl ConstantEvaluateContext<'_> { let Ok(trait_id) = db.impl_def_trait(impl_def) else { return false; }; - let expected_trait_name = match imp.function_body.name(db).as_str() { - "neg" => "Neg", - "add" => "Add", - "sub" => "Sub", - "mul" => "Mul", - "div" => "Div", - "rem" => "Rem", - "bitand" => "BitAnd", - "bitor" => "BitOr", - "bitxor" => "BitXor", - _ => return false, - }; - trait_id == get_core_trait(db, CoreTraitContext::TopLevel, expected_trait_name.into()) + [ + self.neg_trait, + self.add_trait, + self.sub_trait, + self.mul_trait, + self.div_trait, + self.rem_trait, + self.bit_and_trait, + self.bit_or_trait, + self.bit_xor_trait, + ] + .contains(&trait_id) } /// Evaluate the given const expression value. @@ -822,22 +826,22 @@ impl ConstantEvaluateContext<'_> { GenericFunctionId::Impl ); let is_felt252_ty = expr.ty == db.core_felt252_ty(); - let mut value = match imp.function.name(db.upcast()).as_str() { - "neg" => -&args[0], - "add" => &args[0] + &args[1], - "sub" => &args[0] - &args[1], - "mul" => &args[0] * &args[1], - "div" | "rem" if args[1].is_zero() => { + let mut value = match imp.impl_id.concrete_trait(self.db).unwrap().trait_id(self.db) { + id if id == self.neg_trait => -&args[0], + id if id == self.add_trait => &args[0] + &args[1], + id if id == self.sub_trait => &args[0] - &args[1], + id if id == self.mul_trait => &args[0] * &args[1], + id if (id == self.div_trait || id == self.rem_trait) && args[1].is_zero() => { return ConstValue::Missing( self.diagnostics .report(expr.stable_ptr.untyped(), SemanticDiagnosticKind::DivisionByZero), ); } - "div" if !is_felt252_ty => &args[0] / &args[1], - "rem" if !is_felt252_ty => &args[0] % &args[1], - "bitand" if !is_felt252_ty => &args[0] & &args[1], - "bitor" if !is_felt252_ty => &args[0] | &args[1], - "bitxor" if !is_felt252_ty => &args[0] ^ &args[1], + id if id == self.div_trait && !is_felt252_ty => &args[0] / &args[1], + id if id == self.rem_trait && !is_felt252_ty => &args[0] % &args[1], + id if id == self.bit_and_trait && !is_felt252_ty => &args[0] & &args[1], + id if id == self.bit_or_trait && !is_felt252_ty => &args[0] | &args[1], + id if id == self.bit_xor_trait && !is_felt252_ty => &args[0] ^ &args[1], _ => unreachable!("Unexpected function call in constant lowering: {:?}", expr), }; if is_felt252_ty { @@ -928,6 +932,13 @@ impl ConstantEvaluateContext<'_> { } } +impl std::ops::Deref for ConstantEvaluateContext<'_> { + type Target = ConstCalcInfo; + fn deref(&self) -> &Self::Target { + self.info + } +} + /// Query implementation of [SemanticGroup::constant_semantic_diagnostics]. pub fn constant_semantic_diagnostics( db: &dyn SemanticGroup, @@ -997,3 +1008,48 @@ pub fn constant_const_type_cycle( // Forwarding cycle handling to `priv_constant_semantic_data` handler. db.priv_constant_semantic_data(*const_id, true)?.const_value.ty(db) } + +/// Query implementation of [crate::db::SemanticGroup::const_calc_info]. +pub fn const_calc_info(db: &dyn SemanticGroup) -> Arc { + Arc::new(ConstCalcInfo::new(db)) +} + +/// Holds static information about extern functions required for const calculations. +#[derive(Debug, PartialEq, Eq)] +pub struct ConstCalcInfo { + /// The trait for negation. + pub neg_trait: TraitId, + /// The trait for addition. + pub add_trait: TraitId, + /// The trait for subtraction. + pub sub_trait: TraitId, + /// The trait for multiplication. + pub mul_trait: TraitId, + /// The trait for division. + pub div_trait: TraitId, + /// The trait for remainder. + pub rem_trait: TraitId, + /// The trait for bitwise and. + pub bit_and_trait: TraitId, + /// The trait for bitwise or. + pub bit_or_trait: TraitId, + /// The trait for bitwise xor. + pub bit_xor_trait: TraitId, +} + +impl ConstCalcInfo { + /// Creates a new ConstCalcInfo. + pub fn new(db: &dyn SemanticGroup) -> Self { + Self { + neg_trait: get_core_trait(db, CoreTraitContext::TopLevel, "Neg".into()), + add_trait: get_core_trait(db, CoreTraitContext::TopLevel, "Add".into()), + sub_trait: get_core_trait(db, CoreTraitContext::TopLevel, "Sub".into()), + mul_trait: get_core_trait(db, CoreTraitContext::TopLevel, "Mul".into()), + div_trait: get_core_trait(db, CoreTraitContext::TopLevel, "Div".into()), + rem_trait: get_core_trait(db, CoreTraitContext::TopLevel, "Rem".into()), + bit_and_trait: get_core_trait(db, CoreTraitContext::TopLevel, "BitAnd".into()), + bit_or_trait: get_core_trait(db, CoreTraitContext::TopLevel, "BitOr".into()), + bit_xor_trait: get_core_trait(db, CoreTraitContext::TopLevel, "BitXor".into()), + } + } +}