Skip to content

Commit

Permalink
Mod pow and sqr (#304)
Browse files Browse the repository at this point in the history
<!--- Please provide a general summary of your changes in the title
above -->

## Pull Request type

<!-- Please try to limit your pull request to one type; submit multiple
pull requests if needed. -->

Please check the type of change your PR introduces:

- [ ] Bugfix
- [ ] Feature
- [ ] Code style update (formatting, renaming)
- [x] Refactoring (no functional changes, no API changes)
- [ ] Build-related changes
- [ ] Documentation content changes
- [ ] Other (please describe):

## What is the current behavior?

> No dedicated u256 sqr, which can be done with one less u128 mul.

`pow_mod` uses `mult_mod`

## What is the new behavior?

> Instead of `mult_mod(x, x)`, call the more efficient `sqr_mod(x)`
>
> A new `u256_wide_sqr` which is largely identical to `u256_wide_mul` is
added,
> <img width="1582" alt="image"
src="https://github.com/keep-starknet-strange/alexandria/assets/11048263/3bf4223b-a751-4932-a9d4-77c6ba67abc1">


`pow_mod` uses `sqr_mod`

```diff
- test alexandria_math::tests::mod_arithmetics_test::pow_mod_test ... ok (gas usage est.: 27066340)
+ test alexandria_math::tests::mod_arithmetics_test::pow_mod_test ... ok (gas usage est.: 26160370)
- test alexandria_math::tests::mod_arithmetics_test::pow_mod_1_test ... ok (gas usage est.: 27015940)
+ test alexandria_math::tests::mod_arithmetics_test::pow_mod_1_test ... ok (gas usage est.: 26109970)
- test alexandria_math::tests::mod_arithmetics_test::pow_mod_2_test ... ok (gas usage est.: 27015940)
+ test alexandria_math::tests::mod_arithmetics_test::pow_mod_2_test ... ok (gas usage est.: 26109970)
```
## Does this introduce a breaking change?

- [ ] Yes
- [x] No

<!-- If this does introduce a breaking change, please describe the
impact and migration path for existing applications below. -->

## Other information

<!-- Any other information that is important to this PR, such as
screenshots of how the component looks before and after the change. -->
  • Loading branch information
shramee authored May 15, 2024
1 parent 4059381 commit b26f9e1
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 3 deletions.
50 changes: 48 additions & 2 deletions src/math/src/mod_arithmetics.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use core::integer::{u512, u512_safe_div_rem_by_u256, u256_wide_mul};
use core::integer::{
u512, u512_safe_div_rem_by_u256, u256_wide_mul, u128_wide_mul, u128_overflowing_add,
u128_wrapping_add
};
use core::option::OptionTrait;
use core::traits::TryInto;

Expand Down Expand Up @@ -72,6 +75,49 @@ pub fn mult_mod(a: u256, b: u256, mod_non_zero: NonZero<u256>) -> u256 {
rem_u256
}

#[inline(always)]
// core::integer::u128_add_with_carry
fn u128_add_with_carry(a: u128, b: u128) -> (u128, u128) nopanic {
match u128_overflowing_add(a, b) {
Result::Ok(v) => (v, 0),
Result::Err(v) => (v, 1),
}
}

pub fn u256_wide_sqr(a: u256) -> u512 nopanic {
let (limb1, limb0) = u128_wide_mul(a.low, a.low);
let (limb2, limb1_part) = u128_wide_mul(a.low, a.high);
let (limb1, limb1_overflow0) = u128_add_with_carry(limb1, limb1_part);
let (limb1, limb1_overflow1) = u128_add_with_carry(limb1, limb1_part);
let (limb2, limb2_overflow) = u128_add_with_carry(limb2, limb2);
let (limb3, limb2_part) = u128_wide_mul(a.high, a.high);
// No overflow since no limb4.
let limb3 = u128_wrapping_add(limb3, limb2_overflow);
let (limb2, limb2_overflow) = u128_add_with_carry(limb2, limb2_part);
// No overflow since no limb4.
let limb3 = u128_wrapping_add(limb3, limb2_overflow);
// No overflow possible in this addition since both operands are 0/1.
let limb1_overflow = u128_wrapping_add(limb1_overflow0, limb1_overflow1);
let (limb2, limb2_overflow) = u128_add_with_carry(limb2, limb1_overflow);
// No overflow since no limb4.
let limb3 = u128_wrapping_add(limb3, limb2_overflow);
u512 { limb0, limb1, limb2, limb3 }
}

/// Function that performs modular multiplication.
/// # Arguments
/// * `a` - Left hand side of multiplication.
/// * `b` - Right hand side of multiplication.
/// * `modulo` - modulo.
/// # Returns
/// * `u256` - result of modular multiplication
#[inline(always)]
pub fn sqr_mod(a: u256, mod_non_zero: NonZero<u256>) -> u256 {
let mult: u512 = u256_wide_sqr(a);
let (_, rem_u256) = u512_safe_div_rem_by_u256(mult, mod_non_zero);
rem_u256
}

/// Function that performs modular division.
/// # Arguments
/// * `a` - Left hand side of division.
Expand Down Expand Up @@ -100,7 +146,7 @@ pub fn pow_mod(mut base: u256, mut pow: u256, mod_non_zero: NonZero<u256>) -> u2
result = mult_mod(result, base, mod_non_zero);
}
pow = q;
base = mult_mod(base, base, mod_non_zero);
base = sqr_mod(base, mod_non_zero);
};

result
Expand Down
18 changes: 17 additions & 1 deletion src/math/src/tests/mod_arithmetics_test.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use alexandria_math::mod_arithmetics::{add_mod, sub_mod, mult_mod, div_mod, pow_mod};
use alexandria_math::mod_arithmetics::{add_mod, sub_mod, mult_mod, sqr_mod, div_mod, pow_mod};
use core::traits::TryInto;

const p: u256 =
Expand Down Expand Up @@ -115,6 +115,22 @@ fn mult_mod_2_test() {
assert_eq!(mult_mod(pow_256_minus_1, 1, 2), 1, "Incorrect result");
}

#[test]
#[available_gas(500000000)]
fn sqr_mod_test() {
assert_eq!(sqr_mod(p, 2), 1, "Incorrect result");
assert_eq!(
sqr_mod(p, pow_256_minus_1.try_into().unwrap()),
mult_mod(p, p, pow_256_minus_1.try_into().unwrap()),
"Incorrect result"
);
assert_eq!(
sqr_mod(pow_256_minus_1, p.try_into().unwrap()),
mult_mod(pow_256_minus_1, pow_256_minus_1, p.try_into().unwrap()),
"Incorrect result"
);
}

#[test]
#[available_gas(500000000)]
fn div_mod_test() {
Expand Down

0 comments on commit b26f9e1

Please sign in to comment.