-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7f7c612
commit 9957aa3
Showing
2 changed files
with
363 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
use ark_std::test_rng; | ||
use rand::Rng; | ||
|
||
const N_LIMBS: usize = 18; | ||
const MASK120: u128 = (1 << 120) - 1; | ||
const MASK60: u128 = (1 << 60) - 1; | ||
const MASK8: u128 = (1 << 8) - 1; | ||
const HEX_PER_LIMB: usize = 30; | ||
|
||
#[derive(Debug, Clone, Copy, PartialEq)] | ||
pub struct RSAFieldElement { | ||
// an RSA field element is a 2048 bits integer | ||
// it is represented as an array of 18 u120 elements, stored each in a u128 | ||
data: [u128; N_LIMBS], | ||
} | ||
|
||
#[inline] | ||
fn add_u120_with_carry(a: &u128, b: &u128, carry: &u128) -> (u128, u128) { | ||
// a, b, carry are all 120 bits integers, so we can simply add them | ||
let mut sum = *a + *b + *carry; | ||
|
||
let carry = sum >> 120; | ||
sum = sum & MASK120; | ||
|
||
(sum, carry) | ||
} | ||
|
||
#[inline] | ||
fn mul_u120_with_carry(a: &u128, b: &u128, carry: &u128) -> (u128, u128) { | ||
let a_lo = a & MASK60; | ||
let a_hi = a >> 60; | ||
let b_lo = b & MASK60; | ||
let b_hi = b >> 60; | ||
let c_lo = *carry & MASK60; | ||
let c_hi = *carry >> 60; | ||
|
||
let tmp_0 = &a_lo * &b_lo + &c_lo; | ||
let tmp_1 = &a_lo * &b_hi + &a_hi * &b_lo + c_hi; | ||
let tmp_2 = &a_hi * &b_hi; | ||
|
||
let tmp_1_lo = tmp_1 & MASK60; | ||
let tmp_1_hi = tmp_1 >> 60; | ||
|
||
let (res, mut c) = add_u120_with_carry(&tmp_0, &(tmp_1_lo << 60), &0u128); | ||
c += tmp_1_hi + tmp_2; | ||
|
||
(res, c) | ||
} | ||
|
||
impl RSAFieldElement { | ||
pub fn new(data: [u128; N_LIMBS]) -> Self { | ||
Self { data } | ||
} | ||
|
||
pub fn random(rng: &mut impl Rng) -> Self { | ||
let mut data = [0; N_LIMBS]; | ||
rng.fill(&mut data); | ||
data.iter_mut() | ||
.take(N_LIMBS - 1) | ||
.for_each(|x| *x &= MASK120); | ||
data[N_LIMBS - 1] &= MASK8; | ||
Self { data } | ||
} | ||
|
||
pub fn to_string(&self) -> String { | ||
let mut s = String::new(); | ||
for i in 0..N_LIMBS { | ||
s = (&format!("{:030x}", self.data[i])).to_string() + &s; | ||
} | ||
s | ||
} | ||
|
||
pub fn from_string(s: &str) -> Self { | ||
let mut data = [0; N_LIMBS]; | ||
for i in 0..N_LIMBS { | ||
data[N_LIMBS - 1 - i] = | ||
u128::from_str_radix(&s[i * HEX_PER_LIMB..(i + 1) * HEX_PER_LIMB], 16).unwrap(); | ||
} | ||
Self { data } | ||
} | ||
|
||
// assert a + b = result + r * carry | ||
pub fn assert_addition(a: &Self, b: &Self, modulus: &Self, carry: &bool, result: &Self) { | ||
let mut left_result = [0u128; N_LIMBS]; // for a + b | ||
let mut right_result = result.data.clone(); // for result + r * carry | ||
|
||
// First compute a + b | ||
let mut c = 0u128; | ||
for i in 0..N_LIMBS { | ||
let (sum, new_carry) = add_u120_with_carry(&a.data[i], &b.data[i], &c); | ||
left_result[i] = sum; | ||
c = new_carry; | ||
} | ||
|
||
// If carry is true, add modulus to right_result | ||
if *carry { | ||
let mut c = 0u128; | ||
for i in 0..N_LIMBS { | ||
let (sum, new_carry) = add_u120_with_carry(&right_result[i], &modulus.data[i], &c); | ||
right_result[i] = sum; | ||
c = new_carry; | ||
} | ||
} | ||
|
||
// Assert equality | ||
assert!( | ||
left_result == right_result, | ||
"Addition assertion failed\n{:?}\n{:?}", | ||
left_result, | ||
right_result | ||
); | ||
} | ||
} | ||
|
||
#[test] | ||
fn test_rsa_field_serial() { | ||
let mut rng = test_rng(); | ||
let a = RSAFieldElement::random(&mut rng); | ||
let a_str = a.to_string(); | ||
println!("{:?}", a_str); | ||
|
||
let a2 = RSAFieldElement::from_string(&a_str); | ||
assert_eq!(a, a2); | ||
|
||
for _ in 0..100 { | ||
let a = RSAFieldElement::random(&mut rng); | ||
let a_str = a.to_string(); | ||
let a2 = RSAFieldElement::from_string(&a_str); | ||
assert_eq!(a, a2); | ||
} | ||
} | ||
|
||
#[test] | ||
fn test_u120_add() { | ||
let a = MASK120; | ||
let b = 1; | ||
let carry = 0; | ||
let (sum, carry) = add_u120_with_carry(&a, &b, &carry); | ||
|
||
assert_eq!(sum, 0); | ||
assert_eq!(carry, 1); | ||
} | ||
|
||
#[test] | ||
fn test_u120_mul() { | ||
let a = MASK120; | ||
let b = 8; | ||
let carry = 0; | ||
let (sum, carry) = mul_u120_with_carry(&a, &b, &carry); | ||
|
||
assert_eq!(sum, 0xfffffffffffffffffffffffffffff8); | ||
assert_eq!(carry, 7); | ||
|
||
let a = MASK120; | ||
let b = MASK120 - 1; | ||
let carry = a; | ||
let (sum, carry) = mul_u120_with_carry(&a, &b, &carry); | ||
|
||
assert_eq!(sum, 1); | ||
assert_eq!(carry, 0xfffffffffffffffffffffffffffffe); | ||
} | ||
|
||
#[test] | ||
fn test_assert_rsa_addition() { | ||
let mut r = RSAFieldElement::new([MASK120; N_LIMBS]); | ||
r.data[N_LIMBS - 1] = MASK8; | ||
|
||
{ | ||
let a = RSAFieldElement::new([1u128; N_LIMBS]); | ||
let b = RSAFieldElement::new([2u128; N_LIMBS]); | ||
let result = RSAFieldElement::new([3u128; N_LIMBS]); | ||
RSAFieldElement::assert_addition(&a, &b, &r, &false, &result); | ||
println!("case 1 passed"); | ||
} | ||
|
||
{ | ||
let mut a = RSAFieldElement::new([MASK120 - 1; N_LIMBS]); | ||
a.data[N_LIMBS - 1] = MASK8 - 1; | ||
let b = RSAFieldElement::new([1u128; N_LIMBS]); | ||
let result = RSAFieldElement::new([0u128; N_LIMBS]); | ||
println!("a: {:?}", a.to_string()); | ||
println!("b: {:?}", b.to_string()); | ||
println!("r: {:?}", r.to_string()); | ||
println!("result: {:?}", result.to_string()); | ||
RSAFieldElement::assert_addition(&a, &b, &r, &true, &result); | ||
println!("case 2 passed"); | ||
} | ||
|
||
{ | ||
let mut a = RSAFieldElement::new([MASK120 - 1; N_LIMBS]); | ||
a.data[N_LIMBS - 1] = MASK8 - 1; | ||
let b = RSAFieldElement::new([2u128; N_LIMBS]); | ||
let result = RSAFieldElement::from_string("000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001"); | ||
println!("a: {:?}", a.to_string()); | ||
println!("b: {:?}", b.to_string()); | ||
println!("r: {:?}", r.to_string()); | ||
println!("result: {:?}", result.to_string()); | ||
RSAFieldElement::assert_addition(&a, &b, &r, &true, &result); | ||
println!("case 3 passed"); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
def multiply_mod_254bit_constraint(a_chunks, b_chunks, r_chunks, chunk_size=120): | ||
""" | ||
Compute (a * b) mod r where: | ||
- a, b, r are stored as arrays of 17 120-bit chunks | ||
- All intermediate calculations must stay under 254 bits | ||
- Result should be in the same format (17 120-bit chunks) | ||
Parameters: | ||
a_chunks: List[int] - 17 elements, each ≤ 2^120 | ||
b_chunks: List[int] - 17 elements, each ≤ 2^120 | ||
r_chunks: List[int] - 17 elements, each ≤ 2^120 | ||
""" | ||
n = len(a_chunks) # Should be 17 | ||
assert len(a_chunks) == len(b_chunks) == len(r_chunks) == 17 | ||
|
||
# Maximum value in each chunk | ||
MAX_CHUNK = (1 << chunk_size) | ||
|
||
def single_chunk_mult(a_i, b_j): | ||
""" | ||
Multiply two 120-bit chunks, ensuring result stays under 254 bits | ||
Returns (high, low) tuple where high is the overflow into next chunk | ||
""" | ||
product = a_i * b_j | ||
low = product & (MAX_CHUNK - 1) | ||
high = product >> chunk_size | ||
return high, low | ||
|
||
# Initialize result array (need extra space for overflow) | ||
result = [0] * (2 * n) | ||
|
||
# Perform multiplication with careful overflow handling | ||
for i in range(n): | ||
for j in range(n): | ||
# Multiply individual chunks | ||
high, low = single_chunk_mult(a_chunks[i], b_chunks[j]) | ||
|
||
# Position in result | ||
pos = i + j | ||
|
||
# Add low part | ||
result[pos] += low | ||
|
||
# Handle overflow from low part | ||
if result[pos] >= MAX_CHUNK: | ||
result[pos + 1] += result[pos] >> chunk_size | ||
result[pos] &= (MAX_CHUNK - 1) | ||
|
||
# Add high part to next chunk | ||
result[pos + 1] += high | ||
|
||
# Handle overflow from high part | ||
if result[pos + 1] >= MAX_CHUNK: | ||
result[pos + 2] += result[pos + 1] >> chunk_size | ||
result[pos + 1] &= (MAX_CHUNK - 1) | ||
|
||
# Now we need to reduce modulo r | ||
def reduce_mod_r(): | ||
""" | ||
Reduce the result modulo r, maintaining chunk structure | ||
This is a simplified Barrett reduction adapted for our chunk structure | ||
""" | ||
# Convert chunks to a more manageable form for reduction | ||
temp_result = result.copy() | ||
|
||
while any(temp_result[n:]): # While there are any nonzero high chunks | ||
# Find the highest nonzero chunk | ||
highest_chunk = 2 * n - 1 | ||
while highest_chunk >= n and temp_result[highest_chunk] == 0: | ||
highest_chunk -= 1 | ||
|
||
if highest_chunk < n: | ||
break | ||
|
||
# Perform the reduction | ||
shift = highest_chunk - n + 1 | ||
for i in range(n): | ||
if r_chunks[i] != 0: | ||
for j in range(n): | ||
if i + j + shift < 2 * n: | ||
temp_result[i + j + shift] -= ( | ||
(temp_result[highest_chunk] * r_chunks[i]) | ||
>> (chunk_size * (n - j - 1)) | ||
) & (MAX_CHUNK - 1) | ||
|
||
# Normalize chunks | ||
for i in range(2 * n - 1): | ||
while temp_result[i] < 0: | ||
temp_result[i] += MAX_CHUNK | ||
temp_result[i + 1] -= 1 | ||
while temp_result[i] >= MAX_CHUNK: | ||
temp_result[i] -= MAX_CHUNK | ||
temp_result[i + 1] += 1 | ||
|
||
return temp_result[:n] | ||
|
||
return reduce_mod_r() | ||
|
||
def test_complex_case(): | ||
# Create a more complex test case with larger numbers | ||
|
||
# Helper function to create chunks from a large number | ||
def number_to_chunks(num, chunk_size=120, num_chunks=17): | ||
chunks = [] | ||
mask = (1 << chunk_size) - 1 | ||
for _ in range(num_chunks): | ||
chunks.append(num & mask) | ||
num >>= chunk_size | ||
return chunks | ||
|
||
# Test with some large prime-like numbers | ||
# Using numbers that will exercise multiple chunks | ||
|
||
# Create a large number for testing (about 1000 bits set) | ||
a = (1 << 1000) - (1 << 500) + (1 << 200) - (1 << 100) + 12345 | ||
b = (1 << 900) - (1 << 400) + (1 << 300) - (1 << 50) + 67890 | ||
r = (1 << 1024) - (1 << 512) + (1 << 256) - (1 << 128) + 11111 | ||
|
||
# Convert to chunks | ||
a_chunks = number_to_chunks(a) | ||
b_chunks = number_to_chunks(b) | ||
r_chunks = number_to_chunks(r) | ||
|
||
# Perform the modular multiplication | ||
result_chunks = multiply_mod_254bit_constraint(a_chunks, b_chunks, r_chunks) | ||
|
||
# Convert result back to number for verification | ||
def chunks_to_number(chunks, chunk_size=120): | ||
result = 0 | ||
for i, chunk in enumerate(chunks): | ||
result += chunk << (i * chunk_size) | ||
return result | ||
|
||
# Calculate expected result using regular Python arithmetic | ||
expected = (a * b) % r | ||
actual = chunks_to_number(result_chunks) | ||
|
||
print("Test with large numbers:") | ||
print(f"a (bits): {a.bit_length()}") | ||
print(f"b (bits): {b.bit_length()}") | ||
print(f"r (bits): {r.bit_length()}") | ||
print("a:", a) | ||
print("b:", b) | ||
print("r:", r) | ||
print("expected:", expected) | ||
print("actual:", actual) | ||
print("\nFirst few chunks of a:", a_chunks[:3]) | ||
print("First few chunks of b:", b_chunks[:3]) | ||
print("First few chunks of r:", r_chunks[:3]) | ||
print("\nFirst few chunks of result:", result_chunks[:3]) | ||
print(f"\nExpected result (bits): {expected.bit_length()}") | ||
print(f"Actual result (bits): {actual.bit_length()}") | ||
print(f"Results match: {expected == actual}") | ||
|
||
# Additional verification | ||
print("\nVerification that no chunk exceeds 120 bits:") | ||
max_chunk_size = max(chunk.bit_length() for chunk in result_chunks) | ||
print(f"Maximum chunk size in result: {max_chunk_size} bits") | ||
print(f"All chunks within 120-bit limit: {max_chunk_size <= 120}") | ||
|
||
# Run the test | ||
test_complex_case() |