From bb5e4f350689549768350405d3e78a3022b76b2f Mon Sep 17 00:00:00 2001
From: Shahar Papini <43779613+spapinistarkware@users.noreply.github.com>
Date: Tue, 2 Apr 2024 14:18:30 +0300
Subject: [PATCH] merkle intermediate layers (#548)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
This change is [](https://reviewable.io/reviews/starkware-libs/stwo/548)
---
src/commitment_scheme/blake2_merkle.rs | 81 ++++----
src/commitment_scheme/prover.rs | 182 ++++++++++++++---
src/commitment_scheme/utils.rs | 20 ++
src/commitment_scheme/verifier.rs | 260 +++++++++++--------------
src/core/utils.rs | 37 ++++
5 files changed, 376 insertions(+), 204 deletions(-)
diff --git a/src/commitment_scheme/blake2_merkle.rs b/src/commitment_scheme/blake2_merkle.rs
index e3e48ff64..5fad8755e 100644
--- a/src/commitment_scheme/blake2_merkle.rs
+++ b/src/commitment_scheme/blake2_merkle.rs
@@ -6,9 +6,11 @@ use super::ops::{MerkleHasher, MerkleOps};
use crate::core::backend::CPUBackend;
use crate::core::fields::m31::BaseField;
+#[derive(Copy, Clone, PartialEq, Eq, Default)]
+pub struct Blake2sHash(pub [u32; 8]);
pub struct Blake2Hasher;
impl MerkleHasher for Blake2Hasher {
- type Hash = [u32; 8];
+ type Hash = Blake2sHash;
fn hash_node(
children_hashes: Option<(Self::Hash, Self::Hash)>,
@@ -33,16 +35,26 @@ impl MerkleHasher for Blake2Hasher {
for chunk in padded_values.array_chunks::<16>() {
state = compress(state, unsafe { std::mem::transmute(chunk) }, 0, 0, 0, 0);
}
- state
+ Blake2sHash(state)
+ }
+}
+
+impl std::fmt::Debug for Blake2sHash {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ // Write as hex.
+ for &byte in self.0.iter() {
+ write!(f, "{:02x}", byte)?;
+ }
+ Ok(())
}
}
impl MerkleOps for CPUBackend {
fn commit_on_layer(
log_size: u32,
- prev_layer: Option<&Vec<[u32; 8]>>,
+ prev_layer: Option<&Vec>,
columns: &[&Vec],
- ) -> Vec<[u32; 8]> {
+ ) -> Vec {
(0..(1 << log_size))
.map(|i| {
Blake2Hasher::hash_node(
@@ -56,34 +68,34 @@ impl MerkleOps for CPUBackend {
#[cfg(test)]
mod tests {
+ use std::collections::BTreeMap;
+
use itertools::Itertools;
use num_traits::Zero;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
- use crate::commitment_scheme::blake2_merkle::Blake2Hasher;
+ use crate::commitment_scheme::blake2_merkle::{Blake2Hasher, Blake2sHash};
use crate::commitment_scheme::prover::{MerkleDecommitment, MerkleProver};
- use crate::commitment_scheme::verifier::{MerkleTreeVerifier, MerkleVerificationError};
+ use crate::commitment_scheme::verifier::{MerkleVerificationError, MerkleVerifier};
use crate::core::backend::CPUBackend;
use crate::core::fields::m31::BaseField;
type TestData = (
- Vec,
+ BTreeMap>,
MerkleDecommitment,
- Vec<(u32, Vec)>,
- MerkleTreeVerifier,
+ Vec>,
+ MerkleVerifier,
);
fn prepare_merkle() -> TestData {
const N_COLS: usize = 400;
const N_QUERIES: usize = 7;
+ let log_size_range = 6..9;
let rng = &mut StdRng::seed_from_u64(0);
let log_sizes = (0..N_COLS)
- .map(|_| rng.gen_range(6..9))
- .sorted()
- .rev()
+ .map(|_| rng.gen_range(log_size_range.clone()))
.collect_vec();
- let max_log_size = *log_sizes.iter().max().unwrap();
let cols = log_sizes
.iter()
.map(|&log_size| {
@@ -94,26 +106,21 @@ mod tests {
.collect_vec();
let merkle = MerkleProver::::commit(cols.iter().collect_vec());
- let queries = (0..N_QUERIES)
- .map(|_| rng.gen_range(0..(1 << max_log_size)))
- .sorted()
- .dedup()
- .collect_vec();
- let decommitment = merkle.decommit(queries.clone());
- let values = cols
- .iter()
- .map(|col| {
- let layer_queries = queries
- .iter()
- .map(|&q| q >> (max_log_size - col.len().ilog2()))
- .dedup();
- layer_queries.map(|q| col[q]).collect_vec()
- })
- .collect_vec();
- let values = log_sizes.into_iter().zip(values).collect_vec();
+ let mut queries = BTreeMap::>::new();
+ for log_size in log_size_range.rev() {
+ let layer_queries = (0..N_QUERIES)
+ .map(|_| rng.gen_range(0..(1 << log_size)))
+ .sorted()
+ .dedup()
+ .collect_vec();
+ queries.insert(log_size, layer_queries);
+ }
+
+ let (values, decommitment) = merkle.decommit(queries.clone(), cols.iter().collect_vec());
- let verifier = MerkleTreeVerifier {
+ let verifier = MerkleVerifier {
root: merkle.root(),
+ column_log_sizes: log_sizes,
};
(queries, decommitment, values, verifier)
}
@@ -128,7 +135,7 @@ mod tests {
#[test]
fn test_merkle_invalid_witness() {
let (queries, mut decommitment, values, verifier) = prepare_merkle();
- decommitment.witness[20] = [0; 8];
+ decommitment.hash_witness[20] = Blake2sHash([0; 8]);
assert_eq!(
verifier.verify(queries, values, decommitment).unwrap_err(),
@@ -139,7 +146,7 @@ mod tests {
#[test]
fn test_merkle_invalid_value() {
let (queries, decommitment, mut values, verifier) = prepare_merkle();
- values[3].1[6] = BaseField::zero();
+ values[3][6] = BaseField::zero();
assert_eq!(
verifier.verify(queries, values, decommitment).unwrap_err(),
@@ -150,7 +157,7 @@ mod tests {
#[test]
fn test_merkle_witness_too_short() {
let (queries, mut decommitment, values, verifier) = prepare_merkle();
- decommitment.witness.pop();
+ decommitment.hash_witness.pop();
assert_eq!(
verifier.verify(queries, values, decommitment).unwrap_err(),
@@ -161,7 +168,7 @@ mod tests {
#[test]
fn test_merkle_column_values_too_long() {
let (queries, decommitment, mut values, verifier) = prepare_merkle();
- values[3].1.push(BaseField::zero());
+ values[3].push(BaseField::zero());
assert_eq!(
verifier.verify(queries, values, decommitment).unwrap_err(),
@@ -172,7 +179,7 @@ mod tests {
#[test]
fn test_merkle_column_values_too_short() {
let (queries, decommitment, mut values, verifier) = prepare_merkle();
- values[3].1.pop();
+ values[3].pop();
assert_eq!(
verifier.verify(queries, values, decommitment).unwrap_err(),
@@ -183,7 +190,7 @@ mod tests {
#[test]
fn test_merkle_witness_too_long() {
let (queries, mut decommitment, values, verifier) = prepare_merkle();
- decommitment.witness.push([0; 8]);
+ decommitment.hash_witness.push(Blake2sHash([0; 8]));
assert_eq!(
verifier.verify(queries, values, decommitment).unwrap_err(),
diff --git a/src/commitment_scheme/prover.rs b/src/commitment_scheme/prover.rs
index 8823213e2..108afd61d 100644
--- a/src/commitment_scheme/prover.rs
+++ b/src/commitment_scheme/prover.rs
@@ -1,15 +1,19 @@
use std::cmp::Reverse;
+use std::collections::BTreeMap;
use itertools::Itertools;
use super::ops::{MerkleHasher, MerkleOps};
+use super::utils::{next_decommitment_node, option_flatten_peekable};
use crate::core::backend::{Col, Column};
use crate::core::fields::m31::BaseField;
+use crate::core::utils::PeekableExt;
+use crate::core::ColumnVec;
pub struct MerkleProver, H: MerkleHasher> {
/// Layers of the Merkle tree.
- /// The first layer is the largest column.
- /// The last layer is the root.
+ /// The first layer is the root layer.
+ /// The last layer is the largest layer.
/// See [MerkleOps::commit_on_layer] for more details.
pub layers: Vec>,
}
@@ -18,7 +22,7 @@ pub struct MerkleProver, H: MerkleHasher> {
/// hasher respectively.
impl, H: MerkleHasher> MerkleProver {
/// Commits to columns.
- /// Columns must be of power of 2 sizes and sorted in descending order.
+ /// Columns must be of power of 2 sizes.
///
/// # Arguments
///
@@ -33,22 +37,24 @@ impl, H: MerkleHasher> MerkleProver {
///
/// A new instance of `MerkleProver` with the committed layers.
pub fn commit(columns: Vec<&Col>) -> Self {
- // Check that columns are of descending order.
assert!(!columns.is_empty());
- assert!(columns.is_sorted_by_key(|c| Reverse(c.len())));
- let mut columns = &mut columns.into_iter().peekable();
+ let columns = &mut columns
+ .into_iter()
+ .sorted_by_key(|c| Reverse(c.len()))
+ .peekable();
let mut layers: Vec> = Vec::new();
let max_log_size = columns.peek().unwrap().len().ilog2();
for log_size in (0..=max_log_size).rev() {
// Take columns of the current log_size.
- let layer_columns = (&mut columns)
- .take_while(|column| column.len().ilog2() == log_size)
+ let layer_columns = columns
+ .peek_take_while(|column| column.len().ilog2() == log_size)
.collect_vec();
layers.push(B::commit_on_layer(log_size, layers.last(), &layer_columns));
}
+ layers.reverse();
Self { layers }
}
@@ -57,36 +63,160 @@ impl, H: MerkleHasher> MerkleProver {
///
/// # Arguments
///
- /// * `queries` - A vector of query indices to the largest column.
+ /// * `queries_per_log_size` - A map from log_size to a vector of queries for columns of that
+ /// log_size.
+ /// * `columns` - A vector of references to columns.
///
/// # Returns
///
- /// A `Decommitment` struct containing the witness.
- pub fn decommit(&self, mut queries: Vec) -> MerkleDecommitment {
- let mut witness = Vec::new();
- for layer in &self.layers[..self.layers.len() - 1] {
- let mut queries_iter = queries.into_iter().peekable();
-
- // Propagate queries and hashes to the next layer.
- let mut next_queries = Vec::new();
- while let Some(query) = queries_iter.next() {
- next_queries.push(query / 2);
- if queries_iter.next_if_eq(&(query ^ 1)).is_some() {
- continue;
+ /// A tuple containing:
+ /// * A vector of vectors of queried values for each column, in the order of the input columns.
+ /// * A `MerkleDecommitment` containing the hash and column witnesses.
+ pub fn decommit(
+ &self,
+ queries_per_log_size: BTreeMap>,
+ columns: Vec<&Col>,
+ ) -> (ColumnVec>, MerkleDecommitment) {
+ // Check that queries are sorted and deduped.
+ for queries in queries_per_log_size.values() {
+ assert!(
+ queries.windows(2).all(|w| w[0] < w[1]),
+ "Queries are not sorted."
+ );
+ }
+
+ // Prepare output buffers.
+ let mut queried_values_by_layer = vec![];
+ let mut decommitment = MerkleDecommitment::empty();
+
+ // Sort columns by layer.
+ let mut columns_by_layer = columns
+ .iter()
+ .sorted_by_key(|c| Reverse(c.len()))
+ .peekable();
+
+ let mut last_layer_queries = vec![];
+ for layer_log_size in (0..self.layers.len() as u32).rev() {
+ // Prepare write buffer for queried values to the current layer.
+ let mut layer_queried_values = vec![];
+
+ // Prepare write buffer for queries to the current layer. This will propagate to the
+ // next layer.
+ let mut layer_total_queries = vec![];
+
+ // Each layer node is a hash of column values as well as previous layer hashes.
+ // Prepare the relevant columns and previous layer hashes to read from.
+ let layer_columns = columns_by_layer
+ .peek_take_while(|column| column.len().ilog2() == layer_log_size)
+ .collect_vec();
+ let previous_layer_hashes = self.layers.get(layer_log_size as usize + 1);
+
+ // Queries to this layer come from queried nodes in the previous layer and queried
+ // columns in this one.
+ let mut prev_layer_queries = last_layer_queries.into_iter().peekable();
+ let mut layer_column_queries =
+ option_flatten_peekable(queries_per_log_size.get(&layer_log_size));
+
+ // Merge previous layer queries and column queries.
+ while let Some(node_index) =
+ next_decommitment_node(&mut prev_layer_queries, &mut layer_column_queries)
+ {
+ if let Some(previous_layer_hashes) = previous_layer_hashes {
+ // If the left child was not computed, add it to the witness.
+ if prev_layer_queries.next_if_eq(&(2 * node_index)).is_none() {
+ decommitment
+ .hash_witness
+ .push(previous_layer_hashes.at(2 * node_index));
+ }
+
+ // If the right child was not computed, add it to the witness.
+ if prev_layer_queries
+ .next_if_eq(&(2 * node_index + 1))
+ .is_none()
+ {
+ decommitment
+ .hash_witness
+ .push(previous_layer_hashes.at(2 * node_index + 1));
+ }
+ }
+
+ // If the column values were queried, return them.
+ let node_values = layer_columns.iter().map(|c| c.at(node_index));
+ if layer_column_queries.next_if_eq(&node_index).is_some() {
+ layer_queried_values.push(node_values.collect_vec());
+ } else {
+ // Otherwise, add them to the witness.
+ decommitment.column_witness.extend(node_values);
}
- witness.push(layer.at(query ^ 1));
+
+ layer_total_queries.push(node_index);
}
- queries = next_queries;
+
+ queried_values_by_layer.push(layer_queried_values);
+
+ // Propagate queries to the next layer.
+ last_layer_queries = layer_total_queries;
}
- MerkleDecommitment { witness }
+ queried_values_by_layer.reverse();
+
+ // Rearrange returned queried values according to input, and not by layer.
+ let queried_values = Self::rearrange_queried_values(queried_values_by_layer, columns);
+
+ (queried_values, decommitment)
+ }
+
+ /// Given queried values by layer, rearranges in the order of input columns.
+ fn rearrange_queried_values(
+ queried_values_by_layer: Vec>>,
+ columns: Vec<&Col>,
+ ) -> Vec> {
+ // Turn each column queried values into an iterator.
+ let mut queried_values_by_layer = queried_values_by_layer
+ .into_iter()
+ .map(|layer_results| {
+ layer_results
+ .into_iter()
+ .map(|x| x.into_iter())
+ .collect_vec()
+ })
+ .collect_vec();
+
+ // For each input column, fetch the queried values from the corresponding layer.
+ let queried_values = columns
+ .iter()
+ .map(|column| {
+ queried_values_by_layer
+ .get_mut(column.len().ilog2() as usize)
+ .unwrap()
+ .iter_mut()
+ .map(|x| x.next().unwrap())
+ .collect_vec()
+ })
+ .collect_vec();
+ queried_values
}
pub fn root(&self) -> H::Hash {
- self.layers.last().unwrap().at(0)
+ self.layers.first().unwrap().at(0)
}
}
#[derive(Debug)]
pub struct MerkleDecommitment {
- pub witness: Vec,
+ /// Hash values that the verifier needs but cannot deduce from previous computations, in the
+ /// order they are needed.
+ pub hash_witness: Vec,
+ /// Column values that the verifier needs but cannot deduce from previous computations, in the
+ /// order they are needed.
+ /// This complements the column values that were queried. These must be supplied directly to
+ /// the verifier.
+ pub column_witness: Vec,
+}
+impl MerkleDecommitment {
+ fn empty() -> Self {
+ Self {
+ hash_witness: Vec::new(),
+ column_witness: Vec::new(),
+ }
+ }
}
diff --git a/src/commitment_scheme/utils.rs b/src/commitment_scheme/utils.rs
index 34420c297..260fe72b9 100644
--- a/src/commitment_scheme/utils.rs
+++ b/src/commitment_scheme/utils.rs
@@ -1,4 +1,5 @@
use std::collections::BTreeMap;
+use std::iter::Peekable;
use std::slice::Iter;
use super::hasher::Hasher;
@@ -15,6 +16,25 @@ pub fn allocate_layer(n_bytes: usize) -> TreeLayer {
unsafe { Box::<[T]>::new_zeroed_slice(n_bytes).assume_init() }
}
+/// Fetches the next node that needs to be decommitted in the current Merkle layer.
+pub fn next_decommitment_node(
+ prev_queries: &mut Peekable>,
+ layer_queries: &mut Peekable>,
+) -> Option {
+ prev_queries
+ .peek()
+ .map(|q| *q / 2)
+ .into_iter()
+ .chain(layer_queries.peek().into_iter().copied())
+ .min()
+}
+
+pub fn option_flatten_peekable<'a, I: IntoIterator- >(
+ a: Option,
+) -> Peekable as IntoIterator>::IntoIter>>> {
+ a.into_iter().flatten().copied().peekable()
+}
+
pub fn allocate_balanced_tree(
bottom_layer_length: usize,
size_of_node_bytes: usize,
diff --git a/src/commitment_scheme/verifier.rs b/src/commitment_scheme/verifier.rs
index dabff528f..800549277 100644
--- a/src/commitment_scheme/verifier.rs
+++ b/src/commitment_scheme/verifier.rs
@@ -1,27 +1,30 @@
use std::cmp::Reverse;
-use std::iter::Peekable;
+use std::collections::BTreeMap;
use itertools::Itertools;
use thiserror::Error;
use super::ops::MerkleHasher;
use super::prover::MerkleDecommitment;
+use super::utils::{next_decommitment_node, option_flatten_peekable};
use crate::core::fields::m31::BaseField;
+use crate::core::utils::PeekableExt;
+use crate::core::ColumnVec;
// TODO(spapini): This struct is not necessary. Make it a function on decommitment?
-pub struct MerkleTreeVerifier {
+pub struct MerkleVerifier {
pub root: H::Hash,
+ pub column_log_sizes: Vec,
}
-impl MerkleTreeVerifier {
+impl MerkleVerifier {
/// Verifies the decommitment of the columns.
///
/// # Arguments
///
- /// * `queries` - A vector of indices representing the queries to the largest column.
- /// Note: This is sufficient for bit reversed STARK commitments.
- /// It could be extended to support queries to any column.
- /// * `values` - A vector of pairs containing the log_size of the column and the decommitted
- /// values of the column. Must be given in the same order as the columns were committed.
+ /// * `queries_per_log_size` - A map from log_size to a vector of queries for columns of that
+ /// log_size.
+ /// * `queried_values` - A vector of vectors of queried values. For each column, there is a
+ /// vector of queried values to that column.
/// * `decommitment` - The decommitment object containing the witness and column values.
///
/// # Errors
@@ -44,157 +47,132 @@ impl MerkleTreeVerifier {
/// Returns `Ok(())` if the decommitment is successfully verified.
pub fn verify(
&self,
- queries: Vec,
- values: Vec<(u32, Vec)>,
+ queries_per_log_size: BTreeMap>,
+ queried_values: ColumnVec>,
decommitment: MerkleDecommitment,
) -> Result<(), MerkleVerificationError> {
- // Check that columns are of descending order.
- assert!(values.is_sorted_by_key(|(log_size, _)| Reverse(log_size)));
-
- // Compute root from decommitment.
- let mut verifier = MerkleVerifier:: {
- witness: decommitment.witness.into_iter(),
- column_values: values.into_iter().peekable(),
- layer_column_values: Vec::new(),
- };
- let computed_root = verifier.compute_root_from_decommitment(queries)?;
-
- // Check that all witnesses and values have been consumed.
- if !verifier.witness.is_empty() {
- return Err(MerkleVerificationError::WitnessTooLong);
- }
- if !verifier.column_values.is_empty() {
- return Err(MerkleVerificationError::ColumnValuesTooLong);
- }
-
- // Check that the computed root matches the expected root.
- if computed_root != self.root {
- return Err(MerkleVerificationError::RootMismatch);
- }
-
- Ok(())
- }
-}
-
-/// A helper struct for verifying a [MerkleDecommitment].
-struct MerkleVerifier {
- /// A queue for consuming the next hash witness from the decommitment.
- witness: std::vec::IntoIter<::Hash>,
- /// A queue for consuming the next claimed values for each column.
- column_values: Peekable)>>,
- /// A queue for consuming the next claimed values for each column in the current layer.
- layer_column_values: Vec>,
-}
-impl MerkleVerifier {
- /// Computes the root hash of a Merkle tree from the decommitment information.
- ///
- /// # Arguments
- ///
- /// * `queries` - A vector of query indices to the largest column.
- ///
- /// # Returns
- ///
- /// Returns the computed root hash of the Merkle tree.
- ///
- /// # Errors
- ///
- /// Returns a `MerkleVerificationError` if there is an error during the computation.
- pub fn compute_root_from_decommitment(
- &mut self,
- queries: Vec,
- ) -> Result {
- let max_log_size = self.column_values.peek().unwrap().0;
- assert!(*queries.iter().max().unwrap() < 1 << max_log_size);
-
- // A sequence of queries to the current layer.
- // Each query is a pair of the query index and the known hashes of the children, if any.
- // The known hashes are represented as ChildrenHashesAtQuery.
- // None on the largest layer, or a pair of Option, for the known hashes of the left
- // and right children.
- let mut queries = queries.into_iter().map(|query| (query, None)).collect_vec();
-
+ let max_log_size = self.column_log_sizes.iter().max().copied().unwrap_or(0);
+
+ // Prepare read buffers.
+ let mut queried_values_by_layer = self
+ .column_log_sizes
+ .iter()
+ .copied()
+ .zip(
+ queried_values
+ .into_iter()
+ .map(|column_values| column_values.into_iter()),
+ )
+ .sorted_by_key(|(log_size, _)| Reverse(*log_size))
+ .peekable();
+ let mut hash_witness = decommitment.hash_witness.into_iter();
+ let mut column_witness = decommitment.column_witness.into_iter();
+
+ let mut last_layer_hashes: Option> = None;
for layer_log_size in (0..=max_log_size).rev() {
- // Take values for columns of the current log_size.
- self.layer_column_values = (&mut self.column_values)
- .take_while(|(log_size, _)| *log_size == layer_log_size)
- .map(|(_, values)| values.into_iter())
- .collect();
-
- // Compute node hashes for the current layer.
- let mut hashes_at_layer = queries
+ // Prepare read buffer for queried values to the current layer.
+ let mut layer_queried_values = queried_values_by_layer
+ .peek_take_while(|(log_size, _)| *log_size == layer_log_size)
+ .collect_vec();
+ let n_columns_in_layer = layer_queried_values.len();
+
+ // Prepare write buffer for queries to the current layer. This will propagate to the
+ // next layer.
+ let mut layer_total_queries = vec![];
+
+ // Queries to this layer come from queried nodes in the previous layer and queried
+ // columns in this one.
+ let mut prev_layer_queries = last_layer_hashes
+ .iter()
+ .flatten()
+ .map(|(q, _)| *q)
+ .collect_vec()
.into_iter()
- .map(|(index, children_hashes)| (index, self.compute_node_hash(children_hashes)))
.peekable();
+ let mut prev_layer_hashes = last_layer_hashes.as_ref().map(|x| x.iter().peekable());
+ let mut layer_column_queries =
+ option_flatten_peekable(queries_per_log_size.get(&layer_log_size));
- // Propagate queries and hashes to the next layer.
- let mut next_queries = Vec::new();
- while let Some((index, node_hash)) = hashes_at_layer.next() {
- // If the sibling hash is known, propagate it to the next layer.
- if let Some((_, sibling_hash)) =
- hashes_at_layer.next_if(|(next_index, _)| *next_index == index ^ 1)
- {
- next_queries.push((index / 2, Some((Some(node_hash?), Some(sibling_hash?)))));
- continue;
- }
- // Otherwise, propagate the node hash to the next layer, in the correct direction.
- if index & 1 == 0 {
- next_queries.push((index / 2, Some((Some(node_hash?), None))));
+ // Merge previous layer queries and column queries.
+ while let Some(node_index) =
+ next_decommitment_node(&mut prev_layer_queries, &mut layer_column_queries)
+ {
+ prev_layer_queries
+ .peek_take_while(|q| q / 2 == node_index)
+ .for_each(drop);
+
+ let node_hashes = prev_layer_hashes
+ .as_mut()
+ .map(|prev_layer_hashes| {
+ {
+ // If the left child was not computed, read it from the witness.
+ let left_hash = prev_layer_hashes
+ .next_if(|(index, _)| *index == 2 * node_index)
+ .map(|(_, hash)| Ok(hash.clone()))
+ .unwrap_or_else(|| {
+ hash_witness
+ .next()
+ .ok_or(MerkleVerificationError::WitnessTooShort)
+ })?;
+
+ // If the right child was not computed, read it from the witness.
+ let right_hash = prev_layer_hashes
+ .next_if(|(index, _)| *index == 2 * node_index + 1)
+ .map(|(_, hash)| Ok(hash.clone()))
+ .unwrap_or_else(|| {
+ hash_witness
+ .next()
+ .ok_or(MerkleVerificationError::WitnessTooShort)
+ })?;
+ Ok((left_hash, right_hash))
+ }
+ })
+ .transpose()?;
+
+ // If the column values were queried, read them from `queried_values`.
+ let node_values = if layer_column_queries.next_if_eq(&node_index).is_some() {
+ layer_queried_values
+ .iter_mut()
+ .map(|(_, ref mut column_queries)| {
+ column_queries
+ .next()
+ .ok_or(MerkleVerificationError::ColumnValuesTooShort)
+ })
+ .collect::, _>>()?
} else {
- next_queries.push((index / 2, Some((None, Some(node_hash?)))));
+ // Otherwise, read them from the witness.
+ (&mut column_witness).take(n_columns_in_layer).collect_vec()
+ };
+ if node_values.len() != n_columns_in_layer {
+ return Err(MerkleVerificationError::WitnessTooShort);
}
+
+ layer_total_queries.push((node_index, H::hash_node(node_hashes, &node_values)));
}
- queries = next_queries;
- // Check that all layer_column_values have been consumed.
- if self
- .layer_column_values
- .iter_mut()
- .any(|values| values.next().is_some())
- {
+ if !layer_queried_values.iter().all(|(_, c)| c.is_empty()) {
return Err(MerkleVerificationError::ColumnValuesTooLong);
}
+ last_layer_hashes = Some(layer_total_queries);
}
- assert_eq!(queries.len(), 1);
- Ok(queries.pop().unwrap().1.unwrap().0.unwrap())
- }
+ // Check that all witnesses and values have been consumed.
+ if !hash_witness.is_empty() {
+ return Err(MerkleVerificationError::WitnessTooLong);
+ }
+ if !column_witness.is_empty() {
+ return Err(MerkleVerificationError::WitnessTooLong);
+ }
- fn compute_node_hash(
- &mut self,
- children_hashes: ChildrenHashesAtQuery,
- ) -> Result {
- // For each child with an unknown hash, fill it from the witness queue.
- let hashes_part = children_hashes
- .map(|(l, r)| {
- let l = l
- .or_else(|| self.witness.next())
- .ok_or(MerkleVerificationError::WitnessTooShort)?;
- let r = r
- .or_else(|| self.witness.next())
- .ok_or(MerkleVerificationError::WitnessTooShort)?;
- Ok((l, r))
- })
- .transpose()?;
- // Fill the column values from the layer_column_values queue.
- let column_values = self
- .layer_column_values
- .iter_mut()
- .map(|values| {
- values
- .next()
- .ok_or(MerkleVerificationError::ColumnValuesTooShort)
- })
- .collect::, _>>()?;
- // Hash the node.
- Ok(H::hash_node(hashes_part, &column_values))
+ let [(_, computed_root)] = last_layer_hashes.unwrap().try_into().unwrap();
+ if computed_root != self.root {
+ return Err(MerkleVerificationError::RootMismatch);
+ }
+
+ Ok(())
}
}
-type ChildrenHashesAtQuery = Option<(
- Option<::Hash>,
- Option<::Hash>,
-)>;
-
#[derive(Clone, Copy, Debug, Error, PartialEq, Eq)]
pub enum MerkleVerificationError {
#[error("Witness is too short.")]
diff --git a/src/core/utils.rs b/src/core/utils.rs
index 8a6483e36..f9d928c31 100644
--- a/src/core/utils.rs
+++ b/src/core/utils.rs
@@ -1,3 +1,5 @@
+use std::iter::Peekable;
+
pub trait IteratorMutExt<'a, T: 'a>: Iterator
- {
fn assign(self, other: impl IntoIterator
- )
where
@@ -9,6 +11,41 @@ pub trait IteratorMutExt<'a, T: 'a>: Iterator
- {
impl<'a, T: 'a, I: Iterator
- > IteratorMutExt<'a, T> for I {}
+/// An iterator that takes elements from the underlying [Peekable] while the predicate is true.
+/// Used to implement [PeekableExt::peek_take_while].
+pub struct PeekTakeWhile<'a, I: Iterator, P: FnMut(&I::Item) -> bool> {
+ iter: &'a mut Peekable,
+ predicate: P,
+}
+impl<'a, I: Iterator, P: FnMut(&I::Item) -> bool> Iterator for PeekTakeWhile<'a, I, P> {
+ type Item = I::Item;
+
+ fn next(&mut self) -> Option {
+ self.iter.next_if(&mut self.predicate)
+ }
+}
+pub trait PeekableExt<'a, I: Iterator> {
+ /// Returns an iterator that takes elements from the underlying [Peekable] while the predicate
+ /// is true.
+ /// Unlike [Iterator::take_while], this iterator does not consume the first element that does
+ /// not satisfy the predicate.
+ fn peek_take_while bool>(
+ &'a mut self,
+ predicate: P,
+ ) -> PeekTakeWhile<'a, I, P>;
+}
+impl<'a, I: Iterator> PeekableExt<'a, I> for Peekable {
+ fn peek_take_while bool>(
+ &'a mut self,
+ predicate: P,
+ ) -> PeekTakeWhile<'a, I, P> {
+ PeekTakeWhile {
+ iter: self,
+ predicate,
+ }
+ }
+}
+
pub(crate) fn bit_reverse_index(i: usize, log_size: u32) -> usize {
if log_size == 0 {
return i;