From 1b8d622e077dfc0aeb4cc2700f9c58ef6ace0d40 Mon Sep 17 00:00:00 2001 From: Ming Date: Mon, 28 Oct 2024 16:54:47 +0800 Subject: [PATCH] Simplify thread pool configuration (#464) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR should resolve occasionally hang issue during various condition. Root cause is (probably) from `RAYON_NUM_THREADS` env parse from `cons_env` crate, which properly a race condition so the environment var change doesn't trigger a rebuild effectively. This will cause potiential mismatch with system rayon thread (for unknown reason), probably a bug in `cons_env`. Despite the root cause is not 100% for sure, we can clean up those complexity, and instead just respect rayon thread pool parsing from global entry, which greatly simplify the overall flow. Change has been verified on remote benchmark machine. before/after didn't cause performance difference --------- Co-authored-by: Matthias Görgens --- Cargo.lock | 24 ----- Cargo.toml | 1 - build.rs | 3 - ceno_zkvm/Cargo.toml | 1 - ceno_zkvm/benches/riscv_add.rs | 24 ----- ceno_zkvm/examples/riscv_opcodes.rs | 27 +---- ceno_zkvm/src/scheme/prover.rs | 16 +-- ceno_zkvm/src/scheme/tests.rs | 3 +- ceno_zkvm/src/utils.rs | 4 +- gkr/Cargo.toml | 1 - gkr/benches/keccak256.rs | 25 +---- multilinear_extensions/src/util.rs | 18 ++++ multilinear_extensions/src/virtual_poly.rs | 5 +- singer/Cargo.toml | 1 - singer/benches/add.rs | 25 +---- sumcheck/Cargo.toml | 1 - sumcheck/benches/devirgo_sumcheck.rs | 14 ++- sumcheck/examples/devirgo_sumcheck.rs | 112 --------------------- sumcheck/src/prover_v2.rs | 1 + 19 files changed, 40 insertions(+), 266 deletions(-) delete mode 100644 build.rs delete mode 100644 sumcheck/examples/devirgo_sumcheck.rs diff --git a/Cargo.lock b/Cargo.lock index bf37a3132..c80dfdedb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -262,7 +262,6 @@ dependencies = [ "ceno_emul", "cfg-if", "clap", - "const_env", "criterion", "ff", "ff_ext", @@ -409,26 +408,6 @@ dependencies = [ "tiny-keccak", ] -[[package]] -name = "const_env" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e9e4f72c6e3398ca6da372abd9affd8f89781fe728869bbf986206e9af9627e" -dependencies = [ - "const_env_impl", -] - -[[package]] -name = "const_env_impl" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a4f51209740b5e1589e702b3044cdd4562cef41b6da404904192ffffb852d62" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "constant_time_eq" version = "0.3.1" @@ -719,7 +698,6 @@ version = "0.1.0" dependencies = [ "ark-std", "cfg-if", - "const_env", "criterion", "crossbeam-channel", "ff", @@ -1692,7 +1670,6 @@ version = "0.1.0" dependencies = [ "ark-std", "cfg-if", - "const_env", "criterion", "ff", "ff_ext", @@ -1800,7 +1777,6 @@ name = "sumcheck" version = "0.1.0" dependencies = [ "ark-std", - "const_env", "criterion", "crossbeam-channel", "ff", diff --git a/Cargo.toml b/Cargo.toml index fc601342b..ea6d35b27 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,6 @@ version = "0.1.0" [workspace.dependencies] ark-std = "0.4" cfg-if = "1.0" -const_env = "0.1" criterion = { version = "0.5", features = ["html_reports"] } crossbeam-channel = "0.5" ff = "0.13" diff --git a/build.rs b/build.rs deleted file mode 100644 index 3e31cb0a9..000000000 --- a/build.rs +++ /dev/null @@ -1,3 +0,0 @@ -fn main() { - println!("cargo:rerun-if-env-changed=RAYON_NUM_THREADS"); -} diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 36d2e77cb..ac16a6de9 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -36,7 +36,6 @@ thread_local = "1.1" [dev-dependencies] base64 = "0.22" cfg-if.workspace = true -const_env.workspace = true criterion.workspace = true pprof.workspace = true serde_json.workspace = true diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 9d69e120f..16d5cfe67 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -7,7 +7,6 @@ use ceno_zkvm::{ scheme::prover::ZKVMProver, structs::{ZKVMConstraintSystem, ZKVMFixedTraces}, }; -use const_env::from_env; use criterion::*; use ceno_zkvm::scheme::constants::MAX_NUM_VARIABLES; @@ -37,31 +36,9 @@ cfg_if::cfg_if! { criterion_main!(op_add); const NUM_SAMPLES: usize = 10; -#[from_env] -const RAYON_NUM_THREADS: usize = 8; fn bench_add(c: &mut Criterion) { type Pcs = BasefoldDefault; - let max_threads = { - if !RAYON_NUM_THREADS.is_power_of_two() { - #[cfg(not(feature = "non_pow2_rayon_thread"))] - { - panic!( - "add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool" - ); - } - - #[cfg(feature = "non_pow2_rayon_thread")] - { - use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; - let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); - create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); - max_thread_id - } - } else { - RAYON_NUM_THREADS - } - }; let mut zkvm_cs = ZKVMConstraintSystem::default(); let _ = zkvm_cs.register_opcode_circuit::>(); let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); @@ -128,7 +105,6 @@ fn bench_add(c: &mut Criterion) { commit, &[], num_instances, - max_threads, &mut transcript, &challenges, ) diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index d322d360c..b979975d2 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -8,7 +8,6 @@ use ceno_zkvm::{ tables::{MemFinalRecord, ProgramTableCircuit, initial_memory, initial_registers}, }; use clap::Parser; -use const_env::from_env; use ceno_emul::{ ByteAddr, CENO_PLATFORM, EmuContext, @@ -28,9 +27,6 @@ use tracing_flame::FlameLayer; use tracing_subscriber::{EnvFilter, Registry, fmt, layer::SubscriberExt}; use transcript::Transcript; -#[from_env] -const RAYON_NUM_THREADS: usize = 8; - const PROGRAM_SIZE: usize = 512; // For now, we assume registers // - x0 is not touched, @@ -80,27 +76,6 @@ fn main() { type E = GoldilocksExt2; type Pcs = Basefold; - let max_threads = { - if !RAYON_NUM_THREADS.is_power_of_two() { - #[cfg(not(feature = "non_pow2_rayon_thread"))] - { - panic!( - "add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool" - ); - } - - #[cfg(feature = "non_pow2_rayon_thread")] - { - use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; - let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); - create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); - max_thread_id - } - } else { - RAYON_NUM_THREADS - } - }; - let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap(); let subscriber = Registry::default() .with( @@ -237,7 +212,7 @@ fn main() { let transcript = Transcript::new(b"riscv"); let mut zkvm_proof = prover - .create_proof(zkvm_witness, pi, max_threads, transcript) + .create_proof(zkvm_witness, pi, transcript) .expect("create_proof failed"); println!( diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 2c92197a5..c87283c01 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -32,7 +32,7 @@ use crate::{ structs::{ Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses, }, - utils::{get_challenge_pows, next_pow2_instance_padding, proper_num_threads}, + utils::{get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads}, virtual_polys::VirtualPolynomials, }; @@ -52,7 +52,6 @@ impl> ZKVMProver { &self, witnesses: ZKVMWitnesses, pi: PublicValues, - max_threads: usize, mut transcript: Transcript, ) -> Result, ZKVMError> { let mut vm_proof = ZKVMProof::empty(pi); @@ -135,7 +134,6 @@ impl> ZKVMProver { wits_commit, pi, num_instances, - max_threads, transcript, &challenges, )?; @@ -155,7 +153,6 @@ impl> ZKVMProver { witness.into_iter().map(|v| v.into()).collect_vec(), wits_commit, pi, - max_threads, transcript, &challenges, )?; @@ -186,7 +183,6 @@ impl> ZKVMProver { wits_commit: PCS::CommitmentWithData, pi: &[E::BaseField], num_instances: usize, - max_threads: usize, transcript: &mut Transcript, challenges: &[E; 2], ) -> Result, ZKVMError> { @@ -320,7 +316,6 @@ impl> ZKVMProver { let lk_q2_out_eval = lk_wit_layers[0][3].get_ext_field_vec()[0]; assert!(record_r_out_evals.len() == NUM_FANIN && record_w_out_evals.len() == NUM_FANIN); let (rt_tower, tower_proof) = TowerProver::create_proof( - max_threads, vec![ TowerProverSpec { witness: r_wit_layers, @@ -363,7 +358,7 @@ impl> ZKVMProver { rt_tower[..log2_num_instances].to_vec(), ); - let num_threads = proper_num_threads(log2_num_instances, max_threads); + let num_threads = optimal_sumcheck_threads(log2_num_instances); let alpha_pow = get_challenge_pows( MAINCONSTRAIN_SUMCHECK_BATCH_SIZE + cs.assert_zero_sumcheck_expressions.len(), transcript, @@ -624,7 +619,6 @@ impl> ZKVMProver { witnesses: Vec>, wits_commit: PCS::CommitmentWithData, pi: &[E::BaseField], - max_threads: usize, transcript: &mut Transcript, challenges: &[E; 2], ) -> Result, ZKVMError> { @@ -843,7 +837,6 @@ impl> ZKVMProver { .collect_vec(); let (rt_tower, tower_proof) = TowerProver::create_proof( - max_threads, // pattern [r1, w1, r2, w2, ...] same pair are chain together r_wit_layers .into_iter() @@ -884,7 +877,7 @@ impl> ZKVMProver { // If all table length are the same, we can skip this sumcheck let span = entered_span!("sumcheck::opening_same_point"); // NOTE: max concurrency will be dominated by smallest table since it will blo - let num_threads = proper_num_threads(min_log2_num_instance, max_threads); + let num_threads = optimal_sumcheck_threads(min_log2_num_instance); let alpha_pow = get_challenge_pows( cs.r_table_expressions.len() + cs.w_table_expressions.len() @@ -1074,7 +1067,6 @@ impl TowerProofs { /// Tower Prover impl TowerProver { pub fn create_proof<'a, E: ExtensionField>( - max_threads: usize, prod_specs: Vec>, logup_specs: Vec>, num_fanin: usize, @@ -1109,7 +1101,7 @@ impl TowerProver { let (next_rt, _) = (1..=max_round_index).fold((initial_rt, alpha_pows), |(out_rt, alpha_pows), round| { // in first few round we just run on single thread - let num_threads = proper_num_threads(out_rt.len(), max_threads); + let num_threads = optimal_sumcheck_threads(out_rt.len()); let eq: ArcMultilinearExtension = build_eq_x_r_vec(&out_rt).into_mle().into(); let mut virtual_polys = VirtualPolynomials::::new(num_threads, out_rt.len()); diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 80c43af0b..f928a8918 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -141,7 +141,6 @@ fn test_rw_lk_expression_combination() { commit, &[], num_instances, - 1, &mut transcript, &prover_challenges, ) @@ -290,7 +289,7 @@ fn test_single_add_instance_e2e() { let pi = PublicValues::new(0, 0, 0, 0, 0); let transcript = Transcript::new(b"riscv"); let zkvm_proof = prover - .create_proof(zkvm_witness, pi, 1, transcript) + .create_proof(zkvm_witness, pi, transcript) .expect("create_proof failed"); let transcript = Transcript::new(b"riscv"); diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 119b13439..47af02967 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -2,6 +2,7 @@ use ff::Field; use ff_ext::ExtensionField; use goldilocks::SmallField; use itertools::Itertools; +use multilinear_extensions::util::max_usable_threads; use transcript::Transcript; /// convert ext field element to u64, assume it is inside the range @@ -113,7 +114,8 @@ pub fn u64vec(x: u64) -> [u64; W] { /// we expect each thread at least take 4 num of sumcheck variables /// return optimal num threads to run sumcheck -pub fn proper_num_threads(num_vars: usize, expected_max_threads: usize) -> usize { +pub fn optimal_sumcheck_threads(num_vars: usize) -> usize { + let expected_max_threads = max_usable_threads(); let min_numvar_per_thread = 4; if num_vars <= min_numvar_per_thread { 1 diff --git a/gkr/Cargo.toml b/gkr/Cargo.toml index 702791bd0..fab3ec6b8 100644 --- a/gkr/Cargo.toml +++ b/gkr/Cargo.toml @@ -9,7 +9,6 @@ ark-std.workspace = true ff.workspace = true goldilocks.workspace = true -const_env.workspace = true crossbeam-channel.workspace = true ff_ext = { path = "../ff_ext" } itertools.workspace = true diff --git a/gkr/benches/keccak256.rs b/gkr/benches/keccak256.rs index fce8ffb31..d48920dd7 100644 --- a/gkr/benches/keccak256.rs +++ b/gkr/benches/keccak256.rs @@ -3,11 +3,11 @@ use std::time::Duration; -use const_env::from_env; use criterion::*; use gkr::gadgets::keccak256::{keccak256_circuit, prove_keccak256, verify_keccak256}; use goldilocks::GoldilocksExt2; +use multilinear_extensions::util::max_usable_threads; cfg_if::cfg_if! { if #[cfg(feature = "flamegraph")] { @@ -28,8 +28,6 @@ cfg_if::cfg_if! { criterion_main!(keccak256); const NUM_SAMPLES: usize = 10; -#[from_env] -const RAYON_NUM_THREADS: usize = 8; fn bench_keccak256(c: &mut Criterion) { println!( @@ -37,26 +35,7 @@ fn bench_keccak256(c: &mut Criterion) { keccak256_circuit::().layers.len() ); - let max_thread_id = { - if !RAYON_NUM_THREADS.is_power_of_two() { - #[cfg(not(feature = "non_pow2_rayon_thread"))] - { - panic!( - "add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool" - ); - } - - #[cfg(feature = "non_pow2_rayon_thread")] - { - use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; - let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); - create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); - max_thread_id - } - } else { - RAYON_NUM_THREADS - } - }; + let max_thread_id = max_usable_threads(); let circuit = keccak256_circuit::(); diff --git a/multilinear_extensions/src/util.rs b/multilinear_extensions/src/util.rs index 28e4f8284..a0a8e56a2 100644 --- a/multilinear_extensions/src/util.rs +++ b/multilinear_extensions/src/util.rs @@ -30,3 +30,21 @@ pub fn create_uninit_vec(len: usize) -> Vec> { pub fn largest_even_below(n: usize) -> usize { if n % 2 == 0 { n } else { n.saturating_sub(1) } } + +fn prev_power_of_two(n: usize) -> usize { + (n + 1).next_power_of_two() / 2 +} + +/// Largest power of two that fits the available rayon threads +pub fn max_usable_threads() -> usize { + if cfg!(test) { + 1 + } else { + let n = rayon::current_num_threads(); + let threads = prev_power_of_two(n); + if n != threads { + tracing::warn!("thread size {n} is not power of 2, using {threads} threads instead."); + } + threads + } +} diff --git a/multilinear_extensions/src/virtual_poly.rs b/multilinear_extensions/src/virtual_poly.rs index a5468ad90..45e4e9fb7 100644 --- a/multilinear_extensions/src/virtual_poly.rs +++ b/multilinear_extensions/src/virtual_poly.rs @@ -2,7 +2,7 @@ use std::{cmp::max, collections::HashMap, marker::PhantomData, mem::MaybeUninit, use crate::{ mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, MultilinearExtension}, - util::{bit_decompose, create_uninit_vec}, + util::{bit_decompose, create_uninit_vec, max_usable_threads}, }; use ark_std::{end_timer, iterable::Iterable, rand::Rng, start_timer}; use ff::{Field, PrimeField}; @@ -452,8 +452,7 @@ pub fn build_eq_x_r_vec(r: &[E]) -> Vec { // .... // 1 1 1 1 -> r0 * r1 * r2 * r3 // we will need 2^num_var evaluations - let nthreads = - std::env::var("RAYON_NUM_THREADS").map_or(8, |s| s.parse::().unwrap_or(8)); + let nthreads = max_usable_threads(); let nbits = nthreads.trailing_zeros() as usize; assert_eq!(1 << nbits, nthreads); diff --git a/singer/Cargo.toml b/singer/Cargo.toml index e6c10602e..ac71ab4b9 100644 --- a/singer/Cargo.toml +++ b/singer/Cargo.toml @@ -28,7 +28,6 @@ tracing-subscriber.workspace = true [dev-dependencies] cfg-if.workspace = true -const_env.workspace = true criterion.workspace = true pprof.workspace = true tracing.workspace = true diff --git a/singer/benches/add.rs b/singer/benches/add.rs index 70fee6f28..5984a19b2 100644 --- a/singer/benches/add.rs +++ b/singer/benches/add.rs @@ -4,7 +4,6 @@ use std::time::{Duration, Instant}; use ark_std::test_rng; -use const_env::from_env; use criterion::*; use ff_ext::{ExtensionField, ff::Field}; @@ -30,9 +29,8 @@ cfg_if::cfg_if! { criterion_main!(op_add); const NUM_SAMPLES: usize = 10; -#[from_env] -const RAYON_NUM_THREADS: usize = 8; +use multilinear_extensions::util::max_usable_threads; use singer::{ CircuitWiresIn, SingerGraphBuilder, SingerParams, instructions::{Instruction, InstructionGraph, SingerCircuitBuilder, add::AddInstruction}, @@ -42,26 +40,7 @@ use singer_utils::structs::ChipChallenges; use transcript::Transcript; fn bench_add(c: &mut Criterion) { - let max_thread_id = { - if !RAYON_NUM_THREADS.is_power_of_two() { - #[cfg(not(feature = "non_pow2_rayon_thread"))] - { - panic!( - "add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool" - ); - } - - #[cfg(feature = "non_pow2_rayon_thread")] - { - use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; - let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); - create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); - max_thread_id - } - } else { - RAYON_NUM_THREADS - } - }; + let max_thread_id = max_usable_threads(); let chip_challenges = ChipChallenges::default(); let circuit_builder = SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); diff --git a/sumcheck/Cargo.toml b/sumcheck/Cargo.toml index 4b9605cdf..540495c0a 100644 --- a/sumcheck/Cargo.toml +++ b/sumcheck/Cargo.toml @@ -6,7 +6,6 @@ version.workspace = true [dependencies] ark-std.workspace = true -const_env.workspace = true ff.workspace = true ff_ext = { path = "../ff_ext" } goldilocks.workspace = true diff --git a/sumcheck/benches/devirgo_sumcheck.rs b/sumcheck/benches/devirgo_sumcheck.rs index 7116789b9..fd33b9d09 100644 --- a/sumcheck/benches/devirgo_sumcheck.rs +++ b/sumcheck/benches/devirgo_sumcheck.rs @@ -4,7 +4,6 @@ use std::array; use ark_std::test_rng; -use const_env::from_env; use criterion::*; use ff_ext::ExtensionField; use itertools::Itertools; @@ -14,6 +13,7 @@ use goldilocks::GoldilocksExt2; use multilinear_extensions::{ mle::DenseMultilinearExtension, op_mle, + util::max_usable_threads, virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2 as VirtualPolynomial}, }; use transcript::Transcript; @@ -41,10 +41,10 @@ pub fn transpose(v: Vec>) -> Vec> { } fn prepare_input<'a, E: ExtensionField>( - max_thread_id: usize, nv: usize, ) -> (E, VirtualPolynomial<'a, E>, Vec>) { let mut rng = test_rng(); + let max_thread_id = max_usable_threads(); let size_log2 = ceil_log2(max_thread_id); let fs: [ArcMultilinearExtension<'a, E>; NUM_DEGREE] = array::from_fn(|_| { let mle: ArcMultilinearExtension<'a, E> = @@ -100,9 +100,6 @@ fn prepare_input<'a, E: ExtensionField>( (asserted_sum, virtual_poly_v1, virtual_poly_v2) } -#[from_env] -const RAYON_NUM_THREADS: usize = 8; - fn sumcheck_fn(c: &mut Criterion) { type E = GoldilocksExt2; @@ -119,7 +116,7 @@ fn sumcheck_fn(c: &mut Criterion) { || { let prover_transcript = Transcript::::new(b"test"); let (asserted_sum, virtual_poly, virtual_poly_splitted) = - { prepare_input(RAYON_NUM_THREADS, nv) }; + { prepare_input(nv) }; ( prover_transcript, asserted_sum, @@ -150,6 +147,7 @@ fn sumcheck_fn(c: &mut Criterion) { fn devirgo_sumcheck_fn(c: &mut Criterion) { type E = GoldilocksExt2; + let threads = max_usable_threads(); for nv in NV.into_iter() { // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("devirgo_nv_{}", nv)); @@ -163,7 +161,7 @@ fn devirgo_sumcheck_fn(c: &mut Criterion) { || { let prover_transcript = Transcript::::new(b"test"); let (asserted_sum, virtual_poly, virtual_poly_splitted) = - { prepare_input(RAYON_NUM_THREADS, nv) }; + { prepare_input(nv) }; ( prover_transcript, asserted_sum, @@ -178,7 +176,7 @@ fn devirgo_sumcheck_fn(c: &mut Criterion) { virtual_poly_splitted, )| { let (_sumcheck_proof_v2, _) = IOPProverState::::prove_batch_polys( - RAYON_NUM_THREADS, + threads, virtual_poly_splitted, &mut prover_transcript, ); diff --git a/sumcheck/examples/devirgo_sumcheck.rs b/sumcheck/examples/devirgo_sumcheck.rs deleted file mode 100644 index 29ff81368..000000000 --- a/sumcheck/examples/devirgo_sumcheck.rs +++ /dev/null @@ -1,112 +0,0 @@ -use std::sync::Arc; - -use ark_std::test_rng; -use const_env::from_env; -use ff_ext::{ExtensionField, ff::Field}; -use goldilocks::GoldilocksExt2; -use itertools::Itertools; -use multilinear_extensions::{ - commutative_op_mle_pair, - mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, MultilinearExtension}, - virtual_poly::VirtualPolynomial, -}; -use sumcheck::{ - structs::{IOPProverState, IOPVerifierState}, - util::ceil_log2, -}; -use transcript::Transcript; - -type E = GoldilocksExt2; - -fn prepare_input( - max_thread_id: usize, -) -> (E, VirtualPolynomial, Vec>) { - let nv = 10; - let mut rng = test_rng(); - let size_log2 = ceil_log2(max_thread_id); - let f1: Arc> = - DenseMultilinearExtension::::random(nv, &mut rng).into(); - let g1: Arc> = - DenseMultilinearExtension::::random(nv, &mut rng).into(); - - let mut virtual_poly_1 = VirtualPolynomial::new_from_mle(f1.clone(), E::BaseField::ONE); - virtual_poly_1.mul_by_mle(g1.clone(), ::BaseField::ONE); - - let mut virtual_poly_f1: Vec> = match &f1.evaluations { - multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations - .chunks((1 << nv) >> size_log2) - .map(|chunk| { - DenseMultilinearExtension::::from_evaluations_vec(nv - size_log2, chunk.to_vec()) - .into() - }) - .map(|mle| VirtualPolynomial::new_from_mle(mle, E::BaseField::ONE)) - .collect_vec(), - _ => unreachable!(), - }; - - let poly_g1: Vec> = match &g1.evaluations { - multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations - .chunks((1 << nv) >> size_log2) - .map(|chunk| { - DenseMultilinearExtension::::from_evaluations_vec(nv - size_log2, chunk.to_vec()) - .into() - }) - .collect_vec(), - _ => unreachable!(), - }; - - let asserted_sum = commutative_op_mle_pair!(|f1, g1| { - (0..f1.len()) - .map(|i| f1[i] * g1[i]) - .fold(E::ZERO, |acc, item| acc + item) - }); - - virtual_poly_f1 - .iter_mut() - .zip(poly_g1.iter()) - .for_each(|(f1, g1)| f1.mul_by_mle(g1.clone(), E::BaseField::ONE)); - (asserted_sum, virtual_poly_1, virtual_poly_f1) -} - -#[from_env] -const RAYON_NUM_THREADS: usize = 8; - -fn main() { - let mut prover_transcript_v1 = Transcript::::new(b"test"); - let mut prover_transcript_v2 = Transcript::::new(b"test"); - - let (asserted_sum, virtual_poly, virtual_poly_splitted) = prepare_input(RAYON_NUM_THREADS); - let (sumcheck_proof_v2, _) = IOPProverState::::prove_batch_polys( - RAYON_NUM_THREADS, - virtual_poly_splitted.clone(), - &mut prover_transcript_v2, - ); - println!("v2 finish"); - - let mut transcript = Transcript::new(b"test"); - let poly_info = virtual_poly.aux_info.clone(); - let subclaim = IOPVerifierState::::verify( - asserted_sum, - &sumcheck_proof_v2, - &poly_info, - &mut transcript, - ); - assert!( - virtual_poly.evaluate( - subclaim - .point - .iter() - .map(|c| c.elements) - .collect::>() - .as_ref() - ) == subclaim.expected_evaluation, - "wrong subclaim" - ); - - #[allow(deprecated)] - let (sumcheck_proof_v1, _) = - IOPProverState::::prove_parallel(virtual_poly.clone(), &mut prover_transcript_v1); - - println!("v1 finish"); - assert!(sumcheck_proof_v2 == sumcheck_proof_v1); -} diff --git a/sumcheck/src/prover_v2.rs b/sumcheck/src/prover_v2.rs index 2336e52a9..f2ede8680 100644 --- a/sumcheck/src/prover_v2.rs +++ b/sumcheck/src/prover_v2.rs @@ -43,6 +43,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { ) -> (IOPProof, IOPProverStateV2<'a, E>) { assert!(!polys.is_empty()); assert_eq!(polys.len(), max_thread_id); + assert!(max_thread_id.is_power_of_two()); let log2_max_thread_id = ceil_log2(max_thread_id); // do not support SIZE not power of 2 assert!(