From bb5e4f350689549768350405d3e78a3022b76b2f Mon Sep 17 00:00:00 2001 From: Shahar Papini <43779613+spapinistarkware@users.noreply.github.com> Date: Tue, 2 Apr 2024 14:18:30 +0300 Subject: [PATCH] merkle intermediate layers (#548) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This change is [Reviewable](https://reviewable.io/reviews/starkware-libs/stwo/548) --- src/commitment_scheme/blake2_merkle.rs | 81 ++++---- src/commitment_scheme/prover.rs | 182 ++++++++++++++--- src/commitment_scheme/utils.rs | 20 ++ src/commitment_scheme/verifier.rs | 260 +++++++++++-------------- src/core/utils.rs | 37 ++++ 5 files changed, 376 insertions(+), 204 deletions(-) diff --git a/src/commitment_scheme/blake2_merkle.rs b/src/commitment_scheme/blake2_merkle.rs index e3e48ff64..5fad8755e 100644 --- a/src/commitment_scheme/blake2_merkle.rs +++ b/src/commitment_scheme/blake2_merkle.rs @@ -6,9 +6,11 @@ use super::ops::{MerkleHasher, MerkleOps}; use crate::core::backend::CPUBackend; use crate::core::fields::m31::BaseField; +#[derive(Copy, Clone, PartialEq, Eq, Default)] +pub struct Blake2sHash(pub [u32; 8]); pub struct Blake2Hasher; impl MerkleHasher for Blake2Hasher { - type Hash = [u32; 8]; + type Hash = Blake2sHash; fn hash_node( children_hashes: Option<(Self::Hash, Self::Hash)>, @@ -33,16 +35,26 @@ impl MerkleHasher for Blake2Hasher { for chunk in padded_values.array_chunks::<16>() { state = compress(state, unsafe { std::mem::transmute(chunk) }, 0, 0, 0, 0); } - state + Blake2sHash(state) + } +} + +impl std::fmt::Debug for Blake2sHash { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Write as hex. + for &byte in self.0.iter() { + write!(f, "{:02x}", byte)?; + } + Ok(()) } } impl MerkleOps for CPUBackend { fn commit_on_layer( log_size: u32, - prev_layer: Option<&Vec<[u32; 8]>>, + prev_layer: Option<&Vec>, columns: &[&Vec], - ) -> Vec<[u32; 8]> { + ) -> Vec { (0..(1 << log_size)) .map(|i| { Blake2Hasher::hash_node( @@ -56,34 +68,34 @@ impl MerkleOps for CPUBackend { #[cfg(test)] mod tests { + use std::collections::BTreeMap; + use itertools::Itertools; use num_traits::Zero; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; - use crate::commitment_scheme::blake2_merkle::Blake2Hasher; + use crate::commitment_scheme::blake2_merkle::{Blake2Hasher, Blake2sHash}; use crate::commitment_scheme::prover::{MerkleDecommitment, MerkleProver}; - use crate::commitment_scheme::verifier::{MerkleTreeVerifier, MerkleVerificationError}; + use crate::commitment_scheme::verifier::{MerkleVerificationError, MerkleVerifier}; use crate::core::backend::CPUBackend; use crate::core::fields::m31::BaseField; type TestData = ( - Vec, + BTreeMap>, MerkleDecommitment, - Vec<(u32, Vec)>, - MerkleTreeVerifier, + Vec>, + MerkleVerifier, ); fn prepare_merkle() -> TestData { const N_COLS: usize = 400; const N_QUERIES: usize = 7; + let log_size_range = 6..9; let rng = &mut StdRng::seed_from_u64(0); let log_sizes = (0..N_COLS) - .map(|_| rng.gen_range(6..9)) - .sorted() - .rev() + .map(|_| rng.gen_range(log_size_range.clone())) .collect_vec(); - let max_log_size = *log_sizes.iter().max().unwrap(); let cols = log_sizes .iter() .map(|&log_size| { @@ -94,26 +106,21 @@ mod tests { .collect_vec(); let merkle = MerkleProver::::commit(cols.iter().collect_vec()); - let queries = (0..N_QUERIES) - .map(|_| rng.gen_range(0..(1 << max_log_size))) - .sorted() - .dedup() - .collect_vec(); - let decommitment = merkle.decommit(queries.clone()); - let values = cols - .iter() - .map(|col| { - let layer_queries = queries - .iter() - .map(|&q| q >> (max_log_size - col.len().ilog2())) - .dedup(); - layer_queries.map(|q| col[q]).collect_vec() - }) - .collect_vec(); - let values = log_sizes.into_iter().zip(values).collect_vec(); + let mut queries = BTreeMap::>::new(); + for log_size in log_size_range.rev() { + let layer_queries = (0..N_QUERIES) + .map(|_| rng.gen_range(0..(1 << log_size))) + .sorted() + .dedup() + .collect_vec(); + queries.insert(log_size, layer_queries); + } + + let (values, decommitment) = merkle.decommit(queries.clone(), cols.iter().collect_vec()); - let verifier = MerkleTreeVerifier { + let verifier = MerkleVerifier { root: merkle.root(), + column_log_sizes: log_sizes, }; (queries, decommitment, values, verifier) } @@ -128,7 +135,7 @@ mod tests { #[test] fn test_merkle_invalid_witness() { let (queries, mut decommitment, values, verifier) = prepare_merkle(); - decommitment.witness[20] = [0; 8]; + decommitment.hash_witness[20] = Blake2sHash([0; 8]); assert_eq!( verifier.verify(queries, values, decommitment).unwrap_err(), @@ -139,7 +146,7 @@ mod tests { #[test] fn test_merkle_invalid_value() { let (queries, decommitment, mut values, verifier) = prepare_merkle(); - values[3].1[6] = BaseField::zero(); + values[3][6] = BaseField::zero(); assert_eq!( verifier.verify(queries, values, decommitment).unwrap_err(), @@ -150,7 +157,7 @@ mod tests { #[test] fn test_merkle_witness_too_short() { let (queries, mut decommitment, values, verifier) = prepare_merkle(); - decommitment.witness.pop(); + decommitment.hash_witness.pop(); assert_eq!( verifier.verify(queries, values, decommitment).unwrap_err(), @@ -161,7 +168,7 @@ mod tests { #[test] fn test_merkle_column_values_too_long() { let (queries, decommitment, mut values, verifier) = prepare_merkle(); - values[3].1.push(BaseField::zero()); + values[3].push(BaseField::zero()); assert_eq!( verifier.verify(queries, values, decommitment).unwrap_err(), @@ -172,7 +179,7 @@ mod tests { #[test] fn test_merkle_column_values_too_short() { let (queries, decommitment, mut values, verifier) = prepare_merkle(); - values[3].1.pop(); + values[3].pop(); assert_eq!( verifier.verify(queries, values, decommitment).unwrap_err(), @@ -183,7 +190,7 @@ mod tests { #[test] fn test_merkle_witness_too_long() { let (queries, mut decommitment, values, verifier) = prepare_merkle(); - decommitment.witness.push([0; 8]); + decommitment.hash_witness.push(Blake2sHash([0; 8])); assert_eq!( verifier.verify(queries, values, decommitment).unwrap_err(), diff --git a/src/commitment_scheme/prover.rs b/src/commitment_scheme/prover.rs index 8823213e2..108afd61d 100644 --- a/src/commitment_scheme/prover.rs +++ b/src/commitment_scheme/prover.rs @@ -1,15 +1,19 @@ use std::cmp::Reverse; +use std::collections::BTreeMap; use itertools::Itertools; use super::ops::{MerkleHasher, MerkleOps}; +use super::utils::{next_decommitment_node, option_flatten_peekable}; use crate::core::backend::{Col, Column}; use crate::core::fields::m31::BaseField; +use crate::core::utils::PeekableExt; +use crate::core::ColumnVec; pub struct MerkleProver, H: MerkleHasher> { /// Layers of the Merkle tree. - /// The first layer is the largest column. - /// The last layer is the root. + /// The first layer is the root layer. + /// The last layer is the largest layer. /// See [MerkleOps::commit_on_layer] for more details. pub layers: Vec>, } @@ -18,7 +22,7 @@ pub struct MerkleProver, H: MerkleHasher> { /// hasher respectively. impl, H: MerkleHasher> MerkleProver { /// Commits to columns. - /// Columns must be of power of 2 sizes and sorted in descending order. + /// Columns must be of power of 2 sizes. /// /// # Arguments /// @@ -33,22 +37,24 @@ impl, H: MerkleHasher> MerkleProver { /// /// A new instance of `MerkleProver` with the committed layers. pub fn commit(columns: Vec<&Col>) -> Self { - // Check that columns are of descending order. assert!(!columns.is_empty()); - assert!(columns.is_sorted_by_key(|c| Reverse(c.len()))); - let mut columns = &mut columns.into_iter().peekable(); + let columns = &mut columns + .into_iter() + .sorted_by_key(|c| Reverse(c.len())) + .peekable(); let mut layers: Vec> = Vec::new(); let max_log_size = columns.peek().unwrap().len().ilog2(); for log_size in (0..=max_log_size).rev() { // Take columns of the current log_size. - let layer_columns = (&mut columns) - .take_while(|column| column.len().ilog2() == log_size) + let layer_columns = columns + .peek_take_while(|column| column.len().ilog2() == log_size) .collect_vec(); layers.push(B::commit_on_layer(log_size, layers.last(), &layer_columns)); } + layers.reverse(); Self { layers } } @@ -57,36 +63,160 @@ impl, H: MerkleHasher> MerkleProver { /// /// # Arguments /// - /// * `queries` - A vector of query indices to the largest column. + /// * `queries_per_log_size` - A map from log_size to a vector of queries for columns of that + /// log_size. + /// * `columns` - A vector of references to columns. /// /// # Returns /// - /// A `Decommitment` struct containing the witness. - pub fn decommit(&self, mut queries: Vec) -> MerkleDecommitment { - let mut witness = Vec::new(); - for layer in &self.layers[..self.layers.len() - 1] { - let mut queries_iter = queries.into_iter().peekable(); - - // Propagate queries and hashes to the next layer. - let mut next_queries = Vec::new(); - while let Some(query) = queries_iter.next() { - next_queries.push(query / 2); - if queries_iter.next_if_eq(&(query ^ 1)).is_some() { - continue; + /// A tuple containing: + /// * A vector of vectors of queried values for each column, in the order of the input columns. + /// * A `MerkleDecommitment` containing the hash and column witnesses. + pub fn decommit( + &self, + queries_per_log_size: BTreeMap>, + columns: Vec<&Col>, + ) -> (ColumnVec>, MerkleDecommitment) { + // Check that queries are sorted and deduped. + for queries in queries_per_log_size.values() { + assert!( + queries.windows(2).all(|w| w[0] < w[1]), + "Queries are not sorted." + ); + } + + // Prepare output buffers. + let mut queried_values_by_layer = vec![]; + let mut decommitment = MerkleDecommitment::empty(); + + // Sort columns by layer. + let mut columns_by_layer = columns + .iter() + .sorted_by_key(|c| Reverse(c.len())) + .peekable(); + + let mut last_layer_queries = vec![]; + for layer_log_size in (0..self.layers.len() as u32).rev() { + // Prepare write buffer for queried values to the current layer. + let mut layer_queried_values = vec![]; + + // Prepare write buffer for queries to the current layer. This will propagate to the + // next layer. + let mut layer_total_queries = vec![]; + + // Each layer node is a hash of column values as previous layer hashes. + // Prepare the relevant columns and previous layer hashes to read from. + let layer_columns = columns_by_layer + .peek_take_while(|column| column.len().ilog2() == layer_log_size) + .collect_vec(); + let previous_layer_hashes = self.layers.get(layer_log_size as usize + 1); + + // Queries to this layer come from queried node in the previous layer and queried + // columns in this one. + let mut prev_layer_queries = last_layer_queries.into_iter().peekable(); + let mut layer_column_queries = + option_flatten_peekable(queries_per_log_size.get(&layer_log_size)); + + // Merge previous layer queries and column queries. + while let Some(node_index) = + next_decommitment_node(&mut prev_layer_queries, &mut layer_column_queries) + { + if let Some(previous_layer_hashes) = previous_layer_hashes { + // If the left child was not computed, add it to the witness. + if prev_layer_queries.next_if_eq(&(2 * node_index)).is_none() { + decommitment + .hash_witness + .push(previous_layer_hashes.at(2 * node_index)); + } + + // If the right child was not computed, add it to the witness. + if prev_layer_queries + .next_if_eq(&(2 * node_index + 1)) + .is_none() + { + decommitment + .hash_witness + .push(previous_layer_hashes.at(2 * node_index + 1)); + } + } + + // If the column values were queried, return them. + let node_values = layer_columns.iter().map(|c| c.at(node_index)); + if layer_column_queries.next_if_eq(&node_index).is_some() { + layer_queried_values.push(node_values.collect_vec()); + } else { + // Otherwise, add them to the witness. + decommitment.column_witness.extend(node_values); } - witness.push(layer.at(query ^ 1)); + + layer_total_queries.push(node_index); } - queries = next_queries; + + queried_values_by_layer.push(layer_queried_values); + + // Propagate queries to the next layer. + last_layer_queries = layer_total_queries; } - MerkleDecommitment { witness } + queried_values_by_layer.reverse(); + + // Rearrange returned queried values according to input, and not by layer. + let queried_values = Self::rearrange_queried_values(queried_values_by_layer, columns); + + (queried_values, decommitment) + } + + /// Given queried values by layer, rearranges in the order of input columns. + fn rearrange_queried_values( + queried_values_by_layer: Vec>>, + columns: Vec<&Col>, + ) -> Vec> { + // Turn each column queried values into an iterator. + let mut queried_values_by_layer = queried_values_by_layer + .into_iter() + .map(|layer_results| { + layer_results + .into_iter() + .map(|x| x.into_iter()) + .collect_vec() + }) + .collect_vec(); + + // For each input column, fetch the queried values from the corresponding layer. + let queried_values = columns + .iter() + .map(|column| { + queried_values_by_layer + .get_mut(column.len().ilog2() as usize) + .unwrap() + .iter_mut() + .map(|x| x.next().unwrap()) + .collect_vec() + }) + .collect_vec(); + queried_values } pub fn root(&self) -> H::Hash { - self.layers.last().unwrap().at(0) + self.layers.first().unwrap().at(0) } } #[derive(Debug)] pub struct MerkleDecommitment { - pub witness: Vec, + /// Hash values that the verifier needs but cannot deduce from previous computations, in the + /// order they are needed. + pub hash_witness: Vec, + /// Column values that the verifier needs but cannot deduce from previous computations, in the + /// order they are needed. + /// This complements the column values that were queried. These must be supplied directly to + /// the verifier. + pub column_witness: Vec, +} +impl MerkleDecommitment { + fn empty() -> Self { + Self { + hash_witness: Vec::new(), + column_witness: Vec::new(), + } + } } diff --git a/src/commitment_scheme/utils.rs b/src/commitment_scheme/utils.rs index 34420c297..260fe72b9 100644 --- a/src/commitment_scheme/utils.rs +++ b/src/commitment_scheme/utils.rs @@ -1,4 +1,5 @@ use std::collections::BTreeMap; +use std::iter::Peekable; use std::slice::Iter; use super::hasher::Hasher; @@ -15,6 +16,25 @@ pub fn allocate_layer(n_bytes: usize) -> TreeLayer { unsafe { Box::<[T]>::new_zeroed_slice(n_bytes).assume_init() } } +/// Fetches the next node that needs to be decommited in the current Merkle layer. +pub fn next_decommitment_node( + prev_queries: &mut Peekable>, + layer_queries: &mut Peekable>, +) -> Option { + prev_queries + .peek() + .map(|q| *q / 2) + .into_iter() + .chain(layer_queries.peek().into_iter().copied()) + .min() +} + +pub fn option_flatten_peekable<'a, I: IntoIterator>( + a: Option, +) -> Peekable as IntoIterator>::IntoIter>>> { + a.into_iter().flatten().copied().peekable() +} + pub fn allocate_balanced_tree( bottom_layer_length: usize, size_of_node_bytes: usize, diff --git a/src/commitment_scheme/verifier.rs b/src/commitment_scheme/verifier.rs index dabff528f..800549277 100644 --- a/src/commitment_scheme/verifier.rs +++ b/src/commitment_scheme/verifier.rs @@ -1,27 +1,30 @@ use std::cmp::Reverse; -use std::iter::Peekable; +use std::collections::BTreeMap; use itertools::Itertools; use thiserror::Error; use super::ops::MerkleHasher; use super::prover::MerkleDecommitment; +use super::utils::{next_decommitment_node, option_flatten_peekable}; use crate::core::fields::m31::BaseField; +use crate::core::utils::PeekableExt; +use crate::core::ColumnVec; // TODO(spapini): This struct is not necessary. Make it a function on decommitment? -pub struct MerkleTreeVerifier { +pub struct MerkleVerifier { pub root: H::Hash, + pub column_log_sizes: Vec, } -impl MerkleTreeVerifier { +impl MerkleVerifier { /// Verifies the decommitment of the columns. /// /// # Arguments /// - /// * `queries` - A vector of indices representing the queries to the largest column. - /// Note: This is sufficient for bit reversed STARK commitments. - /// It could be extended to support queries to any column. - /// * `values` - A vector of pairs containing the log_size of the column and the decommitted - /// values of the column. Must be given in the same order as the columns were committed. + /// * `queries_per_log_size` - A map from log_size to a vector of queries for columns of that + /// log_size. + /// * `queried_values` - A vector of vectors of queried values. For each column, there is a + /// vector of queried values to that column. /// * `decommitment` - The decommitment object containing the witness and column values. /// /// # Errors @@ -44,157 +47,132 @@ impl MerkleTreeVerifier { /// Returns `Ok(())` if the decommitment is successfully verified. pub fn verify( &self, - queries: Vec, - values: Vec<(u32, Vec)>, + queries_per_log_size: BTreeMap>, + queried_values: ColumnVec>, decommitment: MerkleDecommitment, ) -> Result<(), MerkleVerificationError> { - // Check that columns are of descending order. - assert!(values.is_sorted_by_key(|(log_size, _)| Reverse(log_size))); - - // Compute root from decommitment. - let mut verifier = MerkleVerifier:: { - witness: decommitment.witness.into_iter(), - column_values: values.into_iter().peekable(), - layer_column_values: Vec::new(), - }; - let computed_root = verifier.compute_root_from_decommitment(queries)?; - - // Check that all witnesses and values have been consumed. - if !verifier.witness.is_empty() { - return Err(MerkleVerificationError::WitnessTooLong); - } - if !verifier.column_values.is_empty() { - return Err(MerkleVerificationError::ColumnValuesTooLong); - } - - // Check that the computed root matches the expected root. - if computed_root != self.root { - return Err(MerkleVerificationError::RootMismatch); - } - - Ok(()) - } -} - -/// A helper struct for verifying a [MerkleDecommitment]. -struct MerkleVerifier { - /// A queue for consuming the next hash witness from the decommitment. - witness: std::vec::IntoIter<::Hash>, - /// A queue for consuming the next claimed values for each column. - column_values: Peekable)>>, - /// A queue for consuming the next claimed values for each column in the current layer. - layer_column_values: Vec>, -} -impl MerkleVerifier { - /// Computes the root hash of a Merkle tree from the decommitment information. - /// - /// # Arguments - /// - /// * `queries` - A vector of query indices to the largest column. - /// - /// # Returns - /// - /// Returns the computed root hash of the Merkle tree. - /// - /// # Errors - /// - /// Returns a `MerkleVerificationError` if there is an error during the computation. - pub fn compute_root_from_decommitment( - &mut self, - queries: Vec, - ) -> Result { - let max_log_size = self.column_values.peek().unwrap().0; - assert!(*queries.iter().max().unwrap() < 1 << max_log_size); - - // A sequence of queries to the current layer. - // Each query is a pair of the query index and the known hashes of the children, if any. - // The known hashes are represented as ChildrenHashesAtQuery. - // None on the largest layer, or a pair of Option, for the known hashes of the left - // and right children. - let mut queries = queries.into_iter().map(|query| (query, None)).collect_vec(); - + let max_log_size = self.column_log_sizes.iter().max().copied().unwrap_or(0); + + // Prepare read buffers. + let mut queried_values_by_layer = self + .column_log_sizes + .iter() + .copied() + .zip( + queried_values + .into_iter() + .map(|column_values| column_values.into_iter()), + ) + .sorted_by_key(|(log_size, _)| Reverse(*log_size)) + .peekable(); + let mut hash_witness = decommitment.hash_witness.into_iter(); + let mut column_witness = decommitment.column_witness.into_iter(); + + let mut last_layer_hashes: Option> = None; for layer_log_size in (0..=max_log_size).rev() { - // Take values for columns of the current log_size. - self.layer_column_values = (&mut self.column_values) - .take_while(|(log_size, _)| *log_size == layer_log_size) - .map(|(_, values)| values.into_iter()) - .collect(); - - // Compute node hashes for the current layer. - let mut hashes_at_layer = queries + // Prepare read buffer for queried values to the current layer. + let mut layer_queried_values = queried_values_by_layer + .peek_take_while(|(log_size, _)| *log_size == layer_log_size) + .collect_vec(); + let n_columns_in_layer = layer_queried_values.len(); + + // Prepare write buffer for queries to the current layer. This will propagate to the + // next layer. + let mut layer_total_queries = vec![]; + + // Queries to this layer come from queried node in the previous layer and queried + // columns in this one. + let mut prev_layer_queries = last_layer_hashes + .iter() + .flatten() + .map(|(q, _)| *q) + .collect_vec() .into_iter() - .map(|(index, children_hashes)| (index, self.compute_node_hash(children_hashes))) .peekable(); + let mut prev_layer_hashes = last_layer_hashes.as_ref().map(|x| x.iter().peekable()); + let mut layer_column_queries = + option_flatten_peekable(queries_per_log_size.get(&layer_log_size)); - // Propagate queries and hashes to the next layer. - let mut next_queries = Vec::new(); - while let Some((index, node_hash)) = hashes_at_layer.next() { - // If the sibling hash is known, propagate it to the next layer. - if let Some((_, sibling_hash)) = - hashes_at_layer.next_if(|(next_index, _)| *next_index == index ^ 1) - { - next_queries.push((index / 2, Some((Some(node_hash?), Some(sibling_hash?))))); - continue; - } - // Otherwise, propagate the node hash to the next layer, in the correct direction. - if index & 1 == 0 { - next_queries.push((index / 2, Some((Some(node_hash?), None)))); + // Merge previous layer queries and column queries. + while let Some(node_index) = + next_decommitment_node(&mut prev_layer_queries, &mut layer_column_queries) + { + prev_layer_queries + .peek_take_while(|q| q / 2 == node_index) + .for_each(drop); + + let node_hashes = prev_layer_hashes + .as_mut() + .map(|prev_layer_hashes| { + { + // If the left child was not computed, read it from the witness. + let left_hash = prev_layer_hashes + .next_if(|(index, _)| *index == 2 * node_index) + .map(|(_, hash)| Ok(hash.clone())) + .unwrap_or_else(|| { + hash_witness + .next() + .ok_or(MerkleVerificationError::WitnessTooShort) + })?; + + // If the right child was not computed, read it to from the witness. + let right_hash = prev_layer_hashes + .next_if(|(index, _)| *index == 2 * node_index + 1) + .map(|(_, hash)| Ok(hash.clone())) + .unwrap_or_else(|| { + hash_witness + .next() + .ok_or(MerkleVerificationError::WitnessTooShort) + })?; + Ok((left_hash, right_hash)) + } + }) + .transpose()?; + + // If the column values were queried, read them from `queried_value`. + let node_values = if layer_column_queries.next_if_eq(&node_index).is_some() { + layer_queried_values + .iter_mut() + .map(|(_, ref mut column_queries)| { + column_queries + .next() + .ok_or(MerkleVerificationError::ColumnValuesTooShort) + }) + .collect::, _>>()? } else { - next_queries.push((index / 2, Some((None, Some(node_hash?))))); + // Otherwise, read them from the witness. + (&mut column_witness).take(n_columns_in_layer).collect_vec() + }; + if node_values.len() != n_columns_in_layer { + return Err(MerkleVerificationError::WitnessTooShort); } + + layer_total_queries.push((node_index, H::hash_node(node_hashes, &node_values))); } - queries = next_queries; - // Check that all layer_column_values have been consumed. - if self - .layer_column_values - .iter_mut() - .any(|values| values.next().is_some()) - { + if !layer_queried_values.iter().all(|(_, c)| c.is_empty()) { return Err(MerkleVerificationError::ColumnValuesTooLong); } + last_layer_hashes = Some(layer_total_queries); } - assert_eq!(queries.len(), 1); - Ok(queries.pop().unwrap().1.unwrap().0.unwrap()) - } + // Check that all witnesses and values have been consumed. + if !hash_witness.is_empty() { + return Err(MerkleVerificationError::WitnessTooLong); + } + if !column_witness.is_empty() { + return Err(MerkleVerificationError::WitnessTooLong); + } - fn compute_node_hash( - &mut self, - children_hashes: ChildrenHashesAtQuery, - ) -> Result { - // For each child with an unknown hash, fill it from the witness queue. - let hashes_part = children_hashes - .map(|(l, r)| { - let l = l - .or_else(|| self.witness.next()) - .ok_or(MerkleVerificationError::WitnessTooShort)?; - let r = r - .or_else(|| self.witness.next()) - .ok_or(MerkleVerificationError::WitnessTooShort)?; - Ok((l, r)) - }) - .transpose()?; - // Fill the column values from the layer_column_values queue. - let column_values = self - .layer_column_values - .iter_mut() - .map(|values| { - values - .next() - .ok_or(MerkleVerificationError::ColumnValuesTooShort) - }) - .collect::, _>>()?; - // Hash the node. - Ok(H::hash_node(hashes_part, &column_values)) + let [(_, computed_root)] = last_layer_hashes.unwrap().try_into().unwrap(); + if computed_root != self.root { + return Err(MerkleVerificationError::RootMismatch); + } + + Ok(()) } } -type ChildrenHashesAtQuery = Option<( - Option<::Hash>, - Option<::Hash>, -)>; - #[derive(Clone, Copy, Debug, Error, PartialEq, Eq)] pub enum MerkleVerificationError { #[error("Witness is too short.")] diff --git a/src/core/utils.rs b/src/core/utils.rs index 8a6483e36..f9d928c31 100644 --- a/src/core/utils.rs +++ b/src/core/utils.rs @@ -1,3 +1,5 @@ +use std::iter::Peekable; + pub trait IteratorMutExt<'a, T: 'a>: Iterator { fn assign(self, other: impl IntoIterator) where @@ -9,6 +11,41 @@ pub trait IteratorMutExt<'a, T: 'a>: Iterator { impl<'a, T: 'a, I: Iterator> IteratorMutExt<'a, T> for I {} +/// An iterator that takes elements from the underlying [Peekable] while the predicate is true. +/// Used to implement [PeekableExt::peek_take_while]. +pub struct PeekTakeWhile<'a, I: Iterator, P: FnMut(&I::Item) -> bool> { + iter: &'a mut Peekable, + predicate: P, +} +impl<'a, I: Iterator, P: FnMut(&I::Item) -> bool> Iterator for PeekTakeWhile<'a, I, P> { + type Item = I::Item; + + fn next(&mut self) -> Option { + self.iter.next_if(&mut self.predicate) + } +} +pub trait PeekableExt<'a, I: Iterator> { + /// Returns an iterator that takes elements from the underlying [Peekable] while the predicate + /// is true. + /// Unlike [Iterator::take_while], this iterator does not consume the first element that does + /// not satisfy the predicate. + fn peek_take_while bool>( + &'a mut self, + predicate: P, + ) -> PeekTakeWhile<'a, I, P>; +} +impl<'a, I: Iterator> PeekableExt<'a, I> for Peekable { + fn peek_take_while bool>( + &'a mut self, + predicate: P, + ) -> PeekTakeWhile<'a, I, P> { + PeekTakeWhile { + iter: self, + predicate, + } + } +} + pub(crate) fn bit_reverse_index(i: usize, log_size: u32) -> usize { if log_size == 0 { return i;