Skip to content

Commit

Permalink
rearrange queried_values_by_layer for merkle.
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyalesokhin-starkware committed Nov 25, 2024
1 parent af9250e commit 3e31dca
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 87 deletions.
44 changes: 41 additions & 3 deletions crates/prover/src/core/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::fmt::Debug;
use std::iter::zip;
use std::ops::RangeInclusive;

use itertools::{zip_eq, Itertools};
use itertools::{izip, zip_eq, Itertools};
use num_traits::Zero;
use serde::{Deserialize, Serialize};
use thiserror::Error;
Expand Down Expand Up @@ -738,6 +738,30 @@ impl<H: MerkleHasher> FriFirstLayerVerifier<H> {
return Err(FriVerificationError::FirstLayerEvaluationsInvalid);
}

// assert_eq!(all_column_decommitment_values.len(), self.column_commitment_domains.len());

let mut column_iters = all_column_decommitment_values
.iter()
.map(|column| column.iter())
.collect_vec();

let mut decommit_by_log_size = vec![Vec::new(); (max_column_log_size + 1) as usize];
let mut done = false;
while !done {
done = true;
for (log_size, column_iter) in zip_eq(
self.column_commitment_domains
.iter()
.flat_map(|column_domain| [column_domain.log_size(); SECURE_EXTENSION_DEGREE]),
column_iters.iter_mut(),
) {
if let Some(value) = column_iter.next() {
decommit_by_log_size[log_size as usize].push(*value);
done = false;
}
}
}

let merkle_verifier = MerkleVerifier::new(
self.proof.commitment,
self.column_commitment_domains
Expand All @@ -749,7 +773,7 @@ impl<H: MerkleHasher> FriFirstLayerVerifier<H> {
merkle_verifier
.verify(
&decommitment_positions_by_log_size,
all_column_decommitment_values,
decommit_by_log_size,
self.proof.decommitment.clone(),
)
.map_err(|error| FriVerificationError::FirstLayerCommitmentInvalid { error })?;
Expand Down Expand Up @@ -823,10 +847,24 @@ impl<H: MerkleHasher> FriInnerLayerVerifier<H> {
vec![self.domain.log_size(); SECURE_EXTENSION_DEGREE],
);

let mut decommit_by_log_size = vec![Vec::new(); (self.domain.log_size() + 1) as usize];

for (a, b, c, d) in izip!(
&decommitment_values.columns[0],
&decommitment_values.columns[1],
&decommitment_values.columns[2],
&decommitment_values.columns[3]
) {
decommit_by_log_size[self.domain.log_size() as usize].push(*a);
decommit_by_log_size[self.domain.log_size() as usize].push(*b);
decommit_by_log_size[self.domain.log_size() as usize].push(*c);
decommit_by_log_size[self.domain.log_size() as usize].push(*d);
}

merkle_verifier
.verify(
&BTreeMap::from_iter([(self.domain.log_size(), decommitment_positions)]),
decommitment_values.columns.to_vec(),
decommit_by_log_size,
self.proof.decommitment.clone(),
)
.map_err(|e| FriVerificationError::InnerLayerCommitmentInvalid {
Expand Down
47 changes: 27 additions & 20 deletions crates/prover/src/core/pcs/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::iter::zip;
use itertools::{izip, multiunzip, Itertools};
use tracing::{span, Level};

use super::TreeVec;
use crate::core::backend::cpu::quotients::{accumulate_row_quotients, quotient_constants};
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
Expand Down Expand Up @@ -104,21 +105,29 @@ pub fn fri_answers(
samples: &[Vec<PointSample>],
random_coeff: SecureField,
query_positions_per_log_size: &BTreeMap<u32, Vec<usize>>,
queried_values_per_column: &[Vec<BaseField>],
mut queried_values_per_layer: TreeVec<Vec<Vec<BaseField>>>,
mut columns_per_log_size: TreeVec<BTreeMap<u32, usize>>,
) -> Result<ColumnVec<Vec<SecureField>>, VerificationError> {
izip!(column_log_sizes, samples, queried_values_per_column)
izip!(column_log_sizes, samples)
.sorted_by_key(|(log_size, ..)| Reverse(*log_size))
.group_by(|(log_size, ..)| *log_size)
.into_iter()
.map(|(log_size, tuples)| {
let (_, samples, queried_values_per_column): (Vec<_>, Vec<_>, Vec<_>) =
multiunzip(tuples);
let (_, samples): (Vec<_>, Vec<_>) = multiunzip(tuples);
fri_answers_for_log_size(
log_size,
&samples,
random_coeff,
&query_positions_per_log_size[&log_size],
&queried_values_per_column,
queried_values_per_layer.as_mut().map(|queried_values| {
match queried_values.get_mut(log_size as usize) {
Some(queried_values) => std::mem::take(queried_values).into_iter(),
None => (vec![]).into_iter(),
}
}),
columns_per_log_size
.as_mut()
.map(|colums_log_sizes| *colums_log_sizes.get(&log_size).unwrap_or(&0)),
)
})
.collect()
Expand All @@ -129,27 +138,25 @@ pub fn fri_answers_for_log_size(
samples: &[&Vec<PointSample>],
random_coeff: SecureField,
query_positions: &[usize],
queried_values_per_column: &[&Vec<BaseField>],
mut queried_values_per_layer: TreeVec<std::vec::IntoIter<BaseField>>,
n_columns: TreeVec<usize>,
) -> Result<Vec<SecureField>, VerificationError> {
for queried_values in queried_values_per_column {
if queried_values.len() != query_positions.len() {
return Err(VerificationError::InvalidStructure(
"Insufficient number of queried values".to_string(),
));
}
}

let sample_batches = ColumnSampleBatch::new_vec(samples);
let quotient_constants = quotient_constants(&sample_batches, random_coeff);
let commitment_domain = CanonicCoset::new(log_size).circle_domain();
let mut quotient_evals_at_queries = Vec::new();

for (row, &query_position) in query_positions.iter().enumerate() {
let mut quotient_evals_at_queries = Vec::new();
for &query_position in query_positions {
let domain_point = commitment_domain.at(bit_reverse_index(query_position, log_size));
let queried_values_at_row = queried_values_per_column
.iter()
.map(|col| col[row])
.collect_vec();

let queried_values_at_row = queried_values_per_layer
.as_mut()
.zip_eq(n_columns.as_ref())
.map(|(queried_values, n_columns)| {
queried_values.take(*n_columns).collect()
})
.flatten();

quotient_evals_at_queries.push(accumulate_row_quotients(
&sample_batches,
&queried_values_at_row,
Expand Down
14 changes: 13 additions & 1 deletion crates/prover/src/core/pcs/verifier.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::BTreeMap;
use std::iter::zip;

use itertools::Itertools;
Expand Down Expand Up @@ -86,6 +87,16 @@ impl<MC: MerkleChannel> CommitmentSchemeVerifier<MC> {
// Get FRI query positions.
let query_positions_per_log_size = fri_verifier.sample_query_positions(channel);

let colums_per_log_size = self.column_log_sizes().map(|log_sizes| {
let mut columns_per_layer = BTreeMap::new();

for log_size in &log_sizes {
*columns_per_layer.entry(*log_size).or_default() += 1;
}

columns_per_layer
});

// Verify merkle decommitments.
self.trees
.as_ref()
Expand Down Expand Up @@ -113,7 +124,8 @@ impl<MC: MerkleChannel> CommitmentSchemeVerifier<MC> {
&samples,
random_coeff,
&query_positions_per_log_size,
&proof.queried_values.flatten(),
proof.queried_values,
colums_per_log_size,
)?;

fri_verifier.decommit(fri_answers)?;
Expand Down
42 changes: 6 additions & 36 deletions crates/prover/src/core/vcs/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,13 @@ impl<B: MerkleOps<H>, H: MerkleHasher> MerkleProver<B, H> {
};
}

assert!(columns.is_sorted_by_key(|c| Reverse(c.len())));

let columns = &mut columns
.into_iter()
.sorted_by_key(|c| Reverse(c.len()))
.peekable();

let mut layers: Vec<Col<B, H::Hash>> = Vec::new();

let max_log_size = columns.peek().unwrap().len().ilog2();
Expand Down Expand Up @@ -140,7 +143,7 @@ impl<B: MerkleOps<H>, H: MerkleHasher> MerkleProver<B, H> {
// If the column values were queried, return them.
let node_values = layer_columns.iter().map(|c| c.at(node_index));
if layer_column_queries.next_if_eq(&node_index).is_some() {
layer_queried_values.push(node_values.collect_vec());
layer_queried_values.extend(node_values);
} else {
// Otherwise, add them to the witness.
decommitment.column_witness.extend(node_values);
Expand All @@ -154,43 +157,10 @@ impl<B: MerkleOps<H>, H: MerkleHasher> MerkleProver<B, H> {
// Propagate queries to the next layer.
last_layer_queries = layer_total_queries;
}
queried_values_by_layer.reverse();

// Rearrange returned queried values according to input, and not by layer.
let queried_values = Self::rearrange_queried_values(queried_values_by_layer, columns);

(queried_values, decommitment)
}
queried_values_by_layer.reverse();

/// Given queried values by layer, rearranges in the order of input columns.
fn rearrange_queried_values(
queried_values_by_layer: Vec<Vec<Vec<BaseField>>>,
columns: Vec<&Col<B, BaseField>>,
) -> Vec<Vec<BaseField>> {
// Turn each column queried values into an iterator.
let mut queried_values_by_layer = queried_values_by_layer
.into_iter()
.map(|layer_results| {
layer_results
.into_iter()
.map(|x| x.into_iter())
.collect_vec()
})
.collect_vec();

// For each input column, fetch the queried values from the corresponding layer.
let queried_values = columns
.iter()
.map(|column| {
queried_values_by_layer
.get_mut(column.len().ilog2() as usize)
.unwrap()
.iter_mut()
.map(|x| x.next().unwrap())
.collect_vec()
})
.collect_vec();
queried_values
(queried_values_by_layer, decommitment)
}

pub fn root(&self) -> H::Hash {
Expand Down
47 changes: 20 additions & 27 deletions crates/prover/src/core/vcs/verifier.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::cmp::Reverse;
use std::collections::BTreeMap;

use itertools::Itertools;
Expand All @@ -9,7 +8,6 @@ use super::prover::MerkleDecommitment;
use super::utils::{next_decommitment_node, option_flatten_peekable};
use crate::core::fields::m31::BaseField;
use crate::core::utils::PeekableExt;
use crate::core::ColumnVec;

pub struct MerkleVerifier<H: MerkleHasher> {
pub root: H::Hash,
Expand Down Expand Up @@ -53,35 +51,33 @@ impl<H: MerkleHasher> MerkleVerifier<H> {
pub fn verify(
&self,
queries_per_log_size: &BTreeMap<u32, Vec<usize>>,
queried_values: ColumnVec<Vec<BaseField>>,
queried_values_by_layer: Vec<Vec<BaseField>>,
decommitment: MerkleDecommitment<H>,
) -> Result<(), MerkleVerificationError> {
let Some(max_log_size) = self.column_log_sizes.iter().max() else {
return Ok(());
};

let mut columns_per_layer = vec![0; *max_log_size as usize + 1];

for log_size in &self.column_log_sizes {
columns_per_layer[*log_size as usize] += 1;
}

assert_eq!(queried_values_by_layer.len(), columns_per_layer.len());
let mut queried_values_by_layer_iter = queried_values_by_layer.into_iter().rev();

// Prepare read buffers.
let mut queried_values_by_layer = self
.column_log_sizes
.iter()
.copied()
.zip(
queried_values
.into_iter()
.map(|column_values| column_values.into_iter()),
)
.sorted_by_key(|(log_size, _)| Reverse(*log_size))
.peekable();

let mut hash_witness = decommitment.hash_witness.into_iter();
let mut column_witness = decommitment.column_witness.into_iter();

let mut last_layer_hashes: Option<Vec<(usize, H::Hash)>> = None;
for layer_log_size in (0..=*max_log_size).rev() {
// Prepare read buffer for queried values to the current layer.
let mut layer_queried_values = queried_values_by_layer
.peek_take_while(|(log_size, _)| *log_size == layer_log_size)
.collect_vec();
let n_columns_in_layer = layer_queried_values.len();
let mut layer_queried_values = queried_values_by_layer_iter.next().unwrap().into_iter();

let n_columns_in_layer = columns_per_layer.pop().unwrap();

// Prepare write buffer for queries to the current layer. This will propagate to the
// next layer.
Expand Down Expand Up @@ -138,26 +134,23 @@ impl<H: MerkleHasher> MerkleVerifier<H> {

// If the column values were queried, read them from `queried_value`.
let node_values = if layer_column_queries.next_if_eq(&node_index).is_some() {
layer_queried_values
.iter_mut()
.map(|(_, ref mut column_queries)| {
column_queries
.next()
.ok_or(MerkleVerificationError::ColumnValuesTooShort)
})
.collect::<Result<Vec<_>, _>>()?
(&mut layer_queried_values)
.take(n_columns_in_layer)
.collect_vec()
} else {
// Otherwise, read them from the witness.
(&mut column_witness).take(n_columns_in_layer).collect_vec()
};

assert_eq!(node_values.len(), n_columns_in_layer);
if node_values.len() != n_columns_in_layer {
return Err(MerkleVerificationError::WitnessTooShort);
}

layer_total_queries.push((node_index, H::hash_node(node_hashes, &node_values)));
}

if !layer_queried_values.iter().all(|(_, c)| c.is_empty()) {
if !layer_queried_values.is_empty() {
return Err(MerkleVerificationError::ColumnValuesTooLong);
}
last_layer_hashes = Some(layer_total_queries);
Expand Down

0 comments on commit 3e31dca

Please sign in to comment.