Skip to content

Commit

Permalink
Parallelize poseidon trace gen
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Sep 5, 2024
1 parent dfa1003 commit b80533c
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 8 deletions.
2 changes: 2 additions & 0 deletions crates/prover/src/core/backend/simd/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::mem::transmute;
use bytemuck::{cast_slice, Zeroable};
use itertools::Itertools;
use num_traits::One;
use tracing::{span, Level};

use super::fft::{ifft, rfft, CACHED_FFT_LOG_SIZE};
use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
Expand Down Expand Up @@ -280,6 +281,7 @@ impl PolyOps for SimdBackend {

#[allow(clippy::int_plus_one)]
fn precompute_twiddles(mut coset: Coset) -> TwiddleTree<Self> {
let _span = span!(Level::INFO, "Compute twiddles").entered();
let root_coset = coset;

// Generate xs for descending cosets, each bit reversed.
Expand Down
2 changes: 0 additions & 2 deletions crates/prover/src/examples/blake/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,6 @@ where
);

// Precompute twiddles.
let span = span!(Level::INFO, "Precompute twiddles").entered();
const XOR_TABLE_MAX_LOG_SIZE: u32 = 16;
let log_max_rows =
(log_size + *ROUND_LOG_SPLIT.iter().max().unwrap()).max(XOR_TABLE_MAX_LOG_SIZE);
Expand All @@ -237,7 +236,6 @@ where
.circle_domain()
.half_coset,
);
span.exit();

// Prepare inputs.
let blake_inputs = (0..(1 << (log_size - LOG_N_LANES)))
Expand Down
2 changes: 0 additions & 2 deletions crates/prover/src/examples/plonk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,11 @@ pub fn prove_fibonacci_plonk(
circuit.mult.set((1 << log_n_rows) - 2, 1.into());

// Precompute twiddles.
let span = span!(Level::INFO, "Precompute twiddles").entered();
let twiddles = SimdBackend::precompute_twiddles(
CanonicCoset::new(log_n_rows + config.fri_config.log_blowup_factor + 1)
.circle_domain()
.half_coset,
);
span.exit();

// Setup protocol.
let channel = &mut Blake2sChannel::default();
Expand Down
21 changes: 17 additions & 4 deletions crates/prover/src/examples/poseidon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use std::ops::{Add, AddAssign, Mul, Sub};

use itertools::Itertools;
use num_traits::One;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use tracing::{span, Level};

use crate::constraint_framework::logup::{LogupAtRow, LogupTraceGenerator, LookupElements};
Expand All @@ -13,6 +15,7 @@ use crate::constraint_framework::{
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::utils::UnsafeShared;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Col, Column};
use crate::core::channel::Blake2sChannel;
Expand Down Expand Up @@ -214,7 +217,19 @@ pub fn gen_trace(
}),
};

for vec_index in 0..(1 << (log_size - LOG_N_LANES)) {
let iter_range = 0..1 << (log_size - LOG_N_LANES);

#[cfg(not(feature = "parallel"))]
let iter = iter_range;

#[cfg(feature = "parallel")]
let iter = iter_range.into_par_iter();

let borrowed_trace = unsafe { UnsafeShared::new(&mut trace) };
let borrowed_lookup_data = unsafe { UnsafeShared::new(&mut lookup_data) };
iter.for_each(|vec_index| {
let trace = unsafe { borrowed_trace.get() };
let lookup_data = unsafe { borrowed_lookup_data.get() };
// Initial state.
let mut col_index = 0;
for rep_i in 0..N_INSTANCES_PER_ROW {
Expand Down Expand Up @@ -274,7 +289,7 @@ pub fn gen_trace(
.zip(state)
.for_each(|(res, state_i)| res.data[vec_index] = state_i);
}
}
});
let domain = CanonicCoset::new(log_size).circle_domain();
let trace = trace
.into_iter()
Expand Down Expand Up @@ -326,13 +341,11 @@ pub fn prove_poseidon(
let log_n_rows = log_n_instances - N_LOG_INSTANCES_PER_ROW as u32;

// Precompute twiddles.
let span = span!(Level::INFO, "Precompute twiddles").entered();
let twiddles = SimdBackend::precompute_twiddles(
CanonicCoset::new(log_n_rows + LOG_EXPAND + config.fri_config.log_blowup_factor)
.circle_domain()
.half_coset,
);
span.exit();

// Setup protocol.
let channel = &mut Blake2sChannel::default();
Expand Down

0 comments on commit b80533c

Please sign in to comment.