Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: simple code cleanups in mst implementation #269

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 56 additions & 2 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ env:

jobs:
wakeup:
if: github.event.pull_request.head.repo.full_name == 'summa-dev/summa-solvency'
runs-on: ubuntu-latest
permissions:
id-token: write
Expand All @@ -31,9 +32,10 @@ jobs:
aws-region: us-west-2

- name: Wakeup runner
run: .github/scripts/wakeup.sh
run: .github/scripts/wakeup.sh

build:
if: github.event.pull_request.head.repo.full_name == 'summa-dev/summa-solvency'
runs-on: [summa-solvency-runner]
needs: [wakeup]

Expand Down Expand Up @@ -71,4 +73,56 @@ jobs:
run: |
cd backend
cargo run --release --example summa_solvency_flow


test-zk-prover:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Test Zk Prover
run: |
cd zk_prover
cargo test --release --features dev-graph -- --nocapture

test-zk-prover-examples:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Install solc
run: (hash svm 2>/dev/null || cargo install --version 0.2.23 svm-rs) && svm install 0.8.20 && solc --version
- name: Test Zk Prover examples
run: |
cd zk_prover
cargo run --release --example gen_inclusion_verifier
cargo run --release --example gen_commitment
cargo run --release --example gen_inclusion_proof

test-zk-prover-examples-nova:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Test Zk Prover examples
run: |
cd zk_prover
cargo run --release --example nova_incremental_verifier

test-backend:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Install Foundry
uses: foundry-rs/foundry-toolchain@v1
- name: Test backend
run: |
cd backend
cargo test --release -- --nocapture

test-backend-examples:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Install Foundry
uses: foundry-rs/foundry-toolchain@v1
- name: Test backend example
run: |
cd backend
cargo run --release --example summa_solvency_flow
2 changes: 1 addition & 1 deletion backend/src/apis/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ where
.map(|balance| BigUint::from_str_radix(balance, 10).unwrap())
.collect();

let entry: Entry<N_CURRENCIES> = Entry::new(username, balances.try_into().unwrap()).unwrap();
let entry: Entry<N_CURRENCIES> = Entry::new(username, balances.try_into().unwrap());

// Convert Fp to U256
let hash_str = format!("{:?}", entry.compute_leaf().hash);
Expand Down
2 changes: 1 addition & 1 deletion zk_prover/src/circuits/merkle_sum_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ where
// Assign the entry username to the witness
let username = self.assign_value_to_witness(
layouter.namespace(|| "assign entry username"),
big_uint_to_fp(self.entry.username_as_big_uint()),
big_uint_to_fp(&self.entry.username_as_big_uint()),
"entry username",
config.advices[0],
)?;
Expand Down
3 changes: 1 addition & 2 deletions zk_prover/src/circuits/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,7 @@ mod test {
let invalid_leaf_balances = [1000.to_biguint().unwrap(), 1000.to_biguint().unwrap()];

// invalidate user entry
let invalid_entry =
Entry::new(circuit.entry.username().to_string(), invalid_leaf_balances).unwrap();
let invalid_entry = Entry::new(circuit.entry.username().to_string(), invalid_leaf_balances);

circuit.entry = invalid_entry;

Expand Down
6 changes: 3 additions & 3 deletions zk_prover/src/merkle_sum_tree/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@ pub struct Entry<const N_CURRENCIES: usize> {
}

impl<const N_CURRENCIES: usize> Entry<N_CURRENCIES> {
pub fn new(username: String, balances: [BigUint; N_CURRENCIES]) -> Result<Self, &'static str> {
pub fn new(username: String, balances: [BigUint; N_CURRENCIES]) -> Self {
teddav marked this conversation as resolved.
Show resolved Hide resolved
// Security Assumptions:
// Using `keccak256` for `hashed_username` ensures high collision resistance,
// appropriate for the assumed userbase of $2^{30}$.
// The `hashed_username` utilizes the full 256 bits produced by `keccak256`,
// but is adjusted to the field size through the Poseidon hash function's modulo operation.
let hashed_username: BigUint = BigUint::from_bytes_be(&keccak256(username.as_bytes()));
Ok(Entry {
Entry {
hashed_username,
balances,
username,
})
}
}

/// Returns a zero entry where the username is 0 and the balances are all 0
Expand Down
5 changes: 1 addition & 4 deletions zk_prover/src/merkle_sum_tree/mst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,6 @@ impl<const N_CURRENCIES: usize, const N_BYTES: usize> MerkleSumTree<N_CURRENCIES
{
let depth = (entries.len() as f64).log2().ceil() as usize;

let mut nodes = vec![];

// Pad the entries with empty entries to make the number of entries equal to 2^depth
if entries.len() < 2usize.pow(depth as u32) {
entries.extend(vec![
Expand All @@ -123,7 +121,7 @@ impl<const N_CURRENCIES: usize, const N_BYTES: usize> MerkleSumTree<N_CURRENCIES

let leaves = build_leaves_from_entries(&entries);

let root = build_merkle_tree_from_leaves(&leaves, depth, &mut nodes)?;
let (root, nodes) = build_merkle_tree_from_leaves(&leaves, depth)?;

Ok(MerkleSumTree {
root,
Expand Down Expand Up @@ -202,7 +200,6 @@ impl<const N_CURRENCIES: usize, const N_BYTES: usize> MerkleSumTree<N_CURRENCIES
}

let root = self.nodes[self.depth][0].clone();

Ok(root)
}

Expand Down
45 changes: 8 additions & 37 deletions zk_prover/src/merkle_sum_tree/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,11 @@ impl<const N_CURRENCIES: usize> Node<N_CURRENCIES> {
where
[usize; N_CURRENCIES + 1]: Sized,
{
let hash =
poseidon::Hash::<Fp, PoseidonSpec, ConstantLength<{ N_CURRENCIES + 1 }>, 2, 1>::init()
.hash(preimage.clone());
Node {
hash: Self::poseidon_hash_leaf(preimage[0], preimage[1..].try_into().unwrap()),
hash,
balances: preimage[1..].try_into().unwrap(),
}
}
Expand All @@ -71,44 +74,12 @@ impl<const N_CURRENCIES: usize> Node<N_CURRENCIES> {
where
[usize; N_CURRENCIES + 2]: Sized,
{
let hash =
poseidon::Hash::<Fp, PoseidonSpec, ConstantLength<{ N_CURRENCIES + 2 }>, 2, 1>::init()
.hash(preimage.clone());
Node {
hash: Self::poseidon_hash_middle(
preimage[0..N_CURRENCIES].try_into().unwrap(),
preimage[N_CURRENCIES],
preimage[N_CURRENCIES + 1],
),
hash,
balances: preimage[0..N_CURRENCIES].try_into().unwrap(),
}
}

fn poseidon_hash_middle(
balances_sum: [Fp; N_CURRENCIES],
hash_child_left: Fp,
hash_child_right: Fp,
) -> Fp
where
[usize; N_CURRENCIES + 2]: Sized,
{
let mut hash_inputs: [Fp; N_CURRENCIES + 2] = [Fp::zero(); N_CURRENCIES + 2];

hash_inputs[0..N_CURRENCIES].copy_from_slice(&balances_sum);
hash_inputs[N_CURRENCIES] = hash_child_left;
hash_inputs[N_CURRENCIES + 1] = hash_child_right;

poseidon::Hash::<Fp, PoseidonSpec, ConstantLength<{ N_CURRENCIES + 2 }>, 2, 1>::init()
.hash(hash_inputs)
}

fn poseidon_hash_leaf(username: Fp, balances: [Fp; N_CURRENCIES]) -> Fp
where
[usize; N_CURRENCIES + 1]: Sized,
{
let mut hash_inputs: [Fp; N_CURRENCIES + 1] = [Fp::zero(); N_CURRENCIES + 1];

hash_inputs[0] = username;
hash_inputs[1..N_CURRENCIES + 1].copy_from_slice(&balances);

poseidon::Hash::<Fp, PoseidonSpec, ConstantLength<{ N_CURRENCIES + 1 }>, 2, 1>::init()
.hash(hash_inputs)
}
}
3 changes: 1 addition & 2 deletions zk_prover/src/merkle_sum_tree/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ mod test {
let invalid_entry = Entry::new(
"AtwIxZHo".to_string(),
[35479.to_biguint().unwrap(), 35479.to_biguint().unwrap()],
)
.unwrap();
);
let invalid_entry = invalid_entry;
let mut proof_invalid_1 = proof.clone();
proof_invalid_1.entry = invalid_entry;
Expand Down
17 changes: 9 additions & 8 deletions zk_prover/src/merkle_sum_tree/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub trait Tree<const N_CURRENCIES: usize> {
let mut preimage = [Fp::zero(); N_CURRENCIES + 1];

// Add username to preimage
preimage[0] = big_uint_to_fp(entry.username_as_big_uint());
preimage[0] = big_uint_to_fp(&entry.username_as_big_uint());

// Add balances to preimage
for (i, balance) in preimage.iter_mut().enumerate().skip(1).take(N_CURRENCIES) {
Expand All @@ -97,6 +97,7 @@ pub trait Tree<const N_CURRENCIES: usize> {
if index >= nodes[0].len() {
return Err(Box::from("Index out of bounds"));
}
assert_eq!(nodes[0].len(), 2usize.pow(depth as u32));

let mut sibling_middle_node_hash_preimages = Vec::with_capacity(depth - 1);

Expand All @@ -111,7 +112,9 @@ pub trait Tree<const N_CURRENCIES: usize> {
let position = current_index % 2;
let sibling_index = current_index - position + (1 - position);

if sibling_index < nodes[level].len() && level != 0 {
// we asserted that the leaves vec length is a power of 2
// so the index shouldn't overflow the level's length
if level > 0 {
// Fetch hash preimage for sibling middle nodes
let sibling_node_preimage =
self.get_middle_node_hash_preimage(level, sibling_index)?;
Expand Down Expand Up @@ -152,14 +155,13 @@ pub trait Tree<const N_CURRENCIES: usize> {
if proof.path_indices[0] == 0.into() {
hash_preimage[N_CURRENCIES] = node.hash;
hash_preimage[N_CURRENCIES + 1] = sibling_leaf_node.hash;
node = Node::middle_node_from_preimage(&hash_preimage);
} else {
hash_preimage[N_CURRENCIES] = sibling_leaf_node.hash;
hash_preimage[N_CURRENCIES + 1] = node.hash;
node = Node::middle_node_from_preimage(&hash_preimage);
}
node = Node::middle_node_from_preimage(&hash_preimage);

for i in 1..proof.path_indices.len() {
for (i, path_index) in proof.path_indices.iter().enumerate().skip(1) {
let sibling_node = Node::<N_CURRENCIES>::middle_node_from_preimage(
&proof.sibling_middle_node_hash_preimages[i - 1],
);
Expand All @@ -169,15 +171,14 @@ pub trait Tree<const N_CURRENCIES: usize> {
*balance = node.balances[i] + sibling_node.balances[i];
}

if proof.path_indices[i] == 0.into() {
if *path_index == 0.into() {
hash_preimage[N_CURRENCIES] = node.hash;
hash_preimage[N_CURRENCIES + 1] = sibling_node.hash;
node = Node::middle_node_from_preimage(&hash_preimage);
} else {
hash_preimage[N_CURRENCIES] = sibling_node.hash;
hash_preimage[N_CURRENCIES + 1] = node.hash;
node = Node::middle_node_from_preimage(&hash_preimage);
}
node = Node::middle_node_from_preimage(&hash_preimage);
}

proof.root.hash == node.hash && proof.root.balances == node.balances
Expand Down
45 changes: 11 additions & 34 deletions zk_prover/src/merkle_sum_tree/utils/build_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,48 +5,25 @@ use rayon::prelude::*;
pub fn build_merkle_tree_from_leaves<const N_CURRENCIES: usize>(
leaves: &[Node<N_CURRENCIES>],
depth: usize,
nodes: &mut Vec<Vec<Node<N_CURRENCIES>>>,
) -> Result<Node<N_CURRENCIES>, Box<dyn std::error::Error>>
) -> Result<(Node<N_CURRENCIES>, Vec<Vec<Node<N_CURRENCIES>>>), Box<dyn std::error::Error>>
where
[usize; N_CURRENCIES + 1]: Sized,
[usize; N_CURRENCIES + 2]: Sized,
{
let n = leaves.len();

let mut tree: Vec<Vec<Node<N_CURRENCIES>>> = Vec::with_capacity(depth + 1);

tree.push(vec![
Node {
hash: Fp::from(0),
balances: [Fp::from(0); N_CURRENCIES]
};
n
]);

for _ in 1..=depth {
let previous_level = tree.last().unwrap();
let nodes_in_level = (previous_level.len() + 1) / 2;
// the size of a leaf layer must be a power of 2
// if not, the `leaves` Vec should be completed with "zero entries" until a power of 2
assert_eq!(leaves.len(), 2usize.pow(depth as u32));

tree.push(vec![
Node {
hash: Fp::from(0),
balances: [Fp::from(0); N_CURRENCIES]
};
nodes_in_level
]);
}

for (index, leaf) in leaves.iter().enumerate() {
tree[0][index] = leaf.clone();
}
tree.push(leaves.to_vec());

for level in 1..=depth {
build_middle_level(level, &mut tree)
}

let root = tree[depth][0].clone();
*nodes = tree;
Ok(root)
Ok((root, tree))
}

pub fn build_leaves_from_entries<const N_CURRENCIES: usize>(
Expand Down Expand Up @@ -74,8 +51,10 @@ where
leaves
}

fn build_middle_level<const N_CURRENCIES: usize>(level: usize, tree: &mut [Vec<Node<N_CURRENCIES>>])
where
fn build_middle_level<const N_CURRENCIES: usize>(
level: usize,
tree: &mut Vec<Vec<Node<N_CURRENCIES>>>,
) where
[usize; N_CURRENCIES + 2]: Sized,
{
let results: Vec<Node<N_CURRENCIES>> = (0..tree[level - 1].len())
Expand All @@ -95,7 +74,5 @@ where
})
.collect();

for (index, new_node) in results.into_iter().enumerate() {
tree[level][index] = new_node;
}
tree.push(results);
}
3 changes: 2 additions & 1 deletion zk_prover/src/merkle_sum_tree/utils/csv_parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ pub fn parse_csv_to_entries<P: AsRef<Path>, const N_CURRENCIES: usize, const N_B
balances_big_int.push(balance);
}

let entry = Entry::new(username, balances_big_int.try_into().unwrap())?;
let entry = Entry::new(username, balances_big_int.try_into().unwrap());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've reviewed the potential for errors during the try_into operation. It appears that converting BigUint in this line does not produce any errors.
However, for the record, this should be checked again if the balance type in the Entry struct changes.


entries.push(entry);
}

Expand Down
Loading