Skip to content

Commit

Permalink
Added support for const eval div_rem.
Browse files Browse the repository at this point in the history
Required minor refactoring of args handling.

commit-id:e185f290
  • Loading branch information
orizi committed Jan 5, 2025
1 parent 8c5fdd5 commit 42b9426
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 39 deletions.
9 changes: 5 additions & 4 deletions crates/cairo-lang-semantic/src/expr/test_data/constant
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ const VALID_LT: () = assert(1_usize < 2, '1 < 2');
const VALID_LE: () = assert(1_usize <= 1, '1 <= 1');
const VALID_GT: () = assert(2_usize > 1, '2 > 1');
const VALID_GE: () = assert(1_usize >= 1, '1 >= 1');
const VALID_DIVREM: () = assert(DivRem::div_rem(5_u8, 2) == (2, 1), 'div_rem(5, 2) == (2, 1)');

const FUNC_CALC_SUCCESS_OPTION: felt252 = Option::Some(5).unwrap();
const FUNC_CALC_FAILURE_OPTION: felt252 = Option::None.unwrap();
Expand Down Expand Up @@ -125,7 +126,7 @@ note: In `core::assert`:
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

error: Failed to calculate constant.
--> lib.cairo:32:43
--> lib.cairo:33:43
const FUNC_CALC_FAILURE_OPTION: felt252 = Option::None.unwrap();
^^^^^^^^^^^^^^^^^^^^^
note: In `core::option::OptionTraitImpl::<T>::expect`:
Expand All @@ -138,7 +139,7 @@ note: In `core::option::OptionTraitImpl::<core::felt252>::unwrap`:
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

error: Failed to calculate constant.
--> lib.cairo:34:44
--> lib.cairo:35:44
const FUNC_CALC_FAILURE_RESULT1: felt252 = Result::Err(5).unwrap();
^^^^^^^^^^^^^^^^^^^^^^^
note: In `core::result::ResultTraitImpl::<T, E>::expect::<core::traits::PanicDestructForDestruct::<E, +Destruct<E>>>`:
Expand All @@ -151,7 +152,7 @@ note: In `core::result::ResultTraitImpl::<core::felt252, core::felt252>::unwrap:
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

error: Failed to calculate constant.
--> lib.cairo:36:44
--> lib.cairo:37:44
const FUNC_CALC_FAILURE_RESULT2: felt252 = Result::Ok(5).unwrap_err();
^^^^^^^^^^^^^^^^^^^^^^^^^^
note: In `core::result::ResultTraitImpl::<T, E>::expect_err::<+PanicDestruct<T>>`:
Expand All @@ -164,7 +165,7 @@ note: In `core::result::ResultTraitImpl::<core::felt252, core::felt252>::unwrap_
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

error: Constant calculation depth exceeded.
--> lib.cairo:38:43
--> lib.cairo:39:43
const FUNC_CALC_STACK_EXCEEDED: felt252 = call_myself();
^^^^^^^^^^^^^

Expand Down
101 changes: 66 additions & 35 deletions crates/cairo-lang-semantic/src/items/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -835,49 +835,45 @@ impl ConstantEvaluateContext<'_> {

let args = match args
.into_iter()
.map(|arg| match arg {
ConstValue::Int(v, _ty) => Ok(v),
// Handling u256 constants to enable const evaluation of them.
ConstValue::Struct(v, _) => {
if let [ConstValue::Int(low, _), ConstValue::Int(high, _)] = &v[..] {
Ok(low + (high << 128))
} else {
// Dignostic can be skipped as we would either have a semantic error for a
// bad arg for the function, or the arg itself
// could'nt have been calculated.
Err(skip_diagnostic())
}
}
ConstValue::Missing(err) => Err(err),
// Dignostic can be skipped as we would either have a semantic error for a bad arg
// for the function, or the arg itself could'nt have been calculated.
_ => Err(skip_diagnostic()),
})
.collect::<Result<Vec<_>, _>>()
.map(|arg| NumericArg::try_new(db, arg))
.collect::<Option<Vec<_>>>()
{
Ok(args) => args,
Err(err) => return ConstValue::Missing(err),
Some(args) => args,
// Dignostic can be skipped as we would either have a semantic error for a bad arg for
// the function, or the arg itself could'nt have been calculated.
None => return ConstValue::Missing(skip_diagnostic()),
};
let mut value = match imp.function {
id if id == self.neg_fn => -&args[0],
id if id == self.add_fn => &args[0] + &args[1],
id if id == self.sub_fn => &args[0] - &args[1],
id if id == self.mul_fn => &args[0] * &args[1],
id if (id == self.div_fn || id == self.rem_fn) && args[1].is_zero() => {
id if id == self.neg_fn => -&args[0].v,
id if id == self.add_fn => &args[0].v + &args[1].v,
id if id == self.sub_fn => &args[0].v - &args[1].v,
id if id == self.mul_fn => &args[0].v * &args[1].v,
id if (id == self.div_fn || id == self.rem_fn) && args[1].v.is_zero() => {
return ConstValue::Missing(
self.diagnostics
.report(expr.stable_ptr.untyped(), SemanticDiagnosticKind::DivisionByZero),
);
}
id if id == self.div_fn => &args[0] / &args[1],
id if id == self.rem_fn => &args[0] % &args[1],
id if id == self.bit_and_fn => &args[0] & &args[1],
id if id == self.bit_or_fn => &args[0] | &args[1],
id if id == self.bit_xor_fn => &args[0] ^ &args[1],
id if id == self.lt_fn => return bool_value(args[0] < args[1]),
id if id == self.le_fn => return bool_value(args[0] <= args[1]),
id if id == self.gt_fn => return bool_value(args[0] > args[1]),
id if id == self.ge_fn => return bool_value(args[0] >= args[1]),
id if id == self.div_fn => &args[0].v / &args[1].v,
id if id == self.rem_fn => &args[0].v % &args[1].v,
id if id == self.bit_and_fn => &args[0].v & &args[1].v,
id if id == self.bit_or_fn => &args[0].v | &args[1].v,
id if id == self.bit_xor_fn => &args[0].v ^ &args[1].v,
id if id == self.lt_fn => return bool_value(args[0].v < args[1].v),
id if id == self.le_fn => return bool_value(args[0].v <= args[1].v),
id if id == self.gt_fn => return bool_value(args[0].v > args[1].v),
id if id == self.ge_fn => return bool_value(args[0].v >= args[1].v),
id if id == self.div_rem_fn => {
// No need for non-zero check as this is type checked to begin with.
// Also results are always in the range of the input type, so `unwrap`s are ok.
return ConstValue::Struct(
vec![
value_as_const_value(db, args[0].ty, &(&args[0].v / &args[1].v)).unwrap(),
value_as_const_value(db, args[0].ty, &(&args[0].v % &args[1].v)).unwrap(),
],
expr.ty,
);
}
_ => {
unreachable!("Unexpected function call in constant lowering: {:?}", expr)
}
Expand Down Expand Up @@ -1041,6 +1037,36 @@ impl std::ops::Deref for ConstantEvaluateContext<'_> {
}
}

/// Helper for the arguments info.
struct NumericArg {
/// The arg's integer value.
v: BigInt,
/// The arg's type.
ty: TypeId,
}
impl NumericArg {
fn try_new(db: &dyn SemanticGroup, arg: ConstValue) -> Option<Self> {
Some(Self { ty: arg.ty(db).ok()?, v: numeric_arg_value(arg)? })
}
}

/// Helper for creating a `NumericArg` value.
/// This includes unwrapping of `NonZero` values and struct of 2 values as a `u256`.
fn numeric_arg_value(value: ConstValue) -> Option<BigInt> {
match value {
ConstValue::Int(value, _) => Some(value),
ConstValue::Struct(v, _) => {
if let [ConstValue::Int(low, _), ConstValue::Int(high, _)] = &v[..] {
Some(low + (high << 128))
} else {
None
}
}
ConstValue::NonZero(const_value) => numeric_arg_value(*const_value),
_ => None,
}
}

/// Query implementation of [SemanticGroup::constant_semantic_diagnostics].
pub fn constant_semantic_diagnostics(
db: &dyn SemanticGroup,
Expand Down Expand Up @@ -1133,6 +1159,8 @@ pub struct ConstCalcInfo {
div_fn: TraitFunctionId,
/// The trait function for `Rem::rem`.
rem_fn: TraitFunctionId,
/// The trait function for `DivRem::div_rem`.
div_rem_fn: TraitFunctionId,
/// The trait function for `BitAnd::bitand`.
bit_and_fn: TraitFunctionId,
/// The trait function for `BitOr::bitor`.
Expand Down Expand Up @@ -1172,6 +1200,7 @@ impl ConstCalcInfo {
let mul_trait = get_core_trait(db, CoreTraitContext::TopLevel, "Mul".into());
let div_trait = get_core_trait(db, CoreTraitContext::TopLevel, "Div".into());
let rem_trait = get_core_trait(db, CoreTraitContext::TopLevel, "Rem".into());
let div_rem_trait = get_core_trait(db, CoreTraitContext::TopLevel, "DivRem".into());
let bit_and_trait = get_core_trait(db, CoreTraitContext::TopLevel, "BitAnd".into());
let bit_or_trait = get_core_trait(db, CoreTraitContext::TopLevel, "BitOr".into());
let bit_xor_trait = get_core_trait(db, CoreTraitContext::TopLevel, "BitXor".into());
Expand All @@ -1190,6 +1219,7 @@ impl ConstCalcInfo {
mul_trait,
div_trait,
rem_trait,
div_rem_trait,
bit_and_trait,
bit_or_trait,
bit_xor_trait,
Expand All @@ -1205,6 +1235,7 @@ impl ConstCalcInfo {
mul_fn: trait_fn(mul_trait, "mul"),
div_fn: trait_fn(div_trait, "div"),
rem_fn: trait_fn(rem_trait, "rem"),
div_rem_fn: trait_fn(div_rem_trait, "div_rem"),
bit_and_fn: trait_fn(bit_and_trait, "bitand"),
bit_or_fn: trait_fn(bit_or_trait, "bitor"),
bit_xor_fn: trait_fn(bit_xor_trait, "bitxor"),
Expand Down

0 comments on commit 42b9426

Please sign in to comment.