Skip to content

Commit

Permalink
enh: optimise ed25519 (#298)
Browse files Browse the repository at this point in the history
<!--- Please provide a general summary of your changes in the title
above -->

## Pull Request type

<!-- Please try to limit your pull request to one type; submit multiple
pull requests if needed. -->

Please check the type of change your PR introduces:

- [ ] Bugfix
- [ ] Feature
- [ ] Code style update (formatting, renaming)
- [x] Refactoring (no functional changes, no API changes)
- [ ] Build-related changes
- [ ] Documentation content changes
- [ ] Other (please describe):

## What is the current behavior?

<!-- Please describe the current behavior that you are modifying, or
link to a relevant issue. -->

Issue Number: N/A

## What is the new behavior?

<!-- Please describe the behavior or changes that are being added by
this PR. -->

- 15% less gas usage via various changes in the ed25519 verification
algo

## Does this introduce a breaking change?

- [ ] Yes
- [x] No

<!-- If this does introduce a breaking change, please describe the
impact and migration path for existing applications below. -->

## Other information

<!-- Any other information that is important to this PR, such as
screenshots of how the component looks before and after the change. -->
  • Loading branch information
edisontim authored May 15, 2024
1 parent ccd4677 commit 1b6091d
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 109 deletions.
178 changes: 126 additions & 52 deletions src/math/src/ed25519.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@ use alexandria_data_structures::array_ext::SpanTraitExt;
use alexandria_math::mod_arithmetics::{
add_mod, sub_mod, mult_mod, div_mod, pow_mod, add_inverse_mod, equality_mod
};
use alexandria_math::pow;
use alexandria_math::sha512::{sha512, SHA512_LEN};
use core::array::ArrayTrait;
use core::integer::u512;
use core::option::OptionTrait;
use core::traits::Div;
use core::traits::TryInto;

// As per RFC-8032: https://datatracker.ietf.org/doc/html/rfc8032#section-5.1.7
Expand All @@ -16,22 +20,27 @@ const d: u256 =
const l: u256 =
7237005577332262213973186563042994240857116359379907606001950938285454250989; // 2^252 + 27742317777372353535851937790883648493

const TWO_POW_8: u256 = 0x100;
const w: u256 = 4;

const TWO_POW_8_NON_ZERO: NonZero<u256> = 0x100;


#[derive(Drop, Copy)]
pub struct Point {
x: u256,
y: u256
y: u256,
prime: u256,
prime_non_zero: NonZero<u256>
}

#[derive(Drop, Copy)]
pub struct ExtendedHomogeneousPoint {
X: u256,
Y: u256,
Z: u256,
T: u256,
pub X: u256,
pub Y: u256,
pub Z: u256,
pub T: u256,
pub prime: u256,
pub prime_non_zero: NonZero<u256>
}

pub trait PointDoubling<T> {
Expand All @@ -40,15 +49,25 @@ pub trait PointDoubling<T> {

impl PointDoublingExtendedHomogeneousPoint of PointDoubling<ExtendedHomogeneousPoint> {
fn double(self: ExtendedHomogeneousPoint) -> ExtendedHomogeneousPoint {
let A: u256 = mult_mod(self.X, self.X, p);
let B: u256 = mult_mod(self.Y, self.Y, p);
let C: u256 = mult_mod(2, mult_mod(self.Z, self.Z, p), p);
let H: u256 = A + B;
let E: u256 = sub_mod(H, pow_mod(add_mod(self.X, self.Y, p), 2, p), p);
let G: u256 = sub_mod(A, B, p);
let F: u256 = add_mod(C, G, p);
let A: u256 = mult_mod(self.X, self.X, self.prime_non_zero);
let B: u256 = mult_mod(self.Y, self.Y, self.prime_non_zero);
let C: u256 = mult_mod(
2, mult_mod(self.Z, self.Z, self.prime_non_zero), self.prime_non_zero
);
let D: u256 = add_inverse_mod(A, self.prime);
let x_1_y_1 = add_mod(self.X, self.Y, self.prime);
let x_1_squared: u256 = mult_mod(x_1_y_1, x_1_y_1, self.prime_non_zero);
let E: u256 = sub_mod(sub_mod(x_1_squared, A, p), B, p);
let G: u256 = add_mod(D, B, self.prime);
let F: u256 = sub_mod(G, C, self.prime);
let H: u256 = sub_mod(D, B, self.prime);
ExtendedHomogeneousPoint {
X: mult_mod(E, F, p), Y: mult_mod(G, H, p), T: mult_mod(E, H, p), Z: mult_mod(F, G, p)
X: mult_mod(E, F, self.prime_non_zero),
Y: mult_mod(G, H, self.prime_non_zero),
T: mult_mod(E, H, self.prime_non_zero),
Z: mult_mod(F, G, self.prime_non_zero),
prime: self.prime,
prime_non_zero: self.prime_non_zero
}
}
}
Expand All @@ -57,34 +76,63 @@ impl ExtendedHomogeneousPointAdd of Add<ExtendedHomogeneousPoint> {
fn add(
lhs: ExtendedHomogeneousPoint, rhs: ExtendedHomogeneousPoint
) -> ExtendedHomogeneousPoint {
let A: u256 = mult_mod(sub_mod(lhs.Y, lhs.X, p), sub_mod(rhs.Y, rhs.X, p), p);
let B: u256 = mult_mod(add_mod(lhs.Y, lhs.X, p), add_mod(rhs.Y, rhs.X, p), p);
let C: u256 = mult_mod(mult_mod(mult_mod(lhs.T, 2, p), d, p), rhs.T, p);
let D: u256 = mult_mod(mult_mod(lhs.Z, 2, p), rhs.Z, p);
let E: u256 = sub_mod(B, A, p);
let F: u256 = sub_mod(D, C, p);
let G: u256 = add_mod(D, C, p);
let H: u256 = add_mod(B, A, p);

let X_3 = mult_mod(E, F, p);
let Y_3 = mult_mod(G, H, p);
let T_3 = mult_mod(E, H, p);
let Z_3 = mult_mod(F, G, p);

ExtendedHomogeneousPoint { X: X_3, Y: Y_3, T: T_3, Z: Z_3 }
if (lhs.prime != rhs.prime) {
panic!("not in the same field");
}

let A: u256 = mult_mod(
sub_mod(lhs.Y, lhs.X, lhs.prime), sub_mod(rhs.Y, rhs.X, lhs.prime), lhs.prime_non_zero
);
let B: u256 = mult_mod(
add_mod(lhs.Y, lhs.X, lhs.prime), add_mod(rhs.Y, rhs.X, lhs.prime), lhs.prime_non_zero
);
let C: u256 = mult_mod(
mult_mod(mult_mod(lhs.T, 2, lhs.prime_non_zero), d, lhs.prime_non_zero),
rhs.T,
lhs.prime_non_zero
);
let D: u256 = mult_mod(mult_mod(lhs.Z, 2, lhs.prime_non_zero), rhs.Z, lhs.prime_non_zero);

let E: u256 = sub_mod(B, A, lhs.prime);
let F: u256 = sub_mod(D, C, lhs.prime);
let G: u256 = add_mod(D, C, lhs.prime);
let H: u256 = add_mod(B, A, lhs.prime);

let X_3 = mult_mod(E, F, lhs.prime_non_zero);
let Y_3 = mult_mod(G, H, lhs.prime_non_zero);
let T_3 = mult_mod(E, H, lhs.prime_non_zero);
let Z_3 = mult_mod(F, G, lhs.prime_non_zero);

ExtendedHomogeneousPoint {
X: X_3, Y: Y_3, T: T_3, Z: Z_3, prime: lhs.prime, prime_non_zero: lhs.prime_non_zero
}
}
}

impl PartialEqExtendedHomogeneousPoint of PartialEq<ExtendedHomogeneousPoint> {
fn eq(lhs: @ExtendedHomogeneousPoint, rhs: @ExtendedHomogeneousPoint) -> bool {
if (lhs.prime != rhs.prime) {
panic!("not in the same field");
}
// lhs.X * rhs.Z - rhs.X * lhs.Z
if (sub_mod(mult_mod(*lhs.X, *rhs.Z, p), mult_mod(*rhs.X, *lhs.Z, p), p) != 0) {
if (sub_mod(
mult_mod(*lhs.X, *rhs.Z, *lhs.prime_non_zero),
mult_mod(*rhs.X, *lhs.Z, *lhs.prime_non_zero),
*lhs.prime
) != 0) {
return false;
}
// lhs.Y * rhs.Z - rhs.Y * lhs.Z
sub_mod(mult_mod(*lhs.Y, *rhs.Z, p), mult_mod(*rhs.Y, *lhs.Z, p), p) == 0
sub_mod(
mult_mod(*lhs.Y, *rhs.Z, *lhs.prime_non_zero),
mult_mod(*rhs.Y, *lhs.Z, *lhs.prime_non_zero),
*lhs.prime
) == 0
}
fn ne(lhs: @ExtendedHomogeneousPoint, rhs: @ExtendedHomogeneousPoint) -> bool {
if (lhs.prime != rhs.prime) {
panic!("not in the same field");
}
!(lhs == rhs)
}
}
Expand Down Expand Up @@ -196,28 +244,37 @@ impl U256TryIntoPoint of TryInto<u256, Point> {
return Option::None;
}

let y_2 = pow_mod(y, 2, p);
let prime_non_zero: NonZero<u256> = p.try_into().unwrap();

let y_2 = pow_mod(y, 2, prime_non_zero);
let u: u256 = sub_mod(y_2, 1, p);
let v: u256 = add_mod(mult_mod(d, y_2, p), 1, p);
let v_pow_3 = pow_mod(v, 3, p);
let v: u256 = add_mod(mult_mod(d, y_2, prime_non_zero), 1, p);
let v_pow_3 = pow_mod(v, 3, prime_non_zero);

let v_pow_7: u256 = pow_mod(v, 7, p);
let v_pow_7: u256 = pow_mod(v, 7, prime_non_zero);

let p_minus_5_div_8: u256 = div_mod(sub_mod(p, 5, p), 8, p);
let p_minus_5_div_8: u256 = div_mod(sub_mod(p, 5, p), 8, prime_non_zero);

let u_times_v_power_3: u256 = mult_mod(u, v_pow_3, p);
let u_times_v_power_3: u256 = mult_mod(u, v_pow_3, prime_non_zero);

let x_candidate_root: u256 = mult_mod(
u_times_v_power_3, pow_mod(mult_mod(u, v_pow_7, p), p_minus_5_div_8, p), p
u_times_v_power_3,
pow_mod(mult_mod(u, v_pow_7, prime_non_zero), p_minus_5_div_8, prime_non_zero),
prime_non_zero
);

let v_times_x_squared: u256 = mult_mod(v, pow_mod(x_candidate_root, 2, p), p);
let v_times_x_squared: u256 = mult_mod(
v, pow_mod(x_candidate_root, 2, prime_non_zero), prime_non_zero
);

if (equality_mod(v_times_x_squared, u, p)) {
x = x_candidate_root;
} else if (equality_mod(v_times_x_squared, add_inverse_mod(u, p), p)) {
let p_minus_one_over_4: u256 = div_mod(sub_mod(p, 1, p), 4, p);
x = mult_mod(x_candidate_root, pow_mod(2, p_minus_one_over_4, p), p);
let p_minus_one_over_4: u256 = div_mod(sub_mod(p, 1, p), 4, prime_non_zero);
x =
mult_mod(
x_candidate_root, pow_mod(2, p_minus_one_over_4, prime_non_zero), prime_non_zero
);
} else {
return Option::None;
}
Expand All @@ -231,13 +288,21 @@ impl U256TryIntoPoint of TryInto<u256, Point> {
x = p - x;
}

Option::Some(Point { x: x, y: y })
Option::Some(Point { x: x, y: y, prime: p, prime_non_zero: prime_non_zero })
}
}

impl PointIntoExtendedHomogeneousPoint of Into<Point, ExtendedHomogeneousPoint> {
fn into(self: Point) -> ExtendedHomogeneousPoint {
ExtendedHomogeneousPoint { X: self.x, Y: self.y, Z: 1, T: mult_mod(self.x, self.y, p) }
let prime_non_zero = p.try_into().unwrap();
ExtendedHomogeneousPoint {
X: self.x,
Y: self.y,
Z: 1,
T: mult_mod(self.x, self.y, prime_non_zero),
prime: p,
prime_non_zero: prime_non_zero
}
}
}

Expand All @@ -247,8 +312,13 @@ impl PointIntoExtendedHomogeneousPoint of Into<Point, ExtendedHomogeneousPoint>
/// * `P` - Elliptic Curve point in the Extended Homogeneous form.
/// # Returns
/// * `u256` - Resulting point in the Extended Homogeneous form.
fn point_mult(mut scalar: u256, mut P: ExtendedHomogeneousPoint) -> ExtendedHomogeneousPoint {
let mut Q = ExtendedHomogeneousPoint { X: 0, Y: 1, Z: 1, T: 0 };
pub fn point_mult_double_and_add(
mut scalar: u256, mut P: ExtendedHomogeneousPoint
) -> ExtendedHomogeneousPoint {
let prime_non_zero = p.try_into().unwrap();
let mut Q = ExtendedHomogeneousPoint {
X: 0, Y: 1, Z: 1, T: 0, prime: p, prime_non_zero: prime_non_zero // neutral element
};
let zero_u512 = Default::default();

// Double and add method
Expand All @@ -273,17 +343,21 @@ fn point_mult(mut scalar: u256, mut P: ExtendedHomogeneousPoint) -> ExtendedHomo
fn check_group_equation(
S: u256, R: ExtendedHomogeneousPoint, k: u256, A_prime: ExtendedHomogeneousPoint
) -> bool {
let prime_non_zero = p.try_into().unwrap();
// (X(P),Y(P)) of edwards25519 in https://datatracker.ietf.org/doc/html/rfc7748
let B: Point = Point {
x: 15112221349535400772501151409588531511454012693041857206046113283949847762202,
y: 46316835694926478169428394003475163141307993866256225615783033603165251855960
let B: ExtendedHomogeneousPoint = ExtendedHomogeneousPoint {
X: 15112221349535400772501151409588531511454012693041857206046113283949847762202,
Y: 46316835694926478169428394003475163141307993866256225615783033603165251855960,
Z: 1,
T: 46827403850823179245072216630277197565144205554125654976674165829533817101731,
prime: p,
prime_non_zero: prime_non_zero
};

let B_extended: ExtendedHomogeneousPoint = B.into();

// Check group equation [S]B = R + [k]A'
let lhs: ExtendedHomogeneousPoint = point_mult(S, B_extended);
let rhs: ExtendedHomogeneousPoint = R + point_mult(k, A_prime);
let lhs: ExtendedHomogeneousPoint = point_mult_double_and_add(S, B);
let kA: ExtendedHomogeneousPoint = point_mult_double_and_add(k, A_prime);
let rhs: ExtendedHomogeneousPoint = R + kA;
lhs == rhs
}

Expand Down
43 changes: 13 additions & 30 deletions src/math/src/mod_arithmetics.cairo
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use core::integer::{u512, u512_safe_div_rem_by_u256, u256_wide_mul};
use core::option::OptionTrait;
use core::traits::TryInto;

/// Function that performs modular addition.
/// Function that performs modular addition. Will panick if result is > u256 max
/// # Arguments
/// * `a` - Left hand side of addition.
/// * `b` - Right hand side of addition.
Expand All @@ -9,15 +11,7 @@ use core::integer::{u512, u512_safe_div_rem_by_u256, u256_wide_mul};
/// * `u256` - result of modular addition
#[inline(always)]
pub fn add_mod(a: u256, b: u256, modulo: u256) -> u256 {
let mod_non_zero: NonZero<u256> = modulo.try_into().unwrap();
let low: u256 = a.low.into() + b.low.into();
let high: u256 = a.high.into() + b.high.into();
let carry: u256 = low.high.into() + high.low.into();
let add_u512: u512 = u512 {
limb0: low.low, limb1: carry.low, limb2: carry.high + high.high, limb3: 0
};
let (_, res) = u512_safe_div_rem_by_u256(add_u512, mod_non_zero);
res
(a + b) % modulo
}

/// Function that return the modular multiplicative inverse. Disclaimer: this function should only be used with a prime modulo.
Expand All @@ -27,10 +21,8 @@ pub fn add_mod(a: u256, b: u256, modulo: u256) -> u256 {
/// # Returns
/// * `u256` - modular multiplicative inverse
#[inline(always)]
pub fn mult_inverse(b: u256, modulo: u256) -> u256 {
math::u256_inv_mod(b, modulo.try_into().expect('inverse non zero'))
.expect('inverse non zero')
.into()
pub fn mult_inverse(b: u256, mod_non_zero: NonZero<u256>) -> u256 {
math::u256_inv_mod(b, mod_non_zero).expect('inverse non zero').into()
}

/// Function that return the modular additive inverse.
Expand Down Expand Up @@ -74,9 +66,8 @@ pub fn sub_mod(mut a: u256, mut b: u256, modulo: u256) -> u256 {
/// # Returns
/// * `u256` - result of modular multiplication
#[inline(always)]
pub fn mult_mod(a: u256, b: u256, modulo: u256) -> u256 {
pub fn mult_mod(a: u256, b: u256, mod_non_zero: NonZero<u256>) -> u256 {
let mult: u512 = u256_wide_mul(a, b);
let mod_non_zero: NonZero<u256> = modulo.try_into().unwrap();
let (_, rem_u256) = u512_safe_div_rem_by_u256(mult, mod_non_zero);
rem_u256
}
Expand All @@ -89,10 +80,9 @@ pub fn mult_mod(a: u256, b: u256, modulo: u256) -> u256 {
/// # Returns
/// * `u256` - result of modular division
#[inline(always)]
pub fn div_mod(a: u256, b: u256, modulo: u256) -> u256 {
let modulo_nz = modulo.try_into().expect('0 modulo');
let inv = math::u256_inv_mod(b, modulo_nz).unwrap().into();
math::u256_mul_mod_n(a, inv, modulo_nz)
pub fn div_mod(a: u256, b: u256, mod_non_zero: NonZero<u256>) -> u256 {
let inv = math::u256_inv_mod(b, mod_non_zero).unwrap().into();
math::u256_mul_mod_n(a, inv, mod_non_zero)
}

/// Function that performs modular exponentiation.
Expand All @@ -102,23 +92,16 @@ pub fn div_mod(a: u256, b: u256, modulo: u256) -> u256 {
/// * `modulo` - modulo.
/// # Returns
/// * `u256` - result of modular exponentiation
pub fn pow_mod(mut base: u256, mut pow: u256, modulo: u256) -> u256 {
pub fn pow_mod(mut base: u256, mut pow: u256, mod_non_zero: NonZero<u256>) -> u256 {
let mut result: u256 = 1;
let mod_non_zero: NonZero<u256> = modulo.try_into().unwrap();
let mut mult: u512 = u512 { limb0: 0_u128, limb1: 0_u128, limb2: 0_u128, limb3: 0_u128 };

while (pow != 0) {
if ((pow & 1) > 0) {
mult = u256_wide_mul(result, base);
let (_, res_u256,) = u512_safe_div_rem_by_u256(mult, mod_non_zero);
result = res_u256;
result = mult_mod(result, base, mod_non_zero);
}

pow = pow / 2;

mult = u256_wide_mul(base, base);
let (_, base_u256) = u512_safe_div_rem_by_u256(mult, mod_non_zero);
base = base_u256;
base = mult_mod(base, base, mod_non_zero);
};

result
Expand Down
4 changes: 2 additions & 2 deletions src/math/src/tests/ed25519_test.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use alexandria_math::ed25519::verify_signature;
use alexandria_math::ed25519::{verify_signature};

// Public keys and signatures were generated with JS library Noble (https://github.com/paulmillr/noble-ed25519)

Expand Down Expand Up @@ -69,7 +69,7 @@ fn verify_signature_invalid() {
let s_sign: u256 = 0x68e015fa8775659d1f40a01e1f69b8af4409046f4dc8ff02cdb04fdc3585eb01;
let signature = array![r_sign, s_sign];

assert!(!verify_signature(msg, signature.span(), pub_key), "Invalid signature");
assert!(!verify_signature(msg, signature.span(), pub_key), "Signature should be invalid");
}

#[test]
Expand Down
Loading

0 comments on commit 1b6091d

Please sign in to comment.