diff --git a/Cargo.lock b/Cargo.lock index bf37a3132..c80dfdedb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -262,7 +262,6 @@ dependencies = [ "ceno_emul", "cfg-if", "clap", - "const_env", "criterion", "ff", "ff_ext", @@ -409,26 +408,6 @@ dependencies = [ "tiny-keccak", ] -[[package]] -name = "const_env" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e9e4f72c6e3398ca6da372abd9affd8f89781fe728869bbf986206e9af9627e" -dependencies = [ - "const_env_impl", -] - -[[package]] -name = "const_env_impl" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a4f51209740b5e1589e702b3044cdd4562cef41b6da404904192ffffb852d62" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "constant_time_eq" version = "0.3.1" @@ -719,7 +698,6 @@ version = "0.1.0" dependencies = [ "ark-std", "cfg-if", - "const_env", "criterion", "crossbeam-channel", "ff", @@ -1692,7 +1670,6 @@ version = "0.1.0" dependencies = [ "ark-std", "cfg-if", - "const_env", "criterion", "ff", "ff_ext", @@ -1800,7 +1777,6 @@ name = "sumcheck" version = "0.1.0" dependencies = [ "ark-std", - "const_env", "criterion", "crossbeam-channel", "ff", diff --git a/Cargo.toml b/Cargo.toml index fc601342b..ea6d35b27 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,6 @@ version = "0.1.0" [workspace.dependencies] ark-std = "0.4" cfg-if = "1.0" -const_env = "0.1" criterion = { version = "0.5", features = ["html_reports"] } crossbeam-channel = "0.5" ff = "0.13" diff --git a/build.rs b/build.rs deleted file mode 100644 index 3e31cb0a9..000000000 --- a/build.rs +++ /dev/null @@ -1,3 +0,0 @@ -fn main() { - println!("cargo:rerun-if-env-changed=RAYON_NUM_THREADS"); -} diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 36d2e77cb..ac16a6de9 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -36,7 +36,6 @@ thread_local = "1.1" [dev-dependencies] base64 = "0.22" cfg-if.workspace = true -const_env.workspace = true criterion.workspace = true pprof.workspace = true serde_json.workspace = true diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 9d69e120f..16d5cfe67 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -7,7 +7,6 @@ use ceno_zkvm::{ scheme::prover::ZKVMProver, structs::{ZKVMConstraintSystem, ZKVMFixedTraces}, }; -use const_env::from_env; use criterion::*; use ceno_zkvm::scheme::constants::MAX_NUM_VARIABLES; @@ -37,31 +36,9 @@ cfg_if::cfg_if! { criterion_main!(op_add); const NUM_SAMPLES: usize = 10; -#[from_env] -const RAYON_NUM_THREADS: usize = 8; fn bench_add(c: &mut Criterion) { type Pcs = BasefoldDefault; - let max_threads = { - if !RAYON_NUM_THREADS.is_power_of_two() { - #[cfg(not(feature = "non_pow2_rayon_thread"))] - { - panic!( - "add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool" - ); - } - - #[cfg(feature = "non_pow2_rayon_thread")] - { - use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; - let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); - create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); - max_thread_id - } - } else { - RAYON_NUM_THREADS - } - }; let mut zkvm_cs = ZKVMConstraintSystem::default(); let _ = zkvm_cs.register_opcode_circuit::>(); let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); @@ -128,7 +105,6 @@ fn bench_add(c: &mut Criterion) { commit, &[], num_instances, - max_threads, &mut transcript, &challenges, ) diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index d322d360c..b979975d2 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -8,7 +8,6 @@ use ceno_zkvm::{ tables::{MemFinalRecord, ProgramTableCircuit, initial_memory, initial_registers}, }; use clap::Parser; -use const_env::from_env; use ceno_emul::{ ByteAddr, CENO_PLATFORM, EmuContext, @@ -28,9 +27,6 @@ use tracing_flame::FlameLayer; use tracing_subscriber::{EnvFilter, Registry, fmt, layer::SubscriberExt}; use transcript::Transcript; -#[from_env] -const RAYON_NUM_THREADS: usize = 8; - const PROGRAM_SIZE: usize = 512; // For now, we assume registers // - x0 is not touched, @@ -80,27 +76,6 @@ fn main() { type E = GoldilocksExt2; type Pcs = Basefold; - let max_threads = { - if !RAYON_NUM_THREADS.is_power_of_two() { - #[cfg(not(feature = "non_pow2_rayon_thread"))] - { - panic!( - "add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool" - ); - } - - #[cfg(feature = "non_pow2_rayon_thread")] - { - use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; - let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); - create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); - max_thread_id - } - } else { - RAYON_NUM_THREADS - } - }; - let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap(); let subscriber = Registry::default() .with( @@ -237,7 +212,7 @@ fn main() { let transcript = Transcript::new(b"riscv"); let mut zkvm_proof = prover - .create_proof(zkvm_witness, pi, max_threads, transcript) + .create_proof(zkvm_witness, pi, transcript) .expect("create_proof failed"); println!( diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 2c92197a5..c87283c01 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -32,7 +32,7 @@ use crate::{ structs::{ Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses, }, - utils::{get_challenge_pows, next_pow2_instance_padding, proper_num_threads}, + utils::{get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads}, virtual_polys::VirtualPolynomials, }; @@ -52,7 +52,6 @@ impl> ZKVMProver { &self, witnesses: ZKVMWitnesses, pi: PublicValues, - max_threads: usize, mut transcript: Transcript, ) -> Result, ZKVMError> { let mut vm_proof = ZKVMProof::empty(pi); @@ -135,7 +134,6 @@ impl> ZKVMProver { wits_commit, pi, num_instances, - max_threads, transcript, &challenges, )?; @@ -155,7 +153,6 @@ impl> ZKVMProver { witness.into_iter().map(|v| v.into()).collect_vec(), wits_commit, pi, - max_threads, transcript, &challenges, )?; @@ -186,7 +183,6 @@ impl> ZKVMProver { wits_commit: PCS::CommitmentWithData, pi: &[E::BaseField], num_instances: usize, - max_threads: usize, transcript: &mut Transcript, challenges: &[E; 2], ) -> Result, ZKVMError> { @@ -320,7 +316,6 @@ impl> ZKVMProver { let lk_q2_out_eval = lk_wit_layers[0][3].get_ext_field_vec()[0]; assert!(record_r_out_evals.len() == NUM_FANIN && record_w_out_evals.len() == NUM_FANIN); let (rt_tower, tower_proof) = TowerProver::create_proof( - max_threads, vec![ TowerProverSpec { witness: r_wit_layers, @@ -363,7 +358,7 @@ impl> ZKVMProver { rt_tower[..log2_num_instances].to_vec(), ); - let num_threads = proper_num_threads(log2_num_instances, max_threads); + let num_threads = optimal_sumcheck_threads(log2_num_instances); let alpha_pow = get_challenge_pows( MAINCONSTRAIN_SUMCHECK_BATCH_SIZE + cs.assert_zero_sumcheck_expressions.len(), transcript, @@ -624,7 +619,6 @@ impl> ZKVMProver { witnesses: Vec>, wits_commit: PCS::CommitmentWithData, pi: &[E::BaseField], - max_threads: usize, transcript: &mut Transcript, challenges: &[E; 2], ) -> Result, ZKVMError> { @@ -843,7 +837,6 @@ impl> ZKVMProver { .collect_vec(); let (rt_tower, tower_proof) = TowerProver::create_proof( - max_threads, // pattern [r1, w1, r2, w2, ...] same pair are chain together r_wit_layers .into_iter() @@ -884,7 +877,7 @@ impl> ZKVMProver { // If all table length are the same, we can skip this sumcheck let span = entered_span!("sumcheck::opening_same_point"); // NOTE: max concurrency will be dominated by smallest table since it will blo - let num_threads = proper_num_threads(min_log2_num_instance, max_threads); + let num_threads = optimal_sumcheck_threads(min_log2_num_instance); let alpha_pow = get_challenge_pows( cs.r_table_expressions.len() + cs.w_table_expressions.len() @@ -1074,7 +1067,6 @@ impl TowerProofs { /// Tower Prover impl TowerProver { pub fn create_proof<'a, E: ExtensionField>( - max_threads: usize, prod_specs: Vec>, logup_specs: Vec>, num_fanin: usize, @@ -1109,7 +1101,7 @@ impl TowerProver { let (next_rt, _) = (1..=max_round_index).fold((initial_rt, alpha_pows), |(out_rt, alpha_pows), round| { // in first few round we just run on single thread - let num_threads = proper_num_threads(out_rt.len(), max_threads); + let num_threads = optimal_sumcheck_threads(out_rt.len()); let eq: ArcMultilinearExtension = build_eq_x_r_vec(&out_rt).into_mle().into(); let mut virtual_polys = VirtualPolynomials::::new(num_threads, out_rt.len()); diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 80c43af0b..f928a8918 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -141,7 +141,6 @@ fn test_rw_lk_expression_combination() { commit, &[], num_instances, - 1, &mut transcript, &prover_challenges, ) @@ -290,7 +289,7 @@ fn test_single_add_instance_e2e() { let pi = PublicValues::new(0, 0, 0, 0, 0); let transcript = Transcript::new(b"riscv"); let zkvm_proof = prover - .create_proof(zkvm_witness, pi, 1, transcript) + .create_proof(zkvm_witness, pi, transcript) .expect("create_proof failed"); let transcript = Transcript::new(b"riscv"); diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 119b13439..47af02967 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -2,6 +2,7 @@ use ff::Field; use ff_ext::ExtensionField; use goldilocks::SmallField; use itertools::Itertools; +use multilinear_extensions::util::max_usable_threads; use transcript::Transcript; /// convert ext field element to u64, assume it is inside the range @@ -113,7 +114,8 @@ pub fn u64vec(x: u64) -> [u64; W] { /// we expect each thread at least take 4 num of sumcheck variables /// return optimal num threads to run sumcheck -pub fn proper_num_threads(num_vars: usize, expected_max_threads: usize) -> usize { +pub fn optimal_sumcheck_threads(num_vars: usize) -> usize { + let expected_max_threads = max_usable_threads(); let min_numvar_per_thread = 4; if num_vars <= min_numvar_per_thread { 1 diff --git a/gkr/Cargo.toml b/gkr/Cargo.toml index 702791bd0..fab3ec6b8 100644 --- a/gkr/Cargo.toml +++ b/gkr/Cargo.toml @@ -9,7 +9,6 @@ ark-std.workspace = true ff.workspace = true goldilocks.workspace = true -const_env.workspace = true crossbeam-channel.workspace = true ff_ext = { path = "../ff_ext" } itertools.workspace = true diff --git a/gkr/benches/keccak256.rs b/gkr/benches/keccak256.rs index fce8ffb31..d48920dd7 100644 --- a/gkr/benches/keccak256.rs +++ b/gkr/benches/keccak256.rs @@ -3,11 +3,11 @@ use std::time::Duration; -use const_env::from_env; use criterion::*; use gkr::gadgets::keccak256::{keccak256_circuit, prove_keccak256, verify_keccak256}; use goldilocks::GoldilocksExt2; +use multilinear_extensions::util::max_usable_threads; cfg_if::cfg_if! { if #[cfg(feature = "flamegraph")] { @@ -28,8 +28,6 @@ cfg_if::cfg_if! { criterion_main!(keccak256); const NUM_SAMPLES: usize = 10; -#[from_env] -const RAYON_NUM_THREADS: usize = 8; fn bench_keccak256(c: &mut Criterion) { println!( @@ -37,26 +35,7 @@ fn bench_keccak256(c: &mut Criterion) { keccak256_circuit::().layers.len() ); - let max_thread_id = { - if !RAYON_NUM_THREADS.is_power_of_two() { - #[cfg(not(feature = "non_pow2_rayon_thread"))] - { - panic!( - "add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool" - ); - } - - #[cfg(feature = "non_pow2_rayon_thread")] - { - use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; - let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); - create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); - max_thread_id - } - } else { - RAYON_NUM_THREADS - } - }; + let max_thread_id = max_usable_threads(); let circuit = keccak256_circuit::(); diff --git a/multilinear_extensions/src/util.rs b/multilinear_extensions/src/util.rs index 28e4f8284..a0a8e56a2 100644 --- a/multilinear_extensions/src/util.rs +++ b/multilinear_extensions/src/util.rs @@ -30,3 +30,21 @@ pub fn create_uninit_vec(len: usize) -> Vec> { pub fn largest_even_below(n: usize) -> usize { if n % 2 == 0 { n } else { n.saturating_sub(1) } } + +fn prev_power_of_two(n: usize) -> usize { + (n + 1).next_power_of_two() / 2 +} + +/// Largest power of two that fits the available rayon threads +pub fn max_usable_threads() -> usize { + if cfg!(test) { + 1 + } else { + let n = rayon::current_num_threads(); + let threads = prev_power_of_two(n); + if n != threads { + tracing::warn!("thread size {n} is not power of 2, using {threads} threads instead."); + } + threads + } +} diff --git a/multilinear_extensions/src/virtual_poly.rs b/multilinear_extensions/src/virtual_poly.rs index a5468ad90..45e4e9fb7 100644 --- a/multilinear_extensions/src/virtual_poly.rs +++ b/multilinear_extensions/src/virtual_poly.rs @@ -2,7 +2,7 @@ use std::{cmp::max, collections::HashMap, marker::PhantomData, mem::MaybeUninit, use crate::{ mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, MultilinearExtension}, - util::{bit_decompose, create_uninit_vec}, + util::{bit_decompose, create_uninit_vec, max_usable_threads}, }; use ark_std::{end_timer, iterable::Iterable, rand::Rng, start_timer}; use ff::{Field, PrimeField}; @@ -452,8 +452,7 @@ pub fn build_eq_x_r_vec(r: &[E]) -> Vec { // .... // 1 1 1 1 -> r0 * r1 * r2 * r3 // we will need 2^num_var evaluations - let nthreads = - std::env::var("RAYON_NUM_THREADS").map_or(8, |s| s.parse::().unwrap_or(8)); + let nthreads = max_usable_threads(); let nbits = nthreads.trailing_zeros() as usize; assert_eq!(1 << nbits, nthreads); diff --git a/singer/Cargo.toml b/singer/Cargo.toml index e6c10602e..ac71ab4b9 100644 --- a/singer/Cargo.toml +++ b/singer/Cargo.toml @@ -28,7 +28,6 @@ tracing-subscriber.workspace = true [dev-dependencies] cfg-if.workspace = true -const_env.workspace = true criterion.workspace = true pprof.workspace = true tracing.workspace = true diff --git a/singer/benches/add.rs b/singer/benches/add.rs index 70fee6f28..5984a19b2 100644 --- a/singer/benches/add.rs +++ b/singer/benches/add.rs @@ -4,7 +4,6 @@ use std::time::{Duration, Instant}; use ark_std::test_rng; -use const_env::from_env; use criterion::*; use ff_ext::{ExtensionField, ff::Field}; @@ -30,9 +29,8 @@ cfg_if::cfg_if! { criterion_main!(op_add); const NUM_SAMPLES: usize = 10; -#[from_env] -const RAYON_NUM_THREADS: usize = 8; +use multilinear_extensions::util::max_usable_threads; use singer::{ CircuitWiresIn, SingerGraphBuilder, SingerParams, instructions::{Instruction, InstructionGraph, SingerCircuitBuilder, add::AddInstruction}, @@ -42,26 +40,7 @@ use singer_utils::structs::ChipChallenges; use transcript::Transcript; fn bench_add(c: &mut Criterion) { - let max_thread_id = { - if !RAYON_NUM_THREADS.is_power_of_two() { - #[cfg(not(feature = "non_pow2_rayon_thread"))] - { - panic!( - "add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool" - ); - } - - #[cfg(feature = "non_pow2_rayon_thread")] - { - use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; - let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); - create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); - max_thread_id - } - } else { - RAYON_NUM_THREADS - } - }; + let max_thread_id = max_usable_threads(); let chip_challenges = ChipChallenges::default(); let circuit_builder = SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); diff --git a/sumcheck/Cargo.toml b/sumcheck/Cargo.toml index 4b9605cdf..540495c0a 100644 --- a/sumcheck/Cargo.toml +++ b/sumcheck/Cargo.toml @@ -6,7 +6,6 @@ version.workspace = true [dependencies] ark-std.workspace = true -const_env.workspace = true ff.workspace = true ff_ext = { path = "../ff_ext" } goldilocks.workspace = true diff --git a/sumcheck/benches/devirgo_sumcheck.rs b/sumcheck/benches/devirgo_sumcheck.rs index 7116789b9..fd33b9d09 100644 --- a/sumcheck/benches/devirgo_sumcheck.rs +++ b/sumcheck/benches/devirgo_sumcheck.rs @@ -4,7 +4,6 @@ use std::array; use ark_std::test_rng; -use const_env::from_env; use criterion::*; use ff_ext::ExtensionField; use itertools::Itertools; @@ -14,6 +13,7 @@ use goldilocks::GoldilocksExt2; use multilinear_extensions::{ mle::DenseMultilinearExtension, op_mle, + util::max_usable_threads, virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2 as VirtualPolynomial}, }; use transcript::Transcript; @@ -41,10 +41,10 @@ pub fn transpose(v: Vec>) -> Vec> { } fn prepare_input<'a, E: ExtensionField>( - max_thread_id: usize, nv: usize, ) -> (E, VirtualPolynomial<'a, E>, Vec>) { let mut rng = test_rng(); + let max_thread_id = max_usable_threads(); let size_log2 = ceil_log2(max_thread_id); let fs: [ArcMultilinearExtension<'a, E>; NUM_DEGREE] = array::from_fn(|_| { let mle: ArcMultilinearExtension<'a, E> = @@ -100,9 +100,6 @@ fn prepare_input<'a, E: ExtensionField>( (asserted_sum, virtual_poly_v1, virtual_poly_v2) } -#[from_env] -const RAYON_NUM_THREADS: usize = 8; - fn sumcheck_fn(c: &mut Criterion) { type E = GoldilocksExt2; @@ -119,7 +116,7 @@ fn sumcheck_fn(c: &mut Criterion) { || { let prover_transcript = Transcript::::new(b"test"); let (asserted_sum, virtual_poly, virtual_poly_splitted) = - { prepare_input(RAYON_NUM_THREADS, nv) }; + { prepare_input(nv) }; ( prover_transcript, asserted_sum, @@ -150,6 +147,7 @@ fn sumcheck_fn(c: &mut Criterion) { fn devirgo_sumcheck_fn(c: &mut Criterion) { type E = GoldilocksExt2; + let threads = max_usable_threads(); for nv in NV.into_iter() { // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("devirgo_nv_{}", nv)); @@ -163,7 +161,7 @@ fn devirgo_sumcheck_fn(c: &mut Criterion) { || { let prover_transcript = Transcript::::new(b"test"); let (asserted_sum, virtual_poly, virtual_poly_splitted) = - { prepare_input(RAYON_NUM_THREADS, nv) }; + { prepare_input(nv) }; ( prover_transcript, asserted_sum, @@ -178,7 +176,7 @@ fn devirgo_sumcheck_fn(c: &mut Criterion) { virtual_poly_splitted, )| { let (_sumcheck_proof_v2, _) = IOPProverState::::prove_batch_polys( - RAYON_NUM_THREADS, + threads, virtual_poly_splitted, &mut prover_transcript, ); diff --git a/sumcheck/examples/devirgo_sumcheck.rs b/sumcheck/examples/devirgo_sumcheck.rs deleted file mode 100644 index 29ff81368..000000000 --- a/sumcheck/examples/devirgo_sumcheck.rs +++ /dev/null @@ -1,112 +0,0 @@ -use std::sync::Arc; - -use ark_std::test_rng; -use const_env::from_env; -use ff_ext::{ExtensionField, ff::Field}; -use goldilocks::GoldilocksExt2; -use itertools::Itertools; -use multilinear_extensions::{ - commutative_op_mle_pair, - mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, MultilinearExtension}, - virtual_poly::VirtualPolynomial, -}; -use sumcheck::{ - structs::{IOPProverState, IOPVerifierState}, - util::ceil_log2, -}; -use transcript::Transcript; - -type E = GoldilocksExt2; - -fn prepare_input( - max_thread_id: usize, -) -> (E, VirtualPolynomial, Vec>) { - let nv = 10; - let mut rng = test_rng(); - let size_log2 = ceil_log2(max_thread_id); - let f1: Arc> = - DenseMultilinearExtension::::random(nv, &mut rng).into(); - let g1: Arc> = - DenseMultilinearExtension::::random(nv, &mut rng).into(); - - let mut virtual_poly_1 = VirtualPolynomial::new_from_mle(f1.clone(), E::BaseField::ONE); - virtual_poly_1.mul_by_mle(g1.clone(), ::BaseField::ONE); - - let mut virtual_poly_f1: Vec> = match &f1.evaluations { - multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations - .chunks((1 << nv) >> size_log2) - .map(|chunk| { - DenseMultilinearExtension::::from_evaluations_vec(nv - size_log2, chunk.to_vec()) - .into() - }) - .map(|mle| VirtualPolynomial::new_from_mle(mle, E::BaseField::ONE)) - .collect_vec(), - _ => unreachable!(), - }; - - let poly_g1: Vec> = match &g1.evaluations { - multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations - .chunks((1 << nv) >> size_log2) - .map(|chunk| { - DenseMultilinearExtension::::from_evaluations_vec(nv - size_log2, chunk.to_vec()) - .into() - }) - .collect_vec(), - _ => unreachable!(), - }; - - let asserted_sum = commutative_op_mle_pair!(|f1, g1| { - (0..f1.len()) - .map(|i| f1[i] * g1[i]) - .fold(E::ZERO, |acc, item| acc + item) - }); - - virtual_poly_f1 - .iter_mut() - .zip(poly_g1.iter()) - .for_each(|(f1, g1)| f1.mul_by_mle(g1.clone(), E::BaseField::ONE)); - (asserted_sum, virtual_poly_1, virtual_poly_f1) -} - -#[from_env] -const RAYON_NUM_THREADS: usize = 8; - -fn main() { - let mut prover_transcript_v1 = Transcript::::new(b"test"); - let mut prover_transcript_v2 = Transcript::::new(b"test"); - - let (asserted_sum, virtual_poly, virtual_poly_splitted) = prepare_input(RAYON_NUM_THREADS); - let (sumcheck_proof_v2, _) = IOPProverState::::prove_batch_polys( - RAYON_NUM_THREADS, - virtual_poly_splitted.clone(), - &mut prover_transcript_v2, - ); - println!("v2 finish"); - - let mut transcript = Transcript::new(b"test"); - let poly_info = virtual_poly.aux_info.clone(); - let subclaim = IOPVerifierState::::verify( - asserted_sum, - &sumcheck_proof_v2, - &poly_info, - &mut transcript, - ); - assert!( - virtual_poly.evaluate( - subclaim - .point - .iter() - .map(|c| c.elements) - .collect::>() - .as_ref() - ) == subclaim.expected_evaluation, - "wrong subclaim" - ); - - #[allow(deprecated)] - let (sumcheck_proof_v1, _) = - IOPProverState::::prove_parallel(virtual_poly.clone(), &mut prover_transcript_v1); - - println!("v1 finish"); - assert!(sumcheck_proof_v2 == sumcheck_proof_v1); -} diff --git a/sumcheck/src/prover_v2.rs b/sumcheck/src/prover_v2.rs index 2336e52a9..f2ede8680 100644 --- a/sumcheck/src/prover_v2.rs +++ b/sumcheck/src/prover_v2.rs @@ -43,6 +43,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { ) -> (IOPProof, IOPProverStateV2<'a, E>) { assert!(!polys.is_empty()); assert_eq!(polys.len(), max_thread_id); + assert!(max_thread_id.is_power_of_two()); let log2_max_thread_id = ceil_log2(max_thread_id); // do not support SIZE not power of 2 assert!(