Skip to content

Commit

Permalink
get the multiplication correct
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenfeizhang committed Nov 22, 2024
1 parent 9957aa3 commit 962ba5c
Showing 1 changed file with 110 additions and 1 deletion.
111 changes: 110 additions & 1 deletion expander_compiler/tests/example_rsa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ impl RSAFieldElement {
}

// assert a + b = result + r * carry
// a, b, result, modulus are all RSAFieldElement
pub fn assert_addition(a: &Self, b: &Self, modulus: &Self, carry: &bool, result: &Self) {
let mut left_result = [0u128; N_LIMBS]; // for a + b
let mut right_result = result.data.clone(); // for result + r * carry
Expand Down Expand Up @@ -110,6 +111,75 @@ impl RSAFieldElement {
right_result
);
}

#[inline]
// compute a*b without reduction, add the result to res
fn mul_without_reduction(a: &Self, b: &Self, res: &mut [u128; 2 * N_LIMBS]) {
for i in 0..N_LIMBS {
let mut carry = 0u128;
for j in 0..N_LIMBS {
if i + j < 2 * N_LIMBS {
let (prod, prod_carry) = mul_u120_with_carry(&a.data[i], &b.data[j], &carry);

// Add to accumulator at position i+j
let mut acc_carry = 0u128;
let (sum, new_carry) = add_u120_with_carry(&res[i + j], &prod, &acc_carry);
res[i + j] = sum;

// Propagate carries
carry = prod_carry;
acc_carry = new_carry;
if acc_carry > 0 {
let mut k = 1;
while acc_carry > 0 && (i + j + k) < 2 * N_LIMBS {
let (new_val, new_carry) =
add_u120_with_carry(&res[i + j + k], &acc_carry, &0u128);
res[i + j + k] = new_val;
acc_carry = new_carry;
k += 1;
}
}
}
}
// Handle final multiplication carry
if carry > 0 && i + N_LIMBS < 2 * N_LIMBS {
let mut k = 0;
while carry > 0 && (i + N_LIMBS + k) < 2 * N_LIMBS {
let (new_val, new_carry) =
add_u120_with_carry(&res[i + N_LIMBS + k], &carry, &0u128);
res[i + N_LIMBS + k] = new_val;
carry = new_carry;
k += 1;
}
}
}
}

// assert a * b = result + r * carry
// a, b, result, modulus, carry are all RSAFieldElement
pub fn assert_multiplication(a: &Self, b: &Self, modulus: &Self, carry: &Self, result: &Self) {
// Two arrays to hold left and right results: a * b and result + r * carry
let mut left_result = [0u128; 2 * N_LIMBS]; // for a * b
let mut right_result = [0u128; 2 * N_LIMBS]; // for result + r * carry

// First compute a * b
Self::mul_without_reduction(a, b, &mut left_result);
println!("left_result: {:0x?}", left_result);

// Now compute result + r * carry
// First copy result
for i in 0..N_LIMBS {
right_result[i] = result.data[i];
}
Self::mul_without_reduction(modulus, carry, &mut right_result);
println!("right_result: {:0x?}", right_result);

// Assert equality
assert!(
left_result == right_result,
"Multiplication assertion failed"
);
}
}

#[test]
Expand Down Expand Up @@ -190,7 +260,8 @@ fn test_assert_rsa_addition() {
let mut a = RSAFieldElement::new([MASK120 - 1; N_LIMBS]);
a.data[N_LIMBS - 1] = MASK8 - 1;
let b = RSAFieldElement::new([2u128; N_LIMBS]);
let result = RSAFieldElement::from_string("000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001");
let result = RSAFieldElement::from_string("000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001");

println!("a: {:?}", a.to_string());
println!("b: {:?}", b.to_string());
println!("r: {:?}", r.to_string());
Expand All @@ -199,3 +270,41 @@ fn test_assert_rsa_addition() {
println!("case 3 passed");
}
}
#[test]
fn test_assert_rsa_multiplication() {
let mut r = RSAFieldElement::new([MASK120; N_LIMBS]);
r.data[N_LIMBS - 1] = MASK8;

{
let a = RSAFieldElement::new([1u128; N_LIMBS]);
let b = RSAFieldElement::new([2u128; N_LIMBS]);

let carry = RSAFieldElement::from_string("0000000000000000000000000000000200000000000000000000000000000400000000000000000000000000000600000000000000000000000000000800000000000000000000000000000a00000000000000000000000000000c00000000000000000000000000000e00000000000000000000000000001000000000000000000000000000001200000000000000000000000000001400000000000000000000000000001600000000000000000000000000001800000000000000000000000000001a00000000000000000000000000001c00000000000000000000000000001e0000000000000000000000000000200000000000000000000000000000220000000000000000000000000000");
let result = RSAFieldElement::from_string("00000000000000000000000000002402000000000000000000000000002204000000000000000000000000002006000000000000000000000000001e08000000000000000000000000001c0a000000000000000000000000001a0c00000000000000000000000000180e000000000000000000000000001610000000000000000000000000001412000000000000000000000000001214000000000000000000000000001016000000000000000000000000000e18000000000000000000000000000c1a000000000000000000000000000a1c00000000000000000000000000081e0000000000000000000000000006200000000000000000000000000004220000000000000000000000000002");

println!("a: {:?}", a.to_string());
println!("b: {:?}", b.to_string());
println!("r: {:?}", r.to_string());
println!("carry: {:?}", result.to_string());
println!("result: {:?}", carry.to_string());
RSAFieldElement::assert_multiplication(&a, &b, &r, &carry, &result);
println!("case 1 passed");
}

{
let mut a = RSAFieldElement::new([MASK120 - 1; N_LIMBS]);
a.data[N_LIMBS - 1] = MASK8 - 1;
let b = RSAFieldElement::new([2u128; N_LIMBS]);

let carry = RSAFieldElement::from_string("000000000000000000000000000001fe0000000000000000000000000001fc0000000000000000000000000001fa0000000000000000000000000001f80000000000000000000000000001f60000000000000000000000000001f40000000000000000000000000001f20000000000000000000000000001f00000000000000000000000000001ee0000000000000000000000000001ec0000000000000000000000000001ea0000000000000000000000000001e80000000000000000000000000001e60000000000000000000000000001e40000000000000000000000000001e20000000000000000000000000001e00000000000000000000000000001de0000000000000000000000000001");
let result = RSAFieldElement::from_string("0000000000000000000000000000dbfdffffffffffffffffffffffffffddfbffffffffffffffffffffffffffdff9ffffffffffffffffffffffffffe1f7ffffffffffffffffffffffffffe3f5ffffffffffffffffffffffffffe5f3ffffffffffffffffffffffffffe7f1ffffffffffffffffffffffffffe9efffffffffffffffffffffffffffebedffffffffffffffffffffffffffedebffffffffffffffffffffffffffefe9fffffffffffffffffffffffffff1e7fffffffffffffffffffffffffff3e5fffffffffffffffffffffffffff5e3fffffffffffffffffffffffffff7e1fffffffffffffffffffffffffff9dffffffffffffffffffffffffffffbddfffffffffffffffffffffffffffd");

println!("a: {:?}", a.to_string());
println!("b: {:?}", b.to_string());
println!("r: {:?}", r.to_string());
println!("carry: {:?}", result.to_string());
println!("result: {:?}", carry.to_string());
RSAFieldElement::assert_multiplication(&a, &b, &r, &carry, &result);
println!("case 1 passed");
}
}

0 comments on commit 962ba5c

Please sign in to comment.