From 602c81ec390b2622449765db030878309f75c3f2 Mon Sep 17 00:00:00 2001 From: Ori Ziv Date: Fri, 3 Jan 2025 20:31:41 +0200 Subject: [PATCH] Added evaluation of const fns. commit-id:4682d984 --- crates/cairo-lang-semantic/src/diagnostic.rs | 16 ++- .../src/expr/test_data/constant | 18 +++ .../cairo-lang-semantic/src/items/constant.rs | 122 ++++++++++++++---- 3 files changed, 127 insertions(+), 29 deletions(-) diff --git a/crates/cairo-lang-semantic/src/diagnostic.rs b/crates/cairo-lang-semantic/src/diagnostic.rs index d65e8a378c3..bc295e2434c 100644 --- a/crates/cairo-lang-semantic/src/diagnostic.rs +++ b/crates/cairo-lang-semantic/src/diagnostic.rs @@ -9,11 +9,11 @@ use cairo_lang_defs::ids::{ }; use cairo_lang_defs::plugin::PluginDiagnostic; use cairo_lang_diagnostics::{ - DiagnosticAdded, DiagnosticEntry, DiagnosticLocation, DiagnosticsBuilder, ErrorCode, Severity, - error_code, + DiagnosticAdded, DiagnosticEntry, DiagnosticLocation, DiagnosticNote, DiagnosticsBuilder, + ErrorCode, Severity, error_code, }; use cairo_lang_filesystem::db::Edition; -use cairo_lang_syntax as syntax; +use cairo_lang_syntax::{self as syntax}; use itertools::Itertools; use smol_str::SmolStr; use syntax::node::ids::SyntaxStablePtrId; @@ -741,6 +741,7 @@ impl DiagnosticEntry for SemanticDiagnostic { SemanticDiagnosticKind::FailedConstantCalculation => { "Failed to calculate constant.".into() } + SemanticDiagnosticKind::InnerFailedConstantCalculation(inner, _) => inner.format(db), SemanticDiagnosticKind::DivisionByZero => "Division by zero.".into(), SemanticDiagnosticKind::ExternTypeWithImplGenericsNotSupported => { "Extern types with impl generics are not supported.".into() @@ -1027,6 +1028,14 @@ impl DiagnosticEntry for SemanticDiagnostic { } } + fn notes(&self, _db: &Self::DbType) -> &[DiagnosticNote] { + if let SemanticDiagnosticKind::InnerFailedConstantCalculation(_, notes) = &self.kind { + notes + } else { + &[] + } + } + fn error_code(&self) -> Option { self.kind.error_code() } @@ -1313,6 +1322,7 @@ pub enum SemanticDiagnosticKind { UnsupportedOutsideOfFunction(UnsupportedOutsideOfFunctionFeatureName), UnsupportedConstant, FailedConstantCalculation, + InnerFailedConstantCalculation(Box, Vec), DivisionByZero, ExternTypeWithImplGenericsNotSupported, MissingSemicolon, diff --git a/crates/cairo-lang-semantic/src/expr/test_data/constant b/crates/cairo-lang-semantic/src/expr/test_data/constant index 6cdbf673778..4241a63bac1 100644 --- a/crates/cairo-lang-semantic/src/expr/test_data/constant +++ b/crates/cairo-lang-semantic/src/expr/test_data/constant @@ -62,6 +62,15 @@ const FAILING_CALC: felt252 = if true { 70 }; +const FUNC_CALC_SUCCESS: () = panic_if_true(false); +const FUNC_CALC_FAILURE: () = panic_if_true(true); + +const fn panic_if_true(cond: bool) { + if cond { + core::panic_with_felt252('assertion failed') + } +} + //! > expected_diagnostics error: Type not found. --> lib.cairo:1:17 @@ -93,6 +102,15 @@ error: Failed to calculate constant. core::panic_with_felt252('this should fail') ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +error: Failed to calculate constant. + --> lib.cairo:23:31 +const FUNC_CALC_FAILURE: () = panic_if_true(true); + ^^^^^^^^^^^^^^^^^^^ +note: In `test::panic_if_true`: + --> lib.cairo:27:9 + core::panic_with_felt252('assertion failed') + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + //! > ========================================================================== //! > Const of wrong type. diff --git a/crates/cairo-lang-semantic/src/items/constant.rs b/crates/cairo-lang-semantic/src/items/constant.rs index 1d62546bf05..f177d08b86c 100644 --- a/crates/cairo-lang-semantic/src/items/constant.rs +++ b/crates/cairo-lang-semantic/src/items/constant.rs @@ -6,14 +6,16 @@ use cairo_lang_defs::ids::{ ConstantId, GenericParamId, LanguageElementId, LookupItemId, ModuleItemId, NamedLanguageElementId, TraitConstantId, TraitId, VarId, }; -use cairo_lang_diagnostics::{DiagnosticAdded, Diagnostics, Maybe, ToMaybe, skip_diagnostic}; +use cairo_lang_diagnostics::{ + DiagnosticAdded, DiagnosticEntry, DiagnosticNote, Diagnostics, Maybe, ToMaybe, skip_diagnostic, +}; use cairo_lang_proc_macros::{DebugWithDb, SemanticObject}; use cairo_lang_syntax::node::ast::ItemConstant; use cairo_lang_syntax::node::ids::SyntaxStablePtrId; use cairo_lang_syntax::node::{TypedStablePtr, TypedSyntaxNode}; use cairo_lang_utils::ordered_hash_map::OrderedHashMap; use cairo_lang_utils::{ - Intern, LookupIntern, define_short_id, extract_matches, try_extract_matches, + Intern, LookupIntern, define_short_id, extract_matches, require, try_extract_matches, }; use itertools::Itertools; use num_bigint::BigInt; @@ -36,13 +38,13 @@ use crate::expr::inference::conform::InferenceConform; use crate::expr::inference::{ConstVar, InferenceId}; use crate::literals::try_extract_minus_literal; use crate::resolve::{Resolver, ResolverData}; -use crate::substitution::SemanticRewriter; +use crate::substitution::{GenericSubstitution, SemanticRewriter, SubstitutionRewriter}; use crate::types::resolve_type; use crate::{ - Arenas, ConcreteTypeId, ConcreteVariant, Condition, Expr, ExprBlock, ExprConstant, - ExprFunctionCall, ExprFunctionCallArg, ExprId, ExprMemberAccess, ExprStructCtor, FunctionId, - GenericParam, LogicalOperator, Pattern, PatternId, SemanticDiagnostic, Statement, TypeId, - TypeLongId, semantic_object_for_id, + Arenas, ConcreteFunction, ConcreteTypeId, ConcreteVariant, Condition, Expr, ExprBlock, + ExprConstant, ExprFunctionCall, ExprFunctionCallArg, ExprId, ExprMemberAccess, ExprStructCtor, + FunctionId, GenericParam, LogicalOperator, Pattern, PatternId, SemanticDiagnostic, Statement, + TypeId, TypeLongId, semantic_object_for_id, }; #[derive(Clone, Debug, PartialEq, Eq, DebugWithDb)] @@ -372,6 +374,7 @@ pub fn validate_const_expr(ctx: &mut ComputationContext<'_>, expr_id: ExprId) { db: ctx.db, arenas: &ctx.arenas, vars: Default::default(), + generic_substitution: Default::default(), info: info.as_ref(), diagnostics: ctx.diagnostics, }; @@ -414,6 +417,7 @@ pub fn resolve_const_expr_and_evaluate( db, arenas: &ctx.arenas, vars: Default::default(), + generic_substitution: Default::default(), info: info.as_ref(), diagnostics: ctx.diagnostics, }; @@ -463,6 +467,7 @@ struct ConstantEvaluateContext<'a> { db: &'a dyn SemanticGroup, arenas: &'a Arenas, vars: OrderedHashMap, + generic_substitution: GenericSubstitution, diagnostics: &'a mut SemanticDiagnostics, info: &'a ConstCalcInfo, } @@ -589,6 +594,10 @@ impl ConstantEvaluateContext<'_> { let db = self.db; let concrete_function = function_id.get_concrete(db); let Ok(Some(body)) = concrete_function.body(db) else { return false }; + let signature = self.db.function_with_body_signature(body.function_with_body_id(self.db)); + if signature.map(|s| s.is_const) == Ok(true) { + return true; + } let GenericFunctionWithBodyId::Impl(imp) = body.generic_function(db) else { return false }; let impl_def = imp.concrete_impl_id.impl_def_id(db); if impl_def.parent_module(db.upcast()).owning_crate(db.upcast()) != db.core_crate() { @@ -621,7 +630,11 @@ impl ConstantEvaluateContext<'_> { .get(&expr.var) .cloned() .unwrap_or_else(|| ConstValue::Missing(skip_diagnostic())), - Expr::Constant(expr) => expr.const_value_id.lookup_intern(db), + Expr::Constant(expr) => { + SubstitutionRewriter { db, substitution: &self.generic_substitution } + .rewrite(expr.const_value_id.lookup_intern(db)) + .unwrap_or_else(ConstValue::Missing) + } Expr::Block(ExprBlock { statements, tail: Some(inner), .. }) => { for statement_id in statements { match &self.arenas.statements[*statement_id] { @@ -795,21 +808,38 @@ impl ConstantEvaluateContext<'_> { return value_as_const_value(db, expr.ty, &value) .expect("LiteralError should have been caught on `validate`"); } - let args = match expr + let args = expr .args .iter() .filter_map(|arg| try_extract_matches!(arg, ExprFunctionCallArg::Value)) - .map(|arg| match self.evaluate(*arg) { + .map(|arg| self.evaluate(*arg)) + .collect_vec(); + if expr.function == self.panic_with_felt252 { + return ConstValue::Missing(self.diagnostics.report( + expr.stable_ptr.untyped(), + SemanticDiagnosticKind::FailedConstantCalculation, + )); + } + let concrete_function = expr.function.get_concrete(db); + if let Some(calc_result) = + self.evaluate_const_function_call(&concrete_function, &args, expr.stable_ptr.untyped()) + { + return calc_result; + } + + let args = match args + .into_iter() + .map(|arg| match arg { ConstValue::Int(v, _ty) => Ok(v), // Handling u256 constants to enable const evaluation of them. ConstValue::Struct(v, _) => { if let [ConstValue::Int(low, _), ConstValue::Int(high, _)] = &v[..] { Ok(low + (high << 128)) } else { - Err(self.diagnostics.report( - self.arenas.exprs[*arg].stable_ptr().untyped(), - SemanticDiagnosticKind::UnsupportedConstant, - )) + // Dignostic can be skipped as we would either have a semantic error for a + // bad arg for the function, or the arg itself + // could'nt have been calculated. + Err(skip_diagnostic()) } } ConstValue::Missing(err) => Err(err), @@ -817,24 +847,13 @@ impl ConstantEvaluateContext<'_> { // for the function, or the arg itself could'nt have been calculated. _ => Err(skip_diagnostic()), }) - .collect_vec() - .into_iter() .collect::, _>>() { Ok(args) => args, Err(err) => return ConstValue::Missing(err), }; - if expr.function == self.panic_with_felt252 { - return ConstValue::Missing(self.diagnostics.report( - expr.stable_ptr.untyped(), - SemanticDiagnosticKind::FailedConstantCalculation, - )); - } - let imp = extract_matches!( - expr.function.get_concrete(db.upcast()).generic_function, - GenericFunctionId::Impl - ); + let imp = extract_matches!(concrete_function.generic_function, GenericFunctionId::Impl); let is_felt252_ty = expr.ty == db.core_felt252_ty(); let mut value = match imp.impl_id.concrete_trait(self.db).unwrap().trait_id(self.db) { id if id == self.neg_trait => -&args[0], @@ -870,6 +889,57 @@ impl ConstantEvaluateContext<'_> { .unwrap_or_else(ConstValue::Missing) } + /// Attempts to evaluate a constant function call. + fn evaluate_const_function_call( + &mut self, + concrete_function: &ConcreteFunction, + args: &[ConstValue], + stable_ptr: SyntaxStablePtrId, + ) -> Option { + let db = self.db; + let body_id = concrete_function.body(db).ok()??; + let concrete_body_id = body_id.function_with_body_id(db); + let signature = db.function_with_body_signature(concrete_body_id).ok()?; + require(signature.is_const)?; + let generic_substitution = body_id.substitution(db).ok()?; + let body = db.function_body(concrete_body_id).ok()?; + let mut diagnostics = SemanticDiagnostics::default(); + let mut inner = ConstantEvaluateContext { + db, + arenas: &body.arenas, + vars: signature + .params + .into_iter() + .map(|p| VarId::Param(p.id)) + .zip(args.iter().cloned()) + .collect(), + generic_substitution, + info: self.info, + diagnostics: &mut diagnostics, + }; + let value = inner.evaluate(body.body_expr); + for diagnostic in diagnostics.build().get_all() { + let location = diagnostic.location(db.elongate()); + let (inner_diag, mut notes) = + if let SemanticDiagnosticKind::InnerFailedConstantCalculation(inner_diag, notes) = + diagnostic.kind + { + (inner_diag, notes) + } else { + (diagnostic.into(), vec![]) + }; + notes.push(DiagnosticNote::with_location( + format!("In `{}`", concrete_function.full_name(db)), + location, + )); + self.diagnostics.report( + stable_ptr, + SemanticDiagnosticKind::InnerFailedConstantCalculation(inner_diag, notes), + ); + } + Some(value) + } + /// Extract const member access from a const value. fn evaluate_member_access(&mut self, expr: &ExprMemberAccess) -> Maybe { let full_struct = self.evaluate(expr.expr);