From 6b16a780c14eef8d8b77af9c6e23690c4e6ca1da Mon Sep 17 00:00:00 2001 From: zhenfei Date: Tue, 10 Dec 2024 09:33:56 -0500 Subject: [PATCH] fmt and readability --- rsa_circuit/src/constants.rs | 3 +- rsa_circuit/src/tests/native.rs | 105 ++++++++++- rsa_circuit/src/tests/u2048_mul.rs | 206 ++++++++++++++-------- rsa_circuit/src/tests/u2048_mul_no_mod.rs | 59 ++++++- 4 files changed, 293 insertions(+), 80 deletions(-) diff --git a/rsa_circuit/src/constants.rs b/rsa_circuit/src/constants.rs index 5acc0bd..da68ac2 100644 --- a/rsa_circuit/src/constants.rs +++ b/rsa_circuit/src/constants.rs @@ -2,8 +2,7 @@ use halo2curves::bn256::Fr; // we use 18 limbs, each with 120 bits, to store a 2048 bit integer pub const N_LIMBS: usize = 18; -// // 2048 bits = 256 bytes -// pub(crate) const N_BYTES: usize = 256; + // Each 120 bits limb needs 30 hex number to store pub(crate) const HEX_PER_LIMB: usize = 30; diff --git a/rsa_circuit/src/tests/native.rs b/rsa_circuit/src/tests/native.rs index 244e9bc..a6162a9 100644 --- a/rsa_circuit/src/tests/native.rs +++ b/rsa_circuit/src/tests/native.rs @@ -83,7 +83,26 @@ fn test_assert_rsa_addition() { let mut a = RSAFieldElement::new([MASK120 - 1; N_LIMBS]); a.data[N_LIMBS - 1] = MASK8 - 1; let b = RSAFieldElement::new([2u128; N_LIMBS]); - let result = RSAFieldElement::from_string("000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001000000000000000000000000000001"); + let result = RSAFieldElement::from_string( + "000000000000000000000000000001\ + 000000000000000000000000000001\ + 000000000000000000000000000001\ + 000000000000000000000000000001\ + 000000000000000000000000000001\ + 000000000000000000000000000001\ + 000000000000000000000000000001\ + 000000000000000000000000000001\ + 000000000000000000000000000001\ + 000000000000000000000000000001\ + 000000000000000000000000000001\ + 000000000000000000000000000001\ + 000000000000000000000000000001\ + 000000000000000000000000000001\ + 000000000000000000000000000001\ + 000000000000000000000000000001\ + 000000000000000000000000000001\ + 000000000000000000000000000001", + ); println!("a: {:?}", a.to_string()); println!("b: {:?}", b.to_string()); @@ -103,8 +122,46 @@ fn test_assert_rsa_multiplication() { let a = RSAFieldElement::new([1u128; N_LIMBS]); let b = RSAFieldElement::new([2u128; N_LIMBS]); - let carry = RSAFieldElement::from_string("0000000000000000000000000000000200000000000000000000000000000400000000000000000000000000000600000000000000000000000000000800000000000000000000000000000a00000000000000000000000000000c00000000000000000000000000000e00000000000000000000000000001000000000000000000000000000001200000000000000000000000000001400000000000000000000000000001600000000000000000000000000001800000000000000000000000000001a00000000000000000000000000001c00000000000000000000000000001e0000000000000000000000000000200000000000000000000000000000220000000000000000000000000000"); - let result = RSAFieldElement::from_string("00000000000000000000000000002402000000000000000000000000002204000000000000000000000000002006000000000000000000000000001e08000000000000000000000000001c0a000000000000000000000000001a0c00000000000000000000000000180e000000000000000000000000001610000000000000000000000000001412000000000000000000000000001214000000000000000000000000001016000000000000000000000000000e18000000000000000000000000000c1a000000000000000000000000000a1c00000000000000000000000000081e0000000000000000000000000006200000000000000000000000000004220000000000000000000000000002"); + let carry = RSAFieldElement::from_string( + "000000000000000000000000000000\ + 020000000000000000000000000000\ + 040000000000000000000000000000\ + 060000000000000000000000000000\ + 080000000000000000000000000000\ + 0a0000000000000000000000000000\ + 0c0000000000000000000000000000\ + 0e0000000000000000000000000000\ + 100000000000000000000000000000\ + 120000000000000000000000000000\ + 140000000000000000000000000000\ + 160000000000000000000000000000\ + 180000000000000000000000000000\ + 1a0000000000000000000000000000\ + 1c0000000000000000000000000000\ + 1e0000000000000000000000000000\ + 200000000000000000000000000000\ + 220000000000000000000000000000", + ); + let result = RSAFieldElement::from_string( + "000000000000000000000000000024\ + 020000000000000000000000000022\ + 040000000000000000000000000020\ + 06000000000000000000000000001e\ + 08000000000000000000000000001c\ + 0a000000000000000000000000001a\ + 0c0000000000000000000000000018\ + 0e0000000000000000000000000016\ + 100000000000000000000000000014\ + 120000000000000000000000000012\ + 140000000000000000000000000010\ + 16000000000000000000000000000e\ + 18000000000000000000000000000c\ + 1a000000000000000000000000000a\ + 1c0000000000000000000000000008\ + 1e0000000000000000000000000006\ + 200000000000000000000000000004\ + 220000000000000000000000000002", + ); println!("a: {:?}", a.to_string()); println!("b: {:?}", b.to_string()); @@ -120,8 +177,46 @@ fn test_assert_rsa_multiplication() { a.data[N_LIMBS - 1] = MASK8 - 1; let b = RSAFieldElement::new([2u128; N_LIMBS]); - let carry = RSAFieldElement::from_string("000000000000000000000000000001fe0000000000000000000000000001fc0000000000000000000000000001fa0000000000000000000000000001f80000000000000000000000000001f60000000000000000000000000001f40000000000000000000000000001f20000000000000000000000000001f00000000000000000000000000001ee0000000000000000000000000001ec0000000000000000000000000001ea0000000000000000000000000001e80000000000000000000000000001e60000000000000000000000000001e40000000000000000000000000001e20000000000000000000000000001e00000000000000000000000000001de0000000000000000000000000001"); - let result = RSAFieldElement::from_string("0000000000000000000000000000dbfdffffffffffffffffffffffffffddfbffffffffffffffffffffffffffdff9ffffffffffffffffffffffffffe1f7ffffffffffffffffffffffffffe3f5ffffffffffffffffffffffffffe5f3ffffffffffffffffffffffffffe7f1ffffffffffffffffffffffffffe9efffffffffffffffffffffffffffebedffffffffffffffffffffffffffedebffffffffffffffffffffffffffefe9fffffffffffffffffffffffffff1e7fffffffffffffffffffffffffff3e5fffffffffffffffffffffffffff5e3fffffffffffffffffffffffffff7e1fffffffffffffffffffffffffff9dffffffffffffffffffffffffffffbddfffffffffffffffffffffffffffd"); + let carry = RSAFieldElement::from_string( + "000000000000000000000000000001\ + fe0000000000000000000000000001\ + fc0000000000000000000000000001\ + fa0000000000000000000000000001\ + f80000000000000000000000000001\ + f60000000000000000000000000001\ + f40000000000000000000000000001\ + f20000000000000000000000000001\ + f00000000000000000000000000001\ + ee0000000000000000000000000001\ + ec0000000000000000000000000001\ + ea0000000000000000000000000001\ + e80000000000000000000000000001\ + e60000000000000000000000000001\ + e40000000000000000000000000001\ + e20000000000000000000000000001\ + e00000000000000000000000000001\ + de0000000000000000000000000001", + ); + let result = RSAFieldElement::from_string( + "0000000000000000000000000000db\ + fdffffffffffffffffffffffffffdd\ + fbffffffffffffffffffffffffffdf\ + f9ffffffffffffffffffffffffffe1\ + f7ffffffffffffffffffffffffffe3\ + f5ffffffffffffffffffffffffffe5\ + f3ffffffffffffffffffffffffffe7\ + f1ffffffffffffffffffffffffffe9\ + efffffffffffffffffffffffffffeb\ + edffffffffffffffffffffffffffed\ + ebffffffffffffffffffffffffffef\ + e9fffffffffffffffffffffffffff1\ + e7fffffffffffffffffffffffffff3\ + e5fffffffffffffffffffffffffff5\ + e3fffffffffffffffffffffffffff7\ + e1fffffffffffffffffffffffffff9\ + dffffffffffffffffffffffffffffb\ + ddfffffffffffffffffffffffffffd", + ); println!("a: {:?}", a.to_string()); println!("b: {:?}", b.to_string()); diff --git a/rsa_circuit/src/tests/u2048_mul.rs b/rsa_circuit/src/tests/u2048_mul.rs index 57fed91..6f82340 100644 --- a/rsa_circuit/src/tests/u2048_mul.rs +++ b/rsa_circuit/src/tests/u2048_mul.rs @@ -276,14 +276,78 @@ fn test_mul_mod() { } { - let x = BigUint::from_str_radix("7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff000000000000000000000000000000", 16).unwrap(); - let modulus = BigUint::from_str_radix("80000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000", 16).unwrap(); + let x = BigUint::from_str_radix( + "7f\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + 000000000000000000000000000000", + 16, + ) + .unwrap(); + let modulus = BigUint::from_str_radix( + "80\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000001\ + 000000000000000000000000000000", + 16, + ) + .unwrap(); + let res = BigUint::from_str_radix( "4000000000000000000000000000000000000000000000000000000000000", 16, ) .unwrap(); - let carry= BigUint::from_str_radix("7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffd000000000000000000000000000000", 16).unwrap(); + let carry = BigUint::from_str_radix( + "7f\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + fffffffffffffffffffffffffffffd\ + 000000000000000000000000000000", + 16, + ) + .unwrap(); assert_eq!(&x * &x, &res + &carry * &modulus); let x = RSAFieldElement::from_big_uint(x); @@ -355,72 +419,72 @@ fn test_mul_mod() { assert_eq!(output, vec![true]); } - // // Negative test cases - // { - // // Test case 8: Result >= modulus - // let mut x = [[0, 0]; N_LIMBS]; - // let mut y = [[0, 0]; N_LIMBS]; - // let mut result = [[0, 0]; N_LIMBS]; - // let carry = [[0, 0]; N_LIMBS]; - // let mut modulus = [[0, 0]; N_LIMBS]; - - // x[0] = [7, 0]; - // y[0] = [5, 0]; - // result[0] = [10, 0]; // Invalid: result >= modulus - // modulus[0] = [10, 0]; - - // let assignment = MulModCircuit::::create_circuit(x, y, result, carry, modulus); - // let witness = compile_result - // .witness_solver - // .solve_witness(&assignment) - // .unwrap(); - // let output = compile_result.layered_circuit.run(&witness); - // assert_eq!(output, vec![false]); - // } - - // { - // // Test case 9: Incorrect carry value - // let mut x = [[0, 0]; N_LIMBS]; - // let mut y = [[0, 0]; N_LIMBS]; - // let mut result = [[0, 0]; N_LIMBS]; - // let mut carry = [[0, 0]; N_LIMBS]; - // let mut modulus = [[0, 0]; N_LIMBS]; - - // x[0] = [7, 0]; - // y[0] = [5, 0]; - // result[0] = [5, 0]; - // carry[0] = [2, 0]; // Wrong carry (should be 3) - // modulus[0] = [10, 0]; - - // let assignment = MulModCircuit::::create_circuit(x, y, result, carry, modulus); - // let witness = compile_result - // .witness_solver - // .solve_witness(&assignment) - // .unwrap(); - // let output = compile_result.layered_circuit.run(&witness); - // assert_eq!(output, vec![false]); - // } - - // { - // // Test case 10: Incorrect result - // let mut x = [[0, 0]; N_LIMBS]; - // let mut y = [[0, 0]; N_LIMBS]; - // let mut result = [[0, 0]; N_LIMBS]; - // let mut carry = [[0, 0]; N_LIMBS]; - // let mut modulus = [[0, 0]; N_LIMBS]; - - // x[0] = [7, 0]; - // y[0] = [5, 0]; - // result[0] = [6, 0]; // Wrong result (should be 5) - // carry[0] = [3, 0]; - // modulus[0] = [10, 0]; - - // let assignment = MulModCircuit::::create_circuit(x, y, result, carry, modulus); - // let witness = compile_result - // .witness_solver - // .solve_witness(&assignment) - // .unwrap(); - // let output = compile_result.layered_circuit.run(&witness); - // assert_eq!(output, vec![false]); - // } + // Negative test cases + { + // Test case 8: Result >= modulus + let mut x = [[0, 0]; N_LIMBS]; + let mut y = [[0, 0]; N_LIMBS]; + let mut result = [[0, 0]; N_LIMBS]; + let carry = [[0, 0]; N_LIMBS]; + let mut modulus = [[0, 0]; N_LIMBS]; + + x[0] = [7, 0]; + y[0] = [5, 0]; + result[0] = [10, 0]; // Invalid: result >= modulus + modulus[0] = [10, 0]; + + let assignment = MulModCircuit::::create_circuit(x, y, result, carry, modulus); + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![false]); + } + + { + // Test case 9: Incorrect carry value + let mut x = [[0, 0]; N_LIMBS]; + let mut y = [[0, 0]; N_LIMBS]; + let mut result = [[0, 0]; N_LIMBS]; + let mut carry = [[0, 0]; N_LIMBS]; + let mut modulus = [[0, 0]; N_LIMBS]; + + x[0] = [7, 0]; + y[0] = [5, 0]; + result[0] = [5, 0]; + carry[0] = [2, 0]; // Wrong carry (should be 3) + modulus[0] = [10, 0]; + + let assignment = MulModCircuit::::create_circuit(x, y, result, carry, modulus); + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![false]); + } + + { + // Test case 10: Incorrect result + let mut x = [[0, 0]; N_LIMBS]; + let mut y = [[0, 0]; N_LIMBS]; + let mut result = [[0, 0]; N_LIMBS]; + let mut carry = [[0, 0]; N_LIMBS]; + let mut modulus = [[0, 0]; N_LIMBS]; + + x[0] = [7, 0]; + y[0] = [5, 0]; + result[0] = [6, 0]; // Wrong result (should be 5) + carry[0] = [3, 0]; + modulus[0] = [10, 0]; + + let assignment = MulModCircuit::::create_circuit(x, y, result, carry, modulus); + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![false]); + } } diff --git a/rsa_circuit/src/tests/u2048_mul_no_mod.rs b/rsa_circuit/src/tests/u2048_mul_no_mod.rs index e6b84c3..65e8434 100644 --- a/rsa_circuit/src/tests/u2048_mul_no_mod.rs +++ b/rsa_circuit/src/tests/u2048_mul_no_mod.rs @@ -256,9 +256,64 @@ fn test_mul_without_mod() { } { - let x = BigUint::from_str_radix("7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff000000000000000000000000000000", 16).unwrap(); + let x = BigUint::from_str_radix( + "7f\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + 000000000000000000000000000000", + 16, + ) + .unwrap(); let mut res = BigUint::from_str_radix( - "3fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000", + "3fff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffffff\ + ffffffffffffffffffffffffffff00\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000000\ + 000000000000000000000000000001\ + 000000000000000000000000000000\ + 000000000000000000000000000000", 16, ) .unwrap();