diff --git a/zk_prover/src/circuits/merkle_sum_tree.rs b/zk_prover/src/circuits/merkle_sum_tree.rs index f095dfea..9a6cfa12 100644 --- a/zk_prover/src/circuits/merkle_sum_tree.rs +++ b/zk_prover/src/circuits/merkle_sum_tree.rs @@ -3,7 +3,7 @@ use crate::chips::poseidon::hash::{PoseidonChip, PoseidonConfig}; use crate::chips::poseidon::poseidon_spec::PoseidonSpec; use crate::chips::range::range_check::{RangeCheckChip, RangeCheckConfig}; use crate::circuits::traits::CircuitBase; -use crate::merkle_sum_tree::{big_uint_to_fp, Entry, MerkleProof}; +use crate::merkle_sum_tree::{big_uint_to_fp, Entry, MerkleProof, Node}; use halo2_proofs::circuit::{AssignedCell, Layouter, SimpleFloorPlanner}; use halo2_proofs::halo2curves::bn256::Fr as Fp; use halo2_proofs::plonk::{ @@ -31,7 +31,7 @@ pub struct MstInclusionCircuit, pub path_element_balances: Vec<[Fp; N_ASSETS]>, pub path_indices: Vec, - pub root_hash: Fp, + pub root: Node, } impl CircuitExt @@ -40,13 +40,15 @@ where [usize; 2 * (1 + N_ASSETS)]: Sized, [usize; N_ASSETS + 1]: Sized, { - /// Returns the number of public inputs of the circuit. It is 2, namely the laef hash to be verified inclusion of and the root hash of the merkle sum tree. + /// Returns the number of public inputs of the circuit. It is {2 + N_ASSETS}, namely the leaf hash to be verified inclusion of, the root hash of the merkle sum tree and the root balances of the merkle sum tree. fn num_instance(&self) -> Vec { - vec![2] + vec![{ 2 + N_ASSETS }] } /// Returns the values of the public inputs of the circuit. Namely the leaf hash to be verified inclusion of and the root hash of the merkle sum tree. fn instances(&self) -> Vec> { - vec![vec![self.entry.compute_leaf().hash, self.root_hash]] + let mut instance = vec![self.entry.compute_leaf().hash, self.root.hash]; + instance.extend_from_slice(&self.root.balances); + vec![instance] } } @@ -57,6 +59,8 @@ impl CircuitBa impl MstInclusionCircuit +where + [usize; N_ASSETS + 1]: Sized, { pub fn init_empty() -> Self { Self { @@ -64,7 +68,7 @@ impl path_element_hashes: vec![Fp::zero(); LEVELS], path_element_balances: vec![[Fp::zero(); N_ASSETS]; LEVELS], path_indices: vec![Fp::zero(); LEVELS], - root_hash: Fp::zero(), + root: Node::init_empty(), } } @@ -85,7 +89,7 @@ impl path_element_hashes: merkle_proof.sibling_hashes, path_element_balances: merkle_proof.sibling_sums, path_indices: merkle_proof.path_indices, - root_hash: merkle_proof.root_hash, + root: merkle_proof.root, } } } @@ -191,6 +195,7 @@ where impl Circuit for MstInclusionCircuit where + [usize; N_ASSETS + 1]: Sized, [usize; 2 * (1 + N_ASSETS)]: Sized, { type Config = MstInclusionConfig; @@ -374,6 +379,16 @@ where config.instance, )?; + // expose the last current balances, namely the root balances, as public input + for (i, balance) in current_balances.iter().enumerate() { + self.expose_public( + layouter.namespace(|| format!("public root balance {}", i)), + balance, + 2 + i, + config.instance, + )?; + } + // perform range check on the balances of the root to make sure these lie in the range defined by N_BYTES for balance in current_balances.iter() { range_check_chip.assign(layouter.namespace(|| "range check root balance"), balance)?; diff --git a/zk_prover/src/circuits/tests.rs b/zk_prover/src/circuits/tests.rs index 50aab7a4..183868cc 100644 --- a/zk_prover/src/circuits/tests.rs +++ b/zk_prover/src/circuits/tests.rs @@ -41,6 +41,7 @@ mod test { let valid_prover = MockProver::run(K, &circuit, circuit.instances()).unwrap(); assert_eq!(circuit.instances()[0].len(), circuit.num_instance()[0]); + assert_eq!(circuit.instances()[0].len(), 2 + N_ASSETS); valid_prover.assert_satisfied(); } @@ -78,6 +79,21 @@ mod test { // verify the proof to be true assert!(full_verifier(¶ms, &vk, proof, circuit.instances())); + + // the user should perform the check on the public inputs + // public input #0 is the leaf hash + let expected_leaf_hash = user_entry.compute_leaf().hash; + assert_eq!(circuit.instances()[0][0], expected_leaf_hash); + + // public input #1 is the root hash + let expected_root_hash = merkle_sum_tree.root().hash; + assert_eq!(circuit.instances()[0][1], expected_root_hash); + + // public inputs [2, 2+N_ASSETS - 1] are the root balances + let expected_root_balances = merkle_sum_tree.root().balances; + for i in 0..N_ASSETS { + assert_eq!(circuit.instances()[0][2 + i], expected_root_balances[i]); + } } // Passing an invalid root hash in the instance column should fail the permutation check between the computed root hash and the instance column root hash @@ -157,6 +173,7 @@ mod test { // Passing an invalid entry balance as input for the witness generation should fail: // - the permutation check between the leaf hash and the instance column leaf hash // - the permutation check between the computed root hash and the instance column root hash + // - the permutations checks between the computed root balances and the instance column root balances #[test] fn test_invalid_entry_balance_as_witness() { let merkle_sum_tree = @@ -202,6 +219,20 @@ mod test { offset: 36 } }, + VerifyFailure::Permutation { + column: (Any::advice(), 0).into(), + location: FailureLocation::InRegion { + region: (95, "assign value to perform range check").into(), + offset: 0 + } + }, + VerifyFailure::Permutation { + column: (Any::advice(), 0).into(), + location: FailureLocation::InRegion { + region: (96, "assign value to perform range check").into(), + offset: 0 + } + }, VerifyFailure::Permutation { column: (Any::Instance, 0).into(), location: FailureLocation::OutsideRegion { row: 0 } @@ -210,6 +241,14 @@ mod test { column: (Any::Instance, 0).into(), location: FailureLocation::OutsideRegion { row: 1 } }, + VerifyFailure::Permutation { + column: (Any::Instance, 0).into(), + location: FailureLocation::OutsideRegion { row: 2 } + }, + VerifyFailure::Permutation { + column: (Any::Instance, 0).into(), + location: FailureLocation::OutsideRegion { row: 3 } + }, ]) ); } @@ -424,7 +463,8 @@ mod test { ); } - // Adding a balance at the verge of overflowing should fail the range check for any following computed sum and, because we are adding a fake balance, the root hash check should fail too + // Adding a balance at the verge of overflowing should fail the range check for any following computed sum and, because we are adding a fake balance. + // Furthermore, the public input check on the root hash and on root_balances[0] should fail too #[test] fn test_balance_not_in_range() { let merkle_sum_tree = @@ -495,6 +535,13 @@ mod test { offset: 36 } }, + VerifyFailure::Permutation { + column: (Any::advice(), 0).into(), + location: FailureLocation::InRegion { + region: (95, "assign value to perform range check").into(), + offset: 0 + } + }, VerifyFailure::Permutation { column: (Any::advice(), 0).into(), location: FailureLocation::InRegion { @@ -506,6 +553,10 @@ mod test { column: (Any::Instance, 0).into(), location: FailureLocation::OutsideRegion { row: 1 } }, + VerifyFailure::Permutation { + column: (Any::Instance, 0).into(), + location: FailureLocation::OutsideRegion { row: 2 } + }, ]) ); } diff --git a/zk_prover/src/merkle_sum_tree/mod.rs b/zk_prover/src/merkle_sum_tree/mod.rs index 644e2509..d185ef50 100644 --- a/zk_prover/src/merkle_sum_tree/mod.rs +++ b/zk_prover/src/merkle_sum_tree/mod.rs @@ -9,7 +9,7 @@ use halo2_proofs::halo2curves::bn256::Fr as Fp; #[derive(Clone, Debug)] pub struct MerkleProof { pub leaf: Node, - pub root_hash: Fp, + pub root: Node, pub sibling_hashes: Vec, pub sibling_sums: Vec<[Fp; N_ASSETS]>, pub path_indices: Vec, diff --git a/zk_prover/src/merkle_sum_tree/node.rs b/zk_prover/src/merkle_sum_tree/node.rs index e4eef23e..e4d43a84 100644 --- a/zk_prover/src/merkle_sum_tree/node.rs +++ b/zk_prover/src/merkle_sum_tree/node.rs @@ -33,6 +33,16 @@ impl Node { } } + pub fn init_empty() -> Node + where + [usize; N_ASSETS + 1]: Sized, + { + Node { + hash: Fp::zero(), + balances: [Fp::zero(); N_ASSETS], + } + } + /// Builds a leaf-level node of the MST pub fn leaf(username: &BigUint, balances: &[BigUint; N_ASSETS]) -> Node where diff --git a/zk_prover/src/merkle_sum_tree/tests.rs b/zk_prover/src/merkle_sum_tree/tests.rs index 24728d61..0b598c46 100644 --- a/zk_prover/src/merkle_sum_tree/tests.rs +++ b/zk_prover/src/merkle_sum_tree/tests.rs @@ -1,7 +1,7 @@ #[cfg(test)] mod test { - use crate::merkle_sum_tree::utils::{big_uint_to_fp, poseidon_node}; + use crate::merkle_sum_tree::utils::big_uint_to_fp; use crate::merkle_sum_tree::{Entry, MerkleSumTree, Tree}; use num_bigint::{BigUint, ToBigUint}; @@ -81,7 +81,7 @@ mod test { // shouldn't verify a proof with a wrong root hash let mut proof_invalid_2 = proof.clone(); - proof_invalid_2.root_hash = 0.into(); + proof_invalid_2.root.hash = 0.into(); assert!(!merkle_tree.verify_proof(&proof_invalid_2)); // shouldn't verify a proof with a wrong computed balance diff --git a/zk_prover/src/merkle_sum_tree/tree.rs b/zk_prover/src/merkle_sum_tree/tree.rs index 73e9426e..248c5960 100644 --- a/zk_prover/src/merkle_sum_tree/tree.rs +++ b/zk_prover/src/merkle_sum_tree/tree.rs @@ -53,7 +53,7 @@ pub trait Tree { Ok(MerkleProof { leaf: leaf.clone(), - root_hash: root.hash, + root: root.clone(), sibling_hashes, sibling_sums, path_indices, @@ -68,8 +68,6 @@ pub trait Tree { { let mut node = proof.leaf.clone(); - let mut balances = proof.leaf.balances; - for i in 0..proof.sibling_hashes.len() { let sibling_node = Node { hash: proof.sibling_hashes[i], @@ -81,14 +79,9 @@ pub trait Tree { } else { node = Node::middle(&sibling_node, &node); } - - for (balance, sibling_balance) in balances.iter_mut().zip(sibling_node.balances.iter()) - { - *balance += sibling_balance; - } } - proof.root_hash == node.hash && balances == node.balances + proof.root.hash == node.hash && proof.root.balances == node.balances } /// Returns the index of the user with the given username and balances in the tree