From 44550e73f86dd8244176016e72d273174ca34ec2 Mon Sep 17 00:00:00 2001 From: Ilya Lesokhin Date: Thu, 21 Nov 2024 13:14:31 +0200 Subject: [PATCH] Rearrange queried_values_by_layer for merkle. --- .../src/constraint_framework/component.rs | 2 +- crates/prover/src/constraint_framework/mod.rs | 2 +- crates/prover/src/core/fri.rs | 28 +++--- crates/prover/src/core/pcs/prover.rs | 4 +- crates/prover/src/core/pcs/quotients.rs | 46 +++++----- crates/prover/src/core/pcs/verifier.rs | 18 ++-- crates/prover/src/core/vcs/blake2_merkle.rs | 10 +-- .../prover/src/core/vcs/poseidon252_merkle.rs | 14 +-- crates/prover/src/core/vcs/prover.rs | 51 ++--------- crates/prover/src/core/vcs/test_utils.rs | 7 +- crates/prover/src/core/vcs/verifier.rs | 90 +++++++++---------- 11 files changed, 115 insertions(+), 157 deletions(-) diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index 078e84bb8..23981cc06 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -92,7 +92,7 @@ impl TraceLocationAllocator { } } - pub fn preprocessed_columns(&self) -> &HashMap { + pub const fn preprocessed_columns(&self) -> &HashMap { &self.preprocessed_columns } diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 9bbf05402..0870e149d 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -234,7 +234,7 @@ pub struct RelationEntry<'a, F: Clone, EF: RelationEFTraitBound, R: Relation< values: &'a [F], } impl<'a, F: Clone, EF: RelationEFTraitBound, R: Relation> RelationEntry<'a, F, EF, R> { - pub fn new(relation: &'a R, multiplicity: EF, values: &'a [F]) -> Self { + pub const fn new(relation: &'a R, multiplicity: EF, values: &'a [F]) -> Self { Self { relation, multiplicity, diff --git a/crates/prover/src/core/fri.rs b/crates/prover/src/core/fri.rs index 46eabeedb..c80222635 100644 --- a/crates/prover/src/core/fri.rs +++ b/crates/prover/src/core/fri.rs @@ -699,9 +699,9 @@ impl FriFirstLayerVerifier { let mut fri_witness = self.proof.fri_witness.iter().copied(); let mut decommitment_positions_by_log_size = BTreeMap::new(); - let mut all_column_decommitment_values = Vec::new(); let mut folded_evals_by_column = Vec::new(); + let mut decommitment = vec![]; for (&column_domain, column_query_evals) in zip_eq(&self.column_commitment_domains, query_evals_by_column) { @@ -722,15 +722,13 @@ impl FriFirstLayerVerifier { decommitment_positions_by_log_size .insert(column_domain.log_size(), column_decommitment_positions); - // Prepare values in the structure needed for merkle decommitment. - let column_decommitment_values: SecureColumnByCoords = sparse_evaluation - .subset_evals - .iter() - .flatten() - .copied() - .collect(); - - all_column_decommitment_values.extend(column_decommitment_values.columns); + decommitment.extend( + sparse_evaluation + .subset_evals + .iter() + .flatten() + .flat_map(|qm31| qm31.to_m31_array()), + ); let folded_evals = sparse_evaluation.fold_circle(self.folding_alpha, column_domain); folded_evals_by_column.push(folded_evals); @@ -752,7 +750,7 @@ impl FriFirstLayerVerifier { merkle_verifier .verify( &decommitment_positions_by_log_size, - all_column_decommitment_values, + decommitment, self.proof.decommitment.clone(), ) .map_err(|error| FriVerificationError::FirstLayerCommitmentInvalid { error })?; @@ -814,12 +812,12 @@ impl FriInnerLayerVerifier { }); } - let decommitment_values: SecureColumnByCoords = sparse_evaluation + let decommitment = sparse_evaluation .subset_evals .iter() .flatten() - .copied() - .collect(); + .flat_map(|qm31| qm31.to_m31_array()) + .collect_vec(); let merkle_verifier = MerkleVerifier::new( self.proof.commitment, @@ -829,7 +827,7 @@ impl FriInnerLayerVerifier { merkle_verifier .verify( &BTreeMap::from_iter([(self.domain.log_size(), decommitment_positions)]), - decommitment_values.columns.to_vec(), + decommitment, self.proof.decommitment.clone(), ) .map_err(|e| FriVerificationError::InnerLayerCommitmentInvalid { diff --git a/crates/prover/src/core/pcs/prover.rs b/crates/prover/src/core/pcs/prover.rs index ef27f706d..59a2e8e8d 100644 --- a/crates/prover/src/core/pcs/prover.rs +++ b/crates/prover/src/core/pcs/prover.rs @@ -152,7 +152,7 @@ pub struct CommitmentSchemeProof { pub commitments: TreeVec, pub sampled_values: TreeVec>>, pub decommitments: TreeVec>, - pub queried_values: TreeVec>>, + pub queried_values: TreeVec>, pub proof_of_work: u64, pub fri_proof: FriProof, } @@ -231,7 +231,7 @@ impl, MC: MerkleChannel> CommitmentTreeProver { fn decommit( &self, queries: &BTreeMap>, - ) -> (ColumnVec>, MerkleDecommitment) { + ) -> (Vec, MerkleDecommitment) { let eval_vec = self .evaluations .iter() diff --git a/crates/prover/src/core/pcs/quotients.rs b/crates/prover/src/core/pcs/quotients.rs index aca0901c6..d41f17670 100644 --- a/crates/prover/src/core/pcs/quotients.rs +++ b/crates/prover/src/core/pcs/quotients.rs @@ -5,6 +5,7 @@ use std::iter::zip; use itertools::{izip, multiunzip, Itertools}; use tracing::{span, Level}; +use super::TreeVec; use crate::core::backend::cpu::quotients::{accumulate_row_quotients, quotient_constants}; use crate::core::circle::CirclePoint; use crate::core::fields::m31::BaseField; @@ -100,25 +101,30 @@ pub fn compute_fri_quotients( } pub fn fri_answers( - column_log_sizes: Vec, - samples: &[Vec], + column_log_sizes: TreeVec>, + samples: TreeVec>>, random_coeff: SecureField, query_positions_per_log_size: &BTreeMap>, - queried_values_per_column: &[Vec], + queried_values: TreeVec>, + n_columns_per_log_size: TreeVec<&BTreeMap>, ) -> Result>, VerificationError> { - izip!(column_log_sizes, samples, queried_values_per_column) + let mut queried_values = queried_values.map(|values| values.into_iter()); + + izip!(column_log_sizes.flatten(), samples.flatten().iter()) .sorted_by_key(|(log_size, ..)| Reverse(*log_size)) .group_by(|(log_size, ..)| *log_size) .into_iter() .map(|(log_size, tuples)| { - let (_, samples, queried_values_per_column): (Vec<_>, Vec<_>, Vec<_>) = - multiunzip(tuples); + let (_, samples): (Vec<_>, Vec<_>) = multiunzip(tuples); fri_answers_for_log_size( log_size, &samples, random_coeff, &query_positions_per_log_size[&log_size], - &queried_values_per_column, + &mut queried_values, + n_columns_per_log_size + .as_ref() + .map(|colums_log_sizes| *colums_log_sizes.get(&log_size).unwrap_or(&0)), ) }) .collect() @@ -129,27 +135,23 @@ pub fn fri_answers_for_log_size( samples: &[&Vec], random_coeff: SecureField, query_positions: &[usize], - queried_values_per_column: &[&Vec], + queried_values: &mut TreeVec>, + n_columns: TreeVec, ) -> Result, VerificationError> { - for queried_values in queried_values_per_column { - if queried_values.len() != query_positions.len() { - return Err(VerificationError::InvalidStructure( - "Insufficient number of queried values".to_string(), - )); - } - } - let sample_batches = ColumnSampleBatch::new_vec(samples); let quotient_constants = quotient_constants(&sample_batches, random_coeff); let commitment_domain = CanonicCoset::new(log_size).circle_domain(); - let mut quotient_evals_at_queries = Vec::new(); - for (row, &query_position) in query_positions.iter().enumerate() { + let mut quotient_evals_at_queries = Vec::new(); + for &query_position in query_positions { let domain_point = commitment_domain.at(bit_reverse_index(query_position, log_size)); - let queried_values_at_row = queried_values_per_column - .iter() - .map(|col| col[row]) - .collect_vec(); + + let queried_values_at_row = queried_values + .as_mut() + .zip_eq(n_columns.as_ref()) + .map(|(queried_values, n_columns)| queried_values.take(*n_columns).collect()) + .flatten(); + quotient_evals_at_queries.push(accumulate_row_quotients( &sample_batches, &queried_values_at_row, diff --git a/crates/prover/src/core/pcs/verifier.rs b/crates/prover/src/core/pcs/verifier.rs index 200fe98d5..db7bea3d3 100644 --- a/crates/prover/src/core/pcs/verifier.rs +++ b/crates/prover/src/core/pcs/verifier.rs @@ -99,21 +99,23 @@ impl CommitmentSchemeVerifier { .collect::>()?; // Answer FRI queries. - let samples = sampled_points - .zip_cols(proof.sampled_values) - .map_cols(|(sampled_points, sampled_values)| { + let samples = sampled_points.zip_cols(proof.sampled_values).map_cols( + |(sampled_points, sampled_values)| { zip(sampled_points, sampled_values) .map(|(point, value)| PointSample { point, value }) .collect_vec() - }) - .flatten(); + }, + ); + + let n_columns_per_log_size = self.trees.as_ref().map(|tree| &tree.n_columns_per_log_size); let fri_answers = fri_answers( - self.column_log_sizes().flatten().into_iter().collect(), - &samples, + self.column_log_sizes(), + samples, random_coeff, &query_positions_per_log_size, - &proof.queried_values.flatten(), + proof.queried_values, + n_columns_per_log_size, )?; fri_verifier.decommit(fri_answers)?; diff --git a/crates/prover/src/core/vcs/blake2_merkle.rs b/crates/prover/src/core/vcs/blake2_merkle.rs index 8401716fa..3664ea147 100644 --- a/crates/prover/src/core/vcs/blake2_merkle.rs +++ b/crates/prover/src/core/vcs/blake2_merkle.rs @@ -86,7 +86,7 @@ mod tests { #[test] fn test_merkle_invalid_value() { let (queries, decommitment, mut values, verifier) = prepare_merkle::(); - values[3][2] = BaseField::zero(); + values[6] = BaseField::zero(); assert_eq!( verifier.verify(&queries, values, decommitment).unwrap_err(), @@ -119,22 +119,22 @@ mod tests { #[test] fn test_merkle_column_values_too_long() { let (queries, decommitment, mut values, verifier) = prepare_merkle::(); - values[3].push(BaseField::zero()); + values.insert(3, BaseField::zero()); assert_eq!( verifier.verify(&queries, values, decommitment).unwrap_err(), - MerkleVerificationError::ColumnValuesTooLong + MerkleVerificationError::TooManyQueriedValues ); } #[test] fn test_merkle_column_values_too_short() { let (queries, decommitment, mut values, verifier) = prepare_merkle::(); - values[3].pop(); + values.remove(3); assert_eq!( verifier.verify(&queries, values, decommitment).unwrap_err(), - MerkleVerificationError::ColumnValuesTooShort + MerkleVerificationError::TooFewQueriedValues ); } diff --git a/crates/prover/src/core/vcs/poseidon252_merkle.rs b/crates/prover/src/core/vcs/poseidon252_merkle.rs index 5ffba1ea6..f39a2c62d 100644 --- a/crates/prover/src/core/vcs/poseidon252_merkle.rs +++ b/crates/prover/src/core/vcs/poseidon252_merkle.rs @@ -114,7 +114,7 @@ mod tests { fn test_merkle_invalid_value() { let (queries, decommitment, mut values, verifier) = prepare_merkle::(); - values[3][2] = BaseField::zero(); + values[6] = BaseField::zero(); assert_eq!( verifier.verify(&queries, values, decommitment).unwrap_err(), @@ -147,26 +147,26 @@ mod tests { } #[test] - fn test_merkle_column_values_too_long() { + fn test_merkle_values_too_long() { let (queries, decommitment, mut values, verifier) = prepare_merkle::(); - values[3].push(BaseField::zero()); + values.insert(3, BaseField::zero()); assert_eq!( verifier.verify(&queries, values, decommitment).unwrap_err(), - MerkleVerificationError::ColumnValuesTooLong + MerkleVerificationError::TooManyQueriedValues ); } #[test] - fn test_merkle_column_values_too_short() { + fn test_merkle_values_too_short() { let (queries, decommitment, mut values, verifier) = prepare_merkle::(); - values[3].pop(); + values.remove(3); assert_eq!( verifier.verify(&queries, values, decommitment).unwrap_err(), - MerkleVerificationError::ColumnValuesTooShort + MerkleVerificationError::TooFewQueriedValues ); } } diff --git a/crates/prover/src/core/vcs/prover.rs b/crates/prover/src/core/vcs/prover.rs index bc788e51f..da4695d3f 100644 --- a/crates/prover/src/core/vcs/prover.rs +++ b/crates/prover/src/core/vcs/prover.rs @@ -9,7 +9,6 @@ use super::utils::{next_decommitment_node, option_flatten_peekable}; use crate::core::backend::{Col, Column}; use crate::core::fields::m31::BaseField; use crate::core::utils::PeekableExt; -use crate::core::ColumnVec; pub struct MerkleProver, H: MerkleHasher> { /// Layers of the Merkle tree. @@ -48,6 +47,7 @@ impl, H: MerkleHasher> MerkleProver { .into_iter() .sorted_by_key(|c| Reverse(c.len())) .peekable(); + let mut layers: Vec> = Vec::new(); let max_log_size = columns.peek().unwrap().len().ilog2(); @@ -75,15 +75,16 @@ impl, H: MerkleHasher> MerkleProver { /// # Returns /// /// A tuple containing: - /// * A vector of vectors of queried values for each column, in the order of the input columns. + /// * A vector queried values sorted by the order they were queried from the largest layer to + /// the smallest. /// * A `MerkleDecommitment` containing the hash and column witnesses. pub fn decommit( &self, queries_per_log_size: &BTreeMap>, columns: Vec<&Col>, - ) -> (ColumnVec>, MerkleDecommitment) { + ) -> (Vec, MerkleDecommitment) { // Prepare output buffers. - let mut queried_values_by_layer = vec![]; + let mut queried_values = vec![]; let mut decommitment = MerkleDecommitment::empty(); // Sort columns by layer. @@ -94,9 +95,6 @@ impl, H: MerkleHasher> MerkleProver { let mut last_layer_queries = vec![]; for layer_log_size in (0..self.layers.len() as u32).rev() { - // Prepare write buffer for queried values to the current layer. - let mut layer_queried_values = vec![]; - // Prepare write buffer for queries to the current layer. This will propagate to the // next layer. let mut layer_total_queries = vec![]; @@ -140,7 +138,7 @@ impl, H: MerkleHasher> MerkleProver { // If the column values were queried, return them. let node_values = layer_columns.iter().map(|c| c.at(node_index)); if layer_column_queries.next_if_eq(&node_index).is_some() { - layer_queried_values.push(node_values.collect_vec()); + queried_values.extend(node_values); } else { // Otherwise, add them to the witness. decommitment.column_witness.extend(node_values); @@ -149,50 +147,13 @@ impl, H: MerkleHasher> MerkleProver { layer_total_queries.push(node_index); } - queried_values_by_layer.push(layer_queried_values); - // Propagate queries to the next layer. last_layer_queries = layer_total_queries; } - queried_values_by_layer.reverse(); - - // Rearrange returned queried values according to input, and not by layer. - let queried_values = Self::rearrange_queried_values(queried_values_by_layer, columns); (queried_values, decommitment) } - /// Given queried values by layer, rearranges in the order of input columns. - fn rearrange_queried_values( - queried_values_by_layer: Vec>>, - columns: Vec<&Col>, - ) -> Vec> { - // Turn each column queried values into an iterator. - let mut queried_values_by_layer = queried_values_by_layer - .into_iter() - .map(|layer_results| { - layer_results - .into_iter() - .map(|x| x.into_iter()) - .collect_vec() - }) - .collect_vec(); - - // For each input column, fetch the queried values from the corresponding layer. - let queried_values = columns - .iter() - .map(|column| { - queried_values_by_layer - .get_mut(column.len().ilog2() as usize) - .unwrap() - .iter_mut() - .map(|x| x.next().unwrap()) - .collect_vec() - }) - .collect_vec(); - queried_values - } - pub fn root(&self) -> H::Hash { self.layers.first().unwrap().at(0) } diff --git a/crates/prover/src/core/vcs/test_utils.rs b/crates/prover/src/core/vcs/test_utils.rs index b92f9e971..c906f05d0 100644 --- a/crates/prover/src/core/vcs/test_utils.rs +++ b/crates/prover/src/core/vcs/test_utils.rs @@ -14,7 +14,7 @@ use crate::core::vcs::prover::MerkleProver; pub type TestData = ( BTreeMap>, MerkleDecommitment, - Vec>, + Vec, MerkleVerifier, ); @@ -52,9 +52,6 @@ where let (values, decommitment) = merkle.decommit(&queries, cols.iter().collect_vec()); - let verifier = MerkleVerifier { - root: merkle.root(), - column_log_sizes: log_sizes, - }; + let verifier = MerkleVerifier::new(merkle.root(), log_sizes); (queries, decommitment, values, verifier) } diff --git a/crates/prover/src/core/vcs/verifier.rs b/crates/prover/src/core/vcs/verifier.rs index 9c1b0b39a..fcd0453a3 100644 --- a/crates/prover/src/core/vcs/verifier.rs +++ b/crates/prover/src/core/vcs/verifier.rs @@ -1,4 +1,3 @@ -use std::cmp::Reverse; use std::collections::BTreeMap; use itertools::Itertools; @@ -9,27 +8,35 @@ use super::prover::MerkleDecommitment; use super::utils::{next_decommitment_node, option_flatten_peekable}; use crate::core::fields::m31::BaseField; use crate::core::utils::PeekableExt; -use crate::core::ColumnVec; pub struct MerkleVerifier { pub root: H::Hash, pub column_log_sizes: Vec, + pub n_columns_per_log_size: BTreeMap, } impl MerkleVerifier { pub fn new(root: H::Hash, column_log_sizes: Vec) -> Self { + let mut n_columns_per_log_size = BTreeMap::new(); + for log_size in &column_log_sizes { + *n_columns_per_log_size.entry(*log_size).or_insert(0) += 1; + } + Self { root, column_log_sizes, + n_columns_per_log_size, } } /// Verifies the decommitment of the columns. /// + /// Returns `Ok(())` if the decommitment is successfully verified. + /// /// # Arguments /// /// * `queries_per_log_size` - A map from log_size to a vector of queries for columns of that /// log_size. - /// * `queried_values` - A vector of vectors of queried values. For each column, there is a - /// vector of queried values to that column. + /// * `queried_values` - A vector of queried values according to the order in + /// [`MerkleProver::decommit()`]. /// * `decommitment` - The decommitment object containing the witness and column values. /// /// # Errors @@ -38,45 +45,35 @@ impl MerkleVerifier { /// /// * The witness is too long (not fully consumed). /// * The witness is too short (missing values). - /// * The column values are too long (not fully consumed). - /// * The column values are too short (missing values). + /// * Too many queried values (not fully consumed). + /// * Too few queried values (missing values). /// * The computed root does not match the expected root. /// - /// # Returns - /// - /// Returns `Ok(())` if the decommitment is successfully verified. + /// [`MerkleProver::decommit()`]: crate::core::...::MerkleProver::decommit + pub fn verify( &self, queries_per_log_size: &BTreeMap>, - queried_values: ColumnVec>, + queried_values: Vec, decommitment: MerkleDecommitment, ) -> Result<(), MerkleVerificationError> { let Some(max_log_size) = self.column_log_sizes.iter().max() else { return Ok(()); }; + let mut queried_values = queried_values.into_iter(); + // Prepare read buffers. - let mut queried_values_by_layer = self - .column_log_sizes - .iter() - .copied() - .zip( - queried_values - .into_iter() - .map(|column_values| column_values.into_iter()), - ) - .sorted_by_key(|(log_size, _)| Reverse(*log_size)) - .peekable(); + let mut hash_witness = decommitment.hash_witness.into_iter(); let mut column_witness = decommitment.column_witness.into_iter(); let mut last_layer_hashes: Option> = None; for layer_log_size in (0..=*max_log_size).rev() { - // Prepare read buffer for queried values to the current layer. - let mut layer_queried_values = queried_values_by_layer - .peek_take_while(|(log_size, _)| *log_size == layer_log_size) - .collect_vec(); - let n_columns_in_layer = layer_queried_values.len(); + let n_columns_in_layer = *self + .n_columns_per_log_size + .get(&layer_log_size) + .unwrap_or(&0); // Prepare write buffer for queries to the current layer. This will propagate to the // next layer. @@ -132,29 +129,26 @@ impl MerkleVerifier { .transpose()?; // If the column values were queried, read them from `queried_value`. - let node_values = if layer_column_queries.next_if_eq(&node_index).is_some() { - layer_queried_values - .iter_mut() - .map(|(_, ref mut column_queries)| { - column_queries - .next() - .ok_or(MerkleVerificationError::ColumnValuesTooShort) - }) - .collect::, _>>()? - } else { + let (err, node_values_iter) = match layer_column_queries.next_if_eq(&node_index) { + Some(_) => ( + MerkleVerificationError::TooFewQueriedValues, + &mut queried_values, + ), // Otherwise, read them from the witness. - (&mut column_witness).take(n_columns_in_layer).collect_vec() + None => ( + MerkleVerificationError::WitnessTooShort, + &mut column_witness, + ), }; + + let node_values = node_values_iter.take(n_columns_in_layer).collect_vec(); if node_values.len() != n_columns_in_layer { - return Err(MerkleVerificationError::WitnessTooShort); + return Err(err); } layer_total_queries.push((node_index, H::hash_node(node_hashes, &node_values))); } - if !layer_queried_values.iter().all(|(_, c)| c.is_empty()) { - return Err(MerkleVerificationError::ColumnValuesTooLong); - } last_layer_hashes = Some(layer_total_queries); } @@ -162,6 +156,9 @@ impl MerkleVerifier { if !hash_witness.is_empty() { return Err(MerkleVerificationError::WitnessTooLong); } + if !queried_values.is_empty() { + return Err(MerkleVerificationError::TooManyQueriedValues); + } if !column_witness.is_empty() { return Err(MerkleVerificationError::WitnessTooLong); } @@ -175,16 +172,17 @@ impl MerkleVerifier { } } +// TODO(ilya): Make error messages consistent. #[derive(Clone, Copy, Debug, Error, PartialEq, Eq)] pub enum MerkleVerificationError { - #[error("Witness is too short.")] + #[error("Witness is too short")] WitnessTooShort, #[error("Witness is too long.")] WitnessTooLong, - #[error("Column values are too long.")] - ColumnValuesTooLong, - #[error("Column values are too short.")] - ColumnValuesTooShort, + #[error("too many Queried values")] + TooManyQueriedValues, + #[error("too few queried values")] + TooFewQueriedValues, #[error("Root mismatch.")] RootMismatch, }