Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

u256 sqr mod #302

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion src/math/src/mod_arithmetics.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use core::integer::{u512, u512_safe_div_rem_by_u256, u256_wide_mul};
use core::integer::{
u512, u512_safe_div_rem_by_u256, u256_wide_mul, u128_wide_mul, u128_overflowing_add,
u128_wrapping_add
};
use core::option::OptionTrait;
use core::traits::TryInto;

Expand Down Expand Up @@ -72,6 +75,49 @@ pub fn mult_mod(a: u256, b: u256, mod_non_zero: NonZero<u256>) -> u256 {
rem_u256
}

#[inline(always)]
// core::integer::u128_add_with_carry
fn u128_add_with_carry(a: u128, b: u128) -> (u128, u128) nopanic {
match u128_overflowing_add(a, b) {
Result::Ok(v) => (v, 0),
Result::Err(v) => (v, 1),
}
}

pub fn u256_wide_sqr(a: u256) -> u512 nopanic {
let (limb1, limb0) = u128_wide_mul(a.low, a.low);
let (limb2, limb1_part) = u128_wide_mul(a.low, a.high);
let (limb1, limb1_overflow0) = u128_add_with_carry(limb1, limb1_part);
let (limb1, limb1_overflow1) = u128_add_with_carry(limb1, limb1_part);
let (limb2, limb2_overflow) = u128_add_with_carry(limb2, limb2);
let (limb3, limb2_part) = u128_wide_mul(a.high, a.high);
// No overflow since no limb4.
let limb3 = u128_wrapping_add(limb3, limb2_overflow);
let (limb2, limb2_overflow) = u128_add_with_carry(limb2, limb2_part);
// No overflow since no limb4.
let limb3 = u128_wrapping_add(limb3, limb2_overflow);
// No overflow possible in this addition since both operands are 0/1.
let limb1_overflow = u128_wrapping_add(limb1_overflow0, limb1_overflow1);
let (limb2, limb2_overflow) = u128_add_with_carry(limb2, limb1_overflow);
// No overflow since no limb4.
let limb3 = u128_wrapping_add(limb3, limb2_overflow);
u512 { limb0, limb1, limb2, limb3 }
}

/// Function that performs modular multiplication.
/// # Arguments
/// * `a` - Left hand side of multiplication.
/// * `b` - Right hand side of multiplication.
/// * `modulo` - modulo.
/// # Returns
/// * `u256` - result of modular multiplication
#[inline(always)]
pub fn sqr_mod(a: u256, mod_non_zero: NonZero<u256>) -> u256 {
let mult: u512 = u256_wide_sqr(a);
let (_, rem_u256) = u512_safe_div_rem_by_u256(mult, mod_non_zero);
rem_u256
}

/// Function that performs modular division.
/// # Arguments
/// * `a` - Left hand side of division.
Expand Down
18 changes: 17 additions & 1 deletion src/math/src/tests/mod_arithmetics_test.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use alexandria_math::mod_arithmetics::{add_mod, sub_mod, mult_mod, div_mod, pow_mod};
use alexandria_math::mod_arithmetics::{add_mod, sub_mod, mult_mod, sqr_mod, div_mod, pow_mod};
use core::traits::TryInto;

const p: u256 =
Expand Down Expand Up @@ -115,6 +115,22 @@ fn mult_mod_2_test() {
assert_eq!(mult_mod(pow_256_minus_1, 1, 2), 1, "Incorrect result");
}

#[test]
#[available_gas(500000000)]
fn sqr_mod_test() {
assert_eq!(sqr_mod(p, 2), 1, "Incorrect result");
assert_eq!(
sqr_mod(p, pow_256_minus_1.try_into().unwrap()),
mult_mod(p, p, pow_256_minus_1.try_into().unwrap()),
"Incorrect result"
);
assert_eq!(
sqr_mod(pow_256_minus_1, p.try_into().unwrap()),
mult_mod(pow_256_minus_1, pow_256_minus_1, p.try_into().unwrap()),
"Incorrect result"
);
}

#[test]
#[available_gas(500000000)]
fn div_mod_test() {
Expand Down
Loading