Skip to content

Commit

Permalink
feat: fix create_proof to be more generic
Browse files Browse the repository at this point in the history
  • Loading branch information
enricobottazzi committed Nov 1, 2023
1 parent e8f68d6 commit 7a510a5
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 87 deletions.
4 changes: 3 additions & 1 deletion zk_prover/src/circuits/merkle_sum_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ impl<const LEVELS: usize, const N_ASSETS: usize, const N_BYTES: usize>
assert_eq!(proof.sibling_hashes.len(), LEVELS);
assert_eq!(proof.sibling_sums.len(), LEVELS);

let entry = merkle_sum_tree.get_entry(user_index).clone();

Self {
entry: proof.entry,
entry,
path_element_hashes: proof.sibling_hashes,
path_element_balances: proof.sibling_sums,
path_indices: proof.path_indices,
Expand Down
10 changes: 3 additions & 7 deletions zk_prover/src/merkle_sum_tree/aggregation_mst.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use crate::merkle_sum_tree::utils::{
build_merkle_tree_from_leaves, create_proof, create_top_tree_proof, verify_proof,
};
use crate::merkle_sum_tree::utils::{build_merkle_tree_from_leaves, create_proof, verify_proof};
use crate::merkle_sum_tree::{MerkleProof, MerkleSumTree, Node};

/// Aggregation Merkle Sum Tree Data Structure.
Expand Down Expand Up @@ -91,19 +89,17 @@ impl<const N_ASSETS: usize, const N_BYTES: usize> AggregationMerkleSumTree<N_ASS

let partial_proof = create_proof(
user_index,
mini_tree.entries(),
*mini_tree.depth(),
mini_tree.nodes(),
mini_tree.root(),
)?;

let top_tree_proof =
create_top_tree_proof(mini_tree_index, self.depth, &self.nodes, &self.root)?;
let top_tree_proof = create_proof(mini_tree_index, self.depth, &self.nodes, &self.root)?;

// Merge the two proofs
let final_proof = MerkleProof {
root_hash: self.root.hash,
entry: partial_proof.entry,
leaf: partial_proof.leaf,
sibling_hashes: [partial_proof.sibling_hashes, top_tree_proof.sibling_hashes].concat(),
sibling_sums: [partial_proof.sibling_sums, top_tree_proof.sibling_sums].concat(),
path_indices: [partial_proof.path_indices, top_tree_proof.path_indices].concat(),
Expand Down
10 changes: 1 addition & 9 deletions zk_prover/src/merkle_sum_tree/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,7 @@ use halo2_proofs::halo2curves::bn256::Fr as Fp;

#[derive(Clone, Debug)]
pub struct MerkleProof<const N_ASSETS: usize> {
pub root_hash: Fp,
pub entry: Entry<N_ASSETS>,
pub sibling_hashes: Vec<Fp>,
pub sibling_sums: Vec<[Fp; N_ASSETS]>,
pub path_indices: Vec<Fp>,
}

#[derive(Clone, Debug)]
pub struct TopTreeMerkleProof<const N_ASSETS: usize> {
pub leaf: Node<N_ASSETS>,
pub root_hash: Fp,
pub sibling_hashes: Vec<Fp>,
pub sibling_sums: Vec<[Fp; N_ASSETS]>,
Expand Down
6 changes: 5 additions & 1 deletion zk_prover/src/merkle_sum_tree/mst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ impl<const N_ASSETS: usize, const N_BYTES: usize> MerkleSumTree<N_ASSETS, N_BYTE
&self.entries
}

pub fn get_entry(&self, index: usize) -> &Entry<N_ASSETS> {
&self.entries[index]
}

pub fn nodes(&self) -> &[Vec<Node<N_ASSETS>>] {
&self.nodes
}
Expand Down Expand Up @@ -185,7 +189,7 @@ impl<const N_ASSETS: usize, const N_BYTES: usize> MerkleSumTree<N_ASSETS, N_BYTE

/// Generates a MerkleProof for the user with the given index
pub fn generate_proof(&self, index: usize) -> Result<MerkleProof<N_ASSETS>, &'static str> {
create_proof(index, &self.entries, self.depth, &self.nodes, &self.root)
create_proof(index, self.depth, &self.nodes, &self.root)
}

/// Verifies a MerkleProof
Expand Down
27 changes: 10 additions & 17 deletions zk_prover/src/merkle_sum_tree/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,15 @@ mod test {
// shouldn't create a proof for an entry that doesn't exist in the tree
assert!(merkle_tree.generate_proof(16).is_err());

// shouldn't verify a proof with a wrong entry
let mut proof_invalid_1 = proof.clone();
proof_invalid_1.entry = Entry::new(
// shouldn't verify a proof with a wrong leaf
let invalid_entry = Entry::new(
"AtwIxZHo".to_string(),
[35479.to_biguint().unwrap(), 35479.to_biguint().unwrap()],
)
.unwrap();
let invalid_leaf = invalid_entry.compute_leaf();
let mut proof_invalid_1 = proof.clone();
proof_invalid_1.leaf = invalid_leaf;
assert!(!merkle_tree.verify_proof(&proof_invalid_1));

// shouldn't verify a proof with a wrong root hash
Expand Down Expand Up @@ -336,23 +338,14 @@ mod test {
// shouldn't create a proof for an entry that doesn't exist in the tree
assert!(aggregation_mst.generate_proof(16, 0).is_err());

// shouldn't verify a proof with a wrong entry
// shouldn't verify a proof with a wrong root hash
let mut proof_invalid_1 = proof.clone();
proof_invalid_1.entry = Entry::new(
"AtwIxZHo".to_string(),
[35479.to_biguint().unwrap(), 35479.to_biguint().unwrap()],
)
.unwrap();
proof_invalid_1.root_hash = 0.into();
assert!(!aggregation_mst.verify_proof(&proof_invalid_1));

// shouldn't verify a proof with a wrong root hash
let mut proof_invalid_2 = proof.clone();
proof_invalid_2.root_hash = 0.into();
assert!(!aggregation_mst.verify_proof(&proof_invalid_2));

// shouldn't verify a proof with a wrong computed balance
let mut proof_invalid_3 = proof;
proof_invalid_3.sibling_sums[0] = [0.into(), 0.into()];
assert!(!aggregation_mst.verify_proof(&proof_invalid_3))
let mut proof_invalid_2 = proof;
proof_invalid_2.sibling_sums[0] = [0.into(), 0.into()];
assert!(!aggregation_mst.verify_proof(&proof_invalid_2))
}
}
46 changes: 4 additions & 42 deletions zk_prover/src/merkle_sum_tree/utils/create_proof.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use crate::merkle_sum_tree::{Entry, MerkleProof, Node, TopTreeMerkleProof};
use crate::merkle_sum_tree::{MerkleProof, Node};
use halo2_proofs::halo2curves::bn256::Fr as Fp;

pub fn create_proof<const N_ASSETS: usize>(
index: usize,
entries: &[Entry<N_ASSETS>],
depth: usize,
nodes: &[Vec<Node<N_ASSETS>>],
root: &Node<N_ASSETS>,
Expand All @@ -17,45 +16,7 @@ pub fn create_proof<const N_ASSETS: usize>(
let mut path_indices = vec![Fp::from(0); depth];
let mut current_index = index;

for level in 0..depth {
let position = current_index % 2;
let level_start_index = current_index - position;
let level_end_index = level_start_index + 2;

path_indices[level] = Fp::from(position as u64);

for i in level_start_index..level_end_index {
if i != current_index {
sibling_hashes[level] = nodes[level][i].hash;
sibling_sums[level] = nodes[level][i].balances;
}
}
current_index /= 2;
}

Ok(MerkleProof {
root_hash: root.hash,
entry: entries[index].clone(),
sibling_hashes,
sibling_sums,
path_indices,
})
}

pub fn create_top_tree_proof<const N_ASSETS: usize>(
index: usize,
depth: usize,
nodes: &[Vec<Node<N_ASSETS>>],
root: &Node<N_ASSETS>,
) -> Result<TopTreeMerkleProof<N_ASSETS>, &'static str> {
if index >= nodes[0].len() {
return Err("The leaf does not exist in this tree");
}

let mut sibling_hashes = vec![Fp::from(0); depth];
let mut sibling_sums = vec![[Fp::from(0); N_ASSETS]; depth];
let mut path_indices = vec![Fp::from(0); depth];
let mut current_index = index;
let leaf = &nodes[0][index]; // Added this line to store the leaf node

for level in 0..depth {
let position = current_index % 2;
Expand All @@ -73,7 +34,8 @@ pub fn create_top_tree_proof<const N_ASSETS: usize>(
current_index /= 2;
}

Ok(TopTreeMerkleProof {
Ok(MerkleProof {
leaf: leaf.clone(),
root_hash: root.hash,
sibling_hashes,
sibling_sums,
Expand Down
2 changes: 1 addition & 1 deletion zk_prover/src/merkle_sum_tree/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mod operation_helpers;
mod proof_verification;

pub use build_tree::{build_merkle_tree_from_leaves, compute_leaves};
pub use create_proof::{create_proof, create_top_tree_proof};
pub use create_proof::create_proof;
pub use csv_parser::parse_csv_to_entries;
pub use generate_leaf_hash::generate_leaf_hash;
pub use hash::{poseidon_entry, poseidon_node};
Expand Down
12 changes: 3 additions & 9 deletions zk_prover/src/merkle_sum_tree/utils/proof_verification.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
use crate::merkle_sum_tree::utils::big_uint_to_fp;
use crate::merkle_sum_tree::{MerkleProof, Node};
use halo2_proofs::halo2curves::bn256::Fr as Fp;

pub fn verify_proof<const N_ASSETS: usize>(proof: &MerkleProof<N_ASSETS>) -> bool
where
[usize; N_ASSETS + 1]: Sized,
[usize; 2 * (1 + N_ASSETS)]: Sized,
{
let mut node = proof.entry.compute_leaf();
let mut balances = proof
.entry
.balances()
.iter()
.map(big_uint_to_fp)
.collect::<Vec<Fp>>();
let mut node = proof.leaf.clone();

let mut balances = proof.leaf.balances;

for i in 0..proof.sibling_hashes.len() {
let sibling_node = Node {
Expand Down

0 comments on commit 7a510a5

Please sign in to comment.