Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rearrange queried_values_by_layer for merkle. #902

Merged
merged 1 commit into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions crates/prover/src/core/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -699,9 +699,9 @@ impl<H: MerkleHasher> FriFirstLayerVerifier<H> {

let mut fri_witness = self.proof.fri_witness.iter().copied();
let mut decommitment_positions_by_log_size = BTreeMap::new();
let mut all_column_decommitment_values = Vec::new();
let mut folded_evals_by_column = Vec::new();

let mut decommitmented_values = vec![];
for (&column_domain, column_query_evals) in
zip_eq(&self.column_commitment_domains, query_evals_by_column)
{
Expand All @@ -722,15 +722,13 @@ impl<H: MerkleHasher> FriFirstLayerVerifier<H> {
decommitment_positions_by_log_size
.insert(column_domain.log_size(), column_decommitment_positions);

// Prepare values in the structure needed for merkle decommitment.
let column_decommitment_values: SecureColumnByCoords<CpuBackend> = sparse_evaluation
.subset_evals
.iter()
.flatten()
.copied()
.collect();

all_column_decommitment_values.extend(column_decommitment_values.columns);
decommitmented_values.extend(
sparse_evaluation
.subset_evals
.iter()
.flatten()
.flat_map(|qm31| qm31.to_m31_array()),
);

let folded_evals = sparse_evaluation.fold_circle(self.folding_alpha, column_domain);
folded_evals_by_column.push(folded_evals);
Expand All @@ -752,7 +750,7 @@ impl<H: MerkleHasher> FriFirstLayerVerifier<H> {
merkle_verifier
.verify(
&decommitment_positions_by_log_size,
all_column_decommitment_values,
decommitmented_values,
self.proof.decommitment.clone(),
)
.map_err(|error| FriVerificationError::FirstLayerCommitmentInvalid { error })?;
Expand Down Expand Up @@ -814,12 +812,12 @@ impl<H: MerkleHasher> FriInnerLayerVerifier<H> {
});
}

let decommitment_values: SecureColumnByCoords<CpuBackend> = sparse_evaluation
let decommitmented_values = sparse_evaluation
.subset_evals
.iter()
.flatten()
.copied()
.collect();
.flat_map(|qm31| qm31.to_m31_array())
.collect_vec();

let merkle_verifier = MerkleVerifier::new(
self.proof.commitment,
Expand All @@ -829,7 +827,7 @@ impl<H: MerkleHasher> FriInnerLayerVerifier<H> {
merkle_verifier
.verify(
&BTreeMap::from_iter([(self.domain.log_size(), decommitment_positions)]),
decommitment_values.columns.to_vec(),
decommitmented_values,
self.proof.decommitment.clone(),
)
.map_err(|e| FriVerificationError::InnerLayerCommitmentInvalid {
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/core/pcs/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ pub struct CommitmentSchemeProof<H: MerkleHasher> {
pub commitments: TreeVec<H::Hash>,
pub sampled_values: TreeVec<ColumnVec<Vec<SecureField>>>,
pub decommitments: TreeVec<MerkleDecommitment<H>>,
pub queried_values: TreeVec<ColumnVec<Vec<BaseField>>>,
pub queried_values: TreeVec<Vec<BaseField>>,
pub proof_of_work: u64,
pub fri_proof: FriProof<H>,
}
Expand Down Expand Up @@ -231,7 +231,7 @@ impl<B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentTreeProver<B, MC> {
fn decommit(
&self,
queries: &BTreeMap<u32, Vec<usize>>,
) -> (ColumnVec<Vec<BaseField>>, MerkleDecommitment<MC::H>) {
) -> (Vec<BaseField>, MerkleDecommitment<MC::H>) {
let eval_vec = self
.evaluations
.iter()
Expand Down
46 changes: 24 additions & 22 deletions crates/prover/src/core/pcs/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::iter::zip;
use itertools::{izip, multiunzip, Itertools};
use tracing::{span, Level};

use super::TreeVec;
use crate::core::backend::cpu::quotients::{accumulate_row_quotients, quotient_constants};
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
Expand Down Expand Up @@ -100,25 +101,30 @@ pub fn compute_fri_quotients<B: QuotientOps>(
}

pub fn fri_answers(
column_log_sizes: Vec<u32>,
samples: &[Vec<PointSample>],
column_log_sizes: TreeVec<Vec<u32>>,
samples: TreeVec<Vec<Vec<PointSample>>>,
random_coeff: SecureField,
query_positions_per_log_size: &BTreeMap<u32, Vec<usize>>,
queried_values_per_column: &[Vec<BaseField>],
queried_values: TreeVec<Vec<BaseField>>,
n_columns_per_log_size: TreeVec<&BTreeMap<u32, usize>>,
) -> Result<ColumnVec<Vec<SecureField>>, VerificationError> {
izip!(column_log_sizes, samples, queried_values_per_column)
let mut queried_values = queried_values.map(|values| values.into_iter());

izip!(column_log_sizes.flatten(), samples.flatten().iter())
.sorted_by_key(|(log_size, ..)| Reverse(*log_size))
.group_by(|(log_size, ..)| *log_size)
.into_iter()
.map(|(log_size, tuples)| {
let (_, samples, queried_values_per_column): (Vec<_>, Vec<_>, Vec<_>) =
multiunzip(tuples);
let (_, samples): (Vec<_>, Vec<_>) = multiunzip(tuples);
fri_answers_for_log_size(
log_size,
&samples,
random_coeff,
&query_positions_per_log_size[&log_size],
&queried_values_per_column,
&mut queried_values,
n_columns_per_log_size
.as_ref()
.map(|colums_log_sizes| *colums_log_sizes.get(&log_size).unwrap_or(&0)),
)
})
.collect()
Expand All @@ -129,27 +135,23 @@ pub fn fri_answers_for_log_size(
samples: &[&Vec<PointSample>],
random_coeff: SecureField,
query_positions: &[usize],
queried_values_per_column: &[&Vec<BaseField>],
queried_values: &mut TreeVec<impl Iterator<Item = BaseField>>,
n_columns: TreeVec<usize>,
) -> Result<Vec<SecureField>, VerificationError> {
for queried_values in queried_values_per_column {
if queried_values.len() != query_positions.len() {
return Err(VerificationError::InvalidStructure(
"Insufficient number of queried values".to_string(),
));
}
}

let sample_batches = ColumnSampleBatch::new_vec(samples);
let quotient_constants = quotient_constants(&sample_batches, random_coeff);
let commitment_domain = CanonicCoset::new(log_size).circle_domain();
let mut quotient_evals_at_queries = Vec::new();

for (row, &query_position) in query_positions.iter().enumerate() {
let mut quotient_evals_at_queries = Vec::new();
for &query_position in query_positions {
let domain_point = commitment_domain.at(bit_reverse_index(query_position, log_size));
let queried_values_at_row = queried_values_per_column
.iter()
.map(|col| col[row])
.collect_vec();

let queried_values_at_row = queried_values
.as_mut()
.zip_eq(n_columns.as_ref())
.map(|(queried_values, n_columns)| queried_values.take(*n_columns).collect())
.flatten();

quotient_evals_at_queries.push(accumulate_row_quotients(
&sample_batches,
&queried_values_at_row,
Expand Down
18 changes: 10 additions & 8 deletions crates/prover/src/core/pcs/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,21 +99,23 @@ impl<MC: MerkleChannel> CommitmentSchemeVerifier<MC> {
.collect::<Result<_, _>>()?;

// Answer FRI queries.
let samples = sampled_points
.zip_cols(proof.sampled_values)
.map_cols(|(sampled_points, sampled_values)| {
let samples = sampled_points.zip_cols(proof.sampled_values).map_cols(
|(sampled_points, sampled_values)| {
zip(sampled_points, sampled_values)
.map(|(point, value)| PointSample { point, value })
.collect_vec()
})
.flatten();
},
);

let n_columns_per_log_size = self.trees.as_ref().map(|tree| &tree.n_columns_per_log_size);

let fri_answers = fri_answers(
self.column_log_sizes().flatten().into_iter().collect(),
&samples,
self.column_log_sizes(),
samples,
random_coeff,
&query_positions_per_log_size,
&proof.queried_values.flatten(),
proof.queried_values,
n_columns_per_log_size,
)?;

fri_verifier.decommit(fri_answers)?;
Expand Down
10 changes: 5 additions & 5 deletions crates/prover/src/core/vcs/blake2_merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ mod tests {
#[test]
fn test_merkle_invalid_value() {
let (queries, decommitment, mut values, verifier) = prepare_merkle::<Blake2sMerkleHasher>();
values[3][2] = BaseField::zero();
values[6] = BaseField::zero();

assert_eq!(
verifier.verify(&queries, values, decommitment).unwrap_err(),
Expand Down Expand Up @@ -119,22 +119,22 @@ mod tests {
#[test]
fn test_merkle_column_values_too_long() {
let (queries, decommitment, mut values, verifier) = prepare_merkle::<Blake2sMerkleHasher>();
values[3].push(BaseField::zero());
values.insert(3, BaseField::zero());

assert_eq!(
verifier.verify(&queries, values, decommitment).unwrap_err(),
MerkleVerificationError::ColumnValuesTooLong
MerkleVerificationError::TooManyQueriedValues
);
}

#[test]
fn test_merkle_column_values_too_short() {
let (queries, decommitment, mut values, verifier) = prepare_merkle::<Blake2sMerkleHasher>();
values[3].pop();
values.remove(3);

assert_eq!(
verifier.verify(&queries, values, decommitment).unwrap_err(),
MerkleVerificationError::ColumnValuesTooShort
MerkleVerificationError::TooFewQueriedValues
);
}

Expand Down
14 changes: 7 additions & 7 deletions crates/prover/src/core/vcs/poseidon252_merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ mod tests {
fn test_merkle_invalid_value() {
let (queries, decommitment, mut values, verifier) =
prepare_merkle::<Poseidon252MerkleHasher>();
values[3][2] = BaseField::zero();
values[6] = BaseField::zero();

assert_eq!(
verifier.verify(&queries, values, decommitment).unwrap_err(),
Expand Down Expand Up @@ -147,26 +147,26 @@ mod tests {
}

#[test]
fn test_merkle_column_values_too_long() {
fn test_merkle_values_too_long() {
let (queries, decommitment, mut values, verifier) =
prepare_merkle::<Poseidon252MerkleHasher>();
values[3].push(BaseField::zero());
values.insert(3, BaseField::zero());

assert_eq!(
verifier.verify(&queries, values, decommitment).unwrap_err(),
MerkleVerificationError::ColumnValuesTooLong
MerkleVerificationError::TooManyQueriedValues
);
}

#[test]
fn test_merkle_column_values_too_short() {
fn test_merkle_values_too_short() {
let (queries, decommitment, mut values, verifier) =
prepare_merkle::<Poseidon252MerkleHasher>();
values[3].pop();
values.remove(3);

assert_eq!(
verifier.verify(&queries, values, decommitment).unwrap_err(),
MerkleVerificationError::ColumnValuesTooShort
MerkleVerificationError::TooFewQueriedValues
);
}
}
51 changes: 6 additions & 45 deletions crates/prover/src/core/vcs/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use super::utils::{next_decommitment_node, option_flatten_peekable};
use crate::core::backend::{Col, Column};
use crate::core::fields::m31::BaseField;
use crate::core::utils::PeekableExt;
use crate::core::ColumnVec;

pub struct MerkleProver<B: MerkleOps<H>, H: MerkleHasher> {
/// Layers of the Merkle tree.
Expand Down Expand Up @@ -48,6 +47,7 @@ impl<B: MerkleOps<H>, H: MerkleHasher> MerkleProver<B, H> {
.into_iter()
.sorted_by_key(|c| Reverse(c.len()))
.peekable();

let mut layers: Vec<Col<B, H::Hash>> = Vec::new();

let max_log_size = columns.peek().unwrap().len().ilog2();
Expand Down Expand Up @@ -75,15 +75,16 @@ impl<B: MerkleOps<H>, H: MerkleHasher> MerkleProver<B, H> {
/// # Returns
///
/// A tuple containing:
/// * A vector of vectors of queried values for each column, in the order of the input columns.
/// * A vector queried values sorted by the order they were queried from the largest layer to
/// the smallest.
/// * A `MerkleDecommitment` containing the hash and column witnesses.
pub fn decommit(
&self,
queries_per_log_size: &BTreeMap<u32, Vec<usize>>,
columns: Vec<&Col<B, BaseField>>,
) -> (ColumnVec<Vec<BaseField>>, MerkleDecommitment<H>) {
) -> (Vec<BaseField>, MerkleDecommitment<H>) {
// Prepare output buffers.
let mut queried_values_by_layer = vec![];
let mut queried_values = vec![];
let mut decommitment = MerkleDecommitment::empty();

// Sort columns by layer.
Expand All @@ -94,9 +95,6 @@ impl<B: MerkleOps<H>, H: MerkleHasher> MerkleProver<B, H> {

let mut last_layer_queries = vec![];
for layer_log_size in (0..self.layers.len() as u32).rev() {
// Prepare write buffer for queried values to the current layer.
let mut layer_queried_values = vec![];

// Prepare write buffer for queries to the current layer. This will propagate to the
// next layer.
let mut layer_total_queries = vec![];
Expand Down Expand Up @@ -140,7 +138,7 @@ impl<B: MerkleOps<H>, H: MerkleHasher> MerkleProver<B, H> {
// If the column values were queried, return them.
let node_values = layer_columns.iter().map(|c| c.at(node_index));
if layer_column_queries.next_if_eq(&node_index).is_some() {
layer_queried_values.push(node_values.collect_vec());
queried_values.extend(node_values);
} else {
// Otherwise, add them to the witness.
decommitment.column_witness.extend(node_values);
Expand All @@ -149,50 +147,13 @@ impl<B: MerkleOps<H>, H: MerkleHasher> MerkleProver<B, H> {
layer_total_queries.push(node_index);
}

queried_values_by_layer.push(layer_queried_values);

// Propagate queries to the next layer.
last_layer_queries = layer_total_queries;
}
queried_values_by_layer.reverse();

// Rearrange returned queried values according to input, and not by layer.
let queried_values = Self::rearrange_queried_values(queried_values_by_layer, columns);

(queried_values, decommitment)
}

/// Given queried values by layer, rearranges in the order of input columns.
fn rearrange_queried_values(
queried_values_by_layer: Vec<Vec<Vec<BaseField>>>,
columns: Vec<&Col<B, BaseField>>,
) -> Vec<Vec<BaseField>> {
// Turn each column queried values into an iterator.
let mut queried_values_by_layer = queried_values_by_layer
.into_iter()
.map(|layer_results| {
layer_results
.into_iter()
.map(|x| x.into_iter())
.collect_vec()
})
.collect_vec();

// For each input column, fetch the queried values from the corresponding layer.
let queried_values = columns
.iter()
.map(|column| {
queried_values_by_layer
.get_mut(column.len().ilog2() as usize)
.unwrap()
.iter_mut()
.map(|x| x.next().unwrap())
.collect_vec()
})
.collect_vec();
queried_values
}

pub fn root(&self) -> H::Hash {
self.layers.first().unwrap().at(0)
}
Expand Down
Loading
Loading