Skip to content

Commit

Permalink
Rearrange queried_values_by_layer for merkle.
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyalesokhin-starkware committed Nov 28, 2024
1 parent c2ef3ac commit 44550e7
Show file tree
Hide file tree
Showing 11 changed files with 115 additions and 157 deletions.
2 changes: 1 addition & 1 deletion crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl TraceLocationAllocator {
}
}

pub fn preprocessed_columns(&self) -> &HashMap<PreprocessedColumn, usize> {
pub const fn preprocessed_columns(&self) -> &HashMap<PreprocessedColumn, usize> {
&self.preprocessed_columns
}

Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ pub struct RelationEntry<'a, F: Clone, EF: RelationEFTraitBound<F>, R: Relation<
values: &'a [F],
}
impl<'a, F: Clone, EF: RelationEFTraitBound<F>, R: Relation<F, EF>> RelationEntry<'a, F, EF, R> {
pub fn new(relation: &'a R, multiplicity: EF, values: &'a [F]) -> Self {
pub const fn new(relation: &'a R, multiplicity: EF, values: &'a [F]) -> Self {
Self {
relation,
multiplicity,
Expand Down
28 changes: 13 additions & 15 deletions crates/prover/src/core/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -699,9 +699,9 @@ impl<H: MerkleHasher> FriFirstLayerVerifier<H> {

let mut fri_witness = self.proof.fri_witness.iter().copied();
let mut decommitment_positions_by_log_size = BTreeMap::new();
let mut all_column_decommitment_values = Vec::new();
let mut folded_evals_by_column = Vec::new();

let mut decommitment = vec![];
for (&column_domain, column_query_evals) in
zip_eq(&self.column_commitment_domains, query_evals_by_column)
{
Expand All @@ -722,15 +722,13 @@ impl<H: MerkleHasher> FriFirstLayerVerifier<H> {
decommitment_positions_by_log_size
.insert(column_domain.log_size(), column_decommitment_positions);

// Prepare values in the structure needed for merkle decommitment.
let column_decommitment_values: SecureColumnByCoords<CpuBackend> = sparse_evaluation
.subset_evals
.iter()
.flatten()
.copied()
.collect();

all_column_decommitment_values.extend(column_decommitment_values.columns);
decommitment.extend(
sparse_evaluation
.subset_evals
.iter()
.flatten()
.flat_map(|qm31| qm31.to_m31_array()),
);

let folded_evals = sparse_evaluation.fold_circle(self.folding_alpha, column_domain);
folded_evals_by_column.push(folded_evals);
Expand All @@ -752,7 +750,7 @@ impl<H: MerkleHasher> FriFirstLayerVerifier<H> {
merkle_verifier
.verify(
&decommitment_positions_by_log_size,
all_column_decommitment_values,
decommitment,
self.proof.decommitment.clone(),
)
.map_err(|error| FriVerificationError::FirstLayerCommitmentInvalid { error })?;
Expand Down Expand Up @@ -814,12 +812,12 @@ impl<H: MerkleHasher> FriInnerLayerVerifier<H> {
});
}

let decommitment_values: SecureColumnByCoords<CpuBackend> = sparse_evaluation
let decommitment = sparse_evaluation
.subset_evals
.iter()
.flatten()
.copied()
.collect();
.flat_map(|qm31| qm31.to_m31_array())
.collect_vec();

let merkle_verifier = MerkleVerifier::new(
self.proof.commitment,
Expand All @@ -829,7 +827,7 @@ impl<H: MerkleHasher> FriInnerLayerVerifier<H> {
merkle_verifier
.verify(
&BTreeMap::from_iter([(self.domain.log_size(), decommitment_positions)]),
decommitment_values.columns.to_vec(),
decommitment,
self.proof.decommitment.clone(),
)
.map_err(|e| FriVerificationError::InnerLayerCommitmentInvalid {
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/core/pcs/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ pub struct CommitmentSchemeProof<H: MerkleHasher> {
pub commitments: TreeVec<H::Hash>,
pub sampled_values: TreeVec<ColumnVec<Vec<SecureField>>>,
pub decommitments: TreeVec<MerkleDecommitment<H>>,
pub queried_values: TreeVec<ColumnVec<Vec<BaseField>>>,
pub queried_values: TreeVec<Vec<BaseField>>,
pub proof_of_work: u64,
pub fri_proof: FriProof<H>,
}
Expand Down Expand Up @@ -231,7 +231,7 @@ impl<B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentTreeProver<B, MC> {
fn decommit(
&self,
queries: &BTreeMap<u32, Vec<usize>>,
) -> (ColumnVec<Vec<BaseField>>, MerkleDecommitment<MC::H>) {
) -> (Vec<BaseField>, MerkleDecommitment<MC::H>) {
let eval_vec = self
.evaluations
.iter()
Expand Down
46 changes: 24 additions & 22 deletions crates/prover/src/core/pcs/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::iter::zip;
use itertools::{izip, multiunzip, Itertools};
use tracing::{span, Level};

use super::TreeVec;
use crate::core::backend::cpu::quotients::{accumulate_row_quotients, quotient_constants};
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
Expand Down Expand Up @@ -100,25 +101,30 @@ pub fn compute_fri_quotients<B: QuotientOps>(
}

pub fn fri_answers(
column_log_sizes: Vec<u32>,
samples: &[Vec<PointSample>],
column_log_sizes: TreeVec<Vec<u32>>,
samples: TreeVec<Vec<Vec<PointSample>>>,
random_coeff: SecureField,
query_positions_per_log_size: &BTreeMap<u32, Vec<usize>>,
queried_values_per_column: &[Vec<BaseField>],
queried_values: TreeVec<Vec<BaseField>>,
n_columns_per_log_size: TreeVec<&BTreeMap<u32, usize>>,
) -> Result<ColumnVec<Vec<SecureField>>, VerificationError> {
izip!(column_log_sizes, samples, queried_values_per_column)
let mut queried_values = queried_values.map(|values| values.into_iter());

izip!(column_log_sizes.flatten(), samples.flatten().iter())
.sorted_by_key(|(log_size, ..)| Reverse(*log_size))
.group_by(|(log_size, ..)| *log_size)
.into_iter()
.map(|(log_size, tuples)| {
let (_, samples, queried_values_per_column): (Vec<_>, Vec<_>, Vec<_>) =
multiunzip(tuples);
let (_, samples): (Vec<_>, Vec<_>) = multiunzip(tuples);
fri_answers_for_log_size(
log_size,
&samples,
random_coeff,
&query_positions_per_log_size[&log_size],
&queried_values_per_column,
&mut queried_values,
n_columns_per_log_size
.as_ref()
.map(|colums_log_sizes| *colums_log_sizes.get(&log_size).unwrap_or(&0)),
)
})
.collect()
Expand All @@ -129,27 +135,23 @@ pub fn fri_answers_for_log_size(
samples: &[&Vec<PointSample>],
random_coeff: SecureField,
query_positions: &[usize],
queried_values_per_column: &[&Vec<BaseField>],
queried_values: &mut TreeVec<impl Iterator<Item = BaseField>>,
n_columns: TreeVec<usize>,
) -> Result<Vec<SecureField>, VerificationError> {
for queried_values in queried_values_per_column {
if queried_values.len() != query_positions.len() {
return Err(VerificationError::InvalidStructure(
"Insufficient number of queried values".to_string(),
));
}
}

let sample_batches = ColumnSampleBatch::new_vec(samples);
let quotient_constants = quotient_constants(&sample_batches, random_coeff);
let commitment_domain = CanonicCoset::new(log_size).circle_domain();
let mut quotient_evals_at_queries = Vec::new();

for (row, &query_position) in query_positions.iter().enumerate() {
let mut quotient_evals_at_queries = Vec::new();
for &query_position in query_positions {
let domain_point = commitment_domain.at(bit_reverse_index(query_position, log_size));
let queried_values_at_row = queried_values_per_column
.iter()
.map(|col| col[row])
.collect_vec();

let queried_values_at_row = queried_values
.as_mut()
.zip_eq(n_columns.as_ref())
.map(|(queried_values, n_columns)| queried_values.take(*n_columns).collect())
.flatten();

quotient_evals_at_queries.push(accumulate_row_quotients(
&sample_batches,
&queried_values_at_row,
Expand Down
18 changes: 10 additions & 8 deletions crates/prover/src/core/pcs/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,21 +99,23 @@ impl<MC: MerkleChannel> CommitmentSchemeVerifier<MC> {
.collect::<Result<_, _>>()?;

// Answer FRI queries.
let samples = sampled_points
.zip_cols(proof.sampled_values)
.map_cols(|(sampled_points, sampled_values)| {
let samples = sampled_points.zip_cols(proof.sampled_values).map_cols(
|(sampled_points, sampled_values)| {
zip(sampled_points, sampled_values)
.map(|(point, value)| PointSample { point, value })
.collect_vec()
})
.flatten();
},
);

let n_columns_per_log_size = self.trees.as_ref().map(|tree| &tree.n_columns_per_log_size);

let fri_answers = fri_answers(
self.column_log_sizes().flatten().into_iter().collect(),
&samples,
self.column_log_sizes(),
samples,
random_coeff,
&query_positions_per_log_size,
&proof.queried_values.flatten(),
proof.queried_values,
n_columns_per_log_size,
)?;

fri_verifier.decommit(fri_answers)?;
Expand Down
10 changes: 5 additions & 5 deletions crates/prover/src/core/vcs/blake2_merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ mod tests {
#[test]
fn test_merkle_invalid_value() {
let (queries, decommitment, mut values, verifier) = prepare_merkle::<Blake2sMerkleHasher>();
values[3][2] = BaseField::zero();
values[6] = BaseField::zero();

assert_eq!(
verifier.verify(&queries, values, decommitment).unwrap_err(),
Expand Down Expand Up @@ -119,22 +119,22 @@ mod tests {
#[test]
fn test_merkle_column_values_too_long() {
let (queries, decommitment, mut values, verifier) = prepare_merkle::<Blake2sMerkleHasher>();
values[3].push(BaseField::zero());
values.insert(3, BaseField::zero());

assert_eq!(
verifier.verify(&queries, values, decommitment).unwrap_err(),
MerkleVerificationError::ColumnValuesTooLong
MerkleVerificationError::TooManyQueriedValues
);
}

#[test]
fn test_merkle_column_values_too_short() {
let (queries, decommitment, mut values, verifier) = prepare_merkle::<Blake2sMerkleHasher>();
values[3].pop();
values.remove(3);

assert_eq!(
verifier.verify(&queries, values, decommitment).unwrap_err(),
MerkleVerificationError::ColumnValuesTooShort
MerkleVerificationError::TooFewQueriedValues
);
}

Expand Down
14 changes: 7 additions & 7 deletions crates/prover/src/core/vcs/poseidon252_merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ mod tests {
fn test_merkle_invalid_value() {
let (queries, decommitment, mut values, verifier) =
prepare_merkle::<Poseidon252MerkleHasher>();
values[3][2] = BaseField::zero();
values[6] = BaseField::zero();

assert_eq!(
verifier.verify(&queries, values, decommitment).unwrap_err(),
Expand Down Expand Up @@ -147,26 +147,26 @@ mod tests {
}

#[test]
fn test_merkle_column_values_too_long() {
fn test_merkle_values_too_long() {
let (queries, decommitment, mut values, verifier) =
prepare_merkle::<Poseidon252MerkleHasher>();
values[3].push(BaseField::zero());
values.insert(3, BaseField::zero());

assert_eq!(
verifier.verify(&queries, values, decommitment).unwrap_err(),
MerkleVerificationError::ColumnValuesTooLong
MerkleVerificationError::TooManyQueriedValues
);
}

#[test]
fn test_merkle_column_values_too_short() {
fn test_merkle_values_too_short() {
let (queries, decommitment, mut values, verifier) =
prepare_merkle::<Poseidon252MerkleHasher>();
values[3].pop();
values.remove(3);

assert_eq!(
verifier.verify(&queries, values, decommitment).unwrap_err(),
MerkleVerificationError::ColumnValuesTooShort
MerkleVerificationError::TooFewQueriedValues
);
}
}
Loading

0 comments on commit 44550e7

Please sign in to comment.