From 33d676f3353ca9726b83a3db713f5f26e0f88f82 Mon Sep 17 00:00:00 2001 From: Ilya Lesokhin Date: Thu, 21 Nov 2024 13:14:31 +0200 Subject: [PATCH] rearrange queried_values_by_layer for merkle. --- crates/prover/src/core/fri.rs | 42 ++++++++++++++++++-- crates/prover/src/core/pcs/quotients.rs | 45 +++++++++++---------- crates/prover/src/core/pcs/verifier.rs | 8 +++- crates/prover/src/core/vcs/prover.rs | 42 +++----------------- crates/prover/src/core/vcs/test_utils.rs | 5 +-- crates/prover/src/core/vcs/verifier.rs | 50 +++++++++++------------- 6 files changed, 101 insertions(+), 91 deletions(-) diff --git a/crates/prover/src/core/fri.rs b/crates/prover/src/core/fri.rs index 03dac47da..1443ef07c 100644 --- a/crates/prover/src/core/fri.rs +++ b/crates/prover/src/core/fri.rs @@ -4,7 +4,7 @@ use std::fmt::Debug; use std::iter::zip; use std::ops::RangeInclusive; -use itertools::{zip_eq, Itertools}; +use itertools::{izip, zip_eq, Itertools}; use num_traits::Zero; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -738,6 +738,28 @@ impl FriFirstLayerVerifier { return Err(FriVerificationError::FirstLayerEvaluationsInvalid); } + let mut column_iters = all_column_decommitment_values + .iter() + .map(|column| column.iter()) + .collect_vec(); + + let mut decommit_by_log_size = vec![Vec::new(); (max_column_log_size + 1) as usize]; + let mut done = false; + while !done { + done = true; + for (log_size, column_iter) in zip_eq( + self.column_commitment_domains + .iter() + .flat_map(|column_domain| [column_domain.log_size(); SECURE_EXTENSION_DEGREE]), + column_iters.iter_mut(), + ) { + if let Some(value) = column_iter.next() { + decommit_by_log_size[log_size as usize].push(*value); + done = false; + } + } + } + let merkle_verifier = MerkleVerifier::new( self.proof.commitment, self.column_commitment_domains @@ -749,7 +771,7 @@ impl FriFirstLayerVerifier { merkle_verifier .verify( &decommitment_positions_by_log_size, - all_column_decommitment_values, + decommit_by_log_size, 
self.proof.decommitment.clone(), ) .map_err(|error| FriVerificationError::FirstLayerCommitmentInvalid { error })?; @@ -823,10 +845,24 @@ impl FriInnerLayerVerifier { vec![self.domain.log_size(); SECURE_EXTENSION_DEGREE], ); + let mut decommit_by_log_size = vec![Vec::new(); (self.domain.log_size() + 1) as usize]; + + for (a, b, c, d) in izip!( + &decommitment_values.columns[0], + &decommitment_values.columns[1], + &decommitment_values.columns[2], + &decommitment_values.columns[3] + ) { + decommit_by_log_size[self.domain.log_size() as usize].push(*a); + decommit_by_log_size[self.domain.log_size() as usize].push(*b); + decommit_by_log_size[self.domain.log_size() as usize].push(*c); + decommit_by_log_size[self.domain.log_size() as usize].push(*d); + } + merkle_verifier .verify( &BTreeMap::from_iter([(self.domain.log_size(), decommitment_positions)]), - decommitment_values.columns.to_vec(), + decommit_by_log_size, self.proof.decommitment.clone(), ) .map_err(|e| FriVerificationError::InnerLayerCommitmentInvalid { diff --git a/crates/prover/src/core/pcs/quotients.rs b/crates/prover/src/core/pcs/quotients.rs index aca0901c6..244af5677 100644 --- a/crates/prover/src/core/pcs/quotients.rs +++ b/crates/prover/src/core/pcs/quotients.rs @@ -5,6 +5,7 @@ use std::iter::zip; use itertools::{izip, multiunzip, Itertools}; use tracing::{span, Level}; +use super::TreeVec; use crate::core::backend::cpu::quotients::{accumulate_row_quotients, quotient_constants}; use crate::core::circle::CirclePoint; use crate::core::fields::m31::BaseField; @@ -104,21 +105,29 @@ pub fn fri_answers( samples: &[Vec], random_coeff: SecureField, query_positions_per_log_size: &BTreeMap>, - queried_values_per_column: &[Vec], + mut queried_values_per_layer: TreeVec>>, + mut columns_per_log_size: TreeVec>, ) -> Result>, VerificationError> { - izip!(column_log_sizes, samples, queried_values_per_column) + izip!(column_log_sizes, samples) .sorted_by_key(|(log_size, ..)| Reverse(*log_size)) .group_by(|(log_size, 
..)| *log_size) .into_iter() .map(|(log_size, tuples)| { - let (_, samples, queried_values_per_column): (Vec<_>, Vec<_>, Vec<_>) = - multiunzip(tuples); + let (_, samples): (Vec<_>, Vec<_>) = multiunzip(tuples); fri_answers_for_log_size( log_size, &samples, random_coeff, &query_positions_per_log_size[&log_size], - &queried_values_per_column, + queried_values_per_layer.as_mut().map(|queried_values| { + match queried_values.get_mut(log_size as usize) { + Some(queried_values) => std::mem::take(queried_values).into_iter(), + None => (vec![]).into_iter(), + } + }), + columns_per_log_size + .as_mut() + .map(|colums_log_sizes| *colums_log_sizes.get(&log_size).unwrap_or(&0)), ) }) .collect() @@ -129,27 +138,23 @@ pub fn fri_answers_for_log_size( samples: &[&Vec], random_coeff: SecureField, query_positions: &[usize], - queried_values_per_column: &[&Vec], + mut queried_values_per_layer: TreeVec>, + n_columns: TreeVec, ) -> Result, VerificationError> { - for queried_values in queried_values_per_column { - if queried_values.len() != query_positions.len() { - return Err(VerificationError::InvalidStructure( - "Insufficient number of queried values".to_string(), - )); - } - } - let sample_batches = ColumnSampleBatch::new_vec(samples); let quotient_constants = quotient_constants(&sample_batches, random_coeff); let commitment_domain = CanonicCoset::new(log_size).circle_domain(); - let mut quotient_evals_at_queries = Vec::new(); - for (row, &query_position) in query_positions.iter().enumerate() { + let mut quotient_evals_at_queries = Vec::new(); + for &query_position in query_positions { let domain_point = commitment_domain.at(bit_reverse_index(query_position, log_size)); - let queried_values_at_row = queried_values_per_column - .iter() - .map(|col| col[row]) - .collect_vec(); + + let queried_values_at_row = queried_values_per_layer + .as_mut() + .zip_eq(n_columns.as_ref()) + .map(|(queried_values, n_columns)| queried_values.take(*n_columns).collect()) + .flatten(); + 
quotient_evals_at_queries.push(accumulate_row_quotients( &sample_batches, &queried_values_at_row, diff --git a/crates/prover/src/core/pcs/verifier.rs b/crates/prover/src/core/pcs/verifier.rs index 200fe98d5..d77bfdf78 100644 --- a/crates/prover/src/core/pcs/verifier.rs +++ b/crates/prover/src/core/pcs/verifier.rs @@ -108,12 +108,18 @@ impl CommitmentSchemeVerifier { }) .flatten(); + let n_columns_per_log_size = self + .trees + .as_ref() + .map(|tree| tree.n_columns_per_log_size.clone()); + let fri_answers = fri_answers( self.column_log_sizes().flatten().into_iter().collect(), &samples, random_coeff, &query_positions_per_log_size, - &proof.queried_values.flatten(), + proof.queried_values, + n_columns_per_log_size, )?; fri_verifier.decommit(fri_answers)?; diff --git a/crates/prover/src/core/vcs/prover.rs b/crates/prover/src/core/vcs/prover.rs index bc788e51f..d1cb81a9c 100644 --- a/crates/prover/src/core/vcs/prover.rs +++ b/crates/prover/src/core/vcs/prover.rs @@ -48,6 +48,7 @@ impl, H: MerkleHasher> MerkleProver { .into_iter() .sorted_by_key(|c| Reverse(c.len())) .peekable(); + let mut layers: Vec> = Vec::new(); let max_log_size = columns.peek().unwrap().len().ilog2(); @@ -140,7 +141,7 @@ impl, H: MerkleHasher> MerkleProver { // If the column values were queried, return them. let node_values = layer_columns.iter().map(|c| c.at(node_index)); if layer_column_queries.next_if_eq(&node_index).is_some() { - layer_queried_values.push(node_values.collect_vec()); + layer_queried_values.extend(node_values); } else { // Otherwise, add them to the witness. decommitment.column_witness.extend(node_values); @@ -154,43 +155,12 @@ impl, H: MerkleHasher> MerkleProver { // Propagate queries to the next layer. last_layer_queries = layer_total_queries; } - queried_values_by_layer.reverse(); - // Rearrange returned queried values according to input, and not by layer. 
- let queried_values = Self::rearrange_queried_values(queried_values_by_layer, columns); - - (queried_values, decommitment) - } + // Reverse the queried values by layer so that queried_values_by_layer[i] + // corresponds to the queries for the layer whose log size is i. queried_values_by_layer.reverse(); - - /// Given queried values by layer, rearranges in the order of input columns. - fn rearrange_queried_values( queried_values_by_layer: Vec>>, columns: Vec<&Col>, ) -> Vec> { // Turn each column queried values into an iterator. let mut queried_values_by_layer = queried_values_by_layer .into_iter() .map(|layer_results| { layer_results .into_iter() .map(|x| x.into_iter()) .collect_vec() }) .collect_vec(); - - // For each input column, fetch the queried values from the corresponding layer. let queried_values = columns .iter() .map(|column| { queried_values_by_layer .get_mut(column.len().ilog2() as usize) .unwrap() .iter_mut() .map(|x| x.next().unwrap()) .collect_vec() }) .collect_vec(); queried_values + (queried_values_by_layer, decommitment) } pub fn root(&self) -> H::Hash { diff --git a/crates/prover/src/core/vcs/test_utils.rs b/crates/prover/src/core/vcs/test_utils.rs index b92f9e971..2b9cdf186 100644 --- a/crates/prover/src/core/vcs/test_utils.rs +++ b/crates/prover/src/core/vcs/test_utils.rs @@ -52,9 +52,6 @@ where let (values, decommitment) = merkle.decommit(&queries, cols.iter().collect_vec()); - let verifier = MerkleVerifier { - root: merkle.root(), - column_log_sizes: log_sizes, - }; + let verifier = MerkleVerifier::new(merkle.root(), log_sizes); (queries, decommitment, values, verifier) } diff --git a/crates/prover/src/core/vcs/verifier.rs b/crates/prover/src/core/vcs/verifier.rs index 163fed2f1..bf5270230 100644 --- a/crates/prover/src/core/vcs/verifier.rs +++ b/crates/prover/src/core/vcs/verifier.rs @@ -1,4 +1,3 @@ -use std::cmp::Reverse; use std::collections::BTreeMap; use itertools::Itertools; @@ -9,17 
+8,22 @@ use super::prover::MerkleDecommitment; use super::utils::{next_decommitment_node, option_flatten_peekable}; use crate::core::fields::m31::BaseField; use crate::core::utils::PeekableExt; -use crate::core::ColumnVec; pub struct MerkleVerifier { pub root: H::Hash, pub column_log_sizes: Vec, + pub n_columns_per_log_size: BTreeMap, } impl MerkleVerifier { pub fn new(root: H::Hash, column_log_sizes: Vec) -> Self { + let mut n_columns_per_log_size = BTreeMap::::new(); + for log_size in &column_log_sizes { + *n_columns_per_log_size.entry(*log_size).or_insert(0) += 1; + } Self { root, column_log_sizes, + n_columns_per_log_size, } } /// Verifies the decommitment of the columns. @@ -53,35 +57,30 @@ impl MerkleVerifier { pub fn verify( &self, queries_per_log_size: &BTreeMap>, - queried_values: ColumnVec>, + queried_values_by_layer: Vec>, decommitment: MerkleDecommitment, ) -> Result<(), MerkleVerificationError> { let Some(max_log_size) = self.column_log_sizes.iter().max() else { return Ok(()); }; + let mut queried_values_by_layer_iter = queried_values_by_layer.into_iter().rev(); + // Prepare read buffers. - let mut queried_values_by_layer = self - .column_log_sizes - .iter() - .copied() - .zip( - queried_values - .into_iter() - .map(|column_values| column_values.into_iter()), - ) - .sorted_by_key(|(log_size, _)| Reverse(*log_size)) - .peekable(); + let mut hash_witness = decommitment.hash_witness.into_iter(); let mut column_witness = decommitment.column_witness.into_iter(); let mut last_layer_hashes: Option> = None; for layer_log_size in (0..=*max_log_size).rev() { // Prepare read buffer for queried values to the current layer. 
- let mut layer_queried_values = queried_values_by_layer - .peek_take_while(|(log_size, _)| *log_size == layer_log_size) - .collect_vec(); - let n_columns_in_layer = layer_queried_values.len(); + let mut layer_queried_values = queried_values_by_layer_iter.next().unwrap().into_iter(); + + let n_columns_in_layer = self + .n_columns_per_log_size + .get(&layer_log_size) + .copied() + .unwrap_or(0); // Prepare write buffer for queries to the current layer. This will propagate to the // next layer. @@ -138,18 +137,15 @@ impl MerkleVerifier { // If the column values were queried, read them from `queried_value`. let node_values = if layer_column_queries.next_if_eq(&node_index).is_some() { - layer_queried_values - .iter_mut() - .map(|(_, ref mut column_queries)| { - column_queries - .next() - .ok_or(MerkleVerificationError::ColumnValuesTooShort) - }) - .collect::, _>>()? + (&mut layer_queried_values) + .take(n_columns_in_layer) + .collect_vec() } else { // Otherwise, read them from the witness. (&mut column_witness).take(n_columns_in_layer).collect_vec() }; + + assert_eq!(node_values.len(), n_columns_in_layer); if node_values.len() != n_columns_in_layer { return Err(MerkleVerificationError::WitnessTooShort); } @@ -157,7 +153,7 @@ impl MerkleVerifier { layer_total_queries.push((node_index, H::hash_node(node_hashes, &node_values))); } - if !layer_queried_values.iter().all(|(_, c)| c.is_empty()) { + if !layer_queried_values.is_empty() { return Err(MerkleVerificationError::ColumnValuesTooLong); } last_layer_hashes = Some(layer_total_queries);