Skip to content

Commit

Permalink
Fri AVX ops (#563)
Browse files Browse the repository at this point in the history
<!-- Reviewable:start -->
This change is [<img src="https://reviewable.io/review_button.svg" height="34" align="absmiddle" alt="Reviewable"/>](https://reviewable.io/reviews/starkware-libs/stwo/563)
<!-- Reviewable:end -->
  • Loading branch information
spapinistarkware authored Apr 4, 2024
1 parent 7dedf53 commit 9d91b8a
Show file tree
Hide file tree
Showing 12 changed files with 427 additions and 117 deletions.
9 changes: 7 additions & 2 deletions benches/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use stwo::core::fields::m31::BaseField;
use stwo::core::fields::qm31::SecureField;
use stwo::core::fields::secure_column::SecureColumn;
use stwo::core::fri::FriOps;
use stwo::core::poly::circle::CanonicCoset;
use stwo::core::poly::circle::{CanonicCoset, PolyOps};
use stwo::core::poly::line::{LineDomain, LineEvaluation};

fn folding_benchmark(c: &mut Criterion) {
Expand All @@ -19,9 +19,14 @@ fn folding_benchmark(c: &mut Criterion) {
},
);
let alpha = SecureField::from_u32_unchecked(2213980, 2213981, 2213982, 2213983);
let twiddles = CPUBackend::precompute_twiddles(domain.coset());
c.bench_function("fold_line", |b| {
b.iter(|| {
black_box(CPUBackend::fold_line(black_box(&evals), black_box(alpha)));
black_box(CPUBackend::fold_line(
black_box(&evals),
black_box(alpha),
&twiddles,
));
})
});
}
Expand Down
2 changes: 1 addition & 1 deletion src/core/backend/avx512/fft/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub unsafe fn transpose_vecs(values: *mut i32, log_n_vecs: usize) {
/// Computes the twiddles for the first fft layer from the second, and loads both to AVX registers.
/// Returns the twiddles for the first layer and the twiddles for the second layer.
/// # Safety
unsafe fn compute_first_twiddles(twiddle1_dbl: [i32; 8]) -> (__m512i, __m512i) {
pub unsafe fn compute_first_twiddles(twiddle1_dbl: [i32; 8]) -> (__m512i, __m512i) {
// Start by loading the twiddles for the second layer (layer 1):
// The twiddles for layer 1 are replicated in the following pattern:
// 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
Expand Down
160 changes: 160 additions & 0 deletions src/core/backend/avx512/fri.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
use super::AVX512Backend;
use crate::core::backend::avx512::fft::compute_first_twiddles;
use crate::core::backend::avx512::fft::ifft::avx_ibutterfly;
use crate::core::backend::avx512::qm31::PackedSecureField;
use crate::core::backend::avx512::VECS_LOG_SIZE;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;
use crate::core::fri::{self, FriOps};
use crate::core::poly::circle::SecureEvaluation;
use crate::core::poly::line::LineEvaluation;
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::utils::domain_line_twiddles_from_tree;

impl FriOps for AVX512Backend {
    /// Folds a line evaluation by `alpha`, returning an evaluation with half as many values on
    /// `eval.domain().double()`.
    ///
    /// Evaluations with at most `2^VECS_LOG_SIZE` values are copied to the CPU backend, folded
    /// with the generic [`fri::fold_line`], and collected back into an AVX column. Larger
    /// evaluations are folded here, consuming two packed vectors and producing one per iteration.
    ///
    /// # Panics
    ///
    /// Panics if the first-layer inverse twiddles obtained from `twiddles` are shorter than half
    /// the evaluation length.
    fn fold_line(
        eval: &LineEvaluation<Self>,
        alpha: SecureField,
        twiddles: &TwiddleTree<Self>,
    ) -> LineEvaluation<Self> {
        let log_size = eval.len().ilog2();
        if log_size <= VECS_LOG_SIZE as u32 {
            // Too small for the packed path: fold on a CPU copy and convert the result back.
            let eval = fri::fold_line(&eval.to_cpu(), alpha);
            return LineEvaluation::new(eval.domain(), eval.values.into_iter().collect());
        }

        let domain = eval.domain();
        let itwiddles = domain_line_twiddles_from_tree(domain, &twiddles.itwiddles)[0];

        // Number of packed output vectors; each one needs 16 consecutive twiddles.
        let n_vecs: usize = 1 << (log_size - 1 - VECS_LOG_SIZE as u32);
        // Checked once here so the unchecked reads in the loop are provably in bounds.
        assert!(
            itwiddles.len() >= n_vecs * 16,
            "Twiddle tree too small for the evaluation domain"
        );

        let mut folded_values = SecureColumn::zeros(1 << (log_size - 1));

        for vec_index in 0..n_vecs {
            // SAFETY: `vec_index * 16 + i < n_vecs * 16 <= itwiddles.len()` by the assertion
            // above, so every `get_unchecked` is in bounds. `transmute` only reinterprets the 16
            // `i32` twiddles as the packed twiddle argument of `avx_ibutterfly`; the compiler
            // rejects it if the sizes differ.
            // NOTE(review): `avx_ibutterfly` is declared elsewhere; confirm it has no further
            // safety requirements.
            let value = unsafe {
                let twiddle_dbl: [i32; 16] =
                    std::array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 16 + i));
                let val0 = eval.values.packed_at(vec_index * 2).to_packed_m31s();
                let val1 = eval.values.packed_at(vec_index * 2 + 1).to_packed_m31s();
                // Apply the inverse butterfly to each of the four base-field coordinates.
                let pairs: [_; 4] = std::array::from_fn(|i| {
                    let (a, b) = val0[i].deinterleave_with(val1[i]);
                    avx_ibutterfly(a, b, std::mem::transmute(twiddle_dbl))
                });
                let val0 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].0));
                let val1 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].1));
                val0 + PackedSecureField::broadcast(alpha) * val1
            };
            folded_values.set_packed(vec_index, value);
        }

        LineEvaluation::new(domain.double(), folded_values)
    }

    /// Folds the circle evaluation `src` by `alpha` and accumulates the result into `dst`:
    /// every packed vector of `dst` becomes `dst * alpha^2 + folded(src)`.
    ///
    /// # Panics
    ///
    /// Panics if `src` has at most `2^VECS_LOG_SIZE` values, if `dst` does not have
    /// `src.len() >> CIRCLE_TO_LINE_FOLD_STEP` values, or if the first-layer inverse twiddles
    /// obtained from `twiddles` are shorter than a quarter of `src.len()`.
    fn fold_circle_into_line(
        dst: &mut LineEvaluation<Self>,
        src: &SecureEvaluation<Self>,
        alpha: SecureField,
        twiddles: &TwiddleTree<Self>,
    ) {
        let log_size = src.len().ilog2();
        assert!(log_size > VECS_LOG_SIZE as u32, "Evaluation too small");
        // Same precondition the CPU implementation enforces.
        assert_eq!(src.len() >> fri::CIRCLE_TO_LINE_FOLD_STEP, dst.len());

        let domain = src.domain;
        let alpha_sq = alpha * alpha;
        let itwiddles = domain_line_twiddles_from_tree(domain, &twiddles.itwiddles)[0];

        // Number of packed output vectors; each one needs 8 consecutive line twiddles.
        let n_vecs: usize = 1 << (log_size - 1 - VECS_LOG_SIZE as u32);
        // Checked once here so the unchecked reads in the loop are provably in bounds.
        assert!(
            itwiddles.len() >= n_vecs * 8,
            "Twiddle tree too small for the evaluation domain"
        );

        for vec_index in 0..n_vecs {
            // SAFETY: `vec_index * 8 + i < n_vecs * 8 <= itwiddles.len()` by the assertion
            // above, so every `get_unchecked` is in bounds.
            // NOTE(review): `compute_first_twiddles` is an `unsafe fn` whose `# Safety` section
            // is empty, and `avx_ibutterfly` is declared elsewhere; confirm neither has further
            // requirements.
            let value = unsafe {
                // The 16 twiddles of the circle domain can be derived from the 8 twiddles of the
                // next line domain. See `compute_first_twiddles()`.
                let twiddle_dbl: [i32; 8] =
                    std::array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 8 + i));
                let (t0, _) = compute_first_twiddles(twiddle_dbl);
                let val0 = src.values.packed_at(vec_index * 2).to_packed_m31s();
                let val1 = src.values.packed_at(vec_index * 2 + 1).to_packed_m31s();
                // Apply the inverse butterfly to each of the four base-field coordinates.
                let pairs: [_; 4] = std::array::from_fn(|i| {
                    let (a, b) = val0[i].deinterleave_with(val1[i]);
                    avx_ibutterfly(a, b, t0)
                });
                let val0 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].0));
                let val1 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].1));
                val0 + PackedSecureField::broadcast(alpha) * val1
            };
            let accumulated =
                dst.values.packed_at(vec_index) * PackedSecureField::broadcast(alpha_sq) + value;
            dst.values.set_packed(vec_index, accumulated);
        }
    }
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[cfg(test)]
mod tests {
    use crate::core::backend::avx512::AVX512Backend;
    use crate::core::backend::CPUBackend;
    use crate::core::fields::qm31::SecureField;
    use crate::core::fields::secure_column::SecureColumn;
    use crate::core::fri::FriOps;
    use crate::core::poly::circle::{CanonicCoset, PolyOps, SecureEvaluation};
    use crate::core::poly::line::{LineDomain, LineEvaluation};
    use crate::qm31;

    /// Log2 of the number of input values used by both tests.
    const LOG_SIZE: u32 = 7;

    /// Deterministic input: the `i`-th value has coordinates `4i, 4i + 1, 4i + 2, 4i + 3`.
    fn input_values() -> Vec<SecureField> {
        (0..(1 << LOG_SIZE))
            .map(|i| qm31!(4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3))
            .collect()
    }

    /// The AVX512 line fold must produce the same values as the CPU line fold.
    #[test]
    fn test_fold_line() {
        let values = input_values();
        let alpha = qm31!(1, 3, 5, 7);
        let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE + 1).half_coset());

        let cpu_eval: LineEvaluation<CPUBackend> =
            LineEvaluation::new(domain, values.iter().copied().collect());
        let cpu_twiddles = CPUBackend::precompute_twiddles(domain.coset());
        let cpu_fold = CPUBackend::fold_line(&cpu_eval, alpha, &cpu_twiddles);

        let avx_eval: LineEvaluation<AVX512Backend> =
            LineEvaluation::new(domain, values.iter().copied().collect());
        let avx_twiddles = AVX512Backend::precompute_twiddles(domain.coset());
        let avx_fold = AVX512Backend::fold_line(&avx_eval, alpha, &avx_twiddles);

        assert_eq!(cpu_fold.values.to_vec(), avx_fold.values.to_vec());
    }

    /// The AVX512 circle-to-line fold must produce the same values as the CPU one when both
    /// start from a zeroed destination.
    #[test]
    fn test_fold_circle_into_line() {
        let values = input_values();
        let alpha = qm31!(1, 3, 5, 7);
        let circle_domain = CanonicCoset::new(LOG_SIZE).circle_domain();
        let line_domain = LineDomain::new(circle_domain.half_coset);

        let mut cpu_fold =
            LineEvaluation::new(line_domain, SecureColumn::zeros(1 << (LOG_SIZE - 1)));
        let cpu_src: SecureEvaluation<CPUBackend> = SecureEvaluation {
            domain: circle_domain,
            values: values.iter().copied().collect(),
        };
        let cpu_twiddles = CPUBackend::precompute_twiddles(line_domain.coset());
        CPUBackend::fold_circle_into_line(&mut cpu_fold, &cpu_src, alpha, &cpu_twiddles);

        let mut avx_fold =
            LineEvaluation::new(line_domain, SecureColumn::zeros(1 << (LOG_SIZE - 1)));
        let avx_src: SecureEvaluation<AVX512Backend> = SecureEvaluation {
            domain: circle_domain,
            values: values.iter().copied().collect(),
        };
        let avx_twiddles = AVX512Backend::precompute_twiddles(line_domain.coset());
        AVX512Backend::fold_circle_into_line(&mut avx_fold, &avx_src, alpha, &avx_twiddles);

        assert_eq!(cpu_fold.values.to_vec(), avx_fold.values.to_vec());
    }
}
12 changes: 11 additions & 1 deletion src/core/backend/avx512/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod blake2s_avx;
pub mod circle;
pub mod cm31;
pub mod fft;
mod fri;
pub mod m31;
pub mod qm31;
pub mod quotients;
Expand All @@ -17,7 +18,7 @@ use self::bit_reverse::bit_reverse_m31;
use self::cm31::PackedCM31;
pub use self::m31::{PackedBaseField, K_BLOCK_SIZE};
use self::qm31::PackedSecureField;
use super::{Backend, Column, ColumnOps};
use super::{Backend, CPUBackend, Column, ColumnOps};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;
Expand Down Expand Up @@ -252,6 +253,15 @@ impl SecureColumn<AVX512Backend> {
}
}

impl FromIterator<SecureField> for SecureColumn<AVX512Backend> {
fn from_iter<I: IntoIterator<Item = SecureField>>(iter: I) -> Self {
let cpu_col = SecureColumn::<CPUBackend>::from_iter(iter);
SecureColumn {
columns: cpu_col.columns.map(|col| col.into_iter().collect()),
}
}
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[cfg(test)]
mod tests {
Expand Down
17 changes: 12 additions & 5 deletions src/core/backend/avx512/qm31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,24 @@ impl PackedSecureField {

// Multiply packed QM31 by packed M31.
pub fn mul_packed_m31(&self, rhs: PackedBaseField) -> PackedSecureField {
let a = self.0[0].0[0] * rhs;
let b = self.0[0].0[1] * rhs;
let c = self.0[1].0[0] * rhs;
let d = self.0[1].0[1] * rhs;
PackedSecureField([PackedCM31([a, b]), PackedCM31([c, d])])
Self::from_packed_m31s(self.to_packed_m31s().map(|m31| m31 * rhs))
}

/// Adds together all the elements held in this packed value and returns the total as a single
/// `QM31`.
pub fn pointwise_sum(self) -> QM31 {
    let elements = self.to_array();
    elements.into_iter().sum()
}

/// Returns the four packed base-field coordinates in the order
/// `a().a()`, `a().b()`, `b().a()`, `b().b()`.
pub fn to_packed_m31s(&self) -> [PackedBaseField; 4] {
    let (first, second) = (self.a(), self.b());
    [first.a(), first.b(), second.a(), second.b()]
}

/// Builds a packed secure field from four packed base-field coordinates: the first two form
/// the first `PackedCM31`, the last two form the second.
pub fn from_packed_m31s(array: [PackedBaseField; 4]) -> Self {
    let [c0, c1, c2, c3] = array;
    let first = PackedCM31([c0, c1]);
    let second = PackedCM31([c2, c3]);
    Self([first, second])
}
}
impl Add for PackedSecureField {
type Output = Self;
Expand Down
59 changes: 11 additions & 48 deletions src/core/backend/cpu/fri.rs
Original file line number Diff line number Diff line change
@@ -1,62 +1,25 @@
use super::CPUBackend;
use crate::core::fft::ibutterfly;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::fri::{FriOps, CIRCLE_TO_LINE_FOLD_STEP, FOLD_STEP};
use crate::core::fri::{fold_circle_into_line, fold_line, FriOps};
use crate::core::poly::circle::SecureEvaluation;
use crate::core::poly::line::LineEvaluation;
use crate::core::utils::bit_reverse_index;
use crate::core::poly::twiddles::TwiddleTree;

// TODO(spapini): Optimize these functions as well.
impl FriOps for CPUBackend {
fn fold_line(eval: &LineEvaluation<Self>, alpha: SecureField) -> LineEvaluation<Self> {
let n = eval.len();
assert!(n >= 2, "Evaluation too small");

let domain = eval.domain();

let folded_values = eval
.values
.into_iter()
.array_chunks()
.enumerate()
.map(|(i, [f_x, f_neg_x])| {
// TODO(andrew): Inefficient. Update when domain twiddles get stored in a buffer.
let x = domain.at(bit_reverse_index(i << FOLD_STEP, domain.log_size()));

let (mut f0, mut f1) = (f_x, f_neg_x);
ibutterfly(&mut f0, &mut f1, x.inverse());
f0 + alpha * f1
})
.collect();

LineEvaluation::new(domain.double(), folded_values)
fn fold_line(
eval: &LineEvaluation<Self>,
alpha: SecureField,
_twiddles: &TwiddleTree<Self>,
) -> LineEvaluation<Self> {
fold_line(eval, alpha)
}
fn fold_circle_into_line(
dst: &mut LineEvaluation<Self>,
src: &SecureEvaluation<Self>,
alpha: SecureField,
_twiddles: &TwiddleTree<Self>,
) {
assert_eq!(src.len() >> CIRCLE_TO_LINE_FOLD_STEP, dst.len());

let domain = src.domain;
let alpha_sq = alpha * alpha;

src.into_iter()
.array_chunks()
.enumerate()
.for_each(|(i, [f_p, f_neg_p])| {
// TODO(andrew): Inefficient. Update when domain twiddles get stored in a buffer.
let p = domain.at(bit_reverse_index(
i << CIRCLE_TO_LINE_FOLD_STEP,
domain.log_size(),
));

// Calculate `f0(px)` and `f1(px)` such that `2f(p) = f0(px) + py * f1(px)`.
let (mut f0_px, mut f1_px) = (f_p, f_neg_p);
ibutterfly(&mut f0_px, &mut f1_px, p.y.inverse());
let f_prime = alpha * f1_px + f0_px;

dst.values.set(i, dst.values.at(i) * alpha_sq + f_prime);
});
fold_circle_into_line(dst, src, alpha)
}
}
7 changes: 5 additions & 2 deletions src/core/commitment_scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::commitment_scheme::blake2_hash::Blake2sHash;
use crate::commitment_scheme::blake2_merkle::Blake2sMerkleHasher;
use crate::commitment_scheme::prover::{MerkleDecommitment, MerkleProver};
use crate::core::channel::Channel;
use crate::core::poly::twiddles::TwiddleTree;

type MerkleHasher = Blake2sMerkleHasher;
type ProofChannel = Blake2sChannel;
Expand Down Expand Up @@ -65,6 +66,7 @@ impl CommitmentSchemeProver {
&self,
sampled_points: TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>>,
channel: &mut ProofChannel,
twiddles: &TwiddleTree<CPUBackend>,
) -> CommitmentSchemeProof {
// Evaluate polynomials on samples points.
let samples = self
Expand All @@ -90,8 +92,9 @@ impl CommitmentSchemeProver {

// Run FRI commitment phase on the oods quotients.
let fri_config = FriConfig::new(LOG_LAST_LAYER_DEGREE_BOUND, LOG_BLOWUP_FACTOR, N_QUERIES);
let fri_prover =
FriProver::<CPUBackend, MerkleHasher>::commit(channel, fri_config, &quotients);
let fri_prover = FriProver::<CPUBackend, MerkleHasher>::commit(
channel, fri_config, &quotients, twiddles,
);

// Proof of work.
let proof_of_work = ProofOfWork::new(PROOF_OF_WORK_BITS).prove(channel);
Expand Down
Loading

0 comments on commit 9d91b8a

Please sign in to comment.