From bd52fe40ac069953ff5d5be766f0a8e8b632331a Mon Sep 17 00:00:00 2001 From: orizi <104711814+orizi@users.noreply.github.com> Date: Sun, 5 Jan 2025 20:57:54 +0500 Subject: [PATCH] Added support for more complex constant calculations. (#6979) --- .../test/language_features/const_test.cairo | 26 ++ .../cairo-lang-semantic/src/items/constant.rs | 237 ++++++++++++++++-- 2 files changed, 247 insertions(+), 16 deletions(-) diff --git a/corelib/src/test/language_features/const_test.cairo b/corelib/src/test/language_features/const_test.cairo index 786049afe7f..702ef3e1ca4 100644 --- a/corelib/src/test/language_features/const_test.cairo +++ b/corelib/src/test/language_features/const_test.cairo @@ -132,3 +132,29 @@ fn test_two_complex_enums() { .unbox() == (ThreeOptions2::A(1337), ThreeOptions2::C), ); } + +#[test] +fn test_complex_consts() { + const VAR_AND_MATCH_CONST: felt252 = { + let x = Option::Some((1, 2_u8)); + match x { + Option::Some((v, _)) => v, + Option::None => 3, + } + }; + assert_eq!(VAR_AND_MATCH_CONST, 1); + const TRUE: bool = true; + const IF_CONST_TRUE: felt252 = if TRUE { + 4 + } else { + 5 + }; + assert_eq!(IF_CONST_TRUE, 4); + const FALSE: bool = false; + const IF_CONST_FALSE: felt252 = if FALSE { + 6 + } else { + 7 + }; + assert_eq!(IF_CONST_FALSE, 7); +} diff --git a/crates/cairo-lang-semantic/src/items/constant.rs b/crates/cairo-lang-semantic/src/items/constant.rs index e98b6f48ba1..3bcb27640e7 100644 --- a/crates/cairo-lang-semantic/src/items/constant.rs +++ b/crates/cairo-lang-semantic/src/items/constant.rs @@ -1,15 +1,17 @@ +use std::iter::zip; use std::sync::Arc; use cairo_lang_debug::DebugWithDb; use cairo_lang_defs::ids::{ ConstantId, GenericParamId, LanguageElementId, LookupItemId, ModuleItemId, - NamedLanguageElementId, TraitConstantId, + NamedLanguageElementId, TraitConstantId, VarId, }; use cairo_lang_diagnostics::{DiagnosticAdded, Diagnostics, Maybe, ToMaybe, skip_diagnostic}; use cairo_lang_proc_macros::{DebugWithDb, SemanticObject}; use cairo_lang_syntax::node::ast::ItemConstant; use cairo_lang_syntax::node::ids::SyntaxStablePtrId; use cairo_lang_syntax::node::{TypedStablePtr, TypedSyntaxNode}; +use cairo_lang_utils::ordered_hash_map::OrderedHashMap; use cairo_lang_utils::{ Intern, LookupIntern, define_short_id, extract_matches, try_extract_matches, }; @@ -21,8 +23,8 @@ use smol_str::SmolStr; use super::functions::{GenericFunctionId, GenericFunctionWithBodyId}; use super::imp::ImplId; use crate::corelib::{ - CoreTraitContext, LiteralError, core_box_ty, core_nonzero_ty, get_core_trait, - get_core_ty_by_name, try_extract_nz_wrapped_type, validate_literal, + CoreTraitContext, LiteralError, core_box_ty, core_nonzero_ty, false_variant, get_core_trait, + get_core_ty_by_name, true_variant, try_extract_nz_wrapped_type, unit_ty, validate_literal, }; use crate::db::SemanticGroup; use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder}; @@ -36,9 +38,10 @@ use crate::resolve::{Resolver, ResolverData}; use crate::substitution::SemanticRewriter; use crate::types::resolve_type; use crate::{ - Arenas, ConcreteTypeId, ConcreteVariant, Expr, ExprBlock, ExprConstant, ExprFunctionCall, - ExprFunctionCallArg, ExprId, ExprMemberAccess, ExprStructCtor, FunctionId, GenericParam, - SemanticDiagnostic, TypeId, TypeLongId, semantic_object_for_id, + Arenas, ConcreteTypeId, ConcreteVariant, Condition, Expr, ExprBlock, ExprConstant, + ExprFunctionCall, ExprFunctionCallArg, ExprId, ExprMemberAccess, ExprStructCtor, FunctionId, + GenericParam, LogicalOperator, Pattern, PatternId, SemanticDiagnostic, Statement, TypeId, + TypeLongId, semantic_object_for_id, }; #[derive(Clone, Debug, PartialEq, Eq, DebugWithDb)] @@ -363,8 +366,12 @@ pub fn constant_semantic_data_cycle_helper( /// Checks if the given expression only involved constant calculations. pub fn validate_const_expr(ctx: &mut ComputationContext<'_>, expr_id: ExprId) { - let mut eval_ctx = - ConstantEvaluateContext { db: ctx.db, arenas: &ctx.arenas, diagnostics: ctx.diagnostics }; + let mut eval_ctx = ConstantEvaluateContext { + db: ctx.db, + arenas: &ctx.arenas, + vars: Default::default(), + diagnostics: ctx.diagnostics, + }; eval_ctx.validate(expr_id); } @@ -399,8 +406,12 @@ pub fn resolve_const_expr_and_evaluate( // Check that the expression is a valid constant. _ if ctx.diagnostics.error_count > prev_err_count => ConstValue::Missing(skip_diagnostic()), _ => { - let mut eval_ctx = - ConstantEvaluateContext { db, arenas: &ctx.arenas, diagnostics: ctx.diagnostics }; + let mut eval_ctx = ConstantEvaluateContext { + db, + arenas: &ctx.arenas, + vars: Default::default(), + diagnostics: ctx.diagnostics, + }; eval_ctx.validate(value.id); if eval_ctx.diagnostics.error_count > prev_err_count { ConstValue::Missing(skip_diagnostic()) @@ -446,6 +457,7 @@ pub fn value_as_const_value( struct ConstantEvaluateContext<'a> { db: &'a dyn SemanticGroup, arenas: &'a Arenas, + vars: OrderedHashMap, diagnostics: &'a mut SemanticDiagnostics, } impl ConstantEvaluateContext<'_> { @@ -453,9 +465,23 @@ impl ConstantEvaluateContext<'_> { fn validate(&mut self, expr_id: ExprId) { match &self.arenas.exprs[expr_id] { Expr::Var(_) | Expr::Constant(_) | Expr::Missing(_) => {} - Expr::Block(ExprBlock { statements, tail: Some(inner), .. }) - if statements.is_empty() => - { + Expr::Block(ExprBlock { statements, tail: Some(inner), .. }) => { + for statement_id in statements { + match &self.arenas.statements[*statement_id] { + Statement::Let(statement) => { + self.validate(statement.expr); + } + Statement::Expr(expr) => { + self.validate(expr.expr); + } + other => { + self.diagnostics.report( + other.stable_ptr(), + SemanticDiagnosticKind::UnsupportedConstant, + ); + } + } + } self.validate(*inner); } Expr::FunctionCall(expr) => { @@ -519,6 +545,27 @@ impl ConstantEvaluateContext<'_> { self.validate(*value); } }, + Expr::Snapshot(expr) => self.validate(expr.inner), + Expr::Desnap(expr) => self.validate(expr.inner), + Expr::LogicalOperator(expr) => { + self.validate(expr.lhs); + self.validate(expr.rhs); + } + Expr::Match(expr) => { + self.validate(expr.matched_expr); + for arm in &expr.arms { + self.validate(arm.expression); + } + } + Expr::If(expr) => { + self.validate(match &expr.condition { + Condition::BoolExpr(id) | Condition::Let(id, _) => *id, + }); + self.validate(expr.if_block); + if let Some(else_block) = expr.else_block { + self.validate(else_block); + } + } other => { self.diagnostics.report( other.stable_ptr().untyped(), @@ -561,10 +608,30 @@ impl ConstantEvaluateContext<'_> { let expr = &self.arenas.exprs[expr_id]; let db = self.db; match expr { + Expr::Var(expr) => self + .vars + .get(&expr.var) + .cloned() + .unwrap_or_else(|| ConstValue::Missing(skip_diagnostic())), Expr::Constant(expr) => expr.const_value_id.lookup_intern(db), - Expr::Block(ExprBlock { statements, tail: Some(inner), .. }) - if statements.is_empty() => - { + Expr::Block(ExprBlock { statements, tail: Some(inner), .. }) => { + for statement_id in statements { + match &self.arenas.statements[*statement_id] { + Statement::Let(statement) => { + let value = self.evaluate(statement.expr); + self.destructure_pattern(statement.pattern, value); + } + Statement::Expr(expr) => { + self.evaluate(expr.expr); + } + other => { + self.diagnostics.report( + other.stable_ptr(), + SemanticDiagnosticKind::UnsupportedConstant, + ); + } + } + } self.evaluate(*inner) } Expr::FunctionCall(expr) => self.evaluate_function_call(expr), @@ -626,6 +693,89 @@ impl ConstantEvaluateContext<'_> { }, expr.ty, ), + Expr::Snapshot(expr) => self.evaluate(expr.inner), + Expr::Desnap(expr) => self.evaluate(expr.inner), + Expr::LogicalOperator(expr) => { + let lhs = self.evaluate(expr.lhs); + if let ConstValue::Enum(v, _) = &lhs { + let early_return_variant = match expr.op { + LogicalOperator::AndAnd => false_variant(self.db), + LogicalOperator::OrOr => true_variant(self.db), + }; + if *v == early_return_variant { lhs } else { self.evaluate(expr.lhs) } + } else { + ConstValue::Missing(skip_diagnostic()) + } + } + Expr::Match(expr) => { + let value = self.evaluate(expr.matched_expr); + let ConstValue::Enum(variant, value) = value else { + return ConstValue::Missing(skip_diagnostic()); + }; + for arm in &expr.arms { + for pattern_id in &arm.patterns { + let pattern = &self.arenas.patterns[*pattern_id]; + if matches!(pattern, Pattern::Otherwise(_)) { + return self.evaluate(arm.expression); + } + let Pattern::EnumVariant(pattern) = pattern else { + continue; + }; + if pattern.variant.idx != variant.idx { + continue; + } + if let Some(inner_pattern) = pattern.inner_pattern { + self.destructure_pattern(inner_pattern, *value); + } + return self.evaluate(arm.expression); + } + } + ConstValue::Missing( + self.diagnostics.report( + expr.stable_ptr.untyped(), + SemanticDiagnosticKind::UnsupportedConstant, + ), + ) + } + Expr::If(expr) => match &expr.condition { + crate::Condition::BoolExpr(id) => { + let condition = self.evaluate(*id); + let ConstValue::Enum(variant, _) = condition else { + return ConstValue::Missing(skip_diagnostic()); + }; + if variant == true_variant(self.db) { + self.evaluate(expr.if_block) + } else if let Some(else_block) = expr.else_block { + self.evaluate(else_block) + } else { + ConstValue::Struct(vec![], unit_ty(self.db)) + } + } + crate::Condition::Let(id, patterns) => { + let value = self.evaluate(*id); + let ConstValue::Enum(variant, value) = value else { + return ConstValue::Missing(skip_diagnostic()); + }; + for pattern_id in patterns { + let Pattern::EnumVariant(pattern) = &self.arenas.patterns[*pattern_id] + else { + continue; + }; + if pattern.variant != variant { + continue; + } + if let Some(inner_pattern) = pattern.inner_pattern { + self.destructure_pattern(inner_pattern, *value); + } + return self.evaluate(expr.if_block); + } + if let Some(else_block) = expr.else_block { + self.evaluate(else_block) + } else { + ConstValue::Struct(vec![], unit_ty(self.db)) + } + } + }, _ => ConstValue::Missing(skip_diagnostic()), } } @@ -721,6 +871,61 @@ impl ConstantEvaluateContext<'_> { }; Ok(values.swap_remove(member_idx)) } + + /// Destructures the pattern into the const value of the variables in scope. + fn destructure_pattern(&mut self, pattern_id: PatternId, value: ConstValue) { + let pattern = &self.arenas.patterns[pattern_id]; + match pattern { + Pattern::Literal(_) + | Pattern::StringLiteral(_) + | Pattern::Otherwise(_) + | Pattern::Missing(_) => {} + Pattern::Variable(pattern) => { + self.vars.insert(VarId::Local(pattern.var.id), value); + } + Pattern::Struct(pattern) => { + if let ConstValue::Struct(inner_values, _) = value { + let member_order = + match self.db.concrete_struct_members(pattern.concrete_struct_id) { + Ok(member_order) => member_order, + Err(_) => return, + }; + for (member, inner_value) in zip(member_order.values(), inner_values) { + if let Some((_, inner_pattern)) = + pattern.field_patterns.iter().find(|(field, _)| member.id == field.id) + { + self.destructure_pattern(*inner_pattern, inner_value); + } + } + } + } + Pattern::Tuple(pattern) => { + if let ConstValue::Struct(inner_values, _) = value { + for (inner_pattern, inner_value) in zip(&pattern.field_patterns, inner_values) { + self.destructure_pattern(*inner_pattern, inner_value); + } + } + } + Pattern::FixedSizeArray(pattern) => { + if let ConstValue::Struct(inner_values, _) = value { + for (inner_pattern, inner_value) in + zip(&pattern.elements_patterns, inner_values) + { + self.destructure_pattern(*inner_pattern, inner_value); + } + } + } + Pattern::EnumVariant(pattern) => { + if let ConstValue::Enum(variant, inner_value) = value { + if pattern.variant == variant { + if let Some(inner_pattern) = pattern.inner_pattern { + self.destructure_pattern(inner_pattern, *inner_value); + } + } + } + } + } + } } /// Query implementation of [SemanticGroup::constant_semantic_diagnostics].