diff --git a/expander_compiler/tests/example_rsa.rs b/expander_compiler/tests/example_rsa.rs new file mode 100644 index 0000000..334ee02 --- /dev/null +++ b/expander_compiler/tests/example_rsa.rs @@ -0,0 +1,201 @@ +use ark_std::test_rng; +use rand::Rng; + +const N_LIMBS: usize = 18; +const MASK120: u128 = (1 << 120) - 1; +const MASK60: u128 = (1 << 60) - 1; +const MASK8: u128 = (1 << 8) - 1; +const HEX_PER_LIMB: usize = 30; + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct RSAFieldElement { + // an RSA field element is a 2048 bits integer + // it is represented as an array of 18 u120 elements, stored each in a u128 + data: [u128; N_LIMBS], +} + +#[inline] +fn add_u120_with_carry(a: &u128, b: &u128, carry: &u128) -> (u128, u128) { + // a, b, carry are all 120 bits integers, so we can simply add them + let mut sum = *a + *b + *carry; + + let carry = sum >> 120; + sum = sum & MASK120; + + (sum, carry) +} + +#[inline] +fn mul_u120_with_carry(a: &u128, b: &u128, carry: &u128) -> (u128, u128) { + let a_lo = a & MASK60; + let a_hi = a >> 60; + let b_lo = b & MASK60; + let b_hi = b >> 60; + let c_lo = *carry & MASK60; + let c_hi = *carry >> 60; + + let tmp_0 = &a_lo * &b_lo + &c_lo; + let tmp_1 = &a_lo * &b_hi + &a_hi * &b_lo + c_hi; + let tmp_2 = &a_hi * &b_hi; + + let tmp_1_lo = tmp_1 & MASK60; + let tmp_1_hi = tmp_1 >> 60; + + let (res, mut c) = add_u120_with_carry(&tmp_0, &(tmp_1_lo << 60), &0u128); + c += tmp_1_hi + tmp_2; + + (res, c) +} + +impl RSAFieldElement { + pub fn new(data: [u128; N_LIMBS]) -> Self { + Self { data } + } + + pub fn random(rng: &mut impl Rng) -> Self { + let mut data = [0; N_LIMBS]; + rng.fill(&mut data); + data.iter_mut() + .take(N_LIMBS - 1) + .for_each(|x| *x &= MASK120); + data[N_LIMBS - 1] &= MASK8; + Self { data } + } + + pub fn to_string(&self) -> String { + let mut s = String::new(); + for i in 0..N_LIMBS { + s = (&format!("{:030x}", self.data[i])).to_string() + &s; + } + s + } + + pub fn from_string(s: &str) -> Self { + let mut data = [0; N_LIMBS]; + for i in 0..N_LIMBS { + data[N_LIMBS - 1 - i] = + u128::from_str_radix(&s[i * HEX_PER_LIMB..(i + 1) * HEX_PER_LIMB], 16).unwrap(); + } + Self { data } + } + + // assert a + b = result + r * carry + pub fn assert_addition(a: &Self, b: &Self, modulus: &Self, carry: &bool, result: &Self) { + let mut left_result = [0u128; N_LIMBS]; // for a + b + let mut right_result = result.data.clone(); // for result + r * carry + + // First compute a + b + let mut c = 0u128; + for i in 0..N_LIMBS { + let (sum, new_carry) = add_u120_with_carry(&a.data[i], &b.data[i], &c); + left_result[i] = sum; + c = new_carry; + } + + // If carry is true, add modulus to right_result + if *carry { + let mut c = 0u128; + for i in 0..N_LIMBS { + let (sum, new_carry) = add_u120_with_carry(&right_result[i], &modulus.data[i], &c); + right_result[i] = sum; + c = new_carry; + } + } + + // Assert equality + assert!( + left_result == right_result, + "Addition assertion failed\n{:?}\n{:?}", + left_result, + right_result + ); + } +} + +#[test] +fn test_rsa_field_serial() { + let mut rng = test_rng(); + let a = RSAFieldElement::random(&mut rng); + let a_str = a.to_string(); + println!("{:?}", a_str); + + let a2 = RSAFieldElement::from_string(&a_str); + assert_eq!(a, a2); + + for _ in 0..100 { + let a = RSAFieldElement::random(&mut rng); + let a_str = a.to_string(); + let a2 = RSAFieldElement::from_string(&a_str); + assert_eq!(a, a2); + } +} + +#[test] +fn test_u120_add() { + let a = MASK120; + let b = 1; + let carry = 0; + let (sum, carry) = add_u120_with_carry(&a, &b, &carry); + + assert_eq!(sum, 0); + assert_eq!(carry, 1); +} + +#[test] +fn test_u120_mul() { + let a = MASK120; + let b = 8; + let carry = 0; + let (sum, carry) = mul_u120_with_carry(&a, &b, &carry); + + assert_eq!(sum, 0xfffffffffffffffffffffffffffff8); + assert_eq!(carry, 7); + + let a = MASK120; + let b = MASK120 - 1; + let carry = a; + let (sum, carry) = mul_u120_with_carry(&a, &b, &carry); + + assert_eq!(sum, 1); + assert_eq!(carry, 0xfffffffffffffffffffffffffffffe); +} + +#[test] +fn test_assert_rsa_addition() { + let mut r = RSAFieldElement::new([MASK120; N_LIMBS]); + r.data[N_LIMBS - 1] = MASK8; + + { + let a = RSAFieldElement::new([1u128; N_LIMBS]); + let b = RSAFieldElement::new([2u128; N_LIMBS]); + let result = RSAFieldElement::new([3u128; N_LIMBS]); + RSAFieldElement::assert_addition(&a, &b, &r, &false, &result); + println!("case 1 passed"); + } + + { + let mut a = RSAFieldElement::new([MASK120 - 1; N_LIMBS]); + a.data[N_LIMBS - 1] = MASK8 - 1; + let b = RSAFieldElement::new([1u128; N_LIMBS]); + let result = RSAFieldElement::new([0u128; N_LIMBS]); + println!("a: {:?}", a.to_string()); + println!("b: {:?}", b.to_string()); + println!("r: {:?}", r.to_string()); + println!("result: {:?}", result.to_string()); + RSAFieldElement::assert_addition(&a, &b, &r, &true, &result); + println!("case 2 passed"); + } + + { + let mut a = RSAFieldElement::new([MASK120 - 1; N_LIMBS]); + a.data[N_LIMBS - 1] = MASK8 - 1; + let b = RSAFieldElement::new([2u128; N_LIMBS]); + let result = RSAFieldElement::from_string("000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001"); + println!("a: {:?}", a.to_string()); + println!("b: {:?}", b.to_string()); + println!("r: {:?}", r.to_string()); + println!("result: {:?}", result.to_string()); + RSAFieldElement::assert_addition(&a, &b, &r, &true, &result); + println!("case 3 passed"); + } +} diff --git a/expander_compiler/tests/rsa_mul.py b/expander_compiler/tests/rsa_mul.py new file mode 100644 index 0000000..e8c6c50 --- /dev/null +++ b/expander_compiler/tests/rsa_mul.py @@ -0,0 +1,162 @@ +def multiply_mod_254bit_constraint(a_chunks, b_chunks, r_chunks, chunk_size=120): + """ + Compute (a * b) mod r where: + - a, b, r are stored as arrays of 17 120-bit chunks + - All intermediate calculations must stay under 254 bits + - Result should be in the same format (17 120-bit chunks) + + Parameters: + a_chunks: List[int] - 17 elements, each ≤ 2^120 + b_chunks: List[int] - 17 elements, each ≤ 2^120 + r_chunks: List[int] - 17 elements, each ≤ 2^120 + """ + n = len(a_chunks) # Should be 17 + assert len(a_chunks) == len(b_chunks) == len(r_chunks) == 17 + + # Maximum value in each chunk + MAX_CHUNK = (1 << chunk_size) + + def single_chunk_mult(a_i, b_j): + """ + Multiply two 120-bit chunks, ensuring result stays under 254 bits + Returns (high, low) tuple where high is the overflow into next chunk + """ + product = a_i * b_j + low = product & (MAX_CHUNK - 1) + high = product >> chunk_size + return high, low + + # Initialize result array (need extra space for overflow) + result = [0] * (2 * n) + + # Perform multiplication with careful overflow handling + for i in range(n): + for j in range(n): + # Multiply individual chunks + high, low = single_chunk_mult(a_chunks[i], b_chunks[j]) + + # Position in result + pos = i + j + + # Add low part + result[pos] += low + + # Handle overflow from low part + if result[pos] >= MAX_CHUNK: + result[pos + 1] += result[pos] >> chunk_size + result[pos] &= (MAX_CHUNK - 1) + + # Add high part to next chunk + result[pos + 1] += high + + # Handle overflow from high part + if result[pos + 1] >= MAX_CHUNK: + result[pos + 2] += result[pos + 1] >> chunk_size + result[pos + 1] &= (MAX_CHUNK - 1) + + # Now we need to reduce modulo r + def reduce_mod_r(): + """ + Reduce the result modulo r, maintaining chunk structure + This is a simplified Barrett reduction adapted for our chunk structure + """ + # Convert chunks to a more manageable form for reduction + temp_result = result.copy() + + while any(temp_result[n:]): # While there are any nonzero high chunks + # Find the highest nonzero chunk + highest_chunk = 2 * n - 1 + while highest_chunk >= n and temp_result[highest_chunk] == 0: + highest_chunk -= 1 + + if highest_chunk < n: + break + + # Perform the reduction + shift = highest_chunk - n + 1 + for i in range(n): + if r_chunks[i] != 0: + for j in range(n): + if i + j + shift < 2 * n: + temp_result[i + j + shift] -= ( + (temp_result[highest_chunk] * r_chunks[i]) + >> (chunk_size * (n - j - 1)) + ) & (MAX_CHUNK - 1) + + # Normalize chunks + for i in range(2 * n - 1): + while temp_result[i] < 0: + temp_result[i] += MAX_CHUNK + temp_result[i + 1] -= 1 + while temp_result[i] >= MAX_CHUNK: + temp_result[i] -= MAX_CHUNK + temp_result[i + 1] += 1 + + return temp_result[:n] + + return reduce_mod_r() + +def test_complex_case(): + # Create a more complex test case with larger numbers + + # Helper function to create chunks from a large number + def number_to_chunks(num, chunk_size=120, num_chunks=17): + chunks = [] + mask = (1 << chunk_size) - 1 + for _ in range(num_chunks): + chunks.append(num & mask) + num >>= chunk_size + return chunks + + # Test with some large prime-like numbers + # Using numbers that will exercise multiple chunks + + # Create a large number for testing (about 1000 bits set) + a = (1 << 1000) - (1 << 500) + (1 << 200) - (1 << 100) + 12345 + b = (1 << 900) - (1 << 400) + (1 << 300) - (1 << 50) + 67890 + r = (1 << 1024) - (1 << 512) + (1 << 256) - (1 << 128) + 11111 + + # Convert to chunks + a_chunks = number_to_chunks(a) + b_chunks = number_to_chunks(b) + r_chunks = number_to_chunks(r) + + # Perform the modular multiplication + result_chunks = multiply_mod_254bit_constraint(a_chunks, b_chunks, r_chunks) + + # Convert result back to number for verification + def chunks_to_number(chunks, chunk_size=120): + result = 0 + for i, chunk in enumerate(chunks): + result += chunk << (i * chunk_size) + return result + + # Calculate expected result using regular Python arithmetic + expected = (a * b) % r + actual = chunks_to_number(result_chunks) + + print("Test with large numbers:") + print(f"a (bits): {a.bit_length()}") + print(f"b (bits): {b.bit_length()}") + print(f"r (bits): {r.bit_length()}") + print("a:", a) + print("b:", b) + print("r:", r) + print("expected:", expected) + print("actual:", actual) + print("\nFirst few chunks of a:", a_chunks[:3]) + print("First few chunks of b:", b_chunks[:3]) + print("First few chunks of r:", r_chunks[:3]) + print("\nFirst few chunks of result:", result_chunks[:3]) + print(f"\nExpected result (bits): {expected.bit_length()}") + print(f"Actual result (bits): {actual.bit_length()}") + print(f"Results match: {expected == actual}") + + # Additional verification + print("\nVerification that no chunk exceeds 120 bits:") + max_chunk_size = max(chunk.bit_length() for chunk in result_chunks) + print(f"Maximum chunk size in result: {max_chunk_size} bits") + print(f"All chunks within 120-bit limit: {max_chunk_size <= 120}") + +# Run the test +test_complex_case() \ No newline at end of file