Skip to content

Commit

Permalink
Added ! for const evaluation.
Browse files Browse the repository at this point in the history
Additionally refactored bool handling and testing by assert.

commit-id:73bf14cd
  • Loading branch information
orizi committed Jan 5, 2025
1 parent 2950fa3 commit b7c6e77
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 48 deletions.
2 changes: 1 addition & 1 deletion corelib/src/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ pub fn panic_with_felt252(err_code: felt252) -> never {
/// assert(false, 'error message');
/// ```
#[inline]
pub fn assert(cond: bool, err_code: felt252) {
pub const fn assert(cond: bool, err_code: felt252) {
if !cond {
panic_with_felt252(err_code)
}
Expand Down
42 changes: 18 additions & 24 deletions crates/cairo-lang-semantic/src/expr/test_data/constant
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,14 @@ const FAILING_CALC: felt252 = if true {
70
};

const FUNC_CALC_SUCCESS: () = panic_if_true(false);
const FUNC_CALC_FAILURE: () = panic_if_true(true);
const VALID_EQ: () = panic_if_true(1 == 2);
const VALID_NE: () = panic_if_true(1 != 1);
const VALID_LT: () = panic_if_true(1_usize < 1);
const VALID_LE: () = panic_if_true(2_usize <= 1);
const VALID_GT: () = panic_if_true(1_usize > 1);
const VALID_GE: () = panic_if_true(1_usize >= 2);

const fn panic_if_true(cond: bool) {
if cond {
core::panic_with_felt252('assertion failed')
}
}
const FUNC_CALC_SUCCESS: () = assert(true, 'works');
const FUNC_CALC_FAILURE: () = assert(false, 'does not work');
const VALID_EQ: () = assert(1 == 1, '1 == 1');
const VALID_NE: () = assert(1 != 2, '1 != 1');
const VALID_LT: () = assert(1_usize < 2, '1 < 2');
const VALID_LE: () = assert(1_usize <= 1, '1 <= 1');
const VALID_GT: () = assert(2_usize > 1, '2 > 1');
const VALID_GE: () = assert(1_usize >= 1, '1 >= 1');

const FUNC_CALC_SUCCESS_OPTION: felt252 = Option::Some(5).unwrap();
const FUNC_CALC_FAILURE_OPTION: felt252 = Option::None.unwrap();
Expand Down Expand Up @@ -123,15 +117,15 @@ error: Failed to calculate constant.

error: Failed to calculate constant.
--> lib.cairo:23:31
const FUNC_CALC_FAILURE: () = panic_if_true(true);
^^^^^^^^^^^^^^^^^^^
note: In `test::panic_if_true`:
--> lib.cairo:33:9
core::panic_with_felt252('assertion failed')
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
const FUNC_CALC_FAILURE: () = assert(false, 'does not work');
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
note: In `core::assert`:
--> /home/ori/rust/cairo-ws/corelib/src/lib.cairo:359:9
panic_with_felt252(err_code)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

error: Failed to calculate constant.
--> lib.cairo:38:43
--> lib.cairo:32:43
const FUNC_CALC_FAILURE_OPTION: felt252 = Option::None.unwrap();
^^^^^^^^^^^^^^^^^^^^^
note: In `core::option::OptionTraitImpl::<T>::expect`:
Expand All @@ -144,7 +138,7 @@ note: In `core::option::OptionTraitImpl::<core::felt252>::unwrap`:
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

error: Failed to calculate constant.
--> lib.cairo:40:44
--> lib.cairo:34:44
const FUNC_CALC_FAILURE_RESULT1: felt252 = Result::Err(5).unwrap();
^^^^^^^^^^^^^^^^^^^^^^^
note: In `core::result::ResultTraitImpl::<T, E>::expect::<core::traits::PanicDestructForDestruct::<E, +Destruct<E>>>`:
Expand All @@ -157,7 +151,7 @@ note: In `core::result::ResultTraitImpl::<core::felt252, core::felt252>::unwrap:
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

error: Failed to calculate constant.
--> lib.cairo:42:44
--> lib.cairo:36:44
const FUNC_CALC_FAILURE_RESULT2: felt252 = Result::Ok(5).unwrap_err();
^^^^^^^^^^^^^^^^^^^^^^^^^^
note: In `core::result::ResultTraitImpl::<T, E>::expect_err::<+PanicDestruct<T>>`:
Expand All @@ -170,7 +164,7 @@ note: In `core::result::ResultTraitImpl::<core::felt252, core::felt252>::unwrap_
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

error: Constant calculation depth exceeded.
--> lib.cairo:44:43
--> lib.cairo:38:43
const FUNC_CALC_STACK_EXCEEDED: felt252 = call_myself();
^^^^^^^^^^^^^

Expand Down
60 changes: 37 additions & 23 deletions crates/cairo-lang-semantic/src/items/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,7 @@ impl ConstantEvaluateContext<'_> {
} else if let Some(else_block) = expr.else_block {
self.evaluate(else_block)
} else {
ConstValue::Struct(vec![], unit_ty(self.db))
self.unit_const.clone()
}
}
crate::Condition::Let(id, patterns) => {
Expand All @@ -786,7 +786,7 @@ impl ConstantEvaluateContext<'_> {
if let Some(else_block) = expr.else_block {
self.evaluate(else_block)
} else {
ConstValue::Struct(vec![], unit_ty(self.db))
self.unit_const.clone()
}
}
},
Expand Down Expand Up @@ -822,16 +822,15 @@ impl ConstantEvaluateContext<'_> {

let imp = extract_matches!(concrete_function.generic_function, GenericFunctionId::Impl);
let bool_value = |condition: bool| {
ConstValue::Enum(
if condition { true_variant(db) } else { false_variant(db) },
ConstValue::Struct(vec![], unit_ty(db)).into(),
)
if condition { self.true_const.clone() } else { self.false_const.clone() }
};

if imp.function == self.eq_fn {
return bool_value(args[0] == args[1]);
} else if imp.function == self.ne_fn {
return bool_value(args[0] != args[1]);
} else if imp.function == self.not_fn {
return bool_value(args[0] == self.false_const);
}

let args = match args
Expand Down Expand Up @@ -1123,42 +1122,50 @@ pub struct ConstCalcInfo {
/// Traits that are allowed for consts if their impls is in the corelib.
const_traits: UnorderedHashSet<TraitId>,
/// The trait function for `Neg::neg`.
pub neg_fn: TraitFunctionId,
neg_fn: TraitFunctionId,
/// The trait function for `Add::add`.
pub add_fn: TraitFunctionId,
add_fn: TraitFunctionId,
/// The trait function for `Sub::sub`.
pub sub_fn: TraitFunctionId,
sub_fn: TraitFunctionId,
/// The trait function for `Mul::mul`.
pub mul_fn: TraitFunctionId,
mul_fn: TraitFunctionId,
/// The trait function for `Div::div`.
pub div_fn: TraitFunctionId,
div_fn: TraitFunctionId,
/// The trait function for `Rem::rem`.
pub rem_fn: TraitFunctionId,
rem_fn: TraitFunctionId,
/// The trait function for `BitAnd::bitand`.
pub bit_and_fn: TraitFunctionId,
bit_and_fn: TraitFunctionId,
/// The trait function for `BitOr::bitor`.
pub bit_or_fn: TraitFunctionId,
bit_or_fn: TraitFunctionId,
/// The trait function for `BitXor::bitxor`.
pub bit_xor_fn: TraitFunctionId,
bit_xor_fn: TraitFunctionId,
/// The trait function for `PartialEq::eq`.
pub eq_fn: TraitFunctionId,
eq_fn: TraitFunctionId,
/// The trait function for `PartialEq::ne`.
pub ne_fn: TraitFunctionId,
ne_fn: TraitFunctionId,
/// The trait function for `PartialOrd::lt`.
pub lt_fn: TraitFunctionId,
lt_fn: TraitFunctionId,
/// The trait function for `PartialOrd::le`.
pub le_fn: TraitFunctionId,
le_fn: TraitFunctionId,
/// The trait function for `PartialOrd::gt`.
pub gt_fn: TraitFunctionId,
gt_fn: TraitFunctionId,
/// The trait function for `PartialOrd::ge`.
pub ge_fn: TraitFunctionId,
ge_fn: TraitFunctionId,
/// The trait function for `Not::not`.
not_fn: TraitFunctionId,
/// The const value for the unit type `()`.
unit_const: ConstValue,
/// The const value for `true`.
true_const: ConstValue,
/// The const value for `false`.
false_const: ConstValue,
/// The function for panicking with a felt252.
pub panic_with_felt252: FunctionId,
panic_with_felt252: FunctionId,
}

impl ConstCalcInfo {
/// Creates a new ConstCalcInfo.
pub fn new(db: &dyn SemanticGroup) -> Self {
fn new(db: &dyn SemanticGroup) -> Self {
let neg_trait = get_core_trait(db, CoreTraitContext::TopLevel, "Neg".into());
let add_trait = get_core_trait(db, CoreTraitContext::TopLevel, "Add".into());
let sub_trait = get_core_trait(db, CoreTraitContext::TopLevel, "Sub".into());
Expand All @@ -1170,9 +1177,11 @@ impl ConstCalcInfo {
let bit_xor_trait = get_core_trait(db, CoreTraitContext::TopLevel, "BitXor".into());
let partial_eq_trait = get_core_trait(db, CoreTraitContext::TopLevel, "PartialEq".into());
let partial_ord_trait = get_core_trait(db, CoreTraitContext::TopLevel, "PartialOrd".into());
let not_trait = get_core_trait(db, CoreTraitContext::TopLevel, "Not".into());
let trait_fn = |trait_id, name: &str| {
db.trait_function_by_name(trait_id, name.into()).unwrap().unwrap()
};
let unit_const = ConstValue::Struct(vec![], unit_ty(db));
Self {
const_traits: [
neg_trait,
Expand All @@ -1186,6 +1195,7 @@ impl ConstCalcInfo {
bit_xor_trait,
partial_eq_trait,
partial_ord_trait,
not_trait,
]
.into_iter()
.collect(),
Expand All @@ -1204,6 +1214,10 @@ impl ConstCalcInfo {
le_fn: trait_fn(partial_ord_trait, "le"),
gt_fn: trait_fn(partial_ord_trait, "gt"),
ge_fn: trait_fn(partial_ord_trait, "ge"),
not_fn: trait_fn(not_trait, "not"),
true_const: ConstValue::Enum(true_variant(db), unit_const.clone().into()),
false_const: ConstValue::Enum(false_variant(db), unit_const.clone().into()),
unit_const,
panic_with_felt252: get_core_function_id(db, "panic_with_felt252".into(), vec![]),
}
}
Expand Down

0 comments on commit b7c6e77

Please sign in to comment.