Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenfeizhang committed Nov 22, 2024
1 parent 7f7c612 commit 9957aa3
Show file tree
Hide file tree
Showing 2 changed files with 363 additions and 0 deletions.
201 changes: 201 additions & 0 deletions expander_compiler/tests/example_rsa.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
use ark_std::test_rng;
use rand::Rng;

const N_LIMBS: usize = 18;
const MASK120: u128 = (1 << 120) - 1;
const MASK60: u128 = (1 << 60) - 1;
const MASK8: u128 = (1 << 8) - 1;
const HEX_PER_LIMB: usize = 30;

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RSAFieldElement {
// an RSA field element is a 2048 bits integer
// it is represented as an array of 18 u120 elements, stored each in a u128
data: [u128; N_LIMBS],
}

#[inline]
fn add_u120_with_carry(a: &u128, b: &u128, carry: &u128) -> (u128, u128) {
// a, b, carry are all 120 bits integers, so we can simply add them
let mut sum = *a + *b + *carry;

let carry = sum >> 120;
sum = sum & MASK120;

(sum, carry)
}

#[inline]
fn mul_u120_with_carry(a: &u128, b: &u128, carry: &u128) -> (u128, u128) {
let a_lo = a & MASK60;
let a_hi = a >> 60;
let b_lo = b & MASK60;
let b_hi = b >> 60;
let c_lo = *carry & MASK60;
let c_hi = *carry >> 60;

let tmp_0 = &a_lo * &b_lo + &c_lo;
let tmp_1 = &a_lo * &b_hi + &a_hi * &b_lo + c_hi;
let tmp_2 = &a_hi * &b_hi;

let tmp_1_lo = tmp_1 & MASK60;
let tmp_1_hi = tmp_1 >> 60;

let (res, mut c) = add_u120_with_carry(&tmp_0, &(tmp_1_lo << 60), &0u128);
c += tmp_1_hi + tmp_2;

(res, c)
}

impl RSAFieldElement {
pub fn new(data: [u128; N_LIMBS]) -> Self {
Self { data }
}

pub fn random(rng: &mut impl Rng) -> Self {
let mut data = [0; N_LIMBS];
rng.fill(&mut data);
data.iter_mut()
.take(N_LIMBS - 1)
.for_each(|x| *x &= MASK120);
data[N_LIMBS - 1] &= MASK8;
Self { data }
}

pub fn to_string(&self) -> String {
let mut s = String::new();
for i in 0..N_LIMBS {
s = (&format!("{:030x}", self.data[i])).to_string() + &s;
}
s
}

pub fn from_string(s: &str) -> Self {
let mut data = [0; N_LIMBS];
for i in 0..N_LIMBS {
data[N_LIMBS - 1 - i] =
u128::from_str_radix(&s[i * HEX_PER_LIMB..(i + 1) * HEX_PER_LIMB], 16).unwrap();
}
Self { data }
}

// assert a + b = result + r * carry
pub fn assert_addition(a: &Self, b: &Self, modulus: &Self, carry: &bool, result: &Self) {
let mut left_result = [0u128; N_LIMBS]; // for a + b
let mut right_result = result.data.clone(); // for result + r * carry

// First compute a + b
let mut c = 0u128;
for i in 0..N_LIMBS {
let (sum, new_carry) = add_u120_with_carry(&a.data[i], &b.data[i], &c);
left_result[i] = sum;
c = new_carry;
}

// If carry is true, add modulus to right_result
if *carry {
let mut c = 0u128;
for i in 0..N_LIMBS {
let (sum, new_carry) = add_u120_with_carry(&right_result[i], &modulus.data[i], &c);
right_result[i] = sum;
c = new_carry;
}
}

// Assert equality
assert!(
left_result == right_result,
"Addition assertion failed\n{:?}\n{:?}",
left_result,
right_result
);
}
}

#[test]
fn test_rsa_field_serial() {
let mut rng = test_rng();
let a = RSAFieldElement::random(&mut rng);
let a_str = a.to_string();
println!("{:?}", a_str);

let a2 = RSAFieldElement::from_string(&a_str);
assert_eq!(a, a2);

for _ in 0..100 {
let a = RSAFieldElement::random(&mut rng);
let a_str = a.to_string();
let a2 = RSAFieldElement::from_string(&a_str);
assert_eq!(a, a2);
}
}

#[test]
fn test_u120_add() {
let a = MASK120;
let b = 1;
let carry = 0;
let (sum, carry) = add_u120_with_carry(&a, &b, &carry);

assert_eq!(sum, 0);
assert_eq!(carry, 1);
}

#[test]
fn test_u120_mul() {
let a = MASK120;
let b = 8;
let carry = 0;
let (sum, carry) = mul_u120_with_carry(&a, &b, &carry);

assert_eq!(sum, 0xfffffffffffffffffffffffffffff8);
assert_eq!(carry, 7);

let a = MASK120;
let b = MASK120 - 1;
let carry = a;
let (sum, carry) = mul_u120_with_carry(&a, &b, &carry);

assert_eq!(sum, 1);
assert_eq!(carry, 0xfffffffffffffffffffffffffffffe);
}

#[test]
fn test_assert_rsa_addition() {
let mut r = RSAFieldElement::new([MASK120; N_LIMBS]);
r.data[N_LIMBS - 1] = MASK8;

{
let a = RSAFieldElement::new([1u128; N_LIMBS]);
let b = RSAFieldElement::new([2u128; N_LIMBS]);
let result = RSAFieldElement::new([3u128; N_LIMBS]);
RSAFieldElement::assert_addition(&a, &b, &r, &false, &result);
println!("case 1 passed");
}

{
let mut a = RSAFieldElement::new([MASK120 - 1; N_LIMBS]);
a.data[N_LIMBS - 1] = MASK8 - 1;
let b = RSAFieldElement::new([1u128; N_LIMBS]);
let result = RSAFieldElement::new([0u128; N_LIMBS]);
println!("a: {:?}", a.to_string());
println!("b: {:?}", b.to_string());
println!("r: {:?}", r.to_string());
println!("result: {:?}", result.to_string());
RSAFieldElement::assert_addition(&a, &b, &r, &true, &result);
println!("case 2 passed");
}

{
let mut a = RSAFieldElement::new([MASK120 - 1; N_LIMBS]);
a.data[N_LIMBS - 1] = MASK8 - 1;
let b = RSAFieldElement::new([2u128; N_LIMBS]);
let result = RSAFieldElement::from_string("000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001");
println!("a: {:?}", a.to_string());
println!("b: {:?}", b.to_string());
println!("r: {:?}", r.to_string());
println!("result: {:?}", result.to_string());
RSAFieldElement::assert_addition(&a, &b, &r, &true, &result);
println!("case 3 passed");
}
}
162 changes: 162 additions & 0 deletions expander_compiler/tests/rsa_mul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
def multiply_mod_254bit_constraint(a_chunks, b_chunks, r_chunks, chunk_size=120):
"""
Compute (a * b) mod r where:
- a, b, r are stored as arrays of 17 120-bit chunks
- All intermediate calculations must stay under 254 bits
- Result should be in the same format (17 120-bit chunks)
Parameters:
a_chunks: List[int] - 17 elements, each ≤ 2^120
b_chunks: List[int] - 17 elements, each ≤ 2^120
r_chunks: List[int] - 17 elements, each ≤ 2^120
"""
n = len(a_chunks) # Should be 17
assert len(a_chunks) == len(b_chunks) == len(r_chunks) == 17

# Maximum value in each chunk
MAX_CHUNK = (1 << chunk_size)

def single_chunk_mult(a_i, b_j):
"""
Multiply two 120-bit chunks, ensuring result stays under 254 bits
Returns (high, low) tuple where high is the overflow into next chunk
"""
product = a_i * b_j
low = product & (MAX_CHUNK - 1)
high = product >> chunk_size
return high, low

# Initialize result array (need extra space for overflow)
result = [0] * (2 * n)

# Perform multiplication with careful overflow handling
for i in range(n):
for j in range(n):
# Multiply individual chunks
high, low = single_chunk_mult(a_chunks[i], b_chunks[j])

# Position in result
pos = i + j

# Add low part
result[pos] += low

# Handle overflow from low part
if result[pos] >= MAX_CHUNK:
result[pos + 1] += result[pos] >> chunk_size
result[pos] &= (MAX_CHUNK - 1)

# Add high part to next chunk
result[pos + 1] += high

# Handle overflow from high part
if result[pos + 1] >= MAX_CHUNK:
result[pos + 2] += result[pos + 1] >> chunk_size
result[pos + 1] &= (MAX_CHUNK - 1)

# Now we need to reduce modulo r
def reduce_mod_r():
"""
Reduce the result modulo r, maintaining chunk structure
This is a simplified Barrett reduction adapted for our chunk structure
"""
# Convert chunks to a more manageable form for reduction
temp_result = result.copy()

while any(temp_result[n:]): # While there are any nonzero high chunks
# Find the highest nonzero chunk
highest_chunk = 2 * n - 1
while highest_chunk >= n and temp_result[highest_chunk] == 0:
highest_chunk -= 1

if highest_chunk < n:
break

# Perform the reduction
shift = highest_chunk - n + 1
for i in range(n):
if r_chunks[i] != 0:
for j in range(n):
if i + j + shift < 2 * n:
temp_result[i + j + shift] -= (
(temp_result[highest_chunk] * r_chunks[i])
>> (chunk_size * (n - j - 1))
) & (MAX_CHUNK - 1)

# Normalize chunks
for i in range(2 * n - 1):
while temp_result[i] < 0:
temp_result[i] += MAX_CHUNK
temp_result[i + 1] -= 1
while temp_result[i] >= MAX_CHUNK:
temp_result[i] -= MAX_CHUNK
temp_result[i + 1] += 1

return temp_result[:n]

return reduce_mod_r()

def test_complex_case():
# Create a more complex test case with larger numbers

# Helper function to create chunks from a large number
def number_to_chunks(num, chunk_size=120, num_chunks=17):
chunks = []
mask = (1 << chunk_size) - 1
for _ in range(num_chunks):
chunks.append(num & mask)
num >>= chunk_size
return chunks

# Test with some large prime-like numbers
# Using numbers that will exercise multiple chunks

# Create a large number for testing (about 1000 bits set)
a = (1 << 1000) - (1 << 500) + (1 << 200) - (1 << 100) + 12345
b = (1 << 900) - (1 << 400) + (1 << 300) - (1 << 50) + 67890
r = (1 << 1024) - (1 << 512) + (1 << 256) - (1 << 128) + 11111

# Convert to chunks
a_chunks = number_to_chunks(a)
b_chunks = number_to_chunks(b)
r_chunks = number_to_chunks(r)

# Perform the modular multiplication
result_chunks = multiply_mod_254bit_constraint(a_chunks, b_chunks, r_chunks)

# Convert result back to number for verification
def chunks_to_number(chunks, chunk_size=120):
result = 0
for i, chunk in enumerate(chunks):
result += chunk << (i * chunk_size)
return result

# Calculate expected result using regular Python arithmetic
expected = (a * b) % r
actual = chunks_to_number(result_chunks)

print("Test with large numbers:")
print(f"a (bits): {a.bit_length()}")
print(f"b (bits): {b.bit_length()}")
print(f"r (bits): {r.bit_length()}")
print("a:", a)
print("b:", b)
print("r:", r)
print("expected:", expected)
print("actual:", actual)
print("\nFirst few chunks of a:", a_chunks[:3])
print("First few chunks of b:", b_chunks[:3])
print("First few chunks of r:", r_chunks[:3])
print("\nFirst few chunks of result:", result_chunks[:3])
print(f"\nExpected result (bits): {expected.bit_length()}")
print(f"Actual result (bits): {actual.bit_length()}")
print(f"Results match: {expected == actual}")

# Additional verification
print("\nVerification that no chunk exceeds 120 bits:")
max_chunk_size = max(chunk.bit_length() for chunk in result_chunks)
print(f"Maximum chunk size in result: {max_chunk_size} bits")
print(f"All chunks within 120-bit limit: {max_chunk_size <= 120}")

# Run the test
test_complex_case()

0 comments on commit 9957aa3

Please sign in to comment.