Skip to content

Commit

Permalink
Added evaluation of const fns.
Browse files Browse the repository at this point in the history
commit-id:4682d984
  • Loading branch information
orizi committed Jan 4, 2025
1 parent 90e8dbf commit 798c008
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 29 deletions.
16 changes: 13 additions & 3 deletions crates/cairo-lang-semantic/src/diagnostic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ use cairo_lang_defs::ids::{
};
use cairo_lang_defs::plugin::PluginDiagnostic;
use cairo_lang_diagnostics::{
DiagnosticAdded, DiagnosticEntry, DiagnosticLocation, DiagnosticsBuilder, ErrorCode, Severity,
error_code,
DiagnosticAdded, DiagnosticEntry, DiagnosticLocation, DiagnosticNote, DiagnosticsBuilder,
ErrorCode, Severity, error_code,
};
use cairo_lang_filesystem::db::Edition;
use cairo_lang_syntax as syntax;
use cairo_lang_syntax::{self as syntax};
use itertools::Itertools;
use smol_str::SmolStr;
use syntax::node::ids::SyntaxStablePtrId;
Expand Down Expand Up @@ -741,6 +741,7 @@ impl DiagnosticEntry for SemanticDiagnostic {
SemanticDiagnosticKind::FailedConstantCalculation => {
"Failed to calculate constant.".into()
}
SemanticDiagnosticKind::InnerFailedConstantCalculation(inner, _) => inner.format(db),
SemanticDiagnosticKind::DivisionByZero => "Division by zero.".into(),
SemanticDiagnosticKind::ExternTypeWithImplGenericsNotSupported => {
"Extern types with impl generics are not supported.".into()
Expand Down Expand Up @@ -1027,6 +1028,14 @@ impl DiagnosticEntry for SemanticDiagnostic {
}
}

fn notes(&self, _db: &Self::DbType) -> &[DiagnosticNote] {
if let SemanticDiagnosticKind::InnerFailedConstantCalculation(_, notes) = &self.kind {
notes
} else {
&[]
}
}

fn error_code(&self) -> Option<ErrorCode> {
self.kind.error_code()
}
Expand Down Expand Up @@ -1313,6 +1322,7 @@ pub enum SemanticDiagnosticKind {
UnsupportedOutsideOfFunction(UnsupportedOutsideOfFunctionFeatureName),
UnsupportedConstant,
FailedConstantCalculation,
InnerFailedConstantCalculation(Box<SemanticDiagnostic>, Vec<DiagnosticNote>),
DivisionByZero,
ExternTypeWithImplGenericsNotSupported,
MissingSemicolon,
Expand Down
18 changes: 18 additions & 0 deletions crates/cairo-lang-semantic/src/expr/test_data/constant
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ const FAILING_CALC: felt252 = if true {
70
};

const FUNC_CALC_SUCCESS: () = panic_if_true(false);
const FUNC_CALC_FAILURE: () = panic_if_true(true);

const fn panic_if_true(cond: bool) {
if cond {
core::panic_with_felt252('assertion failed')
}
}

//! > expected_diagnostics
error: Type not found.
--> lib.cairo:1:17
Expand Down Expand Up @@ -93,6 +102,15 @@ error: Failed to calculate constant.
core::panic_with_felt252('this should fail')
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

error: Failed to calculate constant.
--> lib.cairo:23:31
const FUNC_CALC_FAILURE: () = panic_if_true(true);
^^^^^^^^^^^^^^^^^^^
note: In `test::panic_if_true`:
--> lib.cairo:27:9
core::panic_with_felt252('assertion failed')
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

//! > ==========================================================================

//! > Const of wrong type.
Expand Down
122 changes: 96 additions & 26 deletions crates/cairo-lang-semantic/src/items/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@ use cairo_lang_defs::ids::{
ConstantId, GenericParamId, LanguageElementId, LookupItemId, ModuleItemId,
NamedLanguageElementId, TraitConstantId, TraitId, VarId,
};
use cairo_lang_diagnostics::{DiagnosticAdded, Diagnostics, Maybe, ToMaybe, skip_diagnostic};
use cairo_lang_diagnostics::{
DiagnosticAdded, DiagnosticEntry, DiagnosticNote, Diagnostics, Maybe, ToMaybe, skip_diagnostic,
};
use cairo_lang_proc_macros::{DebugWithDb, SemanticObject};
use cairo_lang_syntax::node::ast::ItemConstant;
use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
use cairo_lang_syntax::node::{TypedStablePtr, TypedSyntaxNode};
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::{
Intern, LookupIntern, define_short_id, extract_matches, try_extract_matches,
Intern, LookupIntern, define_short_id, extract_matches, require, try_extract_matches,
};
use itertools::Itertools;
use num_bigint::BigInt;
Expand All @@ -36,13 +38,13 @@ use crate::expr::inference::conform::InferenceConform;
use crate::expr::inference::{ConstVar, InferenceId};
use crate::literals::try_extract_minus_literal;
use crate::resolve::{Resolver, ResolverData};
use crate::substitution::SemanticRewriter;
use crate::substitution::{GenericSubstitution, SemanticRewriter, SubstitutionRewriter};
use crate::types::resolve_type;
use crate::{
Arenas, ConcreteTypeId, ConcreteVariant, Condition, Expr, ExprBlock, ExprConstant,
ExprFunctionCall, ExprFunctionCallArg, ExprId, ExprMemberAccess, ExprStructCtor, FunctionId,
GenericParam, LogicalOperator, Pattern, PatternId, SemanticDiagnostic, Statement, TypeId,
TypeLongId, semantic_object_for_id,
Arenas, ConcreteFunction, ConcreteTypeId, ConcreteVariant, Condition, Expr, ExprBlock,
ExprConstant, ExprFunctionCall, ExprFunctionCallArg, ExprId, ExprMemberAccess, ExprStructCtor,
FunctionId, GenericParam, LogicalOperator, Pattern, PatternId, SemanticDiagnostic, Statement,
TypeId, TypeLongId, semantic_object_for_id,
};

#[derive(Clone, Debug, PartialEq, Eq, DebugWithDb)]
Expand Down Expand Up @@ -372,6 +374,7 @@ pub fn validate_const_expr(ctx: &mut ComputationContext<'_>, expr_id: ExprId) {
db: ctx.db,
arenas: &ctx.arenas,
vars: Default::default(),
generic_substitution: Default::default(),
info: info.as_ref(),
diagnostics: ctx.diagnostics,
};
Expand Down Expand Up @@ -414,6 +417,7 @@ pub fn resolve_const_expr_and_evaluate(
db,
arenas: &ctx.arenas,
vars: Default::default(),
generic_substitution: Default::default(),
info: info.as_ref(),
diagnostics: ctx.diagnostics,
};
Expand Down Expand Up @@ -463,6 +467,7 @@ struct ConstantEvaluateContext<'a> {
db: &'a dyn SemanticGroup,
arenas: &'a Arenas,
vars: OrderedHashMap<VarId, ConstValue>,
generic_substitution: GenericSubstitution,
diagnostics: &'a mut SemanticDiagnostics,
info: &'a ConstCalcInfo,
}
Expand Down Expand Up @@ -589,6 +594,10 @@ impl ConstantEvaluateContext<'_> {
let db = self.db;
let concrete_function = function_id.get_concrete(db);
let Ok(Some(body)) = concrete_function.body(db) else { return false };
let signature = self.db.function_with_body_signature(body.function_with_body_id(self.db));
if signature.map(|s| s.is_const) == Ok(true) {
return true;
}
let GenericFunctionWithBodyId::Impl(imp) = body.generic_function(db) else { return false };
let impl_def = imp.concrete_impl_id.impl_def_id(db);
if impl_def.parent_module(db.upcast()).owning_crate(db.upcast()) != db.core_crate() {
Expand Down Expand Up @@ -621,7 +630,11 @@ impl ConstantEvaluateContext<'_> {
.get(&expr.var)
.cloned()
.unwrap_or_else(|| ConstValue::Missing(skip_diagnostic())),
Expr::Constant(expr) => expr.const_value_id.lookup_intern(db),
Expr::Constant(expr) => {
SubstitutionRewriter { db, substitution: &self.generic_substitution }
.rewrite(expr.const_value_id.lookup_intern(db))
.unwrap_or_else(ConstValue::Missing)
}
Expr::Block(ExprBlock { statements, tail: Some(inner), .. }) => {
for statement_id in statements {
match &self.arenas.statements[*statement_id] {
Expand Down Expand Up @@ -795,46 +808,52 @@ impl ConstantEvaluateContext<'_> {
return value_as_const_value(db, expr.ty, &value)
.expect("LiteralError should have been caught on `validate`");
}
let args = match expr
let args = expr
.args
.iter()
.filter_map(|arg| try_extract_matches!(arg, ExprFunctionCallArg::Value))
.map(|arg| match self.evaluate(*arg) {
.map(|arg| self.evaluate(*arg))
.collect_vec();
if expr.function == self.panic_with_felt252 {
return ConstValue::Missing(self.diagnostics.report(
expr.stable_ptr.untyped(),
SemanticDiagnosticKind::FailedConstantCalculation,
));
}
let concrete_function = expr.function.get_concrete(db);
if let Some(calc_result) =
self.evaluate_const_function_call(&concrete_function, &args, expr.stable_ptr.untyped())
{
return calc_result;
}

let args = match args
.into_iter()
.map(|arg| match arg {
ConstValue::Int(v, _ty) => Ok(v),
// Handling u256 constants to enable const evaluation of them.
ConstValue::Struct(v, _) => {
if let [ConstValue::Int(low, _), ConstValue::Int(high, _)] = &v[..] {
Ok(low + (high << 128))
} else {
Err(self.diagnostics.report(
self.arenas.exprs[*arg].stable_ptr().untyped(),
SemanticDiagnosticKind::UnsupportedConstant,
))
// Dignostic can be skipped as we would either have a semantic error for a
// bad arg for the function, or the arg itself
// could'nt have been calculated.
Err(skip_diagnostic())
}
}
ConstValue::Missing(err) => Err(err),
// Dignostic can be skipped as we would either have a semantic error for a bad arg
// for the function, or the arg itself could'nt have been calculated.
_ => Err(skip_diagnostic()),
})
.collect_vec()
.into_iter()
.collect::<Result<Vec<_>, _>>()
{
Ok(args) => args,
Err(err) => return ConstValue::Missing(err),
};
if expr.function == self.panic_with_felt252 {
return ConstValue::Missing(self.diagnostics.report(
expr.stable_ptr.untyped(),
SemanticDiagnosticKind::FailedConstantCalculation,
));
}

let imp = extract_matches!(
expr.function.get_concrete(db.upcast()).generic_function,
GenericFunctionId::Impl
);
let imp = extract_matches!(concrete_function.generic_function, GenericFunctionId::Impl);
let is_felt252_ty = expr.ty == db.core_felt252_ty();
let mut value = match imp.impl_id.concrete_trait(self.db).unwrap().trait_id(self.db) {
id if id == self.neg_trait => -&args[0],
Expand Down Expand Up @@ -870,6 +889,57 @@ impl ConstantEvaluateContext<'_> {
.unwrap_or_else(ConstValue::Missing)
}

/// Attempts to evaluate a constant function call.
fn evaluate_const_function_call(
&mut self,
concrete_function: &ConcreteFunction,
args: &[ConstValue],
stable_ptr: SyntaxStablePtrId,
) -> Option<ConstValue> {
let db = self.db;
let body_id = concrete_function.body(db).ok()??;
let concrete_body_id = body_id.function_with_body_id(db);
let signature = db.function_with_body_signature(concrete_body_id).ok()?;
require(signature.is_const)?;
let generic_substitution = body_id.substitution(db).ok()?;
let body = db.function_body(concrete_body_id).ok()?;
let mut diagnostics = SemanticDiagnostics::default();
let mut inner = ConstantEvaluateContext {
db,
arenas: &body.arenas,
vars: signature
.params
.into_iter()
.map(|p| VarId::Param(p.id))
.zip(args.iter().cloned())
.collect(),
generic_substitution,
info: self.info,
diagnostics: &mut diagnostics,
};
let value = inner.evaluate(body.body_expr);
for diagnostic in diagnostics.build().get_all() {
let location = diagnostic.location(db.elongate());
let (inner_diag, mut notes) =
if let SemanticDiagnosticKind::InnerFailedConstantCalculation(inner_diag, notes) =
diagnostic.kind
{
(inner_diag, notes)
} else {
(diagnostic.into(), vec![])
};
notes.push(DiagnosticNote::with_location(
format!("In `{}`", concrete_function.full_name(db)),
location,
));
self.diagnostics.report(
stable_ptr,
SemanticDiagnosticKind::InnerFailedConstantCalculation(inner_diag, notes),
);
}
Some(value)
}

/// Extract const member access from a const value.
fn evaluate_member_access(&mut self, expr: &ExprMemberAccess) -> Maybe<ConstValue> {
let full_struct = self.evaluate(expr.expr);
Expand Down

0 comments on commit 798c008

Please sign in to comment.