diff --git a/crates/prover/src/core/air/air_ext.rs b/crates/prover/src/core/air/air_ext.rs index e870a90fe..cf31b72b9 100644 --- a/crates/prover/src/core/air/air_ext.rs +++ b/crates/prover/src/core/air/air_ext.rs @@ -10,7 +10,7 @@ use crate::core::pcs::{CommitmentTreeProver, TreeVec}; use crate::core::poly::circle::SecureCirclePoly; use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher; use crate::core::vcs::ops::MerkleOps; -use crate::core::{ColumnVec, ComponentVec, InteractionElements}; +use crate::core::{ColumnVec, InteractionElements}; pub trait AirExt: Air { fn composition_log_degree_bound(&self) -> u32 { @@ -55,12 +55,12 @@ pub trait AirExt: Air { fn eval_composition_polynomial_at_point( &self, point: CirclePoint, - mask_values: &ComponentVec>, + mask_values: &Vec>>>, random_coeff: SecureField, interaction_elements: &InteractionElements, ) -> SecureField { let mut evaluation_accumulator = PointEvaluationAccumulator::new(random_coeff); - zip_eq(self.components(), &mask_values.0).for_each(|(component, mask)| { + zip_eq(self.components(), mask_values).for_each(|(component, mask)| { component.evaluate_constraint_quotients_at_point( point, mask, diff --git a/crates/prover/src/core/air/mod.rs b/crates/prover/src/core/air/mod.rs index b559780df..f77bece2e 100644 --- a/crates/prover/src/core/air/mod.rs +++ b/crates/prover/src/core/air/mod.rs @@ -71,7 +71,7 @@ pub trait Component { fn evaluate_constraint_quotients_at_point( &self, point: CirclePoint, - mask: &ColumnVec>, + mask: &TreeVec>>, evaluation_accumulator: &mut PointEvaluationAccumulator, interaction_elements: &InteractionElements, ); diff --git a/crates/prover/src/core/pcs/utils.rs b/crates/prover/src/core/pcs/utils.rs index 975d7bbba..761a6c7aa 100644 --- a/crates/prover/src/core/pcs/utils.rs +++ b/crates/prover/src/core/pcs/utils.rs @@ -16,12 +16,19 @@ impl TreeVec { TreeVec(self.0.into_iter().map(f).collect()) } pub fn zip(self, other: impl Into>) -> TreeVec<(T, U)> { + let other = other.into(); + TreeVec(self.0.into_iter().zip(other.0).collect()) + } + pub fn zip_eq(self, other: impl Into>) -> TreeVec<(T, U)> { let other = other.into(); TreeVec(zip_eq(self.0, other.0).collect()) } pub fn as_ref(&self) -> TreeVec<&T> { TreeVec(self.iter().collect()) } + pub fn as_mut(&mut self) -> TreeVec<&mut T> { + TreeVec(self.iter_mut().collect()) + } } /// Converts `&TreeVec` to `TreeVec<&T>`. diff --git a/crates/prover/src/core/pcs/verifier.rs b/crates/prover/src/core/pcs/verifier.rs index e2241b56b..d490dc02a 100644 --- a/crates/prover/src/core/pcs/verifier.rs +++ b/crates/prover/src/core/pcs/verifier.rs @@ -91,8 +91,8 @@ impl CommitmentSchemeVerifier { // Verify merkle decommitments. self.trees .as_ref() - .zip(proof.decommitments) - .zip(proof.queried_values.clone()) + .zip_eq(proof.decommitments) + .zip_eq(proof.queried_values.clone()) .map(|((tree, decommitment), queried_values)| { let queries = fri_query_domains .iter() diff --git a/crates/prover/src/core/prover/mod.rs b/crates/prover/src/core/prover/mod.rs index 9f9f0ad47..9cc035093 100644 --- a/crates/prover/src/core/prover/mod.rs +++ b/crates/prover/src/core/prover/mod.rs @@ -25,7 +25,6 @@ use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher; use crate::core::vcs::hasher::Hasher; use crate::core::vcs::ops::MerkleOps; use crate::core::vcs::verifier::MerkleVerificationError; -use crate::core::ComponentVec; type Channel = Blake2sChannel; type ChannelHasher = Blake2sHasher; @@ -237,52 +236,32 @@ pub fn verify( commitment_scheme.verify_values(sample_points, proof.commitment_scheme_proof, channel) } +#[allow(clippy::type_complexity)] /// Structures the tree-wise sampled values into component-wise OODS values and a composition /// polynomial OODS value. fn sampled_values_to_mask( air: &impl Air, sampled_values: &TreeVec>>, -) -> Result<(ComponentVec>, SecureField), InvalidOodsSampleStructure> { - // Retrieve sampled mask values for each component. - let flat_trace_values = &mut sampled_values - .first() - .ok_or(InvalidOodsSampleStructure)? - .iter(); - let mut trace_oods_values = vec![]; - air.components().iter().for_each(|component| { - let n_trace_points = component.mask_points(CirclePoint::zero())[0].len(); - trace_oods_values.push( - flat_trace_values - .take(n_trace_points) - .cloned() - .collect_vec(), - ) - }); - - if air.n_interaction_phases() == 2 { - let interaction_values = &mut sampled_values - .get(1) - .ok_or(InvalidOodsSampleStructure)? - .iter(); - - air.components() - .iter() - .zip_eq(&mut trace_oods_values) - .for_each(|(component, values)| { - let n_interaction_points = component.mask_points(CirclePoint::zero())[1].len(); - values.extend( - interaction_values - .take(n_interaction_points) - .cloned() - .collect_vec(), - ) - }); - } +) -> Result<(Vec>>>, SecureField), InvalidOodsSampleStructure> { + let mut sampled_values = sampled_values.as_ref(); + let composition_values = sampled_values.pop().ok_or(InvalidOodsSampleStructure)?; + + let mut sample_iters = sampled_values.map(|tree_value| tree_value.iter()); + let trace_oods_values = air + .components() + .iter() + .map(|component| { + component + .mask_points(CirclePoint::zero()) + .zip(sample_iters.as_mut()) + .map(|(mask_per_tree, tree_iter)| { + tree_iter.take(mask_per_tree.len()).cloned().collect_vec() + }) + }) + .collect_vec(); - let composition_partial_sampled_values = - sampled_values.last().ok_or(InvalidOodsSampleStructure)?; let composition_oods_value = SecureCirclePoly::::eval_from_partial_evals( - composition_partial_sampled_values + composition_values .iter() .flatten() .cloned() @@ -291,7 +270,7 @@ fn sampled_values_to_mask( .map_err(|_| InvalidOodsSampleStructure)?, ); - Ok((ComponentVec(trace_oods_values), composition_oods_value)) + Ok((trace_oods_values, composition_oods_value)) } /// Error when the sampled values have an invalid structure. @@ -428,7 +407,7 @@ mod tests { fn evaluate_constraint_quotients_at_point( &self, _point: CirclePoint, - _mask: &crate::core::ColumnVec>, + _mask: &TreeVec>>, evaluation_accumulator: &mut PointEvaluationAccumulator, _interaction_elements: &InteractionElements, ) { diff --git a/crates/prover/src/examples/fibonacci/component.rs b/crates/prover/src/examples/fibonacci/component.rs index 2b45c4e4e..7fa317b7f 100644 --- a/crates/prover/src/examples/fibonacci/component.rs +++ b/crates/prover/src/examples/fibonacci/component.rs @@ -110,19 +110,17 @@ impl Component for FibonacciComponent { fn evaluate_constraint_quotients_at_point( &self, point: CirclePoint, - mask: &ColumnVec>, + mask: &TreeVec>>, evaluation_accumulator: &mut PointEvaluationAccumulator, _interaction_elements: &InteractionElements, ) { evaluation_accumulator.accumulate( - self.step_constraint_eval_quotient_by_mask(point, &mask[0][..].try_into().unwrap()), - ); - evaluation_accumulator.accumulate( - self.boundary_constraint_eval_quotient_by_mask( - point, - &mask[0][..1].try_into().unwrap(), - ), + self.step_constraint_eval_quotient_by_mask(point, &mask[0][0][..].try_into().unwrap()), ); + evaluation_accumulator.accumulate(self.boundary_constraint_eval_quotient_by_mask( + point, + &mask[0][0][..1].try_into().unwrap(), + )); } } diff --git a/crates/prover/src/examples/fibonacci/mod.rs b/crates/prover/src/examples/fibonacci/mod.rs index 940f979b2..5c9cdda12 100644 --- a/crates/prover/src/examples/fibonacci/mod.rs +++ b/crates/prover/src/examples/fibonacci/mod.rs @@ -179,7 +179,7 @@ mod tests { let mut evaluation_accumulator = PointEvaluationAccumulator::new(random_coeff); fib.air.component.evaluate_constraint_quotients_at_point( point, - &mask_values, + &TreeVec::new(vec![mask_values]), &mut evaluation_accumulator, &InteractionElements::new(BTreeMap::new()), ); diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 1810741bb..7fa36dacf 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -121,7 +121,7 @@ impl Component for PoseidonComponent { fn evaluate_constraint_quotients_at_point( &self, point: CirclePoint, - mask: &ColumnVec>, + mask: &TreeVec>>, evaluation_accumulator: &mut PointEvaluationAccumulator, _interaction_elements: &InteractionElements, ) { @@ -129,7 +129,7 @@ impl Component for PoseidonComponent { let denom = coset_vanishing(constraint_zero_domain, point); let denom_inverse = denom.inverse(); let mut eval = PoseidonEvalAtPoint { - mask, + mask: &mask[0], evaluation_accumulator, col_index: 0, denom_inverse, diff --git a/crates/prover/src/examples/wide_fibonacci/component.rs b/crates/prover/src/examples/wide_fibonacci/component.rs index 6ca3e21ff..47fd6dcc9 100644 --- a/crates/prover/src/examples/wide_fibonacci/component.rs +++ b/crates/prover/src/examples/wide_fibonacci/component.rs @@ -66,7 +66,7 @@ impl WideFibComponent { fn evaluate_lookup_boundary_constraint_at_point( &self, point: CirclePoint, - mask: &ColumnVec>, + mask: &TreeVec>>, evaluation_accumulator: &mut PointEvaluationAccumulator, constraint_zero_domain: Coset, interaction_elements: &InteractionElements, @@ -74,15 +74,18 @@ impl WideFibComponent { let (alpha, z) = (interaction_elements[ALPHA_ID], interaction_elements[Z_ID]); let value = SecureCirclePoly::::eval_from_partial_evals(std::array::from_fn(|i| { - mask[self.n_columns() + i][0] + mask[1][i][0] })); let numerator = (value * shifted_secure_combination( - &[mask[self.n_columns() - 2][0], mask[self.n_columns() - 1][0]], + &[ + mask[0][self.n_columns() - 2][0], + mask[0][self.n_columns() - 1][0], + ], alpha, z, )) - - shifted_secure_combination(&[mask[0][0], mask[1][0]], alpha, z); + - shifted_secure_combination(&[mask[0][0][0], mask[0][1][0]], alpha, z); let denom = point_vanishing(constraint_zero_domain.at(0), point); evaluation_accumulator.accumulate(numerator / denom); } @@ -90,7 +93,7 @@ impl WideFibComponent { fn evaluate_lookup_step_constraints_at_point( &self, point: CirclePoint, - mask: &ColumnVec>, + mask: &TreeVec>>, evaluation_accumulator: &mut PointEvaluationAccumulator, constraint_zero_domain: Coset, interaction_elements: &InteractionElements, @@ -98,19 +101,22 @@ impl WideFibComponent { let (alpha, z) = (interaction_elements[ALPHA_ID], interaction_elements[Z_ID]); let value = SecureCirclePoly::::eval_from_partial_evals(std::array::from_fn(|i| { - mask[self.n_columns() + i][0] + mask[1][i][0] })); let prev_value = SecureCirclePoly::::eval_from_partial_evals(std::array::from_fn(|i| { - mask[self.n_columns() + i][1] + mask[1][i][1] })); let numerator = (value * shifted_secure_combination( - &[mask[self.n_columns() - 2][0], mask[self.n_columns() - 1][0]], + &[ + mask[0][self.n_columns() - 2][0], + mask[0][self.n_columns() - 1][0], + ], alpha, z, )) - - (prev_value * shifted_secure_combination(&[mask[0][0], mask[1][0]], alpha, z)); + - (prev_value * shifted_secure_combination(&[mask[0][0][0], mask[0][1][0]], alpha, z)); let denom = coset_vanishing(constraint_zero_domain, point) / point_excluder(constraint_zero_domain.at(0), point); evaluation_accumulator.accumulate(numerator / denom); @@ -165,7 +171,7 @@ impl Component for WideFibComponent { fn evaluate_constraint_quotients_at_point( &self, point: CirclePoint, - mask: &ColumnVec>, + mask: &TreeVec>>, evaluation_accumulator: &mut PointEvaluationAccumulator, interaction_elements: &InteractionElements, ) { @@ -186,7 +192,7 @@ impl Component for WideFibComponent { ); self.evaluate_trace_step_constraints_at_point( point, - mask, + &mask[0], evaluation_accumulator, constraint_zero_domain, ); diff --git a/crates/prover/src/examples/wide_fibonacci/simd.rs b/crates/prover/src/examples/wide_fibonacci/simd.rs index 387ece5e7..4219a0d82 100644 --- a/crates/prover/src/examples/wide_fibonacci/simd.rs +++ b/crates/prover/src/examples/wide_fibonacci/simd.rs @@ -113,7 +113,7 @@ impl Component for SimdWideFibComponent { fn evaluate_constraint_quotients_at_point( &self, point: CirclePoint, - mask: &ColumnVec>, + mask: &TreeVec>>, evaluation_accumulator: &mut PointEvaluationAccumulator, _interaction_elements: &InteractionElements, ) { @@ -121,7 +121,7 @@ impl Component for SimdWideFibComponent { let denom = coset_vanishing(constraint_zero_domain, point); let denom_inverse = denom.inverse(); for i in 0..self.n_columns() - 2 { - let numerator = mask[i][0].square() + mask[i + 1][0].square() - mask[i + 2][0]; + let numerator = mask[0][i][0].square() + mask[0][i + 1][0].square() - mask[0][i + 2][0]; evaluation_accumulator.accumulate(numerator * denom_inverse); } } diff --git a/crates/prover/src/trace_generation/registry.rs b/crates/prover/src/trace_generation/registry.rs index afaa961f8..916397f82 100644 --- a/crates/prover/src/trace_generation/registry.rs +++ b/crates/prover/src/trace_generation/registry.rs @@ -82,7 +82,7 @@ mod tests { fn evaluate_constraint_quotients_at_point( &self, _point: CirclePoint, - _mask: &ColumnVec>, + _mask: &TreeVec>>, _evaluation_accumulator: &mut PointEvaluationAccumulator, _interaction_elements: &InteractionElements, ) {