Skip to content

Commit

Permalink
Refactored constant constant fetches. (#6980)
Browse files Browse the repository at this point in the history
  • Loading branch information
orizi authored Jan 5, 2025
1 parent 2d18266 commit c579d8c
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 26 deletions.
5 changes: 4 additions & 1 deletion crates/cairo-lang-semantic/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use smol_str::SmolStr;

use crate::diagnostic::SemanticDiagnosticKind;
use crate::expr::inference::{self, ImplVar, ImplVarId};
use crate::items::constant::{ConstValueId, Constant, ImplConstantId};
use crate::items::constant::{ConstCalcInfo, ConstValueId, Constant, ImplConstantId};
use crate::items::function_with_body::FunctionBody;
use crate::items::functions::{ImplicitPrecedence, InlineConfiguration};
use crate::items::generics::{GenericParam, GenericParamData, GenericParamsData};
Expand Down Expand Up @@ -163,6 +163,9 @@ pub trait SemanticGroup:
#[salsa::invoke(items::constant::constant_const_type)]
#[salsa::cycle(items::constant::constant_const_type_cycle)]
fn constant_const_type(&self, const_id: ConstantId) -> Maybe<TypeId>;
/// Returns information required for const calculations.
#[salsa::invoke(items::constant::const_calc_info)]
fn const_calc_info(&self) -> Arc<ConstCalcInfo>;

// Use.
// ====
Expand Down
106 changes: 81 additions & 25 deletions crates/cairo-lang-semantic/src/items/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::sync::Arc;
use cairo_lang_debug::DebugWithDb;
use cairo_lang_defs::ids::{
ConstantId, GenericParamId, LanguageElementId, LookupItemId, ModuleItemId,
NamedLanguageElementId, TraitConstantId, VarId,
NamedLanguageElementId, TraitConstantId, TraitId, VarId,
};
use cairo_lang_diagnostics::{DiagnosticAdded, Diagnostics, Maybe, ToMaybe, skip_diagnostic};
use cairo_lang_proc_macros::{DebugWithDb, SemanticObject};
Expand Down Expand Up @@ -366,8 +366,10 @@ pub fn constant_semantic_data_cycle_helper(

/// Checks if the given expression only involved constant calculations.
pub fn validate_const_expr(ctx: &mut ComputationContext<'_>, expr_id: ExprId) {
let info = ctx.db.const_calc_info();
let mut eval_ctx = ConstantEvaluateContext {
db: ctx.db,
info: info.as_ref(),
arenas: &ctx.arenas,
vars: Default::default(),
diagnostics: ctx.diagnostics,
Expand Down Expand Up @@ -406,10 +408,12 @@ pub fn resolve_const_expr_and_evaluate(
// Check that the expression is a valid constant.
_ if ctx.diagnostics.error_count > prev_err_count => ConstValue::Missing(skip_diagnostic()),
_ => {
let info = db.const_calc_info();
let mut eval_ctx = ConstantEvaluateContext {
db,
arenas: &ctx.arenas,
vars: Default::default(),
info: info.as_ref(),
diagnostics: ctx.diagnostics,
};
eval_ctx.validate(value.id);
Expand Down Expand Up @@ -456,6 +460,7 @@ pub fn value_as_const_value(
/// A context for evaluating constant expressions.
struct ConstantEvaluateContext<'a> {
db: &'a dyn SemanticGroup,
info: &'a ConstCalcInfo,
arenas: &'a Arenas,
vars: OrderedHashMap<VarId, ConstValue>,
diagnostics: &'a mut SemanticDiagnostics,
Expand Down Expand Up @@ -588,19 +593,18 @@ impl ConstantEvaluateContext<'_> {
let Ok(trait_id) = db.impl_def_trait(impl_def) else {
return false;
};
let expected_trait_name = match imp.function_body.name(db).as_str() {
"neg" => "Neg",
"add" => "Add",
"sub" => "Sub",
"mul" => "Mul",
"div" => "Div",
"rem" => "Rem",
"bitand" => "BitAnd",
"bitor" => "BitOr",
"bitxor" => "BitXor",
_ => return false,
};
trait_id == get_core_trait(db, CoreTraitContext::TopLevel, expected_trait_name.into())
[
self.neg_trait,
self.add_trait,
self.sub_trait,
self.mul_trait,
self.div_trait,
self.rem_trait,
self.bit_and_trait,
self.bit_or_trait,
self.bit_xor_trait,
]
.contains(&trait_id)
}

/// Evaluate the given const expression value.
Expand Down Expand Up @@ -822,22 +826,22 @@ impl ConstantEvaluateContext<'_> {
GenericFunctionId::Impl
);
let is_felt252_ty = expr.ty == db.core_felt252_ty();
let mut value = match imp.function.name(db.upcast()).as_str() {
"neg" => -&args[0],
"add" => &args[0] + &args[1],
"sub" => &args[0] - &args[1],
"mul" => &args[0] * &args[1],
"div" | "rem" if args[1].is_zero() => {
let mut value = match imp.impl_id.concrete_trait(self.db).unwrap().trait_id(self.db) {
id if id == self.neg_trait => -&args[0],
id if id == self.add_trait => &args[0] + &args[1],
id if id == self.sub_trait => &args[0] - &args[1],
id if id == self.mul_trait => &args[0] * &args[1],
id if (id == self.div_trait || id == self.rem_trait) && args[1].is_zero() => {
return ConstValue::Missing(
self.diagnostics
.report(expr.stable_ptr.untyped(), SemanticDiagnosticKind::DivisionByZero),
);
}
"div" if !is_felt252_ty => &args[0] / &args[1],
"rem" if !is_felt252_ty => &args[0] % &args[1],
"bitand" if !is_felt252_ty => &args[0] & &args[1],
"bitor" if !is_felt252_ty => &args[0] | &args[1],
"bitxor" if !is_felt252_ty => &args[0] ^ &args[1],
id if id == self.div_trait && !is_felt252_ty => &args[0] / &args[1],
id if id == self.rem_trait && !is_felt252_ty => &args[0] % &args[1],
id if id == self.bit_and_trait && !is_felt252_ty => &args[0] & &args[1],
id if id == self.bit_or_trait && !is_felt252_ty => &args[0] | &args[1],
id if id == self.bit_xor_trait && !is_felt252_ty => &args[0] ^ &args[1],
_ => unreachable!("Unexpected function call in constant lowering: {:?}", expr),
};
if is_felt252_ty {
Expand Down Expand Up @@ -928,6 +932,13 @@ impl ConstantEvaluateContext<'_> {
}
}

impl std::ops::Deref for ConstantEvaluateContext<'_> {
type Target = ConstCalcInfo;
fn deref(&self) -> &Self::Target {
self.info
}
}

/// Query implementation of [SemanticGroup::constant_semantic_diagnostics].
pub fn constant_semantic_diagnostics(
db: &dyn SemanticGroup,
Expand Down Expand Up @@ -997,3 +1008,48 @@ pub fn constant_const_type_cycle(
// Forwarding cycle handling to `priv_constant_semantic_data` handler.
db.priv_constant_semantic_data(*const_id, true)?.const_value.ty(db)
}

/// Query implementation of [crate::db::SemanticGroup::const_calc_info].
pub fn const_calc_info(db: &dyn SemanticGroup) -> Arc<ConstCalcInfo> {
Arc::new(ConstCalcInfo::new(db))
}

/// Holds static information about extern functions required for const calculations.
#[derive(Debug, PartialEq, Eq)]
pub struct ConstCalcInfo {
/// The trait for negation.
pub neg_trait: TraitId,
/// The trait for addition.
pub add_trait: TraitId,
/// The trait for subtraction.
pub sub_trait: TraitId,
/// The trait for multiplication.
pub mul_trait: TraitId,
/// The trait for division.
pub div_trait: TraitId,
/// The trait for remainder.
pub rem_trait: TraitId,
/// The trait for bitwise and.
pub bit_and_trait: TraitId,
/// The trait for bitwise or.
pub bit_or_trait: TraitId,
/// The trait for bitwise xor.
pub bit_xor_trait: TraitId,
}

impl ConstCalcInfo {
/// Creates a new ConstCalcInfo.
pub fn new(db: &dyn SemanticGroup) -> Self {
Self {
neg_trait: get_core_trait(db, CoreTraitContext::TopLevel, "Neg".into()),
add_trait: get_core_trait(db, CoreTraitContext::TopLevel, "Add".into()),
sub_trait: get_core_trait(db, CoreTraitContext::TopLevel, "Sub".into()),
mul_trait: get_core_trait(db, CoreTraitContext::TopLevel, "Mul".into()),
div_trait: get_core_trait(db, CoreTraitContext::TopLevel, "Div".into()),
rem_trait: get_core_trait(db, CoreTraitContext::TopLevel, "Rem".into()),
bit_and_trait: get_core_trait(db, CoreTraitContext::TopLevel, "BitAnd".into()),
bit_or_trait: get_core_trait(db, CoreTraitContext::TopLevel, "BitOr".into()),
bit_xor_trait: get_core_trait(db, CoreTraitContext::TopLevel, "BitXor".into()),
}
}
}

0 comments on commit c579d8c

Please sign in to comment.