Skip to content

Commit

Permalink
Simplify thread pool configuration (#464)
Browse files Browse the repository at this point in the history
This PR should resolve the occasional hangs observed under various conditions.
The root cause is (probably) the `RAYON_NUM_THREADS` env parsing done by the
`const_env` crate, which probably has a race condition, so that changing the
environment variable doesn't reliably trigger a rebuild. This can cause a potential
mismatch with the system rayon thread count (for unknown reasons), probably a bug
in `const_env`.

Although the root cause is not 100% certain, we can clean up this
complexity and instead just respect the rayon thread pool configuration parsed at
the global entry, which greatly simplifies the overall flow.


The change has been verified on a remote benchmark machine; there is no
performance difference before/after.

---------

Co-authored-by: Matthias Görgens <[email protected]>
  • Loading branch information
hero78119 and matthiasgoergens authored Oct 28, 2024
1 parent e3ce193 commit 1b8d622
Show file tree
Hide file tree
Showing 19 changed files with 40 additions and 266 deletions.
24 changes: 0 additions & 24 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ version = "0.1.0"
[workspace.dependencies]
ark-std = "0.4"
cfg-if = "1.0"
const_env = "0.1"
criterion = { version = "0.5", features = ["html_reports"] }
crossbeam-channel = "0.5"
ff = "0.13"
Expand Down
3 changes: 0 additions & 3 deletions build.rs

This file was deleted.

1 change: 0 additions & 1 deletion ceno_zkvm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ thread_local = "1.1"
[dev-dependencies]
base64 = "0.22"
cfg-if.workspace = true
const_env.workspace = true
criterion.workspace = true
pprof.workspace = true
serde_json.workspace = true
Expand Down
24 changes: 0 additions & 24 deletions ceno_zkvm/benches/riscv_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use ceno_zkvm::{
scheme::prover::ZKVMProver,
structs::{ZKVMConstraintSystem, ZKVMFixedTraces},
};
use const_env::from_env;
use criterion::*;

use ceno_zkvm::scheme::constants::MAX_NUM_VARIABLES;
Expand Down Expand Up @@ -37,31 +36,9 @@ cfg_if::cfg_if! {
criterion_main!(op_add);

const NUM_SAMPLES: usize = 10;
#[from_env]
const RAYON_NUM_THREADS: usize = 8;

fn bench_add(c: &mut Criterion) {
type Pcs = BasefoldDefault<E>;
let max_threads = {
if !RAYON_NUM_THREADS.is_power_of_two() {
#[cfg(not(feature = "non_pow2_rayon_thread"))]
{
panic!(
"add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool"
);
}

#[cfg(feature = "non_pow2_rayon_thread")]
{
use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2};
let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS);
create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true);
max_thread_id
}
} else {
RAYON_NUM_THREADS
}
};
let mut zkvm_cs = ZKVMConstraintSystem::default();
let _ = zkvm_cs.register_opcode_circuit::<AddInstruction<E>>();
let mut zkvm_fixed_traces = ZKVMFixedTraces::default();
Expand Down Expand Up @@ -128,7 +105,6 @@ fn bench_add(c: &mut Criterion) {
commit,
&[],
num_instances,
max_threads,
&mut transcript,
&challenges,
)
Expand Down
27 changes: 1 addition & 26 deletions ceno_zkvm/examples/riscv_opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use ceno_zkvm::{
tables::{MemFinalRecord, ProgramTableCircuit, initial_memory, initial_registers},
};
use clap::Parser;
use const_env::from_env;

use ceno_emul::{
ByteAddr, CENO_PLATFORM, EmuContext,
Expand All @@ -28,9 +27,6 @@ use tracing_flame::FlameLayer;
use tracing_subscriber::{EnvFilter, Registry, fmt, layer::SubscriberExt};
use transcript::Transcript;

#[from_env]
const RAYON_NUM_THREADS: usize = 8;

const PROGRAM_SIZE: usize = 512;
// For now, we assume registers
// - x0 is not touched,
Expand Down Expand Up @@ -80,27 +76,6 @@ fn main() {
type E = GoldilocksExt2;
type Pcs = Basefold<GoldilocksExt2, BasefoldRSParams, ChaCha8Rng>;

let max_threads = {
if !RAYON_NUM_THREADS.is_power_of_two() {
#[cfg(not(feature = "non_pow2_rayon_thread"))]
{
panic!(
"add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool"
);
}

#[cfg(feature = "non_pow2_rayon_thread")]
{
use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2};
let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS);
create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true);
max_thread_id
}
} else {
RAYON_NUM_THREADS
}
};

let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap();
let subscriber = Registry::default()
.with(
Expand Down Expand Up @@ -237,7 +212,7 @@ fn main() {

let transcript = Transcript::new(b"riscv");
let mut zkvm_proof = prover
.create_proof(zkvm_witness, pi, max_threads, transcript)
.create_proof(zkvm_witness, pi, transcript)
.expect("create_proof failed");

println!(
Expand Down
16 changes: 4 additions & 12 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use crate::{
structs::{
Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses,
},
utils::{get_challenge_pows, next_pow2_instance_padding, proper_num_threads},
utils::{get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads},
virtual_polys::VirtualPolynomials,
};

Expand All @@ -52,7 +52,6 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
&self,
witnesses: ZKVMWitnesses<E>,
pi: PublicValues<u32>,
max_threads: usize,
mut transcript: Transcript<E>,
) -> Result<ZKVMProof<E, PCS>, ZKVMError> {
let mut vm_proof = ZKVMProof::empty(pi);
Expand Down Expand Up @@ -135,7 +134,6 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
wits_commit,
pi,
num_instances,
max_threads,
transcript,
&challenges,
)?;
Expand All @@ -155,7 +153,6 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
witness.into_iter().map(|v| v.into()).collect_vec(),
wits_commit,
pi,
max_threads,
transcript,
&challenges,
)?;
Expand Down Expand Up @@ -186,7 +183,6 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
wits_commit: PCS::CommitmentWithData,
pi: &[E::BaseField],
num_instances: usize,
max_threads: usize,
transcript: &mut Transcript<E>,
challenges: &[E; 2],
) -> Result<ZKVMOpcodeProof<E, PCS>, ZKVMError> {
Expand Down Expand Up @@ -320,7 +316,6 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
let lk_q2_out_eval = lk_wit_layers[0][3].get_ext_field_vec()[0];
assert!(record_r_out_evals.len() == NUM_FANIN && record_w_out_evals.len() == NUM_FANIN);
let (rt_tower, tower_proof) = TowerProver::create_proof(
max_threads,
vec![
TowerProverSpec {
witness: r_wit_layers,
Expand Down Expand Up @@ -363,7 +358,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
rt_tower[..log2_num_instances].to_vec(),
);

let num_threads = proper_num_threads(log2_num_instances, max_threads);
let num_threads = optimal_sumcheck_threads(log2_num_instances);
let alpha_pow = get_challenge_pows(
MAINCONSTRAIN_SUMCHECK_BATCH_SIZE + cs.assert_zero_sumcheck_expressions.len(),
transcript,
Expand Down Expand Up @@ -624,7 +619,6 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
witnesses: Vec<ArcMultilinearExtension<'_, E>>,
wits_commit: PCS::CommitmentWithData,
pi: &[E::BaseField],
max_threads: usize,
transcript: &mut Transcript<E>,
challenges: &[E; 2],
) -> Result<ZKVMTableProof<E, PCS>, ZKVMError> {
Expand Down Expand Up @@ -843,7 +837,6 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
.collect_vec();

let (rt_tower, tower_proof) = TowerProver::create_proof(
max_threads,
// pattern [r1, w1, r2, w2, ...] same pair are chain together
r_wit_layers
.into_iter()
Expand Down Expand Up @@ -884,7 +877,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
// If all table length are the same, we can skip this sumcheck
let span = entered_span!("sumcheck::opening_same_point");
// NOTE: max concurrency will be dominated by smallest table since it will blo
let num_threads = proper_num_threads(min_log2_num_instance, max_threads);
let num_threads = optimal_sumcheck_threads(min_log2_num_instance);
let alpha_pow = get_challenge_pows(
cs.r_table_expressions.len()
+ cs.w_table_expressions.len()
Expand Down Expand Up @@ -1074,7 +1067,6 @@ impl<E: ExtensionField> TowerProofs<E> {
/// Tower Prover
impl TowerProver {
pub fn create_proof<'a, E: ExtensionField>(
max_threads: usize,
prod_specs: Vec<TowerProverSpec<'a, E>>,
logup_specs: Vec<TowerProverSpec<'a, E>>,
num_fanin: usize,
Expand Down Expand Up @@ -1109,7 +1101,7 @@ impl TowerProver {
let (next_rt, _) =
(1..=max_round_index).fold((initial_rt, alpha_pows), |(out_rt, alpha_pows), round| {
// in first few round we just run on single thread
let num_threads = proper_num_threads(out_rt.len(), max_threads);
let num_threads = optimal_sumcheck_threads(out_rt.len());

let eq: ArcMultilinearExtension<E> = build_eq_x_r_vec(&out_rt).into_mle().into();
let mut virtual_polys = VirtualPolynomials::<E>::new(num_threads, out_rt.len());
Expand Down
3 changes: 1 addition & 2 deletions ceno_zkvm/src/scheme/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ fn test_rw_lk_expression_combination() {
commit,
&[],
num_instances,
1,
&mut transcript,
&prover_challenges,
)
Expand Down Expand Up @@ -290,7 +289,7 @@ fn test_single_add_instance_e2e() {
let pi = PublicValues::new(0, 0, 0, 0, 0);
let transcript = Transcript::new(b"riscv");
let zkvm_proof = prover
.create_proof(zkvm_witness, pi, 1, transcript)
.create_proof(zkvm_witness, pi, transcript)
.expect("create_proof failed");

let transcript = Transcript::new(b"riscv");
Expand Down
4 changes: 3 additions & 1 deletion ceno_zkvm/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use ff::Field;
use ff_ext::ExtensionField;
use goldilocks::SmallField;
use itertools::Itertools;
use multilinear_extensions::util::max_usable_threads;
use transcript::Transcript;

/// convert ext field element to u64, assume it is inside the range
Expand Down Expand Up @@ -113,7 +114,8 @@ pub fn u64vec<const W: usize, const C: usize>(x: u64) -> [u64; W] {

/// we expect each thread at least take 4 num of sumcheck variables
/// return optimal num threads to run sumcheck
pub fn proper_num_threads(num_vars: usize, expected_max_threads: usize) -> usize {
pub fn optimal_sumcheck_threads(num_vars: usize) -> usize {
let expected_max_threads = max_usable_threads();
let min_numvar_per_thread = 4;
if num_vars <= min_numvar_per_thread {
1
Expand Down
1 change: 0 additions & 1 deletion gkr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ ark-std.workspace = true
ff.workspace = true
goldilocks.workspace = true

const_env.workspace = true
crossbeam-channel.workspace = true
ff_ext = { path = "../ff_ext" }
itertools.workspace = true
Expand Down
25 changes: 2 additions & 23 deletions gkr/benches/keccak256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

use std::time::Duration;

use const_env::from_env;
use criterion::*;

use gkr::gadgets::keccak256::{keccak256_circuit, prove_keccak256, verify_keccak256};
use goldilocks::GoldilocksExt2;
use multilinear_extensions::util::max_usable_threads;

cfg_if::cfg_if! {
if #[cfg(feature = "flamegraph")] {
Expand All @@ -28,35 +28,14 @@ cfg_if::cfg_if! {
criterion_main!(keccak256);

const NUM_SAMPLES: usize = 10;
#[from_env]
const RAYON_NUM_THREADS: usize = 8;

fn bench_keccak256(c: &mut Criterion) {
println!(
"#layers: {}",
keccak256_circuit::<GoldilocksExt2>().layers.len()
);

let max_thread_id = {
if !RAYON_NUM_THREADS.is_power_of_two() {
#[cfg(not(feature = "non_pow2_rayon_thread"))]
{
panic!(
"add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool"
);
}

#[cfg(feature = "non_pow2_rayon_thread")]
{
use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2};
let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS);
create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true);
max_thread_id
}
} else {
RAYON_NUM_THREADS
}
};
let max_thread_id = max_usable_threads();

let circuit = keccak256_circuit::<GoldilocksExt2>();

Expand Down
18 changes: 18 additions & 0 deletions multilinear_extensions/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,21 @@ pub fn create_uninit_vec<T: Sized>(len: usize) -> Vec<MaybeUninit<T>> {
pub fn largest_even_below(n: usize) -> usize {
    // Clearing the lowest bit leaves an even `n` (including 0) untouched and
    // drops an odd `n` to the even number directly below it.
    n & !1
}

/// Returns the largest power of two that is less than or equal to `n`.
///
/// Returns 0 when `n` is 0, since no power of two fits.
///
/// The result is derived from the position of the highest set bit of `n`, so
/// it is defined for every `usize`, including values at or above
/// `1 << (usize::BITS - 1)` where an `(n + 1).next_power_of_two()` based
/// computation would overflow.
fn prev_power_of_two(n: usize) -> usize {
    // `checked_ilog2` is `None` only for 0; otherwise it is the index of the
    // highest set bit, and shifting 1 by it yields the power of two we want.
    n.checked_ilog2().map_or(0, |bit| 1 << bit)
}

/// Largest power of two that fits the available rayon threads.
///
/// Reads `rayon::current_num_threads()` and rounds it down to a power of two
/// with `prev_power_of_two`. When rounding changes the value, a warning is
/// logged through `tracing`; it is emitted at most once per process so that
/// callers invoking this function repeatedly do not flood the log.
///
/// When this crate itself is compiled for tests (`cfg!(test)`), the function
/// always returns 1 without querying rayon.
pub fn max_usable_threads() -> usize {
    // Guards the rounding warning so it is reported only on the first occurrence.
    static WARN_ONCE: std::sync::Once = std::sync::Once::new();

    if cfg!(test) {
        return 1;
    }

    let n = rayon::current_num_threads();
    let threads = prev_power_of_two(n);
    if n != threads {
        WARN_ONCE.call_once(|| {
            tracing::warn!("thread size {n} is not power of 2, using {threads} threads instead.");
        });
    }
    threads
}
Loading

0 comments on commit 1b8d622

Please sign in to comment.