From 51e4286fe4ea1482607ace8273ebb64c1a97a524 Mon Sep 17 00:00:00 2001 From: Alex Kuzmin <6849426+alxkzmn@users.noreply.github.com> Date: Mon, 18 Sep 2023 17:54:55 +0800 Subject: [PATCH] Add MST leaf update function (#156) * Add MST leaf update function * Implement code review changes * Fix documentation --- zk_prover/benches/full_solvency_flow.rs | 23 ++++- zk_prover/src/circuits/merkle_sum_tree.rs | 2 +- .../merkle_sum_tree/csv/entry_16_modified.csv | 17 ++++ zk_prover/src/merkle_sum_tree/entry.rs | 44 ++++----- zk_prover/src/merkle_sum_tree/mod.rs | 8 +- zk_prover/src/merkle_sum_tree/mst.rs | 89 +++++++++++++++++++ zk_prover/src/merkle_sum_tree/node.rs | 60 +++++++++++++ zk_prover/src/merkle_sum_tree/tests.rs | 79 +++++++++++++++- .../src/merkle_sum_tree/utils/build_tree.rs | 3 +- .../utils/create_middle_node.rs | 27 ------ zk_prover/src/merkle_sum_tree/utils/mod.rs | 1 - .../utils/proof_verification.rs | 6 +- 12 files changed, 291 insertions(+), 68 deletions(-) create mode 100644 zk_prover/src/merkle_sum_tree/csv/entry_16_modified.csv create mode 100644 zk_prover/src/merkle_sum_tree/node.rs delete mode 100644 zk_prover/src/merkle_sum_tree/utils/create_middle_node.rs diff --git a/zk_prover/benches/full_solvency_flow.rs b/zk_prover/benches/full_solvency_flow.rs index 92a0f0f2..5c11ef15 100644 --- a/zk_prover/benches/full_solvency_flow.rs +++ b/zk_prover/benches/full_solvency_flow.rs @@ -29,7 +29,7 @@ fn build_mstree(_c: &mut Criterion) { ); let bench_name = format!( - "build merkle sum tree for 2 power of {} entries with {} assets", + "build Merkle sum tree for 2 power of {} entries with {} assets", LEVELS, N_ASSETS ); @@ -40,6 +40,26 @@ fn build_mstree(_c: &mut Criterion) { }); } +fn build_sorted_mstree(_c: &mut Criterion) { + let mut criterion = Criterion::default().sample_size(SAMPLE_SIZE); + + let csv_file = format!( + "benches/csv/{}/{}_entry_2_{}.csv", + PATH_NAME, PATH_NAME, LEVELS + ); + + let bench_name = format!( + "build sorted Merkle sum tree for 2 power of {} entries with {} assets", + LEVELS, N_ASSETS + ); + + criterion.bench_function(&bench_name, |b| { + b.iter(|| { + MerkleSumTree::::new_sorted(&csv_file).unwrap(); + }) + }); +} + fn verification_key_gen_mst_inclusion_circuit(_c: &mut Criterion) { let mut criterion = Criterion::default().sample_size(SAMPLE_SIZE); @@ -239,6 +259,7 @@ fn verify_zk_proof_solvency_circuit(_c: &mut Criterion) { criterion_group!( benches, build_mstree, + build_sorted_mstree, verification_key_gen_mst_inclusion_circuit, proving_key_gen_mst_inclusion_circuit, generate_zk_proof_mst_inclusion_circuit, diff --git a/zk_prover/src/circuits/merkle_sum_tree.rs b/zk_prover/src/circuits/merkle_sum_tree.rs index df07e01c..3ed075c0 100644 --- a/zk_prover/src/circuits/merkle_sum_tree.rs +++ b/zk_prover/src/circuits/merkle_sum_tree.rs @@ -228,7 +228,7 @@ where // Assign the entry username let username = self.assign_value_to_witness( layouter.namespace(|| "assign entry username"), - big_uint_to_fp(self.entry.username_to_big_uint()), + big_uint_to_fp(self.entry.username_as_big_uint()), "entry username", config.advices[0], )?; diff --git a/zk_prover/src/merkle_sum_tree/csv/entry_16_modified.csv b/zk_prover/src/merkle_sum_tree/csv/entry_16_modified.csv new file mode 100644 index 00000000..cacf3c69 --- /dev/null +++ b/zk_prover/src/merkle_sum_tree/csv/entry_16_modified.csv @@ -0,0 +1,17 @@ +username;balances +dxGaEAii;11888,41163 +MBlfbBGI;67823,18651 +lAhWlEWZ;18651,2087 +nuZweYtO;22073,55683 +gbdSwiuY;34897,83296 +RZNneNuP;83296,16881 +YsscHXkp;31699,35479 +RkLzkDun;2086,79732 +HlQlnEYI;30605,11888 +RqkZOFYe;16881,14874 +NjCSRAfD;41163,67823 +pHniJMQY;14874,22073 +dOGIMzKR;10032,10032 +HfMDmNLp;55683,34897 +xPLKzCBl;79731,30605 +AtwIxZHo;35479,31699 diff --git a/zk_prover/src/merkle_sum_tree/entry.rs b/zk_prover/src/merkle_sum_tree/entry.rs index 73fbf6b2..be6fdb35 100644 --- a/zk_prover/src/merkle_sum_tree/entry.rs +++ b/zk_prover/src/merkle_sum_tree/entry.rs @@ -1,13 +1,12 @@ -use crate::merkle_sum_tree::utils::{big_intify_username, big_uint_to_fp, poseidon_entry}; +use crate::merkle_sum_tree::utils::big_intify_username; use crate::merkle_sum_tree::Node; -use halo2_proofs::halo2curves::bn256::Fr as Fp; use num_bigint::BigUint; /// An entry in the Merkle Sum Tree from the database of the CEX. /// It contains the username and the balances of the user. #[derive(Clone, Debug)] pub struct Entry { - username_to_big_uint: BigUint, + username_as_big_uint: BigUint, balances: [BigUint; N_ASSETS], username: String, } @@ -15,7 +14,7 @@ pub struct Entry { impl Entry { pub fn new(username: String, balances: [BigUint; N_ASSETS]) -> Result { Ok(Entry { - username_to_big_uint: big_intify_username(&username), + username_as_big_uint: big_intify_username(&username), balances, username, }) @@ -25,7 +24,7 @@ impl Entry { let empty_balances: [BigUint; N_ASSETS] = std::array::from_fn(|_| BigUint::from(0u32)); Entry { - username_to_big_uint: BigUint::from(0u32), + username_as_big_uint: BigUint::from(0u32), balances: empty_balances, username: "".to_string(), } @@ -35,33 +34,26 @@ impl Entry { where [usize; N_ASSETS + 1]: Sized, { - Node { - hash: poseidon_entry::( - big_uint_to_fp(&self.username_to_big_uint), - self.balances - .iter() - .map(big_uint_to_fp) - .collect::>() - .try_into() - .unwrap(), - ), - //Map the array of balances using big_int_to_fp: - balances: self - .balances - .iter() - .map(big_uint_to_fp) - .collect::>() - .try_into() - .unwrap(), - } + Node::leaf(&self.username_as_big_uint, &self.balances) + } + + /// Stores the new balance values + /// + /// Returns the updated node + pub fn recompute_leaf(&mut self, updated_balances: &[BigUint; N_ASSETS]) -> Node + where + [usize; N_ASSETS + 1]: Sized, + { + self.balances = updated_balances.clone(); + Node::leaf(&self.username_as_big_uint, updated_balances) } pub fn balances(&self) -> &[BigUint; N_ASSETS] { &self.balances } - pub fn username_to_big_uint(&self) -> &BigUint { - &self.username_to_big_uint + pub fn username_as_big_uint(&self) -> &BigUint { + &self.username_as_big_uint } pub fn username(&self) -> &str { diff --git a/zk_prover/src/merkle_sum_tree/mod.rs b/zk_prover/src/merkle_sum_tree/mod.rs index e7cce2f8..b7950d93 100644 --- a/zk_prover/src/merkle_sum_tree/mod.rs +++ b/zk_prover/src/merkle_sum_tree/mod.rs @@ -1,5 +1,6 @@ mod entry; mod mst; +mod node; mod tests; pub mod utils; use halo2_proofs::halo2curves::bn256::Fr as Fp; @@ -13,12 +14,7 @@ pub struct MerkleProof { pub path_indices: Vec, } -#[derive(Clone, Debug)] -pub struct Node { - pub hash: Fp, - pub balances: [Fp; N_ASSETS], -} - pub use entry::Entry; pub use mst::MerkleSumTree; +pub use node::Node; pub use utils::{big_intify_username, big_uint_to_fp}; diff --git a/zk_prover/src/merkle_sum_tree/mst.rs b/zk_prover/src/merkle_sum_tree/mst.rs index 682b5991..0a24aef9 100644 --- a/zk_prover/src/merkle_sum_tree/mst.rs +++ b/zk_prover/src/merkle_sum_tree/mst.rs @@ -22,6 +22,7 @@ pub struct MerkleSumTree { nodes: Vec>>, depth: usize, entries: Vec>, + is_sorted: bool, } impl MerkleSumTree { @@ -38,6 +39,34 @@ impl MerkleSumTree(path)?; + Self::build_tree(entries, false) + } + + /// Builds a Merkle Sum Tree from a CSV file stored at `path`. The MST leaves are sorted by the username byte values. The CSV file must be formatted as follows: + /// + /// `username;balances` + /// + /// `dxGaEAii;11888,41163` + pub fn new_sorted(path: &str) -> Result> + where + [usize; N_ASSETS + 1]: Sized, + [usize; 2 * (1 + N_ASSETS)]: Sized, + { + let mut entries = parse_csv_to_entries::<&str, N_ASSETS, N_BYTES>(path)?; + + entries.sort_by(|a, b| a.username().cmp(b.username())); + + Self::build_tree(entries, true) + } + + fn build_tree( + entries: Vec>, + is_sorted: bool, + ) -> Result, Box> + where + [usize; N_ASSETS + 1]: Sized, + [usize; 2 * (1 + N_ASSETS)]: Sized, + { let depth = (entries.len() as f64).log2().ceil() as usize; if !(1..=Self::MAX_DEPTH).contains(&depth) { @@ -55,9 +84,50 @@ impl MerkleSumTree Result, Box> + where + [usize; N_ASSETS + 1]: Sized, + [usize; 2 * (1 + N_ASSETS)]: Sized, + { + let index = self.index_of_username(username)?; + + // Update the leaf node. + let updated_leaf = self.entries[index].recompute_leaf(new_balances); + self.nodes[0][index] = updated_leaf; + + // Recompute the hashes and balances up the tree. + let mut current_index = index; + for depth in 1..=self.depth { + let parent_index = current_index / 2; + let left_child = &self.nodes[depth - 1][2 * parent_index]; + let right_child = &self.nodes[depth - 1][2 * parent_index + 1]; + self.nodes[depth][parent_index] = Node::::middle(left_child, right_child); + current_index = parent_index; + } + + let root = self.nodes[self.depth][0].clone(); + + Ok(root) + } + pub fn root(&self) -> &Node { &self.root } @@ -94,6 +164,25 @@ impl MerkleSumTree Result> + where + [usize; N_ASSETS + 1]: Sized, + { + if !self.is_sorted { + self.entries + .iter() + .enumerate() + .find(|(_, entry)| entry.username() == username) + .map(|(index, _)| index) + .ok_or_else(|| Box::from("Username not found")) + } else { + self.entries + .binary_search_by_key(&username, |entry| entry.username()) + .map_err(|_| Box::from("Username not found")) + } + } + /// Generates a MerkleProof for the user with the given index pub fn generate_proof(&self, index: usize) -> Result, &'static str> { create_proof(index, &self.entries, self.depth, &self.nodes, &self.root) diff --git a/zk_prover/src/merkle_sum_tree/node.rs b/zk_prover/src/merkle_sum_tree/node.rs new file mode 100644 index 00000000..e4eef23e --- /dev/null +++ b/zk_prover/src/merkle_sum_tree/node.rs @@ -0,0 +1,60 @@ +use halo2_proofs::halo2curves::bn256::Fr as Fp; +use num_bigint::BigUint; + +use super::{ + big_uint_to_fp, + utils::{poseidon_entry, poseidon_node}, +}; + +#[derive(Clone, Debug)] +pub struct Node { + pub hash: Fp, + pub balances: [Fp; N_ASSETS], +} +impl Node { + /// Builds a "middle" (non-leaf-level) node of the MST + pub fn middle(child_l: &Node, child_r: &Node) -> Node + where + [usize; 2 * (1 + N_ASSETS)]: Sized, + { + let mut balances_sum = [Fp::zero(); N_ASSETS]; + for (i, balance) in balances_sum.iter_mut().enumerate() { + *balance = child_l.balances[i] + child_r.balances[i]; + } + + Node { + hash: poseidon_node( + child_l.hash, + child_l.balances, + child_r.hash, + child_r.balances, + ), + balances: balances_sum, + } + } + + /// Builds a leaf-level node of the MST + pub fn leaf(username: &BigUint, balances: &[BigUint; N_ASSETS]) -> Node + where + [usize; N_ASSETS + 1]: Sized, + { + Node { + hash: poseidon_entry::( + big_uint_to_fp(username), + balances + .iter() + .map(big_uint_to_fp) + .collect::>() + .try_into() + .unwrap(), + ), + //Map the array of balances using big_int_to_fp: + balances: balances + .iter() + .map(big_uint_to_fp) + .collect::>() + .try_into() + .unwrap(), + } + } +} diff --git a/zk_prover/src/merkle_sum_tree/tests.rs b/zk_prover/src/merkle_sum_tree/tests.rs index f6a78892..164f64ba 100644 --- a/zk_prover/src/merkle_sum_tree/tests.rs +++ b/zk_prover/src/merkle_sum_tree/tests.rs @@ -87,6 +87,83 @@ mod test { proof_invalid_3.sibling_sums[0] = [0.into(), 0.into()]; } + #[test] + fn test_update_mst_leaf() { + let merkle_tree_1 = + MerkleSumTree::::new("src/merkle_sum_tree/csv/entry_16.csv") + .unwrap(); + + let root_hash_1 = merkle_tree_1.root().hash; + + //Create the second tree with the 7th entry different from the the first tree + let mut merkle_tree_2 = MerkleSumTree::::new( + "src/merkle_sum_tree/csv/entry_16_modified.csv", + ) + .unwrap(); + + let root_hash_2 = merkle_tree_2.root().hash; + assert!(root_hash_1 != root_hash_2); + + //Update the 7th leaf of the second tree so all the entries now match the first tree + let new_root = merkle_tree_2 + .update_leaf( + "RkLzkDun", + &[2087.to_biguint().unwrap(), 79731.to_biguint().unwrap()], + ) + .unwrap(); + //The roots should match + assert!(root_hash_1 == new_root.hash); + } + + #[test] + fn test_update_invalid_mst_leaf() { + let mut merkle_tree = + MerkleSumTree::::new_sorted("src/merkle_sum_tree/csv/entry_16.csv") + .unwrap(); + + let new_root = merkle_tree.update_leaf( + "non_existing_user", //This username is not present in the tree + &[11888.to_biguint().unwrap(), 41163.to_biguint().unwrap()], + ); + + if let Err(e) = new_root { + assert_eq!(e.to_string(), "Username not found"); + } + } + + #[test] + fn test_sorted_mst() { + let merkle_tree = + MerkleSumTree::::new("src/merkle_sum_tree/csv/entry_16.csv") + .unwrap(); + + let old_root_balances = merkle_tree.root().balances; + let old_root_hash = merkle_tree.root().hash; + + let sorted_merkle_tree = + MerkleSumTree::::new_sorted("src/merkle_sum_tree/csv/entry_16.csv") + .unwrap(); + + let new_root_balances = sorted_merkle_tree.root().balances; + let new_root_hash = sorted_merkle_tree.root().hash; + + // The index of an entry should not be the same for sorted and unsorted MST + assert_ne!( + merkle_tree + .index_of( + "AtwIxZHo", + [35479.to_biguint().unwrap(), 31699.to_biguint().unwrap()] + ) + .unwrap(), + sorted_merkle_tree.index_of_username("AtwIxZHo").unwrap() + ); + + // The root balances should be the same for sorted and unsorted MST + assert!(old_root_balances == new_root_balances); + // The root hash should not be the same for sorted and unsorted MST + assert!(old_root_hash != new_root_hash); + } + // Passing a csv file with a single entry that has a balance that is not in the expected range will fail #[test] fn test_mst_overflow_1() { @@ -126,7 +203,7 @@ mod test { assert!(result.is_ok()); } - #[test] + #[test] fn test_big_uint_conversion() { let big_uint = 3.to_biguint().unwrap(); let fp = big_uint_to_fp(&big_uint); diff --git a/zk_prover/src/merkle_sum_tree/utils/build_tree.rs b/zk_prover/src/merkle_sum_tree/utils/build_tree.rs index 968cbe81..66b5860b 100644 --- a/zk_prover/src/merkle_sum_tree/utils/build_tree.rs +++ b/zk_prover/src/merkle_sum_tree/utils/build_tree.rs @@ -1,4 +1,3 @@ -use crate::merkle_sum_tree::utils::create_middle_node::create_middle_node; use crate::merkle_sum_tree::{Entry, Node}; use halo2_proofs::halo2curves::bn256::Fr as Fp; use std::thread; @@ -94,7 +93,7 @@ fn build_middle_level( handles.push(thread::spawn(move || { chunk .chunks(2) - .map(|pair| create_middle_node(&pair[0], &pair[1])) + .map(|pair| Node::middle(&pair[0], &pair[1])) .collect::>() })); } diff --git a/zk_prover/src/merkle_sum_tree/utils/create_middle_node.rs b/zk_prover/src/merkle_sum_tree/utils/create_middle_node.rs deleted file mode 100644 index e473f35e..00000000 --- a/zk_prover/src/merkle_sum_tree/utils/create_middle_node.rs +++ /dev/null @@ -1,27 +0,0 @@ -use halo2_proofs::halo2curves::bn256::Fr as Fp; - -use crate::merkle_sum_tree::utils::hash::poseidon_node; -use crate::merkle_sum_tree::Node; - -pub fn create_middle_node( - child_l: &Node, - child_r: &Node, -) -> Node -where - [usize; 2 * (1 + N_ASSETS)]: Sized, -{ - let mut balances_sum = [Fp::zero(); N_ASSETS]; - for (i, balance) in balances_sum.iter_mut().enumerate() { - *balance = child_l.balances[i] + child_r.balances[i]; - } - - Node { - hash: poseidon_node( - child_l.hash, - child_l.balances, - child_r.hash, - child_r.balances, - ), - balances: balances_sum, - } -} diff --git a/zk_prover/src/merkle_sum_tree/utils/mod.rs b/zk_prover/src/merkle_sum_tree/utils/mod.rs index 53d4db35..86962fa8 100644 --- a/zk_prover/src/merkle_sum_tree/utils/mod.rs +++ b/zk_prover/src/merkle_sum_tree/utils/mod.rs @@ -1,5 +1,4 @@ mod build_tree; -mod create_middle_node; mod create_proof; mod csv_parser; mod hash; diff --git a/zk_prover/src/merkle_sum_tree/utils/proof_verification.rs b/zk_prover/src/merkle_sum_tree/utils/proof_verification.rs index 09b02b88..78c6ec72 100644 --- a/zk_prover/src/merkle_sum_tree/utils/proof_verification.rs +++ b/zk_prover/src/merkle_sum_tree/utils/proof_verification.rs @@ -1,4 +1,4 @@ -use crate::merkle_sum_tree::utils::{big_uint_to_fp, create_middle_node::create_middle_node}; +use crate::merkle_sum_tree::utils::big_uint_to_fp; use crate::merkle_sum_tree::{MerkleProof, Node}; use halo2_proofs::halo2curves::bn256::Fr as Fp; @@ -22,9 +22,9 @@ where }; if proof.path_indices[i] == 0.into() { - node = create_middle_node(&node, &sibling_node); + node = Node::middle(&node, &sibling_node); } else { - node = create_middle_node(&sibling_node, &node); + node = Node::middle(&sibling_node, &node); } for (balance, sibling_balance) in balances.iter_mut().zip(sibling_node.balances.iter()) {