Skip to content

Commit

Permalink
Rearrange queried_values_by_layer for merkle.
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyalesokhin-starkware committed Dec 2, 2024
1 parent 4af2d44 commit cad8b62
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 155 deletions.
28 changes: 13 additions & 15 deletions crates/prover/src/core/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -699,9 +699,9 @@ impl<H: MerkleHasher> FriFirstLayerVerifier<H> {

let mut fri_witness = self.proof.fri_witness.iter().copied();
let mut decommitment_positions_by_log_size = BTreeMap::new();
let mut all_column_decommitment_values = Vec::new();
let mut folded_evals_by_column = Vec::new();

let mut decommitmented_values = vec![];
for (&column_domain, column_query_evals) in
zip_eq(&self.column_commitment_domains, query_evals_by_column)
{
Expand All @@ -722,15 +722,13 @@ impl<H: MerkleHasher> FriFirstLayerVerifier<H> {
decommitment_positions_by_log_size
.insert(column_domain.log_size(), column_decommitment_positions);

// Prepare values in the structure needed for merkle decommitment.
let column_decommitment_values: SecureColumnByCoords<CpuBackend> = sparse_evaluation
.subset_evals
.iter()
.flatten()
.copied()
.collect();

all_column_decommitment_values.extend(column_decommitment_values.columns);
decommitmented_values.extend(
sparse_evaluation
.subset_evals
.iter()
.flatten()
.flat_map(|qm31| qm31.to_m31_array()),
);

let folded_evals = sparse_evaluation.fold_circle(self.folding_alpha, column_domain);
folded_evals_by_column.push(folded_evals);
Expand All @@ -752,7 +750,7 @@ impl<H: MerkleHasher> FriFirstLayerVerifier<H> {
merkle_verifier
.verify(
&decommitment_positions_by_log_size,
all_column_decommitment_values,
decommitmented_values,
self.proof.decommitment.clone(),
)
.map_err(|error| FriVerificationError::FirstLayerCommitmentInvalid { error })?;
Expand Down Expand Up @@ -814,12 +812,12 @@ impl<H: MerkleHasher> FriInnerLayerVerifier<H> {
});
}

let decommitment_values: SecureColumnByCoords<CpuBackend> = sparse_evaluation
let decommitmented_values = sparse_evaluation
.subset_evals
.iter()
.flatten()
.copied()
.collect();
.flat_map(|qm31| qm31.to_m31_array())
.collect_vec();

let merkle_verifier = MerkleVerifier::new(
self.proof.commitment,
Expand All @@ -829,7 +827,7 @@ impl<H: MerkleHasher> FriInnerLayerVerifier<H> {
merkle_verifier
.verify(
&BTreeMap::from_iter([(self.domain.log_size(), decommitment_positions)]),
decommitment_values.columns.to_vec(),
decommitmented_values,
self.proof.decommitment.clone(),
)
.map_err(|e| FriVerificationError::InnerLayerCommitmentInvalid {
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/core/pcs/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ pub struct CommitmentSchemeProof<H: MerkleHasher> {
pub commitments: TreeVec<H::Hash>,
pub sampled_values: TreeVec<ColumnVec<Vec<SecureField>>>,
pub decommitments: TreeVec<MerkleDecommitment<H>>,
pub queried_values: TreeVec<ColumnVec<Vec<BaseField>>>,
pub queried_values: TreeVec<Vec<BaseField>>,
pub proof_of_work: u64,
pub fri_proof: FriProof<H>,
}
Expand Down Expand Up @@ -231,7 +231,7 @@ impl<B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentTreeProver<B, MC> {
fn decommit(
&self,
queries: &BTreeMap<u32, Vec<usize>>,
) -> (ColumnVec<Vec<BaseField>>, MerkleDecommitment<MC::H>) {
) -> (Vec<BaseField>, MerkleDecommitment<MC::H>) {
let eval_vec = self
.evaluations
.iter()
Expand Down
46 changes: 24 additions & 22 deletions crates/prover/src/core/pcs/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::iter::zip;
use itertools::{izip, multiunzip, Itertools};
use tracing::{span, Level};

use super::TreeVec;
use crate::core::backend::cpu::quotients::{accumulate_row_quotients, quotient_constants};
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
Expand Down Expand Up @@ -100,25 +101,30 @@ pub fn compute_fri_quotients<B: QuotientOps>(
}

pub fn fri_answers(
column_log_sizes: Vec<u32>,
samples: &[Vec<PointSample>],
column_log_sizes: TreeVec<Vec<u32>>,
samples: TreeVec<Vec<Vec<PointSample>>>,
random_coeff: SecureField,
query_positions_per_log_size: &BTreeMap<u32, Vec<usize>>,
queried_values_per_column: &[Vec<BaseField>],
queried_values: TreeVec<Vec<BaseField>>,
n_columns_per_log_size: TreeVec<&BTreeMap<u32, usize>>,
) -> Result<ColumnVec<Vec<SecureField>>, VerificationError> {
izip!(column_log_sizes, samples, queried_values_per_column)
let mut queried_values = queried_values.map(|values| values.into_iter());

izip!(column_log_sizes.flatten(), samples.flatten().iter())
.sorted_by_key(|(log_size, ..)| Reverse(*log_size))
.group_by(|(log_size, ..)| *log_size)
.into_iter()
.map(|(log_size, tuples)| {
let (_, samples, queried_values_per_column): (Vec<_>, Vec<_>, Vec<_>) =
multiunzip(tuples);
let (_, samples): (Vec<_>, Vec<_>) = multiunzip(tuples);
fri_answers_for_log_size(
log_size,
&samples,
random_coeff,
&query_positions_per_log_size[&log_size],
&queried_values_per_column,
&mut queried_values,
n_columns_per_log_size
.as_ref()
.map(|colums_log_sizes| *colums_log_sizes.get(&log_size).unwrap_or(&0)),
)
})
.collect()
Expand All @@ -129,27 +135,23 @@ pub fn fri_answers_for_log_size(
samples: &[&Vec<PointSample>],
random_coeff: SecureField,
query_positions: &[usize],
queried_values_per_column: &[&Vec<BaseField>],
queried_values: &mut TreeVec<impl Iterator<Item = BaseField>>,
n_columns: TreeVec<usize>,
) -> Result<Vec<SecureField>, VerificationError> {
for queried_values in queried_values_per_column {
if queried_values.len() != query_positions.len() {
return Err(VerificationError::InvalidStructure(
"Insufficient number of queried values".to_string(),
));
}
}

let sample_batches = ColumnSampleBatch::new_vec(samples);
let quotient_constants = quotient_constants(&sample_batches, random_coeff);
let commitment_domain = CanonicCoset::new(log_size).circle_domain();
let mut quotient_evals_at_queries = Vec::new();

for (row, &query_position) in query_positions.iter().enumerate() {
let mut quotient_evals_at_queries = Vec::new();
for &query_position in query_positions {
let domain_point = commitment_domain.at(bit_reverse_index(query_position, log_size));
let queried_values_at_row = queried_values_per_column
.iter()
.map(|col| col[row])
.collect_vec();

let queried_values_at_row = queried_values
.as_mut()
.zip_eq(n_columns.as_ref())
.map(|(queried_values, n_columns)| queried_values.take(*n_columns).collect())
.flatten();

quotient_evals_at_queries.push(accumulate_row_quotients(
&sample_batches,
&queried_values_at_row,
Expand Down
18 changes: 10 additions & 8 deletions crates/prover/src/core/pcs/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,21 +99,23 @@ impl<MC: MerkleChannel> CommitmentSchemeVerifier<MC> {
.collect::<Result<_, _>>()?;

// Answer FRI queries.
let samples = sampled_points
.zip_cols(proof.sampled_values)
.map_cols(|(sampled_points, sampled_values)| {
let samples = sampled_points.zip_cols(proof.sampled_values).map_cols(
|(sampled_points, sampled_values)| {
zip(sampled_points, sampled_values)
.map(|(point, value)| PointSample { point, value })
.collect_vec()
})
.flatten();
},
);

let n_columns_per_log_size = self.trees.as_ref().map(|tree| &tree.n_columns_per_log_size);

let fri_answers = fri_answers(
self.column_log_sizes().flatten().into_iter().collect(),
&samples,
self.column_log_sizes(),
samples,
random_coeff,
&query_positions_per_log_size,
&proof.queried_values.flatten(),
proof.queried_values,
n_columns_per_log_size,
)?;

fri_verifier.decommit(fri_answers)?;
Expand Down
10 changes: 5 additions & 5 deletions crates/prover/src/core/vcs/blake2_merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ mod tests {
#[test]
fn test_merkle_invalid_value() {
let (queries, decommitment, mut values, verifier) = prepare_merkle::<Blake2sMerkleHasher>();
values[3][2] = BaseField::zero();
values[6] = BaseField::zero();

assert_eq!(
verifier.verify(&queries, values, decommitment).unwrap_err(),
Expand Down Expand Up @@ -119,22 +119,22 @@ mod tests {
#[test]
fn test_merkle_column_values_too_long() {
let (queries, decommitment, mut values, verifier) = prepare_merkle::<Blake2sMerkleHasher>();
values[3].push(BaseField::zero());
values.insert(3, BaseField::zero());

assert_eq!(
verifier.verify(&queries, values, decommitment).unwrap_err(),
MerkleVerificationError::ColumnValuesTooLong
MerkleVerificationError::TooManyQueriedValues
);
}

#[test]
fn test_merkle_column_values_too_short() {
let (queries, decommitment, mut values, verifier) = prepare_merkle::<Blake2sMerkleHasher>();
values[3].pop();
values.remove(3);

assert_eq!(
verifier.verify(&queries, values, decommitment).unwrap_err(),
MerkleVerificationError::ColumnValuesTooShort
MerkleVerificationError::TooFewQueriedValues
);
}

Expand Down
14 changes: 7 additions & 7 deletions crates/prover/src/core/vcs/poseidon252_merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ mod tests {
fn test_merkle_invalid_value() {
let (queries, decommitment, mut values, verifier) =
prepare_merkle::<Poseidon252MerkleHasher>();
values[3][2] = BaseField::zero();
values[6] = BaseField::zero();

assert_eq!(
verifier.verify(&queries, values, decommitment).unwrap_err(),
Expand Down Expand Up @@ -147,26 +147,26 @@ mod tests {
}

#[test]
fn test_merkle_column_values_too_long() {
fn test_merkle_values_too_long() {
let (queries, decommitment, mut values, verifier) =
prepare_merkle::<Poseidon252MerkleHasher>();
values[3].push(BaseField::zero());
values.insert(3, BaseField::zero());

assert_eq!(
verifier.verify(&queries, values, decommitment).unwrap_err(),
MerkleVerificationError::ColumnValuesTooLong
MerkleVerificationError::TooManyQueriedValues
);
}

#[test]
fn test_merkle_column_values_too_short() {
fn test_merkle_values_too_short() {
let (queries, decommitment, mut values, verifier) =
prepare_merkle::<Poseidon252MerkleHasher>();
values[3].pop();
values.remove(3);

assert_eq!(
verifier.verify(&queries, values, decommitment).unwrap_err(),
MerkleVerificationError::ColumnValuesTooShort
MerkleVerificationError::TooFewQueriedValues
);
}
}
51 changes: 6 additions & 45 deletions crates/prover/src/core/vcs/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use super::utils::{next_decommitment_node, option_flatten_peekable};
use crate::core::backend::{Col, Column};
use crate::core::fields::m31::BaseField;
use crate::core::utils::PeekableExt;
use crate::core::ColumnVec;

pub struct MerkleProver<B: MerkleOps<H>, H: MerkleHasher> {
/// Layers of the Merkle tree.
Expand Down Expand Up @@ -48,6 +47,7 @@ impl<B: MerkleOps<H>, H: MerkleHasher> MerkleProver<B, H> {
.into_iter()
.sorted_by_key(|c| Reverse(c.len()))
.peekable();

let mut layers: Vec<Col<B, H::Hash>> = Vec::new();

let max_log_size = columns.peek().unwrap().len().ilog2();
Expand Down Expand Up @@ -75,15 +75,16 @@ impl<B: MerkleOps<H>, H: MerkleHasher> MerkleProver<B, H> {
/// # Returns
///
/// A tuple containing:
/// * A vector of vectors of queried values for each column, in the order of the input columns.
/// * A vector queried values sorted by the order they were queried from the largest layer to
/// the smallest.
/// * A `MerkleDecommitment` containing the hash and column witnesses.
pub fn decommit(
&self,
queries_per_log_size: &BTreeMap<u32, Vec<usize>>,
columns: Vec<&Col<B, BaseField>>,
) -> (ColumnVec<Vec<BaseField>>, MerkleDecommitment<H>) {
) -> (Vec<BaseField>, MerkleDecommitment<H>) {
// Prepare output buffers.
let mut queried_values_by_layer = vec![];
let mut queried_values = vec![];
let mut decommitment = MerkleDecommitment::empty();

// Sort columns by layer.
Expand All @@ -94,9 +95,6 @@ impl<B: MerkleOps<H>, H: MerkleHasher> MerkleProver<B, H> {

let mut last_layer_queries = vec![];
for layer_log_size in (0..self.layers.len() as u32).rev() {
// Prepare write buffer for queried values to the current layer.
let mut layer_queried_values = vec![];

// Prepare write buffer for queries to the current layer. This will propagate to the
// next layer.
let mut layer_total_queries = vec![];
Expand Down Expand Up @@ -140,7 +138,7 @@ impl<B: MerkleOps<H>, H: MerkleHasher> MerkleProver<B, H> {
// If the column values were queried, return them.
let node_values = layer_columns.iter().map(|c| c.at(node_index));
if layer_column_queries.next_if_eq(&node_index).is_some() {
layer_queried_values.push(node_values.collect_vec());
queried_values.extend(node_values);
} else {
// Otherwise, add them to the witness.
decommitment.column_witness.extend(node_values);
Expand All @@ -149,50 +147,13 @@ impl<B: MerkleOps<H>, H: MerkleHasher> MerkleProver<B, H> {
layer_total_queries.push(node_index);
}

queried_values_by_layer.push(layer_queried_values);

// Propagate queries to the next layer.
last_layer_queries = layer_total_queries;
}
queried_values_by_layer.reverse();

// Rearrange returned queried values according to input, and not by layer.
let queried_values = Self::rearrange_queried_values(queried_values_by_layer, columns);

(queried_values, decommitment)
}

/// Given queried values by layer, rearranges in the order of input columns.
fn rearrange_queried_values(
queried_values_by_layer: Vec<Vec<Vec<BaseField>>>,
columns: Vec<&Col<B, BaseField>>,
) -> Vec<Vec<BaseField>> {
// Turn each column queried values into an iterator.
let mut queried_values_by_layer = queried_values_by_layer
.into_iter()
.map(|layer_results| {
layer_results
.into_iter()
.map(|x| x.into_iter())
.collect_vec()
})
.collect_vec();

// For each input column, fetch the queried values from the corresponding layer.
let queried_values = columns
.iter()
.map(|column| {
queried_values_by_layer
.get_mut(column.len().ilog2() as usize)
.unwrap()
.iter_mut()
.map(|x| x.next().unwrap())
.collect_vec()
})
.collect_vec();
queried_values
}

pub fn root(&self) -> H::Hash {
self.layers.first().unwrap().at(0)
}
Expand Down
Loading

0 comments on commit cad8b62

Please sign in to comment.