From 45c8b858818baf603a37d70aaa89913d1e6218b0 Mon Sep 17 00:00:00 2001 From: zhenfei Date: Tue, 3 Dec 2024 10:34:38 -0500 Subject: [PATCH] finished a buggy addition --- rsa_circuit/src/tests.rs | 1 + rsa_circuit/src/tests/u120_comp.rs | 124 ++++++++++++++++++++++++++++ rsa_circuit/src/tests/u2048_comp.rs | 75 +++++++++-------- rsa_circuit/src/u120.rs | 15 +++- rsa_circuit/src/u2048.rs | 92 ++++++++++++++++++++- rsa_circuit/src/util.rs | 2 +- 6 files changed, 267 insertions(+), 42 deletions(-) create mode 100644 rsa_circuit/src/tests/u120_comp.rs diff --git a/rsa_circuit/src/tests.rs b/rsa_circuit/src/tests.rs index b3eda11..7fd5957 100644 --- a/rsa_circuit/src/tests.rs +++ b/rsa_circuit/src/tests.rs @@ -1,6 +1,7 @@ mod native; mod u120_add; mod u120_mul; +mod u120_comp; mod u2048_add; mod u2048_comp; mod util; diff --git a/rsa_circuit/src/tests/u120_comp.rs b/rsa_circuit/src/tests/u120_comp.rs new file mode 100644 index 0000000..9420718 --- /dev/null +++ b/rsa_circuit/src/tests/u120_comp.rs @@ -0,0 +1,124 @@ +use std::mem::transmute; +use expander_compiler::frontend::*; +use expander_compiler::{declare_circuit, frontend::{BN254Config, Define, Variable, API}}; +use halo2curves::bn256::Fr; + +use crate::constants::MASK120; +use crate::u120::is_less_than_u120; + +declare_circuit!(LessThanCircuit { + x: Variable, + y: Variable, + result: Variable, +}); + +impl Define for LessThanCircuit { + fn define(&self, builder: &mut API) { + let res = is_less_than_u120(&self.x, &self.y, builder); + builder.assert_is_equal(res, self.result); + } +} + +impl LessThanCircuit { + fn create_circuit(x: [u64; 2], y: [u64; 2], result: [u64; 2]) -> LessThanCircuit { + Self { + x: Fr::from_raw([x[0], x[1], 0, 0]), + y: Fr::from_raw([y[0], y[1], 0, 0]), + result: Fr::from_raw([result[0], result[1], 0, 0]), + } + } +} + +#[test] +fn test_u120_less_than() { + let compile_result = compile(&LessThanCircuit::default()).unwrap(); + + { + // Test case: Simple less than + let x = [5, 0]; + let y = [10, 0]; + let result = [1, 0]; // true: 5 < 10 + + let assignment = LessThanCircuit::::create_circuit(x, y, result); + let witness = compile_result.witness_solver.solve_witness(&assignment).unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + } + { + // Test case: Equal values + let x = [42, 0]; + let y = [42, 0]; + let result = [0, 0]; // false: 42 = 42 + + let assignment = LessThanCircuit::::create_circuit(x, y, result); + let witness = compile_result.witness_solver.solve_witness(&assignment).unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + } + { + // Test case: Greater than + let x = [100, 0]; + let y = [50, 0]; + let result = [0, 0]; // false: 100 > 50 + + let assignment = LessThanCircuit::::create_circuit(x, y, result); + let witness = compile_result.witness_solver.solve_witness(&assignment).unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + } + { + // Test case: Using second limb + let x = [0, 1]; // 2^64 + let y = [u64::MAX, 0]; + let result = [0, 0]; // false: 2^64 > u64::MAX + + let assignment = LessThanCircuit::::create_circuit(x, y, result); + let witness = compile_result.witness_solver.solve_witness(&assignment).unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + } + { + // Test case: Large numbers near 120-bit limit + let x = unsafe { transmute(MASK120 - 1) }; // 2^120 - 2 + let y = unsafe { transmute(MASK120) }; // 2^120 - 1 + let result = [1, 0]; // true: (2^120 - 2) < (2^120 - 1) + + let assignment = LessThanCircuit::::create_circuit(x, y, result); + let witness = compile_result.witness_solver.solve_witness(&assignment).unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + } + { + // Test case: Equal large numbers + let x = unsafe { transmute(MASK120) }; // 2^120 - 1 + let y = unsafe { transmute(MASK120) }; // 2^120 - 1 + let result = [0, 0]; // false: equal values + + let assignment = LessThanCircuit::::create_circuit(x, y, result); + let witness = compile_result.witness_solver.solve_witness(&assignment).unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + } + { + // Test case: Negative case (incorrect result) + let x = [5, 0]; + let y = [10, 0]; + let result = [0, 0]; // incorrect: should be 1 since 5 < 10 + + let assignment = LessThanCircuit::::create_circuit(x, y, result); + let witness = compile_result.witness_solver.solve_witness(&assignment).unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![false]); + } + { + // Test case: Negative case (incorrect result) + let x = [5, 0]; + let y = [5, 0]; + let result = [1, 0]; // incorrect: should be 0 since 5 = 5 + + let assignment = LessThanCircuit::::create_circuit(x, y, result); + let witness = compile_result.witness_solver.solve_witness(&assignment).unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![false]); + } +} \ No newline at end of file diff --git a/rsa_circuit/src/tests/u2048_comp.rs b/rsa_circuit/src/tests/u2048_comp.rs index e78f6fd..96d1772 100644 --- a/rsa_circuit/src/tests/u2048_comp.rs +++ b/rsa_circuit/src/tests/u2048_comp.rs @@ -19,7 +19,7 @@ impl Define for CompareCircuit { let x = U2048Variable { limbs: self.x }; let y = U2048Variable { limbs: self.y }; - let comparison_result = x.unconstrained_greater_eq(&y, builder); + let comparison_result = x.assert_is_less_than(&y, builder); builder.assert_is_equal(comparison_result, self.result); } } @@ -54,7 +54,7 @@ fn test_u2048_comparison() { let x = vec![5; N_LIMBS]; let y = vec![5; N_LIMBS]; - let assignment = CompareCircuit::::create_circuit(x, y, true); + let assignment = CompareCircuit::::create_circuit(x, y, false); // x < y is false when equal let witness = compile_result .witness_solver .solve_witness(&assignment) @@ -65,13 +65,13 @@ fn test_u2048_comparison() { } { - // Test case: Greater in most significant limb + // Test case: Less than in most significant limb let mut x = vec![0; N_LIMBS]; let mut y = vec![0; N_LIMBS]; - x[N_LIMBS - 1] = 10; - y[N_LIMBS - 1] = 5; + x[N_LIMBS - 1] = 5; + y[N_LIMBS - 1] = 10; - let assignment = CompareCircuit::::create_circuit(x, y, true); + let assignment = CompareCircuit::::create_circuit(x, y, true); // x < y is true let witness = compile_result .witness_solver .solve_witness(&assignment) @@ -82,13 +82,13 @@ fn test_u2048_comparison() { } { - // Test case: Less in most significant limb + // Test case: Greater in most significant limb let mut x = vec![0; N_LIMBS]; let mut y = vec![0; N_LIMBS]; - x[N_LIMBS - 1] = 5; - y[N_LIMBS - 1] = 10; + x[N_LIMBS - 1] = 10; + y[N_LIMBS - 1] = 5; - let assignment = CompareCircuit::::create_circuit(x, y, false); + let assignment = CompareCircuit::::create_circuit(x, y, false); // x < y is false let witness = compile_result .witness_solver .solve_witness(&assignment) @@ -99,15 +99,15 @@ fn test_u2048_comparison() { } { - // Test case: Equal in most significant limb, greater in next limb + // Test case: Equal in most significant limb, less than in next limb let mut x = vec![0; N_LIMBS]; let mut y = vec![0; N_LIMBS]; x[N_LIMBS - 1] = 5; y[N_LIMBS - 1] = 5; - x[N_LIMBS - 2] = 10; - y[N_LIMBS - 2] = 5; + x[N_LIMBS - 2] = 5; + y[N_LIMBS - 2] = 10; - let assignment = CompareCircuit::::create_circuit(x, y, true); + let assignment = CompareCircuit::::create_circuit(x, y, true); // x < y is true let witness = compile_result .witness_solver .solve_witness(&assignment) @@ -119,11 +119,11 @@ fn test_u2048_comparison() { // Negative test cases { - // Negative test: Claiming x >= y when x < y + // Negative test: Claiming x < y when x > y let mut x = vec![0; N_LIMBS]; let mut y = vec![0; N_LIMBS]; - x[N_LIMBS - 1] = 5; - y[N_LIMBS - 1] = 10; + x[N_LIMBS - 1] = 10; + y[N_LIMBS - 1] = 5; let assignment = CompareCircuit::::create_circuit(x, y, true); // incorrect result let witness = compile_result @@ -140,7 +140,7 @@ fn test_u2048_comparison() { let x = vec![5; N_LIMBS]; let y = vec![5; N_LIMBS]; - let assignment = CompareCircuit::::create_circuit(x, y, false); // incorrect result + let assignment = CompareCircuit::::create_circuit(x, y, true); // incorrect result let witness = compile_result .witness_solver .solve_witness(&assignment) @@ -150,23 +150,22 @@ fn test_u2048_comparison() { assert_eq!(output, vec![false]); // should fail } - // { - // // soundness bug: this test should fail but it passes - // // Negative test: Equal in most significant limbs but claiming wrong result for lower limbs - // let mut x = vec![0; N_LIMBS]; - // let mut y = vec![0; N_LIMBS]; - // x[N_LIMBS - 1] = 5; - // y[N_LIMBS - 1] = 5; - // x[N_LIMBS - 2] = 4; - // y[N_LIMBS - 2] = 5; - - // let assignment = CompareCircuit::::create_circuit(x, y, true); // incorrect result - // let witness = compile_result - // .witness_solver - // .solve_witness(&assignment) - // .unwrap(); - - // let output = compile_result.layered_circuit.run(&witness); - // assert_eq!(output, vec![false]); // should fail - // } -} + { + // Test case: Equal in most significant limb, comparison in lower limb + let mut x = vec![0; N_LIMBS]; + let mut y = vec![0; N_LIMBS]; + x[N_LIMBS - 1] = 5; + y[N_LIMBS - 1] = 5; + x[N_LIMBS - 2] = 4; + y[N_LIMBS - 2] = 5; + + let assignment = CompareCircuit::::create_circuit(x, y, true); // x < y is true + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + } +} \ No newline at end of file diff --git a/rsa_circuit/src/u120.rs b/rsa_circuit/src/u120.rs index b8d991f..31f3096 100644 --- a/rsa_circuit/src/u120.rs +++ b/rsa_circuit/src/u120.rs @@ -79,7 +79,18 @@ pub(crate) fn is_less_than_u120( builder: &mut API, ) -> Variable { let diff = builder.sub(x, y); - // if x < y, then diff will underflow a range check will fail + let byte_decomp = crate::util::unconstrained_byte_decomposition(&diff, builder); + let res = builder.unconstrained_lesser(x, y); - todo!() + // if res = 1: x < y, then diff will underflow so byte_decomp[31] will be non-zero + // if res = 0: x >= y, then diff will not underflow so byte_decomp[31] will be zero + let zero = builder.constant(0); + let one = builder.constant(1); + let one_minus_res = builder.sub(one, res); + let t1 = builder.mul(one_minus_res, byte_decomp[31]); + let t2 = builder.mul(res, zero); + let t3 = builder.add(t1, t2); + builder.assert_is_zero(t3); + + res } diff --git a/rsa_circuit/src/u2048.rs b/rsa_circuit/src/u2048.rs index 0819e47..14ff7e1 100644 --- a/rsa_circuit/src/u2048.rs +++ b/rsa_circuit/src/u2048.rs @@ -1,6 +1,6 @@ use expander_compiler::frontend::{extra::UnconstrainedAPI, BN254Config, BasicAPI, Variable, API}; -use crate::{constants::N_LIMBS, u120}; +use crate::{constants::N_LIMBS, u120::{self, is_less_than_u120}}; #[derive(Debug, Clone, Copy)] pub struct U2048Variable { @@ -48,6 +48,92 @@ impl U2048Variable { builder.unconstrained_bit_or(result, all_eq_so_far) } + #[inline] + pub fn assert_is_less_than( + &self, + other: &Self, + builder: &mut API, + ) -> Variable { + let mut result = builder.constant(0); + let mut all_eq_so_far = builder.constant(1); + + // Compare limbs from most significant to least significant + for i in (0..N_LIMBS).rev() { + // Compare current limbs using u120 comparison + let curr_less = is_less_than_u120(&self.limbs[i], &other.limbs[i], builder); + + // Check equality for current limbs + let diff = builder.sub(&self.limbs[i], &other.limbs[i]); + let curr_eq = builder.is_zero(diff); + + // If all previous limbs were equal and current limb is less + let update = builder.mul(all_eq_so_far, curr_less); + + // Update result: result = result OR (all_eq_so_far AND curr_less) + result = builder.add(result, update); + let tmp= builder.mul(result, update); + result = builder.sub(result, tmp); + + // Update equality chain: all_eq_so_far = all_eq_so_far AND curr_eq + all_eq_so_far = builder.mul(all_eq_so_far, curr_eq); + + // Assert boolean constraints + builder.assert_is_bool(result); + builder.assert_is_bool(all_eq_so_far); + builder.assert_is_bool(curr_less); + builder.assert_is_bool(curr_eq); + + // Cannot be both less and equal for current limb + let both = builder.mul(curr_less, curr_eq); + builder.assert_is_zero(both); + } + + // If all limbs were equal, result must be 0 + let equal_case = builder.mul(all_eq_so_far, result); + builder.assert_is_zero(equal_case); + + result + } + + // Helper function to check if one U2048 is greater than or equal to another + #[inline] + pub fn assert_is_greater_eq( + &self, + other: &Self, + builder: &mut API, + ) -> Variable { + let less = other.assert_is_less_than(self, builder); + let eq = self.assert_is_equal(other, builder); + + // result = less OR eq + let mut result = builder.add(less, eq); + let tmp = builder.mul(less, eq); + result = builder.sub(result, tmp); + builder.assert_is_bool(result); + + result + } + + // Helper function to check equality + #[inline] + pub fn assert_is_equal( + &self, + other: &Self, + builder: &mut API, + ) -> Variable { + let mut is_equal = builder.constant(1); + + for i in 0..N_LIMBS { + let diff = builder.sub(&self.limbs[i], &other.limbs[i]); + let curr_eq = builder.is_zero(diff); + is_equal = builder.mul(is_equal, curr_eq); + builder.assert_is_bool(curr_eq); + } + + builder.assert_is_bool(is_equal); + is_equal + } + #[inline] // add two U2048 variables with mod reductions // a + b = result + carry * modulus @@ -100,5 +186,9 @@ impl U2048Variable { // Final carry should be 0 since all numbers are within range builder.assert_is_zero(temp_carry); + + let lt = Self::assert_is_less_than(result, modulus, builder); + let one = builder.constant(1); + builder.assert_is_equal(lt, one); } } diff --git a/rsa_circuit/src/util.rs b/rsa_circuit/src/util.rs index 739234a..f0f7da0 100644 --- a/rsa_circuit/src/util.rs +++ b/rsa_circuit/src/util.rs @@ -30,7 +30,7 @@ pub(crate) fn assert_byte_decomposition( builder: &mut API, ) -> Vec { let bytes = unconstrained_byte_decomposition(x, builder); - + // todo: constraint each byte to be less than 256 via logup let inner_product = bytes.iter().zip(constant_scalars.iter()).fold( builder.constant(Fr::zero()), |acc, (byte, scalar)| {