Skip to content

Commit

Permalink
Add MST leaf update function (#156)
Browse files Browse the repository at this point in the history
* Add MST leaf update function

* Implement code review changes

* Fix documentation
  • Loading branch information
alxkzmn authored Sep 18, 2023
1 parent bc63f43 commit 51e4286
Show file tree
Hide file tree
Showing 12 changed files with 291 additions and 68 deletions.
23 changes: 22 additions & 1 deletion zk_prover/benches/full_solvency_flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ fn build_mstree(_c: &mut Criterion) {
);

let bench_name = format!(
"build merkle sum tree for 2 power of {} entries with {} assets",
"build Merkle sum tree for 2 power of {} entries with {} assets",
LEVELS, N_ASSETS
);

Expand All @@ -40,6 +40,26 @@ fn build_mstree(_c: &mut Criterion) {
});
}

fn build_sorted_mstree(_c: &mut Criterion) {
let mut criterion = Criterion::default().sample_size(SAMPLE_SIZE);

let csv_file = format!(
"benches/csv/{}/{}_entry_2_{}.csv",
PATH_NAME, PATH_NAME, LEVELS
);

let bench_name = format!(
"build sorted Merkle sum tree for 2 power of {} entries with {} assets",
LEVELS, N_ASSETS
);

criterion.bench_function(&bench_name, |b| {
b.iter(|| {
MerkleSumTree::<N_ASSETS, N_BYTES>::new_sorted(&csv_file).unwrap();
})
});
}

fn verification_key_gen_mst_inclusion_circuit(_c: &mut Criterion) {
let mut criterion = Criterion::default().sample_size(SAMPLE_SIZE);

Expand Down Expand Up @@ -239,6 +259,7 @@ fn verify_zk_proof_solvency_circuit(_c: &mut Criterion) {
criterion_group!(
benches,
build_mstree,
build_sorted_mstree,
verification_key_gen_mst_inclusion_circuit,
proving_key_gen_mst_inclusion_circuit,
generate_zk_proof_mst_inclusion_circuit,
Expand Down
2 changes: 1 addition & 1 deletion zk_prover/src/circuits/merkle_sum_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ where
// Assign the entry username
let username = self.assign_value_to_witness(
layouter.namespace(|| "assign entry username"),
big_uint_to_fp(self.entry.username_to_big_uint()),
big_uint_to_fp(self.entry.username_as_big_uint()),
"entry username",
config.advices[0],
)?;
Expand Down
17 changes: 17 additions & 0 deletions zk_prover/src/merkle_sum_tree/csv/entry_16_modified.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
username;balances
dxGaEAii;11888,41163
MBlfbBGI;67823,18651
lAhWlEWZ;18651,2087
nuZweYtO;22073,55683
gbdSwiuY;34897,83296
RZNneNuP;83296,16881
YsscHXkp;31699,35479
RkLzkDun;2086,79732
HlQlnEYI;30605,11888
RqkZOFYe;16881,14874
NjCSRAfD;41163,67823
pHniJMQY;14874,22073
dOGIMzKR;10032,10032
HfMDmNLp;55683,34897
xPLKzCBl;79731,30605
AtwIxZHo;35479,31699
44 changes: 18 additions & 26 deletions zk_prover/src/merkle_sum_tree/entry.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
use crate::merkle_sum_tree::utils::{big_intify_username, big_uint_to_fp, poseidon_entry};
use crate::merkle_sum_tree::utils::big_intify_username;
use crate::merkle_sum_tree::Node;
use halo2_proofs::halo2curves::bn256::Fr as Fp;
use num_bigint::BigUint;

/// An entry in the Merkle Sum Tree from the database of the CEX.
/// It contains the username and the balances of the user.
#[derive(Clone, Debug)]
pub struct Entry<const N_ASSETS: usize> {
username_to_big_uint: BigUint,
username_as_big_uint: BigUint,
balances: [BigUint; N_ASSETS],
username: String,
}

impl<const N_ASSETS: usize> Entry<N_ASSETS> {
pub fn new(username: String, balances: [BigUint; N_ASSETS]) -> Result<Self, &'static str> {
Ok(Entry {
username_to_big_uint: big_intify_username(&username),
username_as_big_uint: big_intify_username(&username),
balances,
username,
})
Expand All @@ -25,7 +24,7 @@ impl<const N_ASSETS: usize> Entry<N_ASSETS> {
let empty_balances: [BigUint; N_ASSETS] = std::array::from_fn(|_| BigUint::from(0u32));

Entry {
username_to_big_uint: BigUint::from(0u32),
username_as_big_uint: BigUint::from(0u32),
balances: empty_balances,
username: "".to_string(),
}
Expand All @@ -35,33 +34,26 @@ impl<const N_ASSETS: usize> Entry<N_ASSETS> {
where
[usize; N_ASSETS + 1]: Sized,
{
Node {
hash: poseidon_entry::<N_ASSETS>(
big_uint_to_fp(&self.username_to_big_uint),
self.balances
.iter()
.map(big_uint_to_fp)
.collect::<Vec<Fp>>()
.try_into()
.unwrap(),
),
//Map the array of balances using big_int_to_fp:
balances: self
.balances
.iter()
.map(big_uint_to_fp)
.collect::<Vec<Fp>>()
.try_into()
.unwrap(),
}
Node::leaf(&self.username_as_big_uint, &self.balances)
}

/// Stores the new balance values
///
/// Returns the updated node
pub fn recompute_leaf(&mut self, updated_balances: &[BigUint; N_ASSETS]) -> Node<N_ASSETS>
where
[usize; N_ASSETS + 1]: Sized,
{
self.balances = updated_balances.clone();
Node::leaf(&self.username_as_big_uint, updated_balances)
}

pub fn balances(&self) -> &[BigUint; N_ASSETS] {
&self.balances
}

pub fn username_to_big_uint(&self) -> &BigUint {
&self.username_to_big_uint
pub fn username_as_big_uint(&self) -> &BigUint {
&self.username_as_big_uint
}

pub fn username(&self) -> &str {
Expand Down
8 changes: 2 additions & 6 deletions zk_prover/src/merkle_sum_tree/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod entry;
mod mst;
mod node;
mod tests;
pub mod utils;
use halo2_proofs::halo2curves::bn256::Fr as Fp;
Expand All @@ -13,12 +14,7 @@ pub struct MerkleProof<const N_ASSETS: usize> {
pub path_indices: Vec<Fp>,
}

#[derive(Clone, Debug)]
pub struct Node<const N_ASSETS: usize> {
pub hash: Fp,
pub balances: [Fp; N_ASSETS],
}

pub use entry::Entry;
pub use mst::MerkleSumTree;
pub use node::Node;
pub use utils::{big_intify_username, big_uint_to_fp};
89 changes: 89 additions & 0 deletions zk_prover/src/merkle_sum_tree/mst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub struct MerkleSumTree<const N_ASSETS: usize, const N_BYTES: usize> {
nodes: Vec<Vec<Node<N_ASSETS>>>,
depth: usize,
entries: Vec<Entry<N_ASSETS>>,
is_sorted: bool,
}

impl<const N_ASSETS: usize, const N_BYTES: usize> MerkleSumTree<N_ASSETS, N_BYTES> {
Expand All @@ -38,6 +39,34 @@ impl<const N_ASSETS: usize, const N_BYTES: usize> MerkleSumTree<N_ASSETS, N_BYTE
[usize; 2 * (1 + N_ASSETS)]: Sized,
{
let entries = parse_csv_to_entries::<&str, N_ASSETS, N_BYTES>(path)?;
Self::build_tree(entries, false)
}

/// Builds a Merkle Sum Tree from a CSV file stored at `path`. The MST leaves are sorted by the username byte values. The CSV file must be formatted as follows:
///
/// `username;balances`
///
/// `dxGaEAii;11888,41163`
pub fn new_sorted(path: &str) -> Result<Self, Box<dyn std::error::Error>>
where
[usize; N_ASSETS + 1]: Sized,
[usize; 2 * (1 + N_ASSETS)]: Sized,
{
let mut entries = parse_csv_to_entries::<&str, N_ASSETS, N_BYTES>(path)?;

entries.sort_by(|a, b| a.username().cmp(b.username()));

Self::build_tree(entries, true)
}

fn build_tree(
entries: Vec<Entry<N_ASSETS>>,
is_sorted: bool,
) -> Result<MerkleSumTree<N_ASSETS, N_BYTES>, Box<dyn std::error::Error>>
where
[usize; N_ASSETS + 1]: Sized,
[usize; 2 * (1 + N_ASSETS)]: Sized,
{
let depth = (entries.len() as f64).log2().ceil() as usize;

if !(1..=Self::MAX_DEPTH).contains(&depth) {
Expand All @@ -55,9 +84,50 @@ impl<const N_ASSETS: usize, const N_BYTES: usize> MerkleSumTree<N_ASSETS, N_BYTE
nodes,
depth,
entries,
is_sorted,
})
}

/// Updates the balances of the entry with the given username and returns the new root of the tree.
///
/// # Arguments
///
/// * `username`: The username of the entry to update
/// * `new_balances`: The new balances of the entry
///
/// # Returns
///
/// The new root of the tree
pub fn update_leaf(
&mut self,
username: &str,
new_balances: &[BigUint; N_ASSETS],
) -> Result<Node<N_ASSETS>, Box<dyn std::error::Error>>
where
[usize; N_ASSETS + 1]: Sized,
[usize; 2 * (1 + N_ASSETS)]: Sized,
{
let index = self.index_of_username(username)?;

// Update the leaf node.
let updated_leaf = self.entries[index].recompute_leaf(new_balances);
self.nodes[0][index] = updated_leaf;

// Recompute the hashes and balances up the tree.
let mut current_index = index;
for depth in 1..=self.depth {
let parent_index = current_index / 2;
let left_child = &self.nodes[depth - 1][2 * parent_index];
let right_child = &self.nodes[depth - 1][2 * parent_index + 1];
self.nodes[depth][parent_index] = Node::<N_ASSETS>::middle(left_child, right_child);
current_index = parent_index;
}

let root = self.nodes[self.depth][0].clone();

Ok(root)
}

pub fn root(&self) -> &Node<N_ASSETS> {
&self.root
}
Expand Down Expand Up @@ -94,6 +164,25 @@ impl<const N_ASSETS: usize, const N_BYTES: usize> MerkleSumTree<N_ASSETS, N_BYTE
index_of(username, balances, &self.nodes)
}

/// Returns the index of the leaf with the matching username
pub fn index_of_username(&self, username: &str) -> Result<usize, Box<dyn std::error::Error>>
where
[usize; N_ASSETS + 1]: Sized,
{
if !self.is_sorted {
self.entries
.iter()
.enumerate()
.find(|(_, entry)| entry.username() == username)
.map(|(index, _)| index)
.ok_or_else(|| Box::from("Username not found"))
} else {
self.entries
.binary_search_by_key(&username, |entry| entry.username())
.map_err(|_| Box::from("Username not found"))
}
}

/// Generates a MerkleProof for the user with the given index
pub fn generate_proof(&self, index: usize) -> Result<MerkleProof<N_ASSETS>, &'static str> {
create_proof(index, &self.entries, self.depth, &self.nodes, &self.root)
Expand Down
60 changes: 60 additions & 0 deletions zk_prover/src/merkle_sum_tree/node.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use halo2_proofs::halo2curves::bn256::Fr as Fp;
use num_bigint::BigUint;

use super::{
big_uint_to_fp,
utils::{poseidon_entry, poseidon_node},
};

#[derive(Clone, Debug)]
pub struct Node<const N_ASSETS: usize> {
pub hash: Fp,
pub balances: [Fp; N_ASSETS],
}
impl<const N_ASSETS: usize> Node<N_ASSETS> {
/// Builds a "middle" (non-leaf-level) node of the MST
pub fn middle(child_l: &Node<N_ASSETS>, child_r: &Node<N_ASSETS>) -> Node<N_ASSETS>
where
[usize; 2 * (1 + N_ASSETS)]: Sized,
{
let mut balances_sum = [Fp::zero(); N_ASSETS];
for (i, balance) in balances_sum.iter_mut().enumerate() {
*balance = child_l.balances[i] + child_r.balances[i];
}

Node {
hash: poseidon_node(
child_l.hash,
child_l.balances,
child_r.hash,
child_r.balances,
),
balances: balances_sum,
}
}

/// Builds a leaf-level node of the MST
pub fn leaf(username: &BigUint, balances: &[BigUint; N_ASSETS]) -> Node<N_ASSETS>
where
[usize; N_ASSETS + 1]: Sized,
{
Node {
hash: poseidon_entry::<N_ASSETS>(
big_uint_to_fp(username),
balances
.iter()
.map(big_uint_to_fp)
.collect::<Vec<Fp>>()
.try_into()
.unwrap(),
),
//Map the array of balances using big_int_to_fp:
balances: balances
.iter()
.map(big_uint_to_fp)
.collect::<Vec<Fp>>()
.try_into()
.unwrap(),
}
}
}
Loading

0 comments on commit 51e4286

Please sign in to comment.