Skip to content

Commit

Permalink
Added support for more complex constant calculations. (#6979)
Browse files Browse the repository at this point in the history
  • Loading branch information
orizi authored Jan 5, 2025
1 parent e4e4d06 commit bd52fe4
Show file tree
Hide file tree
Showing 2 changed files with 247 additions and 16 deletions.
26 changes: 26 additions & 0 deletions corelib/src/test/language_features/const_test.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,29 @@ fn test_two_complex_enums() {
.unbox() == (ThreeOptions2::A(1337), ThreeOptions2::C),
);
}

#[test]
fn test_complex_consts() {
const VAR_AND_MATCH_CONST: felt252 = {
let x = Option::Some((1, 2_u8));
match x {
Option::Some((v, _)) => v,
Option::None => 3,
}
};
assert_eq!(VAR_AND_MATCH_CONST, 1);
const TRUE: bool = true;
const IF_CONST_TRUE: felt252 = if TRUE {
4
} else {
5
};
assert_eq!(IF_CONST_TRUE, 4);
const FALSE: bool = false;
const IF_CONST_FALSE: felt252 = if FALSE {
6
} else {
7
};
assert_eq!(IF_CONST_FALSE, 7);
}
237 changes: 221 additions & 16 deletions crates/cairo-lang-semantic/src/items/constant.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
use std::iter::zip;
use std::sync::Arc;

use cairo_lang_debug::DebugWithDb;
use cairo_lang_defs::ids::{
ConstantId, GenericParamId, LanguageElementId, LookupItemId, ModuleItemId,
NamedLanguageElementId, TraitConstantId,
NamedLanguageElementId, TraitConstantId, VarId,
};
use cairo_lang_diagnostics::{DiagnosticAdded, Diagnostics, Maybe, ToMaybe, skip_diagnostic};
use cairo_lang_proc_macros::{DebugWithDb, SemanticObject};
use cairo_lang_syntax::node::ast::ItemConstant;
use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
use cairo_lang_syntax::node::{TypedStablePtr, TypedSyntaxNode};
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::{
Intern, LookupIntern, define_short_id, extract_matches, try_extract_matches,
};
Expand All @@ -21,8 +23,8 @@ use smol_str::SmolStr;
use super::functions::{GenericFunctionId, GenericFunctionWithBodyId};
use super::imp::ImplId;
use crate::corelib::{
CoreTraitContext, LiteralError, core_box_ty, core_nonzero_ty, get_core_trait,
get_core_ty_by_name, try_extract_nz_wrapped_type, validate_literal,
CoreTraitContext, LiteralError, core_box_ty, core_nonzero_ty, false_variant, get_core_trait,
get_core_ty_by_name, true_variant, try_extract_nz_wrapped_type, unit_ty, validate_literal,
};
use crate::db::SemanticGroup;
use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
Expand All @@ -36,9 +38,10 @@ use crate::resolve::{Resolver, ResolverData};
use crate::substitution::SemanticRewriter;
use crate::types::resolve_type;
use crate::{
Arenas, ConcreteTypeId, ConcreteVariant, Expr, ExprBlock, ExprConstant, ExprFunctionCall,
ExprFunctionCallArg, ExprId, ExprMemberAccess, ExprStructCtor, FunctionId, GenericParam,
SemanticDiagnostic, TypeId, TypeLongId, semantic_object_for_id,
Arenas, ConcreteTypeId, ConcreteVariant, Condition, Expr, ExprBlock, ExprConstant,
ExprFunctionCall, ExprFunctionCallArg, ExprId, ExprMemberAccess, ExprStructCtor, FunctionId,
GenericParam, LogicalOperator, Pattern, PatternId, SemanticDiagnostic, Statement, TypeId,
TypeLongId, semantic_object_for_id,
};

#[derive(Clone, Debug, PartialEq, Eq, DebugWithDb)]
Expand Down Expand Up @@ -363,8 +366,12 @@ pub fn constant_semantic_data_cycle_helper(

/// Checks if the given expression only involved constant calculations.
pub fn validate_const_expr(ctx: &mut ComputationContext<'_>, expr_id: ExprId) {
let mut eval_ctx =
ConstantEvaluateContext { db: ctx.db, arenas: &ctx.arenas, diagnostics: ctx.diagnostics };
let mut eval_ctx = ConstantEvaluateContext {
db: ctx.db,
arenas: &ctx.arenas,
vars: Default::default(),
diagnostics: ctx.diagnostics,
};
eval_ctx.validate(expr_id);
}

Expand Down Expand Up @@ -399,8 +406,12 @@ pub fn resolve_const_expr_and_evaluate(
// Check that the expression is a valid constant.
_ if ctx.diagnostics.error_count > prev_err_count => ConstValue::Missing(skip_diagnostic()),
_ => {
let mut eval_ctx =
ConstantEvaluateContext { db, arenas: &ctx.arenas, diagnostics: ctx.diagnostics };
let mut eval_ctx = ConstantEvaluateContext {
db,
arenas: &ctx.arenas,
vars: Default::default(),
diagnostics: ctx.diagnostics,
};
eval_ctx.validate(value.id);
if eval_ctx.diagnostics.error_count > prev_err_count {
ConstValue::Missing(skip_diagnostic())
Expand Down Expand Up @@ -446,16 +457,31 @@ pub fn value_as_const_value(
struct ConstantEvaluateContext<'a> {
db: &'a dyn SemanticGroup,
arenas: &'a Arenas,
vars: OrderedHashMap<VarId, ConstValue>,
diagnostics: &'a mut SemanticDiagnostics,
}
impl ConstantEvaluateContext<'_> {
/// Validate the given expression can be used as constant.
fn validate(&mut self, expr_id: ExprId) {
match &self.arenas.exprs[expr_id] {
Expr::Var(_) | Expr::Constant(_) | Expr::Missing(_) => {}
Expr::Block(ExprBlock { statements, tail: Some(inner), .. })
if statements.is_empty() =>
{
Expr::Block(ExprBlock { statements, tail: Some(inner), .. }) => {
for statement_id in statements {
match &self.arenas.statements[*statement_id] {
Statement::Let(statement) => {
self.validate(statement.expr);
}
Statement::Expr(expr) => {
self.validate(expr.expr);
}
other => {
self.diagnostics.report(
other.stable_ptr(),
SemanticDiagnosticKind::UnsupportedConstant,
);
}
}
}
self.validate(*inner);
}
Expr::FunctionCall(expr) => {
Expand Down Expand Up @@ -519,6 +545,27 @@ impl ConstantEvaluateContext<'_> {
self.validate(*value);
}
},
Expr::Snapshot(expr) => self.validate(expr.inner),
Expr::Desnap(expr) => self.validate(expr.inner),
Expr::LogicalOperator(expr) => {
self.validate(expr.lhs);
self.validate(expr.rhs);
}
Expr::Match(expr) => {
self.validate(expr.matched_expr);
for arm in &expr.arms {
self.validate(arm.expression);
}
}
Expr::If(expr) => {
self.validate(match &expr.condition {
Condition::BoolExpr(id) | Condition::Let(id, _) => *id,
});
self.validate(expr.if_block);
if let Some(else_block) = expr.else_block {
self.validate(else_block);
}
}
other => {
self.diagnostics.report(
other.stable_ptr().untyped(),
Expand Down Expand Up @@ -561,10 +608,30 @@ impl ConstantEvaluateContext<'_> {
let expr = &self.arenas.exprs[expr_id];
let db = self.db;
match expr {
Expr::Var(expr) => self
.vars
.get(&expr.var)
.cloned()
.unwrap_or_else(|| ConstValue::Missing(skip_diagnostic())),
Expr::Constant(expr) => expr.const_value_id.lookup_intern(db),
Expr::Block(ExprBlock { statements, tail: Some(inner), .. })
if statements.is_empty() =>
{
Expr::Block(ExprBlock { statements, tail: Some(inner), .. }) => {
for statement_id in statements {
match &self.arenas.statements[*statement_id] {
Statement::Let(statement) => {
let value = self.evaluate(statement.expr);
self.destructure_pattern(statement.pattern, value);
}
Statement::Expr(expr) => {
self.evaluate(expr.expr);
}
other => {
self.diagnostics.report(
other.stable_ptr(),
SemanticDiagnosticKind::UnsupportedConstant,
);
}
}
}
self.evaluate(*inner)
}
Expr::FunctionCall(expr) => self.evaluate_function_call(expr),
Expand Down Expand Up @@ -626,6 +693,89 @@ impl ConstantEvaluateContext<'_> {
},
expr.ty,
),
Expr::Snapshot(expr) => self.evaluate(expr.inner),
Expr::Desnap(expr) => self.evaluate(expr.inner),
Expr::LogicalOperator(expr) => {
let lhs = self.evaluate(expr.lhs);
if let ConstValue::Enum(v, _) = &lhs {
let early_return_variant = match expr.op {
LogicalOperator::AndAnd => false_variant(self.db),
LogicalOperator::OrOr => true_variant(self.db),
};
if *v == early_return_variant { lhs } else { self.evaluate(expr.lhs) }
} else {
ConstValue::Missing(skip_diagnostic())
}
}
Expr::Match(expr) => {
let value = self.evaluate(expr.matched_expr);
let ConstValue::Enum(variant, value) = value else {
return ConstValue::Missing(skip_diagnostic());
};
for arm in &expr.arms {
for pattern_id in &arm.patterns {
let pattern = &self.arenas.patterns[*pattern_id];
if matches!(pattern, Pattern::Otherwise(_)) {
return self.evaluate(arm.expression);
}
let Pattern::EnumVariant(pattern) = pattern else {
continue;
};
if pattern.variant.idx != variant.idx {
continue;
}
if let Some(inner_pattern) = pattern.inner_pattern {
self.destructure_pattern(inner_pattern, *value);
}
return self.evaluate(arm.expression);
}
}
ConstValue::Missing(
self.diagnostics.report(
expr.stable_ptr.untyped(),
SemanticDiagnosticKind::UnsupportedConstant,
),
)
}
Expr::If(expr) => match &expr.condition {
crate::Condition::BoolExpr(id) => {
let condition = self.evaluate(*id);
let ConstValue::Enum(variant, _) = condition else {
return ConstValue::Missing(skip_diagnostic());
};
if variant == true_variant(self.db) {
self.evaluate(expr.if_block)
} else if let Some(else_block) = expr.else_block {
self.evaluate(else_block)
} else {
ConstValue::Struct(vec![], unit_ty(self.db))
}
}
crate::Condition::Let(id, patterns) => {
let value = self.evaluate(*id);
let ConstValue::Enum(variant, value) = value else {
return ConstValue::Missing(skip_diagnostic());
};
for pattern_id in patterns {
let Pattern::EnumVariant(pattern) = &self.arenas.patterns[*pattern_id]
else {
continue;
};
if pattern.variant != variant {
continue;
}
if let Some(inner_pattern) = pattern.inner_pattern {
self.destructure_pattern(inner_pattern, *value);
}
return self.evaluate(expr.if_block);
}
if let Some(else_block) = expr.else_block {
self.evaluate(else_block)
} else {
ConstValue::Struct(vec![], unit_ty(self.db))
}
}
},
_ => ConstValue::Missing(skip_diagnostic()),
}
}
Expand Down Expand Up @@ -721,6 +871,61 @@ impl ConstantEvaluateContext<'_> {
};
Ok(values.swap_remove(member_idx))
}

/// Destructures the pattern into the const value of the variables in scope.
fn destructure_pattern(&mut self, pattern_id: PatternId, value: ConstValue) {
let pattern = &self.arenas.patterns[pattern_id];
match pattern {
Pattern::Literal(_)
| Pattern::StringLiteral(_)
| Pattern::Otherwise(_)
| Pattern::Missing(_) => {}
Pattern::Variable(pattern) => {
self.vars.insert(VarId::Local(pattern.var.id), value);
}
Pattern::Struct(pattern) => {
if let ConstValue::Struct(inner_values, _) = value {
let member_order =
match self.db.concrete_struct_members(pattern.concrete_struct_id) {
Ok(member_order) => member_order,
Err(_) => return,
};
for (member, inner_value) in zip(member_order.values(), inner_values) {
if let Some((_, inner_pattern)) =
pattern.field_patterns.iter().find(|(field, _)| member.id == field.id)
{
self.destructure_pattern(*inner_pattern, inner_value);
}
}
}
}
Pattern::Tuple(pattern) => {
if let ConstValue::Struct(inner_values, _) = value {
for (inner_pattern, inner_value) in zip(&pattern.field_patterns, inner_values) {
self.destructure_pattern(*inner_pattern, inner_value);
}
}
}
Pattern::FixedSizeArray(pattern) => {
if let ConstValue::Struct(inner_values, _) = value {
for (inner_pattern, inner_value) in
zip(&pattern.elements_patterns, inner_values)
{
self.destructure_pattern(*inner_pattern, inner_value);
}
}
}
Pattern::EnumVariant(pattern) => {
if let ConstValue::Enum(variant, inner_value) = value {
if pattern.variant == variant {
if let Some(inner_pattern) = pattern.inner_pattern {
self.destructure_pattern(inner_pattern, *inner_value);
}
}
}
}
}
}
}

/// Query implementation of [SemanticGroup::constant_semantic_diagnostics].
Expand Down

0 comments on commit bd52fe4

Please sign in to comment.