From 0807fc8455d0f6249cca0e72e14b5fdc87a76f29 Mon Sep 17 00:00:00 2001 From: Ori Ziv Date: Sun, 5 Jan 2025 12:42:20 +0200 Subject: [PATCH] Added support for `pow` in constant eval. commit-id:1583fc6f --- corelib/src/num/traits/ops/pow.cairo | 246 +++++++++++++++--- .../src/expr/test_data/constant | 15 +- 2 files changed, 229 insertions(+), 32 deletions(-) diff --git a/corelib/src/num/traits/ops/pow.cairo b/corelib/src/num/traits/ops/pow.cairo index 578957ec275..5b54d2a9063 100644 --- a/corelib/src/num/traits/ops/pow.cairo +++ b/corelib/src/num/traits/ops/pow.cairo @@ -28,38 +28,222 @@ pub trait Pow { fn pow(self: Base, exp: Exp) -> Self::Output; } -mod mul_based { - /// Square and multiply implementation for `Pow`. - pub impl PowByMul< - Base, +Mul, +Copy, +Drop, +core::num::traits::One, - > of super::Pow { - type Output = Base; - - fn pow(self: Base, exp: usize) -> Base { - let (tail_exp, head_exp) = DivRem::div_rem(exp, 2); - let tail_result = if tail_exp == 0 { - core::num::traits::One::one() - } else { - Self::pow(self * self, tail_exp) - }; - if head_exp == 0 { - tail_result - } else { - tail_result * self - } +// TODO(gil): Use a macro for it instead of copy paste. +// Not using a trait for the implementation to allow `fn` to be `const`. + +impl PowFelt252 of Pow { + type Output = felt252; + + const fn pow(self: felt252, exp: usize) -> felt252 { + let (tail_exp, head_exp) = DivRem::div_rem(exp, 2); + let tail_result = if tail_exp == 0 { + 1 + } else { + Self::pow(self * self, tail_exp) + }; + if head_exp == 0 { + tail_result + } else { + tail_result * self + } + } +} + +impl PowI8 of Pow { + type Output = i8; + + const fn pow(self: i8, exp: usize) -> i8 { + let (tail_exp, head_exp) = DivRem::div_rem(exp, 2); + let tail_result = if tail_exp == 0 { + 1 + } else { + Self::pow(self * self, tail_exp) + }; + if head_exp == 0 { + tail_result + } else { + tail_result * self + } + } +} + +impl PowU8 of Pow { + type Output = u8; + + const fn pow(self: u8, exp: usize) -> u8 { + let (tail_exp, head_exp) = DivRem::div_rem(exp, 2); + let tail_result = if tail_exp == 0 { + 1 + } else { + Self::pow(self * self, tail_exp) + }; + if head_exp == 0 { + tail_result + } else { + tail_result * self + } + } +} + +impl PowI16 of Pow { + type Output = i16; + + const fn pow(self: i16, exp: usize) -> i16 { + let (tail_exp, head_exp) = DivRem::div_rem(exp, 2); + let tail_result = if tail_exp == 0 { + 1 + } else { + Self::pow(self * self, tail_exp) + }; + if head_exp == 0 { + tail_result + } else { + tail_result * self + } + } +} + +impl PowU16 of Pow { + type Output = u16; + + const fn pow(self: u16, exp: usize) -> u16 { + let (tail_exp, head_exp) = DivRem::div_rem(exp, 2); + let tail_result = if tail_exp == 0 { + 1 + } else { + Self::pow(self * self, tail_exp) + }; + if head_exp == 0 { + tail_result + } else { + tail_result * self + } + } +} + +impl PowI32 of Pow { + type Output = i32; + + const fn pow(self: i32, exp: usize) -> i32 { + let (tail_exp, head_exp) = DivRem::div_rem(exp, 2); + let tail_result = if tail_exp == 0 { + 1 + } else { + Self::pow(self * self, tail_exp) + }; + if head_exp == 0 { + tail_result + } else { + tail_result * self + } + } +} + +impl PowU32 of Pow { + type Output = u32; + + const fn pow(self: u32, exp: usize) -> u32 { + let (tail_exp, head_exp) = DivRem::div_rem(exp, 2); + let tail_result = if tail_exp == 0 { + 1 + } else { + Self::pow(self * self, tail_exp) + }; + if head_exp == 0 { + tail_result + } else { + tail_result * self + } + } +} + +impl PowI64 of Pow { + type Output = i64; + + const fn pow(self: i64, exp: usize) -> i64 { + let (tail_exp, head_exp) = DivRem::div_rem(exp, 2); + let tail_result = if tail_exp == 0 { + 1 + } else { + Self::pow(self * self, tail_exp) + }; + if head_exp == 0 { + tail_result + } else { + tail_result * self + } + } +} + +impl PowU64 of Pow { + type Output = u64; + + const fn pow(self: u64, exp: usize) -> u64 { + let (tail_exp, head_exp) = DivRem::div_rem(exp, 2); + let tail_result = if tail_exp == 0 { + 1 + } else { + Self::pow(self * self, tail_exp) + }; + if head_exp == 0 { + tail_result + } else { + tail_result * self + } + } +} + +impl PowI128 of Pow { + type Output = i128; + + const fn pow(self: i128, exp: usize) -> i128 { + let (tail_exp, head_exp) = DivRem::div_rem(exp, 2); + let tail_result = if tail_exp == 0 { + 1 + } else { + Self::pow(self * self, tail_exp) + }; + if head_exp == 0 { + tail_result + } else { + tail_result * self + } + } +} + +impl PowU128 of Pow { + type Output = u128; + + const fn pow(self: u128, exp: usize) -> u128 { + let (tail_exp, head_exp) = DivRem::div_rem(exp, 2); + let tail_result = if tail_exp == 0 { + 1 + } else { + Self::pow(self * self, tail_exp) + }; + if head_exp == 0 { + tail_result + } else { + tail_result * self + } + } +} + +impl PowU256 of Pow { + type Output = u256; + + const fn pow(self: u256, exp: usize) -> u256 { + let (tail_exp, head_exp) = DivRem::div_rem(exp, 2); + let tail_result = if tail_exp == 0 { + 1 + } else { + Self::pow(self * self, tail_exp) + }; + if head_exp == 0 { + tail_result + } else { + tail_result * self } } } -impl PowFelt252 = mul_based::PowByMul; -impl PowI8 = mul_based::PowByMul; -impl PowU8 = mul_based::PowByMul; -impl PowI16 = mul_based::PowByMul; -impl PowU16 = mul_based::PowByMul; -impl PowI32 = mul_based::PowByMul; -impl PowU32 = mul_based::PowByMul; -impl PowI64 = mul_based::PowByMul; -impl PowU64 = mul_based::PowByMul; -impl PowI128 = mul_based::PowByMul; -impl PowU128 = mul_based::PowByMul; -impl PowU256 = mul_based::PowByMul; diff --git a/crates/cairo-lang-semantic/src/expr/test_data/constant b/crates/cairo-lang-semantic/src/expr/test_data/constant index 49270a0bb01..0900037c713 100644 --- a/crates/cairo-lang-semantic/src/expr/test_data/constant +++ b/crates/cairo-lang-semantic/src/expr/test_data/constant @@ -78,6 +78,19 @@ const VALID_LE: () = assert(1_usize <= 1); const VALID_GT: () = assert(2_usize > 1); const VALID_GE: () = assert(1_usize >= 1); const VALID_DIVREM: () = assert(DivRem::div_rem(5_u8, 2) == (2, 1)); +use core::num::traits::Pow; + +const VALID_POW_MIN2_0: () = assert((-2_i8).pow(0) == 1); +const VALID_POW_MIN2_3: () = assert((-2_i8).pow(3) == -8); +const VALID_POW_MIN2_6: () = assert((-2_i8).pow(6) == 64); + +const VALID_POW_0_0: () = assert(0_felt252.pow(0) == 1); +const VALID_POW_0_1: () = assert(0_felt252.pow(1) == 0); +const VALID_POW_0_2: () = assert(0_felt252.pow(2) == 0); + +const VALID_POW_2_0: () = assert(2_felt252.pow(0) == 0b1); +const VALID_POW_2_6: () = assert(2_felt252.pow(6) == 0b1000000); +const VALID_POW_2_10: () = assert(2_felt252.pow(10) == 0b10000000000); const FUNC_CALC_SUCCESS_OPTION: felt252 = Option::Some(5).unwrap(); const FUNC_CALC_SUCCESS_RESULT1: felt252 = Result::<_, felt252>::Ok(5).unwrap(); @@ -130,7 +143,7 @@ note: In `test::assert`: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ error: Constant calculation depth exceeded. - --> lib.cairo:43:43 + --> lib.cairo:56:43 const FUNC_CALC_STACK_EXCEEDED: felt252 = call_myself(); ^^^^^^^^^^^^^