Skip to content

Commit

Permalink
Logup trace generation
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Jul 17, 2024
1 parent 70ce843 commit cd56fad
Show file tree
Hide file tree
Showing 6 changed files with 312 additions and 88 deletions.
25 changes: 5 additions & 20 deletions crates/prover/benches/poseidon.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,12 @@
use criterion::{criterion_group, criterion_main, Criterion, Throughput};
use stwo_prover::core::backend::simd::SimdBackend;
use stwo_prover::core::channel::{Blake2sChannel, Channel};
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::fields::IntoSlice;
use stwo_prover::core::vcs::blake2_hash::Blake2sHasher;
use stwo_prover::core::vcs::hasher::Hasher;
use stwo_prover::examples::poseidon::{gen_trace, PoseidonAir, PoseidonComponent};
use stwo_prover::trace_generation::commit_and_prove;
use stwo_prover::examples::poseidon::prove_poseidon;

pub fn simd_poseidon(c: &mut Criterion) {
const LOG_N_ROWS: u32 = 15;
const LOG_N_INSTANCES: u32 = 18;
let mut group = c.benchmark_group("poseidon2");
group.throughput(Throughput::Elements(1u64 << (LOG_N_ROWS + 3)));
group.bench_function(format!("poseidon2 2^{} instances", LOG_N_ROWS + 3), |b| {
b.iter(|| {
let component = PoseidonComponent {
log_n_rows: LOG_N_ROWS,
};
let trace = gen_trace(component.log_column_size());
let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[])));
let air = PoseidonAir { component };
commit_and_prove::<SimdBackend>(&air, channel, trace).unwrap()
});
group.throughput(Throughput::Elements(1u64 << LOG_N_INSTANCES));
group.bench_function(format!("poseidon2 2^{} instances", LOG_N_INSTANCES), |b| {
b.iter(|| prove_poseidon(LOG_N_INSTANCES));
});
}

Expand Down
134 changes: 134 additions & 0 deletions crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
use itertools::Itertools;
use num_traits::Zero;
use tracing::{span, Level};

use crate::core::backend::simd::column::SecureFieldVec;
use crate::core::backend::simd::m31::LOG_N_LANES;
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Backend, Column};
use crate::core::channel::{Blake2sChannel, Channel};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;
use crate::core::fields::FieldExpOps;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::bit_reverse_index;
use crate::core::ColumnVec;

#[derive(Copy, Clone, Debug)]
pub struct LookupElements {
pub z: SecureField,
pub alpha: SecureField,
}
impl LookupElements {
pub fn draw(channel: &mut Blake2sChannel) -> Self {
let [z, alpha] = channel.draw_felts(2).try_into().unwrap();
Self { z, alpha }
}
}

// SIMD backend generator.
pub struct LogupTraceGenerator {
log_size: u32,
trace: Vec<SecureColumn<SimdBackend>>,
denom: SecureFieldVec,
denom_inv: SecureFieldVec,
}
impl LogupTraceGenerator {
pub fn new(log_size: u32) -> Self {
let trace = vec![];
let denom = SecureFieldVec::zeros(1 << log_size);
let denom_inv = SecureFieldVec::zeros(1 << log_size);
Self {
log_size,
trace,
denom,
denom_inv,
}
}

pub fn new_col(&mut self) -> LogupColGenerator<'_> {
let log_size = self.log_size;
LogupColGenerator {
gen: self,
numerator: SecureColumn::<SimdBackend>::zeros(1 << log_size),
}
}

pub fn finalize(
mut self,
) -> (
ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
SecureField,
) {
let claimed_xor_sum = eval_order_prefix_sum(self.trace.last_mut().unwrap(), self.log_size);

let trace = self
.trace
.into_iter()
.flat_map(|eval| {
eval.columns.map(|c| {
CircleEvaluation::<SimdBackend, _, BitReversedOrder>::new(
CanonicCoset::new(self.log_size).circle_domain(),
c,
)
})
})
.collect_vec();
(trace, claimed_xor_sum)
}
}

pub struct LogupColGenerator<'a> {
gen: &'a mut LogupTraceGenerator,
numerator: SecureColumn<SimdBackend>,
}
impl<'a> LogupColGenerator<'a> {
pub fn write_frac(&mut self, vec_row: usize, p: PackedSecureField, q: PackedSecureField) {
unsafe {
self.numerator.set_packed(vec_row, p);
*self.gen.denom.data.get_unchecked_mut(vec_row) = q;
}
}

pub fn finalize_col(mut self) {
FieldExpOps::batch_inverse(&self.gen.denom.data, &mut self.gen.denom_inv.data);

#[allow(clippy::needless_range_loop)]
for vec_row in 0..(1 << (self.gen.log_size - LOG_N_LANES)) {
unsafe {
let value = self.numerator.packed_at(vec_row)
* *self.gen.denom_inv.data.get_unchecked(vec_row);
let prev_value = self
.gen
.trace
.last()
.map(|col| col.packed_at(vec_row))
.unwrap_or_else(PackedSecureField::zero);
self.numerator.set_packed(vec_row, value + prev_value)
};
}

self.gen.trace.push(self.numerator)
}
}

// TODO(spapini): Consider adding optional Ops.
pub fn eval_order_prefix_sum<B: Backend>(col: &mut SecureColumn<B>, log_size: u32) -> SecureField {
let _span = span!(Level::INFO, "Prefix sum").entered();

let mut cur = SecureField::zero();
for i in 0..(1 << log_size) {
let index = if i & 1 == 0 {
i / 2
} else {
(1 << (log_size - 1)) + ((1 << log_size) - 1 - i) / 2
};
let index = bit_reverse_index(index, log_size);
cur += col.at(index);
col.set(index, cur);
}
cur
}
1 change: 1 addition & 0 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
mod assert;
pub mod constant_columns;
mod info;
pub mod logup;
mod point;
mod simd_domain;

Expand Down
14 changes: 6 additions & 8 deletions crates/prover/src/core/fields/secure_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use super::m31::BaseField;
use super::qm31::SecureField;
use super::{ExtensionOf, FieldOps};
use crate::core::backend::{Col, Column, CpuBackend};
use crate::core::utils::IteratorMutExt;

pub const SECURE_EXTENSION_DEGREE: usize =
<SecureField as ExtensionOf<BaseField>>::EXTENSION_DEGREE;
Expand All @@ -14,13 +13,6 @@ pub struct SecureColumn<B: FieldOps<BaseField>> {
pub columns: [Col<B, BaseField>; SECURE_EXTENSION_DEGREE],
}
impl SecureColumn<CpuBackend> {
pub fn set(&mut self, index: usize, value: SecureField) {
self.columns
.iter_mut()
.map(|c| &mut c[index])
.assign(value.to_m31_array());
}

// TODO(spapini): Remove when we no longer use CircleEvaluation<SecureField>.
pub fn to_vec(&self) -> Vec<SecureField> {
(0..self.len()).map(|i| self.at(i)).collect()
Expand Down Expand Up @@ -50,6 +42,12 @@ impl<B: FieldOps<BaseField>> SecureColumn<B> {
columns: self.columns.clone().map(|c| c.to_cpu()),
}
}

pub fn set(&mut self, index: usize, value: SecureField) {
for i in 0..SECURE_EXTENSION_DEGREE {
self.columns[i].set(index, value.to_m31_array()[i]);
}
}
}

pub struct SecureColumnIter<'a> {
Expand Down
Loading

0 comments on commit cd56fad

Please sign in to comment.