Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for more complex constant calculations. #6979

Merged
merged 1 commit into from
Jan 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions corelib/src/test/language_features/const_test.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,29 @@ fn test_two_complex_enums() {
.unbox() == (ThreeOptions2::A(1337), ThreeOptions2::C),
);
}

#[test]
fn test_complex_consts() {
const VAR_AND_MATCH_CONST: felt252 = {
let x = Option::Some((1, 2_u8));
match x {
Option::Some((v, _)) => v,
Option::None => 3,
}
};
assert_eq!(VAR_AND_MATCH_CONST, 1);
const TRUE: bool = true;
const IF_CONST_TRUE: felt252 = if TRUE {
4
} else {
5
};
assert_eq!(IF_CONST_TRUE, 4);
const FALSE: bool = false;
const IF_CONST_FALSE: felt252 = if FALSE {
6
} else {
7
};
assert_eq!(IF_CONST_FALSE, 7);
}
237 changes: 221 additions & 16 deletions crates/cairo-lang-semantic/src/items/constant.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
use std::iter::zip;
use std::sync::Arc;

use cairo_lang_debug::DebugWithDb;
use cairo_lang_defs::ids::{
ConstantId, GenericParamId, LanguageElementId, LookupItemId, ModuleItemId,
NamedLanguageElementId, TraitConstantId,
NamedLanguageElementId, TraitConstantId, VarId,
};
use cairo_lang_diagnostics::{DiagnosticAdded, Diagnostics, Maybe, ToMaybe, skip_diagnostic};
use cairo_lang_proc_macros::{DebugWithDb, SemanticObject};
use cairo_lang_syntax::node::ast::ItemConstant;
use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
use cairo_lang_syntax::node::{TypedStablePtr, TypedSyntaxNode};
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::{
Intern, LookupIntern, define_short_id, extract_matches, try_extract_matches,
};
Expand All @@ -21,8 +23,8 @@ use smol_str::SmolStr;
use super::functions::{GenericFunctionId, GenericFunctionWithBodyId};
use super::imp::ImplId;
use crate::corelib::{
CoreTraitContext, LiteralError, core_box_ty, core_nonzero_ty, get_core_trait,
get_core_ty_by_name, try_extract_nz_wrapped_type, validate_literal,
CoreTraitContext, LiteralError, core_box_ty, core_nonzero_ty, false_variant, get_core_trait,
get_core_ty_by_name, true_variant, try_extract_nz_wrapped_type, unit_ty, validate_literal,
};
use crate::db::SemanticGroup;
use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
Expand All @@ -36,9 +38,10 @@ use crate::resolve::{Resolver, ResolverData};
use crate::substitution::SemanticRewriter;
use crate::types::resolve_type;
use crate::{
Arenas, ConcreteTypeId, ConcreteVariant, Expr, ExprBlock, ExprConstant, ExprFunctionCall,
ExprFunctionCallArg, ExprId, ExprMemberAccess, ExprStructCtor, FunctionId, GenericParam,
SemanticDiagnostic, TypeId, TypeLongId, semantic_object_for_id,
Arenas, ConcreteTypeId, ConcreteVariant, Condition, Expr, ExprBlock, ExprConstant,
ExprFunctionCall, ExprFunctionCallArg, ExprId, ExprMemberAccess, ExprStructCtor, FunctionId,
GenericParam, LogicalOperator, Pattern, PatternId, SemanticDiagnostic, Statement, TypeId,
TypeLongId, semantic_object_for_id,
};

#[derive(Clone, Debug, PartialEq, Eq, DebugWithDb)]
Expand Down Expand Up @@ -363,8 +366,12 @@ pub fn constant_semantic_data_cycle_helper(

/// Checks if the given expression only involved constant calculations.
pub fn validate_const_expr(ctx: &mut ComputationContext<'_>, expr_id: ExprId) {
let mut eval_ctx =
ConstantEvaluateContext { db: ctx.db, arenas: &ctx.arenas, diagnostics: ctx.diagnostics };
let mut eval_ctx = ConstantEvaluateContext {
db: ctx.db,
arenas: &ctx.arenas,
vars: Default::default(),
diagnostics: ctx.diagnostics,
};
eval_ctx.validate(expr_id);
}

Expand Down Expand Up @@ -399,8 +406,12 @@ pub fn resolve_const_expr_and_evaluate(
// Check that the expression is a valid constant.
_ if ctx.diagnostics.error_count > prev_err_count => ConstValue::Missing(skip_diagnostic()),
_ => {
let mut eval_ctx =
ConstantEvaluateContext { db, arenas: &ctx.arenas, diagnostics: ctx.diagnostics };
let mut eval_ctx = ConstantEvaluateContext {
db,
arenas: &ctx.arenas,
vars: Default::default(),
diagnostics: ctx.diagnostics,
};
eval_ctx.validate(value.id);
if eval_ctx.diagnostics.error_count > prev_err_count {
ConstValue::Missing(skip_diagnostic())
Expand Down Expand Up @@ -446,16 +457,31 @@ pub fn value_as_const_value(
struct ConstantEvaluateContext<'a> {
db: &'a dyn SemanticGroup,
arenas: &'a Arenas,
vars: OrderedHashMap<VarId, ConstValue>,
diagnostics: &'a mut SemanticDiagnostics,
}
impl ConstantEvaluateContext<'_> {
/// Validate the given expression can be used as constant.
fn validate(&mut self, expr_id: ExprId) {
match &self.arenas.exprs[expr_id] {
Expr::Var(_) | Expr::Constant(_) | Expr::Missing(_) => {}
Expr::Block(ExprBlock { statements, tail: Some(inner), .. })
if statements.is_empty() =>
{
Expr::Block(ExprBlock { statements, tail: Some(inner), .. }) => {
for statement_id in statements {
match &self.arenas.statements[*statement_id] {
Statement::Let(statement) => {
self.validate(statement.expr);
}
Statement::Expr(expr) => {
self.validate(expr.expr);
}
other => {
self.diagnostics.report(
other.stable_ptr(),
SemanticDiagnosticKind::UnsupportedConstant,
);
}
}
}
self.validate(*inner);
}
Expr::FunctionCall(expr) => {
Expand Down Expand Up @@ -519,6 +545,27 @@ impl ConstantEvaluateContext<'_> {
self.validate(*value);
}
},
Expr::Snapshot(expr) => self.validate(expr.inner),
Expr::Desnap(expr) => self.validate(expr.inner),
Expr::LogicalOperator(expr) => {
self.validate(expr.lhs);
self.validate(expr.rhs);
}
Expr::Match(expr) => {
self.validate(expr.matched_expr);
for arm in &expr.arms {
self.validate(arm.expression);
}
}
Expr::If(expr) => {
self.validate(match &expr.condition {
Condition::BoolExpr(id) | Condition::Let(id, _) => *id,
});
self.validate(expr.if_block);
if let Some(else_block) = expr.else_block {
self.validate(else_block);
}
}
other => {
self.diagnostics.report(
other.stable_ptr().untyped(),
Expand Down Expand Up @@ -561,10 +608,30 @@ impl ConstantEvaluateContext<'_> {
let expr = &self.arenas.exprs[expr_id];
let db = self.db;
match expr {
Expr::Var(expr) => self
.vars
.get(&expr.var)
.cloned()
.unwrap_or_else(|| ConstValue::Missing(skip_diagnostic())),
Expr::Constant(expr) => expr.const_value_id.lookup_intern(db),
Expr::Block(ExprBlock { statements, tail: Some(inner), .. })
if statements.is_empty() =>
{
Expr::Block(ExprBlock { statements, tail: Some(inner), .. }) => {
for statement_id in statements {
match &self.arenas.statements[*statement_id] {
Statement::Let(statement) => {
let value = self.evaluate(statement.expr);
self.destructure_pattern(statement.pattern, value);
}
Statement::Expr(expr) => {
self.evaluate(expr.expr);
}
other => {
self.diagnostics.report(
other.stable_ptr(),
SemanticDiagnosticKind::UnsupportedConstant,
);
}
}
}
self.evaluate(*inner)
}
Expr::FunctionCall(expr) => self.evaluate_function_call(expr),
Expand Down Expand Up @@ -626,6 +693,89 @@ impl ConstantEvaluateContext<'_> {
},
expr.ty,
),
Expr::Snapshot(expr) => self.evaluate(expr.inner),
Expr::Desnap(expr) => self.evaluate(expr.inner),
Expr::LogicalOperator(expr) => {
let lhs = self.evaluate(expr.lhs);
if let ConstValue::Enum(v, _) = &lhs {
let early_return_variant = match expr.op {
LogicalOperator::AndAnd => false_variant(self.db),
LogicalOperator::OrOr => true_variant(self.db),
};
if *v == early_return_variant { lhs } else { self.evaluate(expr.lhs) }
} else {
ConstValue::Missing(skip_diagnostic())
}
}
Expr::Match(expr) => {
let value = self.evaluate(expr.matched_expr);
let ConstValue::Enum(variant, value) = value else {
return ConstValue::Missing(skip_diagnostic());
};
for arm in &expr.arms {
for pattern_id in &arm.patterns {
let pattern = &self.arenas.patterns[*pattern_id];
if matches!(pattern, Pattern::Otherwise(_)) {
return self.evaluate(arm.expression);
}
let Pattern::EnumVariant(pattern) = pattern else {
continue;
};
if pattern.variant.idx != variant.idx {
continue;
}
if let Some(inner_pattern) = pattern.inner_pattern {
self.destructure_pattern(inner_pattern, *value);
}
return self.evaluate(arm.expression);
}
}
ConstValue::Missing(
self.diagnostics.report(
expr.stable_ptr.untyped(),
SemanticDiagnosticKind::UnsupportedConstant,
),
)
}
Expr::If(expr) => match &expr.condition {
crate::Condition::BoolExpr(id) => {
let condition = self.evaluate(*id);
let ConstValue::Enum(variant, _) = condition else {
return ConstValue::Missing(skip_diagnostic());
};
if variant == true_variant(self.db) {
self.evaluate(expr.if_block)
} else if let Some(else_block) = expr.else_block {
self.evaluate(else_block)
} else {
ConstValue::Struct(vec![], unit_ty(self.db))
}
}
crate::Condition::Let(id, patterns) => {
let value = self.evaluate(*id);
let ConstValue::Enum(variant, value) = value else {
return ConstValue::Missing(skip_diagnostic());
};
for pattern_id in patterns {
let Pattern::EnumVariant(pattern) = &self.arenas.patterns[*pattern_id]
else {
continue;
};
if pattern.variant != variant {
continue;
}
if let Some(inner_pattern) = pattern.inner_pattern {
self.destructure_pattern(inner_pattern, *value);
}
return self.evaluate(expr.if_block);
}
if let Some(else_block) = expr.else_block {
self.evaluate(else_block)
} else {
ConstValue::Struct(vec![], unit_ty(self.db))
}
}
},
_ => ConstValue::Missing(skip_diagnostic()),
}
}
Expand Down Expand Up @@ -721,6 +871,61 @@ impl ConstantEvaluateContext<'_> {
};
Ok(values.swap_remove(member_idx))
}

/// Destructures the pattern into the const value of the variables in scope.
fn destructure_pattern(&mut self, pattern_id: PatternId, value: ConstValue) {
let pattern = &self.arenas.patterns[pattern_id];
match pattern {
Pattern::Literal(_)
| Pattern::StringLiteral(_)
| Pattern::Otherwise(_)
| Pattern::Missing(_) => {}
Pattern::Variable(pattern) => {
self.vars.insert(VarId::Local(pattern.var.id), value);
}
Pattern::Struct(pattern) => {
if let ConstValue::Struct(inner_values, _) = value {
let member_order =
match self.db.concrete_struct_members(pattern.concrete_struct_id) {
Ok(member_order) => member_order,
Err(_) => return,
};
for (member, inner_value) in zip(member_order.values(), inner_values) {
if let Some((_, inner_pattern)) =
pattern.field_patterns.iter().find(|(field, _)| member.id == field.id)
{
self.destructure_pattern(*inner_pattern, inner_value);
}
}
}
}
Pattern::Tuple(pattern) => {
if let ConstValue::Struct(inner_values, _) = value {
for (inner_pattern, inner_value) in zip(&pattern.field_patterns, inner_values) {
self.destructure_pattern(*inner_pattern, inner_value);
}
}
}
Pattern::FixedSizeArray(pattern) => {
if let ConstValue::Struct(inner_values, _) = value {
for (inner_pattern, inner_value) in
zip(&pattern.elements_patterns, inner_values)
{
self.destructure_pattern(*inner_pattern, inner_value);
}
}
}
Pattern::EnumVariant(pattern) => {
if let ConstValue::Enum(variant, inner_value) = value {
if pattern.variant == variant {
if let Some(inner_pattern) = pattern.inner_pattern {
self.destructure_pattern(inner_pattern, *inner_value);
}
}
}
}
}
}
}

/// Query implementation of [SemanticGroup::constant_semantic_diagnostics].
Expand Down
Loading