Skip to content

Commit

Permalink
fmt and readability
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenfeizhang committed Dec 10, 2024
1 parent 9634607 commit 6b16a78
Show file tree
Hide file tree
Showing 4 changed files with 293 additions and 80 deletions.
3 changes: 1 addition & 2 deletions rsa_circuit/src/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ use halo2curves::bn256::Fr;

// we use 18 limbs, each with 120 bits, to store a 2048 bit integer
pub const N_LIMBS: usize = 18;
// // 2048 bits = 256 bytes
// pub(crate) const N_BYTES: usize = 256;

// Each 120 bits limb needs 30 hex number to store
pub(crate) const HEX_PER_LIMB: usize = 30;

Expand Down
105 changes: 100 additions & 5 deletions rsa_circuit/src/tests/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,26 @@ fn test_assert_rsa_addition() {
let mut a = RSAFieldElement::new([MASK120 - 1; N_LIMBS]);
a.data[N_LIMBS - 1] = MASK8 - 1;
let b = RSAFieldElement::new([2u128; N_LIMBS]);
let result = RSAFieldElement::from_string("000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001");
let result = RSAFieldElement::from_string(
"000000000000000000000000000001\
000000000000000000000000000001\
000000000000000000000000000001\
000000000000000000000000000001\
000000000000000000000000000001\
000000000000000000000000000001\
000000000000000000000000000001\
000000000000000000000000000001\
000000000000000000000000000001\
000000000000000000000000000001\
000000000000000000000000000001\
000000000000000000000000000001\
000000000000000000000000000001\
000000000000000000000000000001\
000000000000000000000000000001\
000000000000000000000000000001\
000000000000000000000000000001\
000000000000000000000000000001",
);

println!("a: {:?}", a.to_string());
println!("b: {:?}", b.to_string());
Expand All @@ -103,8 +122,46 @@ fn test_assert_rsa_multiplication() {
let a = RSAFieldElement::new([1u128; N_LIMBS]);
let b = RSAFieldElement::new([2u128; N_LIMBS]);

let carry = RSAFieldElement::from_string("0000000000000000000000000000000200000000000000000000000000000400000000000000000000000000000600000000000000000000000000000800000000000000000000000000000a00000000000000000000000000000c00000000000000000000000000000e00000000000000000000000000001000000000000000000000000000001200000000000000000000000000001400000000000000000000000000001600000000000000000000000000001800000000000000000000000000001a00000000000000000000000000001c00000000000000000000000000001e0000000000000000000000000000200000000000000000000000000000220000000000000000000000000000");
let result = RSAFieldElement::from_string("00000000000000000000000000002402000000000000000000000000002204000000000000000000000000002006000000000000000000000000001e08000000000000000000000000001c0a000000000000000000000000001a0c00000000000000000000000000180e000000000000000000000000001610000000000000000000000000001412000000000000000000000000001214000000000000000000000000001016000000000000000000000000000e18000000000000000000000000000c1a000000000000000000000000000a1c00000000000000000000000000081e0000000000000000000000000006200000000000000000000000000004220000000000000000000000000002");
let carry = RSAFieldElement::from_string(
"000000000000000000000000000000\
020000000000000000000000000000\
040000000000000000000000000000\
060000000000000000000000000000\
080000000000000000000000000000\
0a0000000000000000000000000000\
0c0000000000000000000000000000\
0e0000000000000000000000000000\
100000000000000000000000000000\
120000000000000000000000000000\
140000000000000000000000000000\
160000000000000000000000000000\
180000000000000000000000000000\
1a0000000000000000000000000000\
1c0000000000000000000000000000\
1e0000000000000000000000000000\
200000000000000000000000000000\
220000000000000000000000000000",
);
let result = RSAFieldElement::from_string(
"000000000000000000000000000024\
020000000000000000000000000022\
040000000000000000000000000020\
06000000000000000000000000001e\
08000000000000000000000000001c\
0a000000000000000000000000001a\
0c0000000000000000000000000018\
0e0000000000000000000000000016\
100000000000000000000000000014\
120000000000000000000000000012\
140000000000000000000000000010\
16000000000000000000000000000e\
18000000000000000000000000000c\
1a000000000000000000000000000a\
1c0000000000000000000000000008\
1e0000000000000000000000000006\
200000000000000000000000000004\
220000000000000000000000000002",
);

println!("a: {:?}", a.to_string());
println!("b: {:?}", b.to_string());
Expand All @@ -120,8 +177,46 @@ fn test_assert_rsa_multiplication() {
a.data[N_LIMBS - 1] = MASK8 - 1;
let b = RSAFieldElement::new([2u128; N_LIMBS]);

let carry = RSAFieldElement::from_string("000000000000000000000000000001fe0000000000000000000000000001fc0000000000000000000000000001fa0000000000000000000000000001f80000000000000000000000000001f60000000000000000000000000001f40000000000000000000000000001f20000000000000000000000000001f00000000000000000000000000001ee0000000000000000000000000001ec0000000000000000000000000001ea0000000000000000000000000001e80000000000000000000000000001e60000000000000000000000000001e40000000000000000000000000001e20000000000000000000000000001e00000000000000000000000000001de0000000000000000000000000001");
let result = RSAFieldElement::from_string("0000000000000000000000000000dbfdffffffffffffffffffffffffffddfbffffffffffffffffffffffffffdff9ffffffffffffffffffffffffffe1f7ffffffffffffffffffffffffffe3f5ffffffffffffffffffffffffffe5f3ffffffffffffffffffffffffffe7f1ffffffffffffffffffffffffffe9efffffffffffffffffffffffffffebedffffffffffffffffffffffffffedebffffffffffffffffffffffffffefe9fffffffffffffffffffffffffff1e7fffffffffffffffffffffffffff3e5fffffffffffffffffffffffffff5e3fffffffffffffffffffffffffff7e1fffffffffffffffffffffffffff9dffffffffffffffffffffffffffffbddfffffffffffffffffffffffffffd");
let carry = RSAFieldElement::from_string(
"000000000000000000000000000001\
fe0000000000000000000000000001\
fc0000000000000000000000000001\
fa0000000000000000000000000001\
f80000000000000000000000000001\
f60000000000000000000000000001\
f40000000000000000000000000001\
f20000000000000000000000000001\
f00000000000000000000000000001\
ee0000000000000000000000000001\
ec0000000000000000000000000001\
ea0000000000000000000000000001\
e80000000000000000000000000001\
e60000000000000000000000000001\
e40000000000000000000000000001\
e20000000000000000000000000001\
e00000000000000000000000000001\
de0000000000000000000000000001",
);
let result = RSAFieldElement::from_string(
"0000000000000000000000000000db\
fdffffffffffffffffffffffffffdd\
fbffffffffffffffffffffffffffdf\
f9ffffffffffffffffffffffffffe1\
f7ffffffffffffffffffffffffffe3\
f5ffffffffffffffffffffffffffe5\
f3ffffffffffffffffffffffffffe7\
f1ffffffffffffffffffffffffffe9\
efffffffffffffffffffffffffffeb\
edffffffffffffffffffffffffffed\
ebffffffffffffffffffffffffffef\
e9fffffffffffffffffffffffffff1\
e7fffffffffffffffffffffffffff3\
e5fffffffffffffffffffffffffff5\
e3fffffffffffffffffffffffffff7\
e1fffffffffffffffffffffffffff9\
dffffffffffffffffffffffffffffb\
ddfffffffffffffffffffffffffffd",
);

println!("a: {:?}", a.to_string());
println!("b: {:?}", b.to_string());
Expand Down
206 changes: 135 additions & 71 deletions rsa_circuit/src/tests/u2048_mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,14 +276,78 @@ fn test_mul_mod() {
}

{
let x = BigUint::from_str_radix("7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff000000000000000000000000000000", 16).unwrap();
let modulus = BigUint::from_str_radix("80000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000", 16).unwrap();
let x = BigUint::from_str_radix(
"7f\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
000000000000000000000000000000",
16,
)
.unwrap();
let modulus = BigUint::from_str_radix(
"80\
000000000000000000000000000000\
000000000000000000000000000000\
000000000000000000000000000000\
000000000000000000000000000000\
000000000000000000000000000000\
000000000000000000000000000000\
000000000000000000000000000000\
000000000000000000000000000000\
000000000000000000000000000000\
000000000000000000000000000000\
000000000000000000000000000000\
000000000000000000000000000000\
000000000000000000000000000000\
000000000000000000000000000000\
000000000000000000000000000000\
000000000000000000000000000001\
000000000000000000000000000000",
16,
)
.unwrap();

let res = BigUint::from_str_radix(
"4000000000000000000000000000000000000000000000000000000000000",
16,
)
.unwrap();
let carry= BigUint::from_str_radix("7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffd000000000000000000000000000000", 16).unwrap();
let carry = BigUint::from_str_radix(
"7f\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
ffffffffffffffffffffffffffffff\
fffffffffffffffffffffffffffffd\
000000000000000000000000000000",
16,
)
.unwrap();
assert_eq!(&x * &x, &res + &carry * &modulus);

let x = RSAFieldElement::from_big_uint(x);
Expand Down Expand Up @@ -355,72 +419,72 @@ fn test_mul_mod() {
assert_eq!(output, vec![true]);
}

// // Negative test cases
// {
// // Test case 8: Result >= modulus
// let mut x = [[0, 0]; N_LIMBS];
// let mut y = [[0, 0]; N_LIMBS];
// let mut result = [[0, 0]; N_LIMBS];
// let carry = [[0, 0]; N_LIMBS];
// let mut modulus = [[0, 0]; N_LIMBS];

// x[0] = [7, 0];
// y[0] = [5, 0];
// result[0] = [10, 0]; // Invalid: result >= modulus
// modulus[0] = [10, 0];

// let assignment = MulModCircuit::<Fr>::create_circuit(x, y, result, carry, modulus);
// let witness = compile_result
// .witness_solver
// .solve_witness(&assignment)
// .unwrap();
// let output = compile_result.layered_circuit.run(&witness);
// assert_eq!(output, vec![false]);
// }

// {
// // Test case 9: Incorrect carry value
// let mut x = [[0, 0]; N_LIMBS];
// let mut y = [[0, 0]; N_LIMBS];
// let mut result = [[0, 0]; N_LIMBS];
// let mut carry = [[0, 0]; N_LIMBS];
// let mut modulus = [[0, 0]; N_LIMBS];

// x[0] = [7, 0];
// y[0] = [5, 0];
// result[0] = [5, 0];
// carry[0] = [2, 0]; // Wrong carry (should be 3)
// modulus[0] = [10, 0];

// let assignment = MulModCircuit::<Fr>::create_circuit(x, y, result, carry, modulus);
// let witness = compile_result
// .witness_solver
// .solve_witness(&assignment)
// .unwrap();
// let output = compile_result.layered_circuit.run(&witness);
// assert_eq!(output, vec![false]);
// }

// {
// // Test case 10: Incorrect result
// let mut x = [[0, 0]; N_LIMBS];
// let mut y = [[0, 0]; N_LIMBS];
// let mut result = [[0, 0]; N_LIMBS];
// let mut carry = [[0, 0]; N_LIMBS];
// let mut modulus = [[0, 0]; N_LIMBS];

// x[0] = [7, 0];
// y[0] = [5, 0];
// result[0] = [6, 0]; // Wrong result (should be 5)
// carry[0] = [3, 0];
// modulus[0] = [10, 0];

// let assignment = MulModCircuit::<Fr>::create_circuit(x, y, result, carry, modulus);
// let witness = compile_result
// .witness_solver
// .solve_witness(&assignment)
// .unwrap();
// let output = compile_result.layered_circuit.run(&witness);
// assert_eq!(output, vec![false]);
// }
// Negative test cases
{
// Test case 8: Result >= modulus
let mut x = [[0, 0]; N_LIMBS];
let mut y = [[0, 0]; N_LIMBS];
let mut result = [[0, 0]; N_LIMBS];
let carry = [[0, 0]; N_LIMBS];
let mut modulus = [[0, 0]; N_LIMBS];

x[0] = [7, 0];
y[0] = [5, 0];
result[0] = [10, 0]; // Invalid: result >= modulus
modulus[0] = [10, 0];

let assignment = MulModCircuit::<Fr>::create_circuit(x, y, result, carry, modulus);
let witness = compile_result
.witness_solver
.solve_witness(&assignment)
.unwrap();
let output = compile_result.layered_circuit.run(&witness);
assert_eq!(output, vec![false]);
}

{
// Test case 9: Incorrect carry value
let mut x = [[0, 0]; N_LIMBS];
let mut y = [[0, 0]; N_LIMBS];
let mut result = [[0, 0]; N_LIMBS];
let mut carry = [[0, 0]; N_LIMBS];
let mut modulus = [[0, 0]; N_LIMBS];

x[0] = [7, 0];
y[0] = [5, 0];
result[0] = [5, 0];
carry[0] = [2, 0]; // Wrong carry (should be 3)
modulus[0] = [10, 0];

let assignment = MulModCircuit::<Fr>::create_circuit(x, y, result, carry, modulus);
let witness = compile_result
.witness_solver
.solve_witness(&assignment)
.unwrap();
let output = compile_result.layered_circuit.run(&witness);
assert_eq!(output, vec![false]);
}

{
// Test case 10: Incorrect result
let mut x = [[0, 0]; N_LIMBS];
let mut y = [[0, 0]; N_LIMBS];
let mut result = [[0, 0]; N_LIMBS];
let mut carry = [[0, 0]; N_LIMBS];
let mut modulus = [[0, 0]; N_LIMBS];

x[0] = [7, 0];
y[0] = [5, 0];
result[0] = [6, 0]; // Wrong result (should be 5)
carry[0] = [3, 0];
modulus[0] = [10, 0];

let assignment = MulModCircuit::<Fr>::create_circuit(x, y, result, carry, modulus);
let witness = compile_result
.witness_solver
.solve_witness(&assignment)
.unwrap();
let output = compile_result.layered_circuit.run(&witness);
assert_eq!(output, vec![false]);
}
}
Loading

0 comments on commit 6b16a78

Please sign in to comment.