From b80533c89f6701dd6a675d9a10d04d4fb0888fe0 Mon Sep 17 00:00:00 2001 From: Shahar Papini Date: Thu, 5 Sep 2024 10:57:41 +0300 Subject: [PATCH] Parallelize poseidon trace gen --- crates/prover/src/core/backend/simd/circle.rs | 2 ++ crates/prover/src/examples/blake/air.rs | 2 -- crates/prover/src/examples/plonk/mod.rs | 2 -- crates/prover/src/examples/poseidon/mod.rs | 21 +++++++++++++++---- 4 files changed, 19 insertions(+), 8 deletions(-) diff --git a/crates/prover/src/core/backend/simd/circle.rs b/crates/prover/src/core/backend/simd/circle.rs index 66f67d52b..1b16ac07d 100644 --- a/crates/prover/src/core/backend/simd/circle.rs +++ b/crates/prover/src/core/backend/simd/circle.rs @@ -4,6 +4,7 @@ use std::mem::transmute; use bytemuck::{cast_slice, Zeroable}; use itertools::Itertools; use num_traits::One; +use tracing::{span, Level}; use super::fft::{ifft, rfft, CACHED_FFT_LOG_SIZE}; use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; @@ -280,6 +281,7 @@ impl PolyOps for SimdBackend { #[allow(clippy::int_plus_one)] fn precompute_twiddles(mut coset: Coset) -> TwiddleTree { + let _span = span!(Level::INFO, "Compute twiddles").entered(); let root_coset = coset; // Generate xs for descending cosets, each bit reversed. diff --git a/crates/prover/src/examples/blake/air.rs b/crates/prover/src/examples/blake/air.rs index e655ee658..0c3e7835a 100644 --- a/crates/prover/src/examples/blake/air.rs +++ b/crates/prover/src/examples/blake/air.rs @@ -228,7 +228,6 @@ where ); // Precompute twiddles. - let span = span!(Level::INFO, "Precompute twiddles").entered(); const XOR_TABLE_MAX_LOG_SIZE: u32 = 16; let log_max_rows = (log_size + *ROUND_LOG_SPLIT.iter().max().unwrap()).max(XOR_TABLE_MAX_LOG_SIZE); @@ -237,7 +236,6 @@ where .circle_domain() .half_coset, ); - span.exit(); // Prepare inputs. let blake_inputs = (0..(1 << (log_size - LOG_N_LANES))) diff --git a/crates/prover/src/examples/plonk/mod.rs b/crates/prover/src/examples/plonk/mod.rs index 58248a0bc..547fd6ef5 100644 --- a/crates/prover/src/examples/plonk/mod.rs +++ b/crates/prover/src/examples/plonk/mod.rs @@ -172,13 +172,11 @@ pub fn prove_fibonacci_plonk( circuit.mult.set((1 << log_n_rows) - 2, 1.into()); // Precompute twiddles. - let span = span!(Level::INFO, "Precompute twiddles").entered(); let twiddles = SimdBackend::precompute_twiddles( CanonicCoset::new(log_n_rows + config.fri_config.log_blowup_factor + 1) .circle_domain() .half_coset, ); - span.exit(); // Setup protocol. let channel = &mut Blake2sChannel::default(); diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index c94f0ba1d..155efe54e 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -4,6 +4,8 @@ use std::ops::{Add, AddAssign, Mul, Sub}; use itertools::Itertools; use num_traits::One; +#[cfg(feature = "parallel")] +use rayon::prelude::*; use tracing::{span, Level}; use crate::constraint_framework::logup::{LogupAtRow, LogupTraceGenerator, LookupElements}; @@ -13,6 +15,7 @@ use crate::constraint_framework::{ use crate::core::backend::simd::column::BaseColumn; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::utils::UnsafeShared; use crate::core::backend::simd::SimdBackend; use crate::core::backend::{Col, Column}; use crate::core::channel::Blake2sChannel; @@ -214,7 +217,19 @@ pub fn gen_trace( }), }; - for vec_index in 0..(1 << (log_size - LOG_N_LANES)) { + let iter_range = 0..1 << (log_size - LOG_N_LANES); + + #[cfg(not(feature = "parallel"))] + let iter = iter_range; + + #[cfg(feature = "parallel")] + let iter = iter_range.into_par_iter(); + + let borrowed_trace = unsafe { UnsafeShared::new(&mut trace) }; + let borrowed_lookup_data = unsafe { UnsafeShared::new(&mut lookup_data) }; + iter.for_each(|vec_index| { + let trace = unsafe { borrowed_trace.get() }; + let lookup_data = unsafe { borrowed_lookup_data.get() }; // Initial state. let mut col_index = 0; for rep_i in 0..N_INSTANCES_PER_ROW { @@ -274,7 +289,7 @@ pub fn gen_trace( .zip(state) .for_each(|(res, state_i)| res.data[vec_index] = state_i); } - } + }); let domain = CanonicCoset::new(log_size).circle_domain(); let trace = trace .into_iter() @@ -326,13 +341,11 @@ pub fn prove_poseidon( let log_n_rows = log_n_instances - N_LOG_INSTANCES_PER_ROW as u32; // Precompute twiddles. - let span = span!(Level::INFO, "Precompute twiddles").entered(); let twiddles = SimdBackend::precompute_twiddles( CanonicCoset::new(log_n_rows + LOG_EXPAND + config.fri_config.log_blowup_factor) .circle_domain() .half_coset, ); - span.exit(); // Setup protocol. let channel = &mut Blake2sChannel::default();