Skip to content

Commit

Permalink
Added support for pow in constant eval.
Browse files Browse the repository at this point in the history
commit-id:1583fc6f
  • Loading branch information
orizi committed Jan 6, 2025
1 parent 6025b75 commit 0807fc8
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 32 deletions.
246 changes: 215 additions & 31 deletions corelib/src/num/traits/ops/pow.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -28,38 +28,222 @@ pub trait Pow<Base, Exp> {
fn pow(self: Base, exp: Exp) -> Self::Output;
}

mod mul_based {
/// Square and multiply implementation for `Pow`.
pub impl PowByMul<
Base, +Mul<Base>, +Copy<Base>, +Drop<Base>, +core::num::traits::One<Base>,
> of super::Pow<Base, usize> {
type Output = Base;

fn pow(self: Base, exp: usize) -> Base {
let (tail_exp, head_exp) = DivRem::div_rem(exp, 2);
let tail_result = if tail_exp == 0 {
core::num::traits::One::one()
} else {
Self::pow(self * self, tail_exp)
};
if head_exp == 0 {
tail_result
} else {
tail_result * self
}
// TODO(gil): Use a macro for it instead of copy paste.
// Not using a trait for the implementation to allow `fn` to be `const`.

impl PowFelt252 of Pow<felt252, usize> {
type Output = felt252;

const fn pow(self: felt252, exp: usize) -> felt252 {
let (tail_exp, head_exp) = DivRem::div_rem(exp, 2);
let tail_result = if tail_exp == 0 {
1
} else {
Self::pow(self * self, tail_exp)
};
if head_exp == 0 {
tail_result
} else {
tail_result * self
}
}
}

impl PowI8 of Pow<i8, usize> {
type Output = i8;

const fn pow(self: i8, exp: usize) -> i8 {
let (tail_exp, head_exp) = DivRem::div_rem(exp, 2);
let tail_result = if tail_exp == 0 {
1
} else {
Self::pow(self * self, tail_exp)
};
if head_exp == 0 {
tail_result
} else {
tail_result * self
}
}
}

impl PowU8 of Pow<u8, usize> {
type Output = u8;

const fn pow(self: u8, exp: usize) -> u8 {
let (tail_exp, head_exp) = DivRem::div_rem(exp, 2);
let tail_result = if tail_exp == 0 {
1
} else {
Self::pow(self * self, tail_exp)
};
if head_exp == 0 {
tail_result
} else {
tail_result * self
}
}
}

impl PowI16 of Pow<i16, usize> {
type Output = i16;

const fn pow(self: i16, exp: usize) -> i16 {
let (tail_exp, head_exp) = DivRem::div_rem(exp, 2);
let tail_result = if tail_exp == 0 {
1
} else {
Self::pow(self * self, tail_exp)
};
if head_exp == 0 {
tail_result
} else {
tail_result * self
}
}
}

impl PowU16 of Pow<u16, usize> {
type Output = u16;

const fn pow(self: u16, exp: usize) -> u16 {
let (tail_exp, head_exp) = DivRem::div_rem(exp, 2);
let tail_result = if tail_exp == 0 {
1
} else {
Self::pow(self * self, tail_exp)
};
if head_exp == 0 {
tail_result
} else {
tail_result * self
}
}
}

impl PowI32 of Pow<i32, usize> {
type Output = i32;

const fn pow(self: i32, exp: usize) -> i32 {
let (tail_exp, head_exp) = DivRem::div_rem(exp, 2);
let tail_result = if tail_exp == 0 {
1
} else {
Self::pow(self * self, tail_exp)
};
if head_exp == 0 {
tail_result
} else {
tail_result * self
}
}
}

impl PowU32 of Pow<u32, usize> {
type Output = u32;

const fn pow(self: u32, exp: usize) -> u32 {
let (tail_exp, head_exp) = DivRem::div_rem(exp, 2);
let tail_result = if tail_exp == 0 {
1
} else {
Self::pow(self * self, tail_exp)
};
if head_exp == 0 {
tail_result
} else {
tail_result * self
}
}
}

impl PowI64 of Pow<i64, usize> {
type Output = i64;

const fn pow(self: i64, exp: usize) -> i64 {
let (tail_exp, head_exp) = DivRem::div_rem(exp, 2);
let tail_result = if tail_exp == 0 {
1
} else {
Self::pow(self * self, tail_exp)
};
if head_exp == 0 {
tail_result
} else {
tail_result * self
}
}
}

impl PowU64 of Pow<u64, usize> {
type Output = u64;

const fn pow(self: u64, exp: usize) -> u64 {
let (tail_exp, head_exp) = DivRem::div_rem(exp, 2);
let tail_result = if tail_exp == 0 {
1
} else {
Self::pow(self * self, tail_exp)
};
if head_exp == 0 {
tail_result
} else {
tail_result * self
}
}
}

impl PowI128 of Pow<i128, usize> {
type Output = i128;

const fn pow(self: i128, exp: usize) -> i128 {
let (tail_exp, head_exp) = DivRem::div_rem(exp, 2);
let tail_result = if tail_exp == 0 {
1
} else {
Self::pow(self * self, tail_exp)
};
if head_exp == 0 {
tail_result
} else {
tail_result * self
}
}
}

impl PowU128 of Pow<u128, usize> {
type Output = u128;

const fn pow(self: u128, exp: usize) -> u128 {
let (tail_exp, head_exp) = DivRem::div_rem(exp, 2);
let tail_result = if tail_exp == 0 {
1
} else {
Self::pow(self * self, tail_exp)
};
if head_exp == 0 {
tail_result
} else {
tail_result * self
}
}
}

impl PowU256 of Pow<u256, usize> {
type Output = u256;

const fn pow(self: u256, exp: usize) -> u256 {
let (tail_exp, head_exp) = DivRem::div_rem(exp, 2);
let tail_result = if tail_exp == 0 {
1
} else {
Self::pow(self * self, tail_exp)
};
if head_exp == 0 {
tail_result
} else {
tail_result * self
}
}
}

impl PowFelt252 = mul_based::PowByMul<felt252>;
impl PowI8 = mul_based::PowByMul<i8>;
impl PowU8 = mul_based::PowByMul<u8>;
impl PowI16 = mul_based::PowByMul<i16>;
impl PowU16 = mul_based::PowByMul<u16>;
impl PowI32 = mul_based::PowByMul<i32>;
impl PowU32 = mul_based::PowByMul<u32>;
impl PowI64 = mul_based::PowByMul<i64>;
impl PowU64 = mul_based::PowByMul<u64>;
impl PowI128 = mul_based::PowByMul<i128>;
impl PowU128 = mul_based::PowByMul<u128>;
impl PowU256 = mul_based::PowByMul<u256>;
15 changes: 14 additions & 1 deletion crates/cairo-lang-semantic/src/expr/test_data/constant
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,19 @@ const VALID_LE: () = assert(1_usize <= 1);
const VALID_GT: () = assert(2_usize > 1);
const VALID_GE: () = assert(1_usize >= 1);
const VALID_DIVREM: () = assert(DivRem::div_rem(5_u8, 2) == (2, 1));
use core::num::traits::Pow;

const VALID_POW_MIN2_0: () = assert((-2_i8).pow(0) == 1);
const VALID_POW_MIN2_3: () = assert((-2_i8).pow(3) == -8);
const VALID_POW_MIN2_6: () = assert((-2_i8).pow(6) == 64);

const VALID_POW_0_0: () = assert(0_felt252.pow(0) == 1);
const VALID_POW_0_1: () = assert(0_felt252.pow(1) == 0);
const VALID_POW_0_2: () = assert(0_felt252.pow(2) == 0);

const VALID_POW_2_0: () = assert(2_felt252.pow(0) == 0b1);
const VALID_POW_2_6: () = assert(2_felt252.pow(6) == 0b1000000);
const VALID_POW_2_10: () = assert(2_felt252.pow(10) == 0b10000000000);

const FUNC_CALC_SUCCESS_OPTION: felt252 = Option::Some(5).unwrap();
const FUNC_CALC_SUCCESS_RESULT1: felt252 = Result::<_, felt252>::Ok(5).unwrap();
Expand Down Expand Up @@ -130,7 +143,7 @@ note: In `test::assert`:
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

error: Constant calculation depth exceeded.
--> lib.cairo:43:43
--> lib.cairo:56:43
const FUNC_CALC_STACK_EXCEEDED: felt252 = call_myself();
^^^^^^^^^^^^^

Expand Down

0 comments on commit 0807fc8

Please sign in to comment.