From 42b942654ebd373c7062f17487261595f059e18d Mon Sep 17 00:00:00 2001 From: Ori Ziv Date: Sun, 5 Jan 2025 12:02:18 +0200 Subject: [PATCH] Added support for const eval `div_rem`. Required minor refactoring of args handling. commit-id:e185f290 --- .../src/expr/test_data/constant | 9 +- .../cairo-lang-semantic/src/items/constant.rs | 101 ++++++++++++------ 2 files changed, 71 insertions(+), 39 deletions(-) diff --git a/crates/cairo-lang-semantic/src/expr/test_data/constant b/crates/cairo-lang-semantic/src/expr/test_data/constant index 61de8ad3644..9f2960b4f3c 100644 --- a/crates/cairo-lang-semantic/src/expr/test_data/constant +++ b/crates/cairo-lang-semantic/src/expr/test_data/constant @@ -70,6 +70,7 @@ const VALID_LT: () = assert(1_usize < 2, '1 < 2'); const VALID_LE: () = assert(1_usize <= 1, '1 <= 1'); const VALID_GT: () = assert(2_usize > 1, '2 > 1'); const VALID_GE: () = assert(1_usize >= 1, '1 >= 1'); +const VALID_DIVREM: () = assert(DivRem::div_rem(5_u8, 2) == (2, 1), 'div_rem(5, 2) == (2, 1)'); const FUNC_CALC_SUCCESS_OPTION: felt252 = Option::Some(5).unwrap(); const FUNC_CALC_FAILURE_OPTION: felt252 = Option::None.unwrap(); @@ -125,7 +126,7 @@ note: In `core::assert`: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ error: Failed to calculate constant. - --> lib.cairo:32:43 + --> lib.cairo:33:43 const FUNC_CALC_FAILURE_OPTION: felt252 = Option::None.unwrap(); ^^^^^^^^^^^^^^^^^^^^^ note: In `core::option::OptionTraitImpl::::expect`: @@ -138,7 +139,7 @@ note: In `core::option::OptionTraitImpl::::unwrap`: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ error: Failed to calculate constant. - --> lib.cairo:34:44 + --> lib.cairo:35:44 const FUNC_CALC_FAILURE_RESULT1: felt252 = Result::Err(5).unwrap(); ^^^^^^^^^^^^^^^^^^^^^^^ note: In `core::result::ResultTraitImpl::::expect::>>`: @@ -151,7 +152,7 @@ note: In `core::result::ResultTraitImpl::::unwrap: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ error: Failed to calculate constant. - --> lib.cairo:36:44 + --> lib.cairo:37:44 const FUNC_CALC_FAILURE_RESULT2: felt252 = Result::Ok(5).unwrap_err(); ^^^^^^^^^^^^^^^^^^^^^^^^^^ note: In `core::result::ResultTraitImpl::::expect_err::<+PanicDestruct>`: @@ -164,7 +165,7 @@ note: In `core::result::ResultTraitImpl::::unwrap_ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ error: Constant calculation depth exceeded. - --> lib.cairo:38:43 + --> lib.cairo:39:43 const FUNC_CALC_STACK_EXCEEDED: felt252 = call_myself(); ^^^^^^^^^^^^^ diff --git a/crates/cairo-lang-semantic/src/items/constant.rs b/crates/cairo-lang-semantic/src/items/constant.rs index ed6ef9d022d..28fe2682786 100644 --- a/crates/cairo-lang-semantic/src/items/constant.rs +++ b/crates/cairo-lang-semantic/src/items/constant.rs @@ -835,49 +835,45 @@ impl ConstantEvaluateContext<'_> { let args = match args .into_iter() - .map(|arg| match arg { - ConstValue::Int(v, _ty) => Ok(v), - // Handling u256 constants to enable const evaluation of them. - ConstValue::Struct(v, _) => { - if let [ConstValue::Int(low, _), ConstValue::Int(high, _)] = &v[..] { - Ok(low + (high << 128)) - } else { - // Dignostic can be skipped as we would either have a semantic error for a - // bad arg for the function, or the arg itself - // could'nt have been calculated. - Err(skip_diagnostic()) - } - } - ConstValue::Missing(err) => Err(err), - // Dignostic can be skipped as we would either have a semantic error for a bad arg - // for the function, or the arg itself could'nt have been calculated. - _ => Err(skip_diagnostic()), - }) - .collect::, _>>() + .map(|arg| NumericArg::try_new(db, arg)) + .collect::>>() { - Ok(args) => args, - Err(err) => return ConstValue::Missing(err), + Some(args) => args, + // Dignostic can be skipped as we would either have a semantic error for a bad arg for + // the function, or the arg itself could'nt have been calculated. + None => return ConstValue::Missing(skip_diagnostic()), }; let mut value = match imp.function { - id if id == self.neg_fn => -&args[0], - id if id == self.add_fn => &args[0] + &args[1], - id if id == self.sub_fn => &args[0] - &args[1], - id if id == self.mul_fn => &args[0] * &args[1], - id if (id == self.div_fn || id == self.rem_fn) && args[1].is_zero() => { + id if id == self.neg_fn => -&args[0].v, + id if id == self.add_fn => &args[0].v + &args[1].v, + id if id == self.sub_fn => &args[0].v - &args[1].v, + id if id == self.mul_fn => &args[0].v * &args[1].v, + id if (id == self.div_fn || id == self.rem_fn) && args[1].v.is_zero() => { return ConstValue::Missing( self.diagnostics .report(expr.stable_ptr.untyped(), SemanticDiagnosticKind::DivisionByZero), ); } - id if id == self.div_fn => &args[0] / &args[1], - id if id == self.rem_fn => &args[0] % &args[1], - id if id == self.bit_and_fn => &args[0] & &args[1], - id if id == self.bit_or_fn => &args[0] | &args[1], - id if id == self.bit_xor_fn => &args[0] ^ &args[1], - id if id == self.lt_fn => return bool_value(args[0] < args[1]), - id if id == self.le_fn => return bool_value(args[0] <= args[1]), - id if id == self.gt_fn => return bool_value(args[0] > args[1]), - id if id == self.ge_fn => return bool_value(args[0] >= args[1]), + id if id == self.div_fn => &args[0].v / &args[1].v, + id if id == self.rem_fn => &args[0].v % &args[1].v, + id if id == self.bit_and_fn => &args[0].v & &args[1].v, + id if id == self.bit_or_fn => &args[0].v | &args[1].v, + id if id == self.bit_xor_fn => &args[0].v ^ &args[1].v, + id if id == self.lt_fn => return bool_value(args[0].v < args[1].v), + id if id == self.le_fn => return bool_value(args[0].v <= args[1].v), + id if id == self.gt_fn => return bool_value(args[0].v > args[1].v), + id if id == self.ge_fn => return bool_value(args[0].v >= args[1].v), + id if id == self.div_rem_fn => { + // No need for non-zero check as this is type checked to begin with. + // Also results are always in the range of the input type, so `unwrap`s are ok. + return ConstValue::Struct( + vec![ + value_as_const_value(db, args[0].ty, &(&args[0].v / &args[1].v)).unwrap(), + value_as_const_value(db, args[0].ty, &(&args[0].v % &args[1].v)).unwrap(), + ], + expr.ty, + ); + } _ => { unreachable!("Unexpected function call in constant lowering: {:?}", expr) } @@ -1041,6 +1037,36 @@ impl std::ops::Deref for ConstantEvaluateContext<'_> { } } +/// Helper for the arguments info. +struct NumericArg { + /// The arg's integer value. + v: BigInt, + /// The arg's type. + ty: TypeId, +} +impl NumericArg { + fn try_new(db: &dyn SemanticGroup, arg: ConstValue) -> Option { + Some(Self { ty: arg.ty(db).ok()?, v: numeric_arg_value(arg)? }) + } +} + +/// Helper for creating a `NumericArg` value. +/// This includes unwrapping of `NonZero` values and struct of 2 values as a `u256`. +fn numeric_arg_value(value: ConstValue) -> Option { + match value { + ConstValue::Int(value, _) => Some(value), + ConstValue::Struct(v, _) => { + if let [ConstValue::Int(low, _), ConstValue::Int(high, _)] = &v[..] { + Some(low + (high << 128)) + } else { + None + } + } + ConstValue::NonZero(const_value) => numeric_arg_value(*const_value), + _ => None, + } +} + /// Query implementation of [SemanticGroup::constant_semantic_diagnostics]. pub fn constant_semantic_diagnostics( db: &dyn SemanticGroup, @@ -1133,6 +1159,8 @@ pub struct ConstCalcInfo { div_fn: TraitFunctionId, /// The trait function for `Rem::rem`. rem_fn: TraitFunctionId, + /// The trait function for `DivRem::div_rem`. + div_rem_fn: TraitFunctionId, /// The trait function for `BitAnd::bitand`. bit_and_fn: TraitFunctionId, /// The trait function for `BitOr::bitor`. @@ -1172,6 +1200,7 @@ impl ConstCalcInfo { let mul_trait = get_core_trait(db, CoreTraitContext::TopLevel, "Mul".into()); let div_trait = get_core_trait(db, CoreTraitContext::TopLevel, "Div".into()); let rem_trait = get_core_trait(db, CoreTraitContext::TopLevel, "Rem".into()); + let div_rem_trait = get_core_trait(db, CoreTraitContext::TopLevel, "DivRem".into()); let bit_and_trait = get_core_trait(db, CoreTraitContext::TopLevel, "BitAnd".into()); let bit_or_trait = get_core_trait(db, CoreTraitContext::TopLevel, "BitOr".into()); let bit_xor_trait = get_core_trait(db, CoreTraitContext::TopLevel, "BitXor".into()); @@ -1190,6 +1219,7 @@ impl ConstCalcInfo { mul_trait, div_trait, rem_trait, + div_rem_trait, bit_and_trait, bit_or_trait, bit_xor_trait, @@ -1205,6 +1235,7 @@ impl ConstCalcInfo { mul_fn: trait_fn(mul_trait, "mul"), div_fn: trait_fn(div_trait, "div"), rem_fn: trait_fn(rem_trait, "rem"), + div_rem_fn: trait_fn(div_rem_trait, "div_rem"), bit_and_fn: trait_fn(bit_and_trait, "bitand"), bit_or_fn: trait_fn(bit_or_trait, "bitor"), bit_xor_fn: trait_fn(bit_xor_trait, "bitxor"),