Skip to content

Commit

Permalink
Dumb down FRI (#529)
Browse files Browse the repository at this point in the history
<!-- Reviewable:start -->
This change is [<img src="https://reviewable.io/review_button.svg" height="34" align="absmiddle" alt="Reviewable"/>](https://reviewable.io/reviews/starkware-libs/stwo/529)
<!-- Reviewable:end -->
  • Loading branch information
spapinistarkware authored Apr 3, 2024
1 parent ea89ad5 commit c3327b8
Show file tree
Hide file tree
Showing 9 changed files with 264 additions and 338 deletions.
10 changes: 8 additions & 2 deletions benches/fri.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use stwo::core::backend::CPUBackend;
use stwo::core::fields::m31::BaseField;
use stwo::core::fields::qm31::SecureField;
use stwo::core::fields::secure_column::SecureColumn;
use stwo::core::fri::FriOps;
use stwo::core::poly::circle::CanonicCoset;
use stwo::core::poly::line::{LineDomain, LineEvaluation};
Expand All @@ -10,9 +12,13 @@ fn folding_benchmark(c: &mut Criterion) {
let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE + 1).half_coset());
let evals = LineEvaluation::new(
domain,
vec![BaseField::from_u32_unchecked(712837213).into(); 1 << LOG_SIZE],
SecureColumn {
columns: std::array::from_fn(|i| {
vec![BaseField::from_u32_unchecked(i as u32); 1 << LOG_SIZE]
}),
},
);
let alpha = BaseField::from_u32_unchecked(12389).into();
let alpha = SecureField::from_u32_unchecked(2213980, 2213981, 2213982, 2213983);
c.bench_function("fold_line", |b| {
b.iter(|| {
black_box(CPUBackend::fold_line(black_box(&evals), black_box(alpha)));
Expand Down
34 changes: 13 additions & 21 deletions src/core/backend/cpu/fri.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,25 @@
use std::iter::zip;

use super::CPUBackend;
use crate::core::fft::ibutterfly;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{ExtensionOf, Field, FieldExpOps};
use crate::core::fields::FieldExpOps;
use crate::core::fri::{FriOps, CIRCLE_TO_LINE_FOLD_STEP, FOLD_STEP};
use crate::core::poly::circle::CircleEvaluation;
use crate::core::poly::circle::SecureEvaluation;
use crate::core::poly::line::LineEvaluation;
use crate::core::poly::BitReversedOrder;
use crate::core::utils::bit_reverse_index;

impl FriOps for CPUBackend {
fn fold_line(
eval: &LineEvaluation<Self, SecureField, BitReversedOrder>,
alpha: SecureField,
) -> LineEvaluation<Self, SecureField, BitReversedOrder> {
fn fold_line(eval: &LineEvaluation<Self>, alpha: SecureField) -> LineEvaluation<Self> {
let n = eval.len();
assert!(n >= 2, "Evaluation too small");

let domain = eval.domain();

let folded_values = eval
.values
.into_iter()
.array_chunks()
.enumerate()
.map(|(i, &[f_x, f_neg_x])| {
.map(|(i, [f_x, f_neg_x])| {
// TODO(andrew): Inefficient. Update when domain twiddles get stored in a buffer.
let x = domain.at(bit_reverse_index(i << FOLD_STEP, domain.log_size()));

Expand All @@ -37,22 +31,20 @@ impl FriOps for CPUBackend {

LineEvaluation::new(domain.double(), folded_values)
}
fn fold_circle_into_line<F: Field>(
dst: &mut LineEvaluation<Self, SecureField, BitReversedOrder>,
src: &CircleEvaluation<Self, F, BitReversedOrder>,
fn fold_circle_into_line(
dst: &mut LineEvaluation<Self>,
src: &SecureEvaluation<Self>,
alpha: SecureField,
) where
F: ExtensionOf<BaseField>,
SecureField: ExtensionOf<F> + Field,
{
) {
assert_eq!(src.len() >> CIRCLE_TO_LINE_FOLD_STEP, dst.len());

let domain = src.domain;
let alpha_sq = alpha * alpha;

zip(&mut dst.values, src.array_chunks())
src.into_iter()
.array_chunks()
.enumerate()
.for_each(|(i, (dst, &[f_p, f_neg_p]))| {
.for_each(|(i, [f_p, f_neg_p])| {
// TODO(andrew): Inefficient. Update when domain twiddles get stored in a buffer.
let p = domain.at(bit_reverse_index(
i << CIRCLE_TO_LINE_FOLD_STEP,
Expand All @@ -64,7 +56,7 @@ impl FriOps for CPUBackend {
ibutterfly(&mut f0_px, &mut f1_px, p.y.inverse());
let f_prime = alpha * f1_px + f0_px;

*dst = *dst * alpha_sq + f_prime;
dst.values.set(i, dst.values.at(i) * alpha_sq + f_prime);
});
}
}
3 changes: 0 additions & 3 deletions src/core/backend/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use std::fmt::Debug;
use super::{Backend, Column, ColumnOps, FieldOps};
use crate::core::fields::Field;
use crate::core::poly::circle::{CircleEvaluation, CirclePoly};
use crate::core::poly::line::LineEvaluation;
use crate::core::utils::bit_reverse;

#[derive(Copy, Clone, Debug)]
Expand Down Expand Up @@ -49,8 +48,6 @@ impl<T: Debug + Clone + Default> Column<T> for Vec<T> {

pub type CPUCirclePoly = CirclePoly<CPUBackend>;
pub type CPUCircleEvaluation<F, EvalOrder> = CircleEvaluation<CPUBackend, F, EvalOrder>;
// TODO(spapini): Remove the EvalOrder on LineEvaluation.
pub type CPULineEvaluation<F, EvalOrder> = LineEvaluation<CPUBackend, F, EvalOrder>;

#[cfg(test)]
mod tests {
Expand Down
8 changes: 0 additions & 8 deletions src/core/commitment_scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ use crate::commitment_scheme::blake2_hash::{Blake2sHash, Blake2sHasher};
use crate::commitment_scheme::blake2_merkle::Blake2sMerkleHasher;
use crate::commitment_scheme::prover::{MerkleDecommitment, MerkleProver};
use crate::core::channel::Channel;
use crate::core::poly::circle::SecureEvaluation;

type MerkleHasher = Blake2sMerkleHasher;
type ProofChannel = Blake2sChannel;
Expand Down Expand Up @@ -89,13 +88,6 @@ impl CommitmentSchemeProver {
let columns = self.evaluations().flatten();
let quotients = compute_fri_quotients(&columns, &samples.flatten(), channel.draw_felt());

// TODO(spapini): Conversion to CircleEvaluation can be removed when FRI supports
// SecureColumn.
let quotients = quotients
.into_iter()
.map(SecureEvaluation::to_cpu)
.collect_vec();

// Run FRI commitment phase on the oods quotients.
let fri_config = FriConfig::new(LOG_LAST_LAYER_DEGREE_BOUND, LOG_BLOWUP_FACTOR, N_QUERIES);
let fri_prover =
Expand Down
4 changes: 2 additions & 2 deletions src/core/commitment_scheme/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ pub fn fri_answers(
random_coeff: SecureField,
query_domain_per_log_size: BTreeMap<u32, SparseSubCircleDomain>,
queried_values_per_column: &[Vec<BaseField>],
) -> Result<Vec<SparseCircleEvaluation<SecureField>>, VerificationError> {
) -> Result<Vec<SparseCircleEvaluation>, VerificationError> {
izip!(column_log_sizes, samples, queried_values_per_column)
.sorted_by_key(|(log_size, ..)| Reverse(*log_size))
.group_by(|(log_size, ..)| *log_size)
Expand All @@ -121,7 +121,7 @@ pub fn fri_answers_for_log_size(
random_coeff: SecureField,
query_domain: &SparseSubCircleDomain,
queried_values_per_column: &[&Vec<BaseField>],
) -> Result<SparseCircleEvaluation<SecureField>, VerificationError> {
) -> Result<SparseCircleEvaluation, VerificationError> {
let commitment_domain = CanonicCoset::new(log_size).circle_domain();
let sample_batches = ColumnSampleBatch::new_vec(samples);
for queried_values in queried_values_per_column {
Expand Down
52 changes: 52 additions & 0 deletions src/core/fields/secure_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,56 @@ impl<B: FieldOps<BaseField>> SecureColumn<B> {
pub fn is_empty(&self) -> bool {
self.columns[0].is_empty()
}

pub fn to_cpu(&self) -> SecureColumn<CPUBackend> {
SecureColumn {
columns: self.columns.clone().map(|c| c.to_vec()),
}
}
}

pub struct SecureColumnIter<'a> {
column: &'a SecureColumn<CPUBackend>,
index: usize,
}
impl Iterator for SecureColumnIter<'_> {
type Item = SecureField;

fn next(&mut self) -> Option<Self::Item> {
if self.index < self.column.len() {
let value = self.column.at(self.index);
self.index += 1;
Some(value)
} else {
None
}
}
}
impl<'a> IntoIterator for &'a SecureColumn<CPUBackend> {
type Item = SecureField;
type IntoIter = SecureColumnIter<'a>;

fn into_iter(self) -> Self::IntoIter {
SecureColumnIter {
column: self,
index: 0,
}
}
}
impl FromIterator<SecureField> for SecureColumn<CPUBackend> {
fn from_iter<I: IntoIterator<Item = SecureField>>(iter: I) -> Self {
let mut columns = std::array::from_fn(|_| vec![]);
for value in iter.into_iter() {
let vals = value.to_m31_array();
for j in 0..SECURE_EXTENSION_DEGREE {
columns[j].push(vals[j]);
}
}
SecureColumn { columns }
}
}
impl From<SecureColumn<CPUBackend>> for Vec<SecureField> {
fn from(column: SecureColumn<CPUBackend>) -> Self {
column.into_iter().collect()
}
}
Loading

0 comments on commit c3327b8

Please sign in to comment.