Skip to content

Commit

Permalink
Make prove() generic in Backend (#490)
Browse files Browse the repository at this point in the history
<!-- Reviewable:start -->
This change is [<img src="https://reviewable.io/review_button.svg" height="34" align="absmiddle" alt="Reviewable"/>](https://reviewable.io/reviews/starkware-libs/stwo/490)
<!-- Reviewable:end -->
  • Loading branch information
spapinistarkware authored Apr 4, 2024
1 parent 8046f4f commit 42e6cf6
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 39 deletions.
27 changes: 15 additions & 12 deletions src/core/air/accumulation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
use itertools::Itertools;

use crate::core::backend::cpu::CPUCircleEvaluation;
use crate::core::backend::{Backend, CPUBackend};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;
use crate::core::poly::circle::{CanonicCoset, CirclePoly, SecureCirclePoly};
use crate::core::fields::FieldOps;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, CirclePoly, SecureCirclePoly};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::generate_secure_powers;

Expand Down Expand Up @@ -101,22 +102,27 @@ impl<B: Backend> DomainEvaluationAccumulator<B> {
}
}

impl DomainEvaluationAccumulator<CPUBackend> {
pub trait AccumulationOps: FieldOps<BaseField> + Sized {
/// Accumulates other into column:
/// column = column + other.
fn accumulate(column: &mut SecureColumn<Self>, other: &SecureColumn<Self>);
}

impl<B: Backend> DomainEvaluationAccumulator<B> {
/// Computes f(P) as coefficients.
pub fn finalize(self) -> SecureCirclePoly {
pub fn finalize(self) -> SecureCirclePoly<B> {
assert_eq!(
self.random_coeff_powers.len(),
0,
"not all random coefficients were used"
);
let mut res_coeffs = SecureColumn::<CPUBackend>::zeros(1 << self.log_size());
let mut res_coeffs = SecureColumn::<B>::zeros(1 << self.log_size());
let res_log_size = self.log_size();
let res_size = 1 << res_log_size;

for (log_size, values) in self.sub_accumulations.into_iter().enumerate().skip(1) {
let coeffs = SecureColumn::<CPUBackend> {
let coeffs = SecureColumn::<B> {
columns: values.columns.map(|c| {
CPUCircleEvaluation::<_, BitReversedOrder>::new(
CircleEvaluation::<B, BaseField, BitReversedOrder>::new(
CanonicCoset::new(log_size as u32).circle_domain(),
c,
)
Expand All @@ -126,10 +132,7 @@ impl DomainEvaluationAccumulator<CPUBackend> {
}),
};
// Add column coefficients into result coefficients, element-wise, in-place.
for i in 0..res_size {
let res_coeff = res_coeffs.at(i) + coeffs.at(i);
res_coeffs.set(i, res_coeff);
}
B::accumulate(&mut res_coeffs, &coeffs);
}

SecureCirclePoly(res_coeffs.columns.map(CirclePoly::new))
Expand Down
16 changes: 8 additions & 8 deletions src/core/air/air_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ use itertools::Itertools;

use super::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use super::{Air, ComponentTrace};
use crate::core::backend::CPUBackend;
use crate::core::backend::Backend;
use crate::core::circle::CirclePoint;
use crate::core::fields::qm31::SecureField;
use crate::core::poly::circle::{CanonicCoset, CirclePoly, SecureCirclePoly};
use crate::core::prover::LOG_BLOWUP_FACTOR;
use crate::core::ComponentVec;

pub trait AirExt: Air<CPUBackend> {
pub trait AirExt<B: Backend>: Air<B> {
fn composition_log_degree_bound(&self) -> u32 {
self.components()
.iter()
Expand All @@ -30,8 +30,8 @@ pub trait AirExt: Air<CPUBackend> {
fn compute_composition_polynomial(
&self,
random_coeff: SecureField,
component_traces: &[ComponentTrace<'_, CPUBackend>],
) -> SecureCirclePoly {
component_traces: &[ComponentTrace<'_, B>],
) -> SecureCirclePoly<B> {
let total_constraints: usize = self.components().iter().map(|c| c.n_constraints()).sum();
let mut accumulator = DomainEvaluationAccumulator::new(
random_coeff,
Expand All @@ -47,7 +47,7 @@ pub trait AirExt: Air<CPUBackend> {
fn mask_points_and_values(
&self,
point: CirclePoint<SecureField>,
component_traces: &[ComponentTrace<'_, CPUBackend>],
component_traces: &[ComponentTrace<'_, B>],
) -> (
ComponentVec<Vec<CirclePoint<SecureField>>>,
ComponentVec<Vec<SecureField>>,
Expand Down Expand Up @@ -104,8 +104,8 @@ pub trait AirExt: Air<CPUBackend> {

fn component_traces<'a>(
&'a self,
polynomials: &'a [CirclePoly<CPUBackend>],
) -> Vec<ComponentTrace<'_, CPUBackend>> {
polynomials: &'a [CirclePoly<B>],
) -> Vec<ComponentTrace<'_, B>> {
let poly_iter = &mut polynomials.iter();
self.components()
.iter()
Expand All @@ -118,4 +118,4 @@ pub trait AirExt: Air<CPUBackend> {
}
}

impl<A: Air<CPUBackend>> AirExt for A {}
impl<B: Backend, A: Air<B>> AirExt<B> for A {}
12 changes: 12 additions & 0 deletions src/core/backend/avx512/accumulation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use super::AVX512Backend;
use crate::core::air::accumulation::AccumulationOps;
use crate::core::fields::secure_column::SecureColumn;

impl AccumulationOps for AVX512Backend {
fn accumulate(column: &mut SecureColumn<Self>, other: &SecureColumn<Self>) {
for i in 0..column.n_packs() {
let res_coeff = column.packed_at(i) + other.packed_at(i);
column.set_packed(i, res_coeff);
}
}
}
5 changes: 5 additions & 0 deletions src/core/backend/avx512/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod accumulation;
pub mod bit_reverse;
mod blake2s;
pub mod blake2s_avx;
Expand Down Expand Up @@ -217,6 +218,10 @@ impl FromIterator<PackedSecureField> for SecureFieldVec {
}

impl SecureColumn<AVX512Backend> {
pub fn n_packs(&self) -> usize {
self.columns[0].data.len()
}

pub fn packed_at(&self, vec_index: usize) -> PackedSecureField {
unsafe {
PackedSecureField([
Expand Down
12 changes: 12 additions & 0 deletions src/core/backend/cpu/accumulation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use super::CPUBackend;
use crate::core::air::accumulation::AccumulationOps;
use crate::core::fields::secure_column::SecureColumn;

impl AccumulationOps for CPUBackend {
fn accumulate(column: &mut SecureColumn<Self>, other: &SecureColumn<Self>) {
for i in 0..column.len() {
let res_coeff = column.at(i) + other.at(i);
column.set(i, res_coeff);
}
}
}
1 change: 1 addition & 0 deletions src/core/backend/cpu/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod accumulation;
mod blake2s;
mod circle;
mod fri;
Expand Down
11 changes: 10 additions & 1 deletion src/core/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::fmt::Debug;

pub use cpu::CPUBackend;

use super::air::accumulation::AccumulationOps;
use super::commitment_scheme::quotients::QuotientOps;
use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
Expand All @@ -14,7 +15,15 @@ pub mod avx512;
pub mod cpu;

pub trait Backend:
Copy + Clone + Debug + FieldOps<BaseField> + FieldOps<SecureField> + PolyOps + QuotientOps + FriOps
Copy
+ Clone
+ Debug
+ FieldOps<BaseField>
+ FieldOps<SecureField>
+ PolyOps
+ QuotientOps
+ FriOps
+ AccumulationOps
{
}

Expand Down
2 changes: 1 addition & 1 deletion src/core/commitment_scheme/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
mod prover;
pub mod quotients;
pub mod utils;
mod utils;
mod verifier;

pub use self::prover::{CommitmentSchemeProof, CommitmentSchemeProver};
Expand Down
12 changes: 6 additions & 6 deletions src/core/poly/circle/secure_poly.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::ops::Deref;

use super::{CircleDomain, CircleEvaluation};
use crate::core::backend::cpu::{CPUCircleEvaluation, CPUCirclePoly};
use super::{CircleDomain, CircleEvaluation, CirclePoly, PolyOps};
use crate::core::backend::cpu::CPUCircleEvaluation;
use crate::core::backend::CPUBackend;
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
Expand All @@ -10,9 +10,9 @@ use crate::core::fields::secure_column::{SecureColumn, SECURE_EXTENSION_DEGREE};
use crate::core::fields::FieldOps;
use crate::core::poly::BitReversedOrder;

pub struct SecureCirclePoly(pub [CPUCirclePoly; SECURE_EXTENSION_DEGREE]);
pub struct SecureCirclePoly<B: FieldOps<BaseField>>(pub [CirclePoly<B>; SECURE_EXTENSION_DEGREE]);

impl SecureCirclePoly {
impl<B: PolyOps> SecureCirclePoly<B> {
pub fn eval_at_point(&self, point: CirclePoint<SecureField>) -> SecureField {
Self::eval_from_partial_evals(self.eval_columns_at_point(point))
}
Expand Down Expand Up @@ -44,8 +44,8 @@ impl SecureCirclePoly {
}
}

impl Deref for SecureCirclePoly {
type Target = [CPUCirclePoly; SECURE_EXTENSION_DEGREE];
impl<B: FieldOps<BaseField>> Deref for SecureCirclePoly<B> {
type Target = [CirclePoly<B>; SECURE_EXTENSION_DEGREE];

fn deref(&self) -> &Self::Target {
&self.0
Expand Down
25 changes: 14 additions & 11 deletions src/core/prover/mod.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
use itertools::Itertools;
use thiserror::Error;

use super::backend::Backend;
use super::commitment_scheme::{CommitmentSchemeProof, TreeVec};
use super::fri::FriVerificationError;
use super::poly::circle::{CanonicCoset, PolyOps, SecureCirclePoly, MAX_CIRCLE_DOMAIN_LOG_SIZE};
use super::poly::circle::{CanonicCoset, SecureCirclePoly, MAX_CIRCLE_DOMAIN_LOG_SIZE};
use super::proof_of_work::ProofOfWorkVerificationError;
use super::ColumnVec;
use crate::commitment_scheme::blake2_hash::Blake2sHasher;
use crate::commitment_scheme::blake2_merkle::Blake2sMerkleHasher;
use crate::commitment_scheme::hasher::Hasher;
use crate::commitment_scheme::ops::MerkleOps;
use crate::commitment_scheme::verifier::MerkleVerificationError;
use crate::core::air::{Air, AirExt};
use crate::core::backend::cpu::CPUCircleEvaluation;
use crate::core::backend::CPUBackend;
use crate::core::channel::{Blake2sChannel, Channel as ChannelTrait};
use crate::core::circle::CirclePoint;
Expand All @@ -22,7 +24,8 @@ use crate::core::poly::BitReversedOrder;
use crate::core::ComponentVec;

type Channel = Blake2sChannel;
type MerkleHasher = Blake2sHasher;
type ChannelHasher = Blake2sHasher;
type MerkleHasher = Blake2sMerkleHasher;

pub const LOG_BLOWUP_FACTOR: u32 = 1;
pub const LOG_LAST_LAYER_DEGREE_BOUND: u32 = 0;
Expand All @@ -31,7 +34,7 @@ pub const N_QUERIES: usize = 3;

#[derive(Debug)]
pub struct StarkProof {
pub commitments: TreeVec<<MerkleHasher as Hasher>::Hash>,
pub commitments: TreeVec<<ChannelHasher as Hasher>::Hash>,
pub commitment_scheme_proof: CommitmentSchemeProof,
}

Expand All @@ -43,10 +46,10 @@ pub struct AdditionalProofData {
pub oods_quotients: Vec<CircleEvaluation<CPUBackend, SecureField, BitReversedOrder>>,
}

pub fn prove(
air: &impl Air<CPUBackend>,
pub fn prove<B: Backend + MerkleOps<MerkleHasher>>(
air: &impl Air<B>,
channel: &mut Channel,
trace: ColumnVec<CPUCircleEvaluation<BaseField, BitReversedOrder>>,
trace: ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>,
) -> Result<StarkProof, ProvingError> {
// Check that traces are not too big.
for (i, trace) in trace.iter().enumerate() {
Expand Down Expand Up @@ -93,7 +96,7 @@ pub fn prove(
sample_points.push(vec![vec![oods_point]; 4]);

// TODO(spapini): Precompute twiddles outside.
let twiddles = CPUBackend::precompute_twiddles(
let twiddles = B::precompute_twiddles(
CanonicCoset::new(composition_polynomial_log_degree_bound + 1).half_coset(),
);
let commitment_scheme_proof = commitment_scheme.prove_values(sample_points, channel, &twiddles);
Expand Down Expand Up @@ -165,12 +168,12 @@ pub fn verify(

/// Structures the tree-wise sampled values into component-wise OODS values and a composition
/// polynomial OODS value.
fn sampled_values_to_mask(
air: &impl Air<CPUBackend>,
fn sampled_values_to_mask<B: Backend>(
air: &impl Air<B>,
mut sampled_values: TreeVec<ColumnVec<Vec<SecureField>>>,
) -> Result<(ComponentVec<Vec<SecureField>>, SecureField), ()> {
let composition_partial_sampled_values = sampled_values.pop().ok_or(())?;
let composition_oods_value = SecureCirclePoly::eval_from_partial_evals(
let composition_oods_value = SecureCirclePoly::<B>::eval_from_partial_evals(
composition_partial_sampled_values
.iter()
.flatten()
Expand Down

0 comments on commit 42e6cf6

Please sign in to comment.