diff --git a/Cargo.lock b/Cargo.lock index 924a5a252..3f5d66b76 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -276,6 +276,7 @@ dependencies = [ "num-traits", "paste", "pprof", + "prettytable-rs", "rand", "rand_chacha", "rayon", @@ -527,6 +528,27 @@ dependencies = [ "typenum", ] +[[package]] +name = "csv" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +dependencies = [ + "memchr", +] + [[package]] name = "ctr" version = "0.9.2" @@ -545,6 +567,27 @@ dependencies = [ "uuid", ] +[[package]] +name = "dirs-next" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1" +dependencies = [ + "cfg-if", + "dirs-sys-next", +] + +[[package]] +name = "dirs-sys-next" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" +dependencies = [ + "libc", + "redox_users", + "winapi", +] + [[package]] name = "either" version = "1.13.0" @@ -563,6 +606,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "361a90feb7004eca4019fb28352a9465666b24f840f5c3cddf0ff13920590b89" +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "env_filter" version = "0.1.2" @@ -926,6 +975,16 @@ version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" +[[package]] +name = "libredox" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" +dependencies = [ + "bitflags 2.6.0", + "libc", +] + [[package]] name = "linux-raw-sys" version = "0.4.14" @@ -1346,6 +1405,20 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "prettytable-rs" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eea25e07510aa6ab6547308ebe3c036016d162b8da920dbb079e3ba8acf3d95a" +dependencies = [ + "csv", + "encode_unicode", + "is-terminal", + "lazy_static", + "term", + "unicode-width", +] + [[package]] name = "primitive-types" version = "0.10.1" @@ -1449,6 +1522,17 @@ dependencies = [ "bitflags 2.6.0", ] +[[package]] +name = "redox_users" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" +dependencies = [ + "getrandom", + "libredox", + "thiserror", +] + [[package]] name = "regex" version = "1.11.0" @@ -1768,6 +1852,17 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "term" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c59df8ac95d96ff9bede18eb7300b0fda5e5d8d90960e76f8e14ae765eedbf1f" +dependencies = [ + "dirs-next", + "rustversion", + "winapi", +] + [[package]] name = "thiserror" version = "1.0.64" @@ -1925,6 +2020,12 @@ version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" +[[package]] +name = "unicode-width" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" + [[package]] name = "unroll" version = "0.1.5" diff --git a/Cargo.toml b/Cargo.toml index a2b16d782..4a638dad0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ paste = "1" plonky2 = "0.2" poseidon = { path = "./poseidon" } pprof = { version = "0.13", features = ["flamegraph"] } +prettytable-rs = "^0.10" rand = "0.8" rand_chacha = { version = "0.3", features = ["serde1"] } rand_core = "0.6" diff --git a/Makefile.toml b/Makefile.toml index 2a4a079a2..b9f231db4 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -53,6 +53,11 @@ args = ["fmt", "-p", "ceno_zkvm", "--", "--check"] command = "cargo" workspace = false +[tasks.riscv_stats] +args = ["run", "--bin", "riscv_stats"] +command = "cargo" +workspace = false + [tasks.clippy] args = [ "clippy", diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 3f87754e9..a6cf65cd4 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -24,12 +24,14 @@ transcript = { path = "../transcript" } itertools.workspace = true num-traits.workspace = true paste.workspace = true +prettytable-rs.workspace = true strum.workspace = true strum_macros.workspace = true tracing.workspace = true tracing-flame.workspace = true tracing-subscriber.workspace = true + clap = { version = "4.5", features = ["derive"] } generic_static = "0.2" rand.workspace = true diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index 89cfaa8d0..b378521a7 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -16,16 +16,17 @@ use ceno_emul::{ }; use ceno_zkvm::{ scheme::{PublicValues, constants::MAX_NUM_VARIABLES, verifier::ZKVMVerifier}, + stats::{StaticReport, TraceReport}, structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, }; use ff_ext::ff::Field; use goldilocks::GoldilocksExt2; use itertools::Itertools; use mpcs::{Basefold, BasefoldRSParams, PolynomialCommitmentScheme}; +use sumcheck::{entered_span, exit_span}; use tracing_flame::FlameLayer; -use tracing_subscriber::{EnvFilter, Registry, fmt, layer::SubscriberExt}; +use tracing_subscriber::{EnvFilter, Registry, fmt, fmt::format::FmtSpan, layer::SubscriberExt}; use transcript::Transcript; - const PROGRAM_SIZE: usize = 16; // For now, we assume registers // - x0 is not touched, @@ -89,17 +90,29 @@ fn main() { .collect(), ); let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap(); + let mut fmt_layer = fmt::layer() + .compact() + .with_span_events(FmtSpan::CLOSE) + .with_thread_ids(false) + .with_thread_names(false); + fmt_layer.set_ansi(false); + + // Take filtering directives from RUST_LOG env_var + // Directive syntax: https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#directives + // Example: RUST_LOG="info" cargo run.. to get spans/events at info level; profiling spans are info + // Example: RUST_LOG="[sumcheck]" cargo run.. to get only events under the "sumcheck" span + let filter = EnvFilter::from_default_env(); + let subscriber = Registry::default() - .with( - fmt::layer() - .compact() - .with_thread_ids(false) - .with_thread_names(false), - ) - .with(EnvFilter::from_default_env()) + .with(fmt_layer) + .with(filter) .with(flame_layer.with_threads_collapsed(true)); tracing::subscriber::set_global_default(subscriber).unwrap(); + let top_level = entered_span!("TOPLEVEL"); + + let keygen = entered_span!("KEYGEN"); + // keygen let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup"); let (pp, vp) = Pcs::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); @@ -118,6 +131,8 @@ fn main() { &program, ); + let static_report = StaticReport::new(&zkvm_cs); + let reg_init = initial_registers(); config.generate_fixed_traces(&zkvm_cs, &mut zkvm_fixed_traces); @@ -129,6 +144,7 @@ fn main() { .expect("keygen failed"); let vk = pk.get_vk(); + exit_span!(keygen); // proving let prover = ZKVMProver::new(pk); let verifier = ZKVMVerifier::new(vk); @@ -238,6 +254,16 @@ fn main() { .assign_table_circuit::>(&zkvm_cs, &prog_config, &program) .unwrap(); + // get instance counts from witness matrices + let trace_report = TraceReport::new_via_witnesses( + &static_report, + &zkvm_witness, + "EXAMPLE_PROGRAM in riscv_opcodes.rs", + ); + + trace_report.save_json("report.json"); + trace_report.save_table("report.txt"); + MockProver::assert_satisfied_full( zkvm_cs.clone(), zkvm_fixed_traces.clone(), @@ -248,6 +274,7 @@ fn main() { let timer = Instant::now(); let transcript = Transcript::new(b"riscv"); + let mut zkvm_proof = prover .create_proof(zkvm_witness, pi, transcript) .expect("create_proof failed"); @@ -255,7 +282,7 @@ fn main() { println!( "riscv_opcodes::create_proof, instance_num_vars = {}, time = {}", instance_num_vars, - timer.elapsed().as_secs_f64() + timer.elapsed().as_secs() ); let transcript = Transcript::new(b"riscv"); @@ -300,4 +327,5 @@ fn main() { } }; } + exit_span!(top_level); } diff --git a/ceno_zkvm/src/bin/riscv_stats.rs b/ceno_zkvm/src/bin/riscv_stats.rs new file mode 100644 index 000000000..01ca11763 --- /dev/null +++ b/ceno_zkvm/src/bin/riscv_stats.rs @@ -0,0 +1,18 @@ +use std::collections::BTreeMap; + +use ceno_zkvm::{ + instructions::riscv::Rv32imConfig, + stats::{StaticReport, TraceReport}, + structs::ZKVMConstraintSystem, +}; +use goldilocks::GoldilocksExt2; +type E = GoldilocksExt2; +fn main() { + let mut zkvm_cs = ZKVMConstraintSystem::default(); + + let _ = Rv32imConfig::::construct_circuits(&mut zkvm_cs); + let static_report = StaticReport::new(&zkvm_cs); + let report = TraceReport::new(&static_report, BTreeMap::new(), "no program"); + report.save_table("riscv_stats.txt"); + println!("INFO: generated riscv_stats.txt"); +} diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index 90461d5d9..a093543db 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -1,6 +1,6 @@ use ceno_emul::Addr; -use itertools::Itertools; -use std::{collections::HashMap, marker::PhantomData}; +use itertools::{Itertools, chain}; +use std::{collections::HashMap, iter::once, marker::PhantomData}; use ff_ext::ExtensionField; use mpcs::PolynomialCommitmentScheme; @@ -15,7 +15,7 @@ use crate::{ }; /// namespace used for annotation, preserve meta info during circuit construction -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default, serde::Serialize)] pub struct NameSpace { namespace: Vec, } @@ -49,7 +49,7 @@ impl NameSpace { let mut name = String::new(); let mut needs_separation = false; - for ns in ns.iter().chain(Some(&this).into_iter()) { + for ns in chain!(ns, once(&this)) { if needs_separation { name += "/"; } diff --git a/ceno_zkvm/src/keygen.rs b/ceno_zkvm/src/keygen.rs index ba853a943..f97a1886d 100644 --- a/ceno_zkvm/src/keygen.rs +++ b/ceno_zkvm/src/keygen.rs @@ -14,7 +14,7 @@ impl ZKVMConstraintSystem { ) -> Result, ZKVMError> { let mut vm_pk = ZKVMProvingKey::new(pp, vp); - for (c_name, cs) in self.circuit_css.into_iter() { + for (c_name, cs) in self.circuit_css { // fixed_traces is optional // verifier will check it existent if cs.num_fixed > 0 let fixed_traces = if cs.num_fixed > 0 { diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index 75b4b377d..a3c2ff02f 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -14,6 +14,7 @@ pub mod expression; pub mod gadgets; mod keygen; pub mod state; +pub mod stats; pub mod structs; mod uint; mod utils; diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index c078d7dd2..1504fd562 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -22,7 +22,7 @@ use ff::Field; use ff_ext::ExtensionField; use generic_static::StaticTypeMap; use goldilocks::SmallField; -use itertools::{Itertools, izip}; +use itertools::{Itertools, enumerate, izip}; use multilinear_extensions::{mle::IntoMLEs, virtual_poly_v2::ArcMultilinearExtension}; use rand::thread_rng; use std::{ @@ -504,7 +504,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr); let expr_evaluated = expr_evaluated.get_base_field_vec(); - for (inst_id, element) in expr_evaluated.iter().enumerate() { + for (inst_id, element) in enumerate(expr_evaluated) { if *element != E::BaseField::ZERO { errors.push(MockProverError::AssertZeroError { expression: expr.clone(), @@ -528,7 +528,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { let expr_evaluated = expr_evaluated.get_ext_field_vec(); // Check each lookup expr exists in t vec - for (inst_id, element) in expr_evaluated.iter().enumerate() { + for (inst_id, element) in enumerate(expr_evaluated) { if !table.contains(&element.to_canonical_u64_vec()) { errors.push(MockProverError::LookupError { expression: expr.clone(), @@ -883,12 +883,12 @@ Hints: num_instances.insert(circuit_name.clone(), num_rows); } - for (rom_type, inputs) in rom_inputs.into_iter() { + for (rom_type, inputs) in rom_inputs { let table = rom_tables.get_mut(&rom_type).unwrap(); for (lk_input_values, circuit_name, lk_input_annotation, input_value_exprs) in inputs { // counting multiplicity in rom_input let mut lk_input_values_multiplicity = HashMap::new(); - for (row, input_value) in lk_input_values.iter().enumerate() { + for (row, input_value) in enumerate(&lk_input_values) { // we only keep first row to restore debug information lk_input_values_multiplicity .entry(input_value) @@ -1009,7 +1009,7 @@ Hints: assert!(gs.insert(circuit_name.clone(), w).is_none()); }; let mut records = vec![]; - for (row, record_rlc) in write_rlc_records.into_iter().enumerate() { + for (row, record_rlc) in enumerate(write_rlc_records) { // TODO: report error assert_eq!(writes.insert(record_rlc), true); records.push((record_rlc, row)); @@ -1045,7 +1045,7 @@ Hints: .get_ext_field_vec()[..*num_rows] .to_vec(); let mut records = vec![]; - for (row, record) in read_records.into_iter().enumerate() { + for (row, record) in enumerate(read_records) { // TODO: return error assert_eq!(reads.insert(record), true); records.push((record, row)); diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 25a08bfec..30d8d9a6d 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -52,6 +52,7 @@ impl> ZKVMProver { } /// create proof for zkvm execution + #[tracing::instrument(skip_all, name = "ZKVM_create_proof")] pub fn create_proof( &self, witnesses: ZKVMWitnesses, @@ -87,10 +88,11 @@ impl> ZKVMProver { let mut commitments = BTreeMap::new(); let mut wits = BTreeMap::new(); + let commit_to_traces_span = entered_span!("commit_to_traces"); // commit to opcode circuits first and then commit to table circuits, sorted by name for (circuit_name, witness) in witnesses.into_iter_sorted() { - let commit_dur = std::time::Instant::now(); let num_instances = witness.num_instances(); + let span = entered_span!("commit to iteration", circuit_name = circuit_name); let witness = match num_instances { 0 => vec![], _ => { @@ -100,16 +102,13 @@ impl> ZKVMProver { PCS::batch_commit_and_write(&self.pk.pp, &witness, &mut transcript) .map_err(ZKVMError::PCSError)?, ); - tracing::info!( - "commit to {} traces took {:?}", - circuit_name, - commit_dur.elapsed() - ); witness } }; + exit_span!(span); wits.insert(circuit_name, (witness, num_instances)); } + exit_span!(commit_to_traces_span); // squeeze two challenges from transcript let challenges = [ @@ -118,6 +117,7 @@ impl> ZKVMProver { ]; tracing::debug!("challenges in prover: {:?}", challenges); + let main_proofs_span = entered_span!("main_proofs"); let mut transcripts = transcript.fork(self.pk.circuit_pks.len()); for ((circuit_name, pk), (i, transcript)) in self .pk @@ -193,6 +193,7 @@ impl> ZKVMProver { } } } + exit_span!(main_proofs_span); Ok(vm_proof) } @@ -201,6 +202,7 @@ impl> ZKVMProver { /// 1: witness layer inferring from input -> output /// 2: proof (sumcheck reduce) from output to input #[allow(clippy::too_many_arguments)] + #[tracing::instrument(skip_all, name = "create_opcode_proof", fields(circuit_name=name))] pub fn create_opcode_proof( &self, name: &str, @@ -226,8 +228,9 @@ impl> ZKVMProver { .all(|v| { v.evaluations().len() == next_pow2_instances }) ); + let wit_inference_span = entered_span!("wit_inference"); // main constraint: read/write record witness inference - let span = entered_span!("wit_inference::record"); + let record_span = entered_span!("record"); let records_wit: Vec> = cs .r_expressions .par_iter() @@ -240,7 +243,7 @@ impl> ZKVMProver { .collect(); let (r_records_wit, w_lk_records_wit) = records_wit.split_at(cs.r_expressions.len()); let (w_records_wit, lk_records_wit) = w_lk_records_wit.split_at(cs.w_expressions.len()); - exit_span!(span); + exit_span!(record_span); // product constraint: tower witness inference let (r_counts_per_instance, w_counts_per_instance, lk_counts_per_instance) = ( @@ -255,7 +258,7 @@ impl> ZKVMProver { ); // process last layer by interleaving all the read/write record respectively // as last layer is the output of sel stage - let span = entered_span!("wit_inference::tower_witness_r_last_layer"); + let span = entered_span!("tower_witness_r_last_layer"); // TODO optimize last layer to avoid alloc new vector to save memory let r_records_last_layer = interleaving_mles_to_mles(r_records_wit, num_instances, NUM_FANIN, E::ONE); @@ -263,7 +266,7 @@ impl> ZKVMProver { exit_span!(span); // infer all tower witness after last layer - let span = entered_span!("wit_inference::tower_witness_r_layers"); + let span = entered_span!("tower_witness_r_layers"); let r_wit_layers = infer_tower_product_witness( log2_num_instances + log2_r_count, r_records_last_layer, @@ -271,14 +274,14 @@ impl> ZKVMProver { ); exit_span!(span); - let span = entered_span!("wit_inference::tower_witness_w_last_layer"); + let span = entered_span!("tower_witness_w_last_layer"); // TODO optimize last layer to avoid alloc new vector to save memory let w_records_last_layer = interleaving_mles_to_mles(w_records_wit, num_instances, NUM_FANIN, E::ONE); assert_eq!(w_records_last_layer.len(), NUM_FANIN); exit_span!(span); - let span = entered_span!("wit_inference::tower_witness_w_layers"); + let span = entered_span!("tower_witness_w_layers"); let w_wit_layers = infer_tower_product_witness( log2_num_instances + log2_w_count, w_records_last_layer, @@ -286,16 +289,17 @@ impl> ZKVMProver { ); exit_span!(span); - let span = entered_span!("wit_inference::tower_witness_lk_last_layer"); + let span = entered_span!("tower_witness_lk_last_layer"); // TODO optimize last layer to avoid alloc new vector to save memory let lk_records_last_layer = interleaving_mles_to_mles(lk_records_wit, num_instances, NUM_FANIN, chip_record_alpha); assert_eq!(lk_records_last_layer.len(), 2); exit_span!(span); - let span = entered_span!("wit_inference::tower_witness_lk_layers"); + let span = entered_span!("tower_witness_lk_layers"); let lk_wit_layers = infer_tower_logup_witness(None, lk_records_last_layer); exit_span!(span); + exit_span!(wit_inference_span); if cfg!(test) { // sanity check @@ -326,8 +330,9 @@ impl> ZKVMProver { })); } + let sumcheck_span = entered_span!("SUMCHECK"); // product constraint tower sumcheck - let span = entered_span!("sumcheck::tower"); + let tower_span = entered_span!("tower"); // final evals for verifier let record_r_out_evals: Vec = r_wit_layers[0] .iter() @@ -365,11 +370,11 @@ impl> ZKVMProver { .max() .unwrap() ); - exit_span!(span); + exit_span!(tower_span); tracing::debug!("tower sumcheck finished"); // batch sumcheck: selector + main degree > 1 constraints - let span = entered_span!("sumcheck::main_sel"); + let main_sel_span = entered_span!("main_sel"); let (rt_r, rt_w, rt_lk, rt_non_lc_sumcheck): (Vec, Vec, Vec, Vec) = ( tower_proof.prod_specs_points[0] .last() @@ -585,7 +590,8 @@ impl> ZKVMProver { ); let input_open_point = main_sel_sumcheck_proofs.point.clone(); assert!(input_open_point.len() == log2_num_instances); - exit_span!(span); + exit_span!(main_sel_span); + exit_span!(sumcheck_span); let span = entered_span!("witin::evals"); let wits_in_evals: Vec = witnesses @@ -594,7 +600,7 @@ impl> ZKVMProver { .collect(); exit_span!(span); - let span = entered_span!("pcs_open"); + let pcs_open_span = entered_span!("pcs_open"); let opening_dur = std::time::Instant::now(); tracing::debug!( "[opcode {}]: build opening proof for {} polys at {:?}", @@ -616,7 +622,7 @@ impl> ZKVMProver { name, opening_dur.elapsed(), ); - exit_span!(span); + exit_span!(pcs_open_span); let wits_commit = PCS::get_pure_commitment(&wits_commit); Ok(ZKVMOpcodeProof { @@ -642,6 +648,7 @@ impl> ZKVMProver { /// support batch prove for logup + product arguments each with different num_vars() /// side effect: concurrency will be determine based on min(thread, num_vars()), /// so suggest dont batch too small table (size < threads) with large table together + #[tracing::instrument(skip_all, name = "create_table_proof", fields(table_name=name))] pub fn create_table_proof( &self, name: &str, @@ -685,8 +692,9 @@ impl> ZKVMProver { .all(|(r, w)| r.table_spec.len == w.table_spec.len) ); + let wit_inference_span = entered_span!("wit_inference"); // main constraint: lookup denominator and numerator record witness inference - let span = entered_span!("wit_inference::record"); + let record_span = entered_span!("record"); let mut records_wit: Vec> = cs .r_table_expressions .par_iter() @@ -711,10 +719,10 @@ impl> ZKVMProver { let (lk_d_wit, _empty) = remains.split_at_mut(cs.lk_table_expressions.len()); assert!(_empty.is_empty()); - exit_span!(span); + exit_span!(record_span); // infer all tower witness after last layer - let span = entered_span!("wit_inference::tower_witness_lk_last_layer"); + let span = entered_span!("tower_witness_lk_last_layer"); let mut r_set_last_layer = r_set_wit .iter() .chain(w_set_wit.iter()) @@ -762,7 +770,7 @@ impl> ZKVMProver { .collect::>(); exit_span!(span); - let span = entered_span!("wit_inference::tower_witness_lk_layers"); + let span = entered_span!("tower_witness_lk_layers"); let r_wit_layers = r_set_last_layer .into_iter() .zip(r_set_wit.iter()) @@ -783,6 +791,7 @@ impl> ZKVMProver { .map(|(lk_n, lk_d)| infer_tower_logup_witness(Some(lk_n), lk_d)) .collect_vec(); exit_span!(span); + exit_span!(wit_inference_span); if cfg!(test) { // sanity check @@ -835,8 +844,9 @@ impl> ZKVMProver { })); } + let sumcheck_span = entered_span!("sumcheck"); // product constraint tower sumcheck - let span = entered_span!("sumcheck::tower"); + let tower_span = entered_span!("tower"); // final evals for verifier let r_out_evals = r_wit_layers .iter() @@ -893,7 +903,7 @@ impl> ZKVMProver { rt_tower.len(), // num var length should equal to max_num_instance max_log2_num_instance ); - exit_span!(span); + exit_span!(tower_span); // same point sumcheck is optional when all witin + fixed are in same num_vars let is_skip_same_point_sumcheck = witnesses @@ -908,7 +918,7 @@ impl> ZKVMProver { } else { // one sumcheck to make them opening on same point r (with different prefix) // If all table length are the same, we can skip this sumcheck - let span = entered_span!("sumcheck::opening_same_point"); + let span = entered_span!("opening_same_point"); // NOTE: max concurrency will be dominated by smallest table since it will blo let num_threads = optimal_sumcheck_threads(min_log2_num_instance); let alpha_pow = get_challenge_pows( @@ -997,6 +1007,7 @@ impl> ZKVMProver { ) }; + exit_span!(sumcheck_span); let span = entered_span!("fixed::evals + witin::evals"); let mut evals = witnesses .par_iter() @@ -1029,7 +1040,7 @@ impl> ZKVMProver { .collect_vec(); // TODO implement mechanism to skip commitment - let span = entered_span!("pcs_opening"); + let pcs_opening = entered_span!("pcs_opening"); let (fixed_opening_proof, fixed_commit) = if !fixed.is_empty() { ( Some( @@ -1068,7 +1079,7 @@ impl> ZKVMProver { transcript, ) .map_err(ZKVMError::PCSError)?; - exit_span!(span); + exit_span!(pcs_opening); let wits_commit = PCS::get_pure_commitment(&wits_commit); tracing::debug!( "[table {}] build opening proof for {} polys at {:?}: values = {:?}, commit = {:?}", @@ -1136,6 +1147,7 @@ impl TowerProofs { /// Tower Prover impl TowerProver { + #[tracing::instrument(skip_all, name = "tower_prover_create_proof")] pub fn create_proof<'a, E: ExtensionField>( prod_specs: Vec>, logup_specs: Vec>, @@ -1229,13 +1241,19 @@ impl TowerProver { virtual_polys.add_mle_list(vec![&eq, &q1, &q2], *alpha_denominator); } } - tracing::debug!("generated tower proof at round {}/{}", round, max_round_index); + + let wrap_batch_span = entered_span!("wrap_batch"); + // NOTE: at the time of adding this span, visualizing it with the flamegraph layer + // shows it to be (inexplicably) much more time-consuming than the call to `prove_batch_polys` + // This is likely a bug in the tracing-flame crate. let (sumcheck_proofs, state) = IOPProverStateV2::prove_batch_polys( num_threads, virtual_polys.get_batched_polys(), transcript, ); + exit_span!(wrap_batch_span); + proofs.push_sumcheck_proofs(sumcheck_proofs.proofs); // rt' = r_merge || rt diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 87200d32d..bdce2bd09 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -44,6 +44,7 @@ impl> ZKVMVerifier } /// Verify a trace from start to halt. + #[tracing::instrument(skip_all, name = "verify_proof")] pub fn verify_proof( &self, vm_proof: ZKVMProof, diff --git a/ceno_zkvm/src/stats.rs b/ceno_zkvm/src/stats.rs new file mode 100644 index 000000000..7643d0c12 --- /dev/null +++ b/ceno_zkvm/src/stats.rs @@ -0,0 +1,279 @@ +use crate::{ + circuit_builder::{ConstraintSystem, NameSpace}, + expression::Expression, + structs::{ZKVMConstraintSystem, ZKVMWitnesses}, + utils, +}; +use ff_ext::ExtensionField; +use itertools::Itertools; +use prettytable::{Table, row}; +use serde_json::json; +use std::{ + collections::{BTreeMap, HashMap}, + fs::File, + io::Write, +}; +#[derive(Clone, Debug, serde::Serialize, Default)] +pub struct OpCodeStats { + namespace: NameSpace, + witnesses: usize, + reads: usize, + writes: usize, + lookups: usize, + // store degrees as frequency maps + assert_zero_expr_degrees: HashMap, + assert_zero_sumcheck_expr_degrees: HashMap, +} + +impl std::ops::Add for OpCodeStats { + type Output = OpCodeStats; + fn add(self, rhs: Self) -> Self::Output { + OpCodeStats { + namespace: NameSpace::default(), + witnesses: self.witnesses + rhs.witnesses, + reads: self.reads + rhs.reads, + writes: self.writes + rhs.writes, + lookups: self.lookups + rhs.lookups, + assert_zero_expr_degrees: utils::merge_frequency_tables( + self.assert_zero_expr_degrees, + rhs.assert_zero_expr_degrees, + ), + assert_zero_sumcheck_expr_degrees: utils::merge_frequency_tables( + self.assert_zero_sumcheck_expr_degrees, + rhs.assert_zero_sumcheck_expr_degrees, + ), + } + } +} + +#[derive(Clone, Debug, serde::Serialize)] +pub struct TableStats { + table_len: usize, +} + +#[derive(Clone, Debug, serde::Serialize)] +pub enum CircuitStats { + OpCode(OpCodeStats), + Table(TableStats), +} + +impl Default for CircuitStats { + fn default() -> Self { + CircuitStats::OpCode(OpCodeStats::default()) + } +} + +// logic to aggregate two circuit stats; ignore tables +impl std::ops::Add for CircuitStats { + type Output = CircuitStats; + fn add(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (CircuitStats::Table(_), CircuitStats::Table(_)) => { + CircuitStats::OpCode(OpCodeStats::default()) + } + (CircuitStats::Table(_), rhs) => rhs, + (lhs, CircuitStats::Table(_)) => lhs, + (CircuitStats::OpCode(lhs), CircuitStats::OpCode(rhs)) => { + CircuitStats::OpCode(lhs + rhs) + } + } + } +} + +impl CircuitStats { + pub fn new(system: &ConstraintSystem) -> Self { + let just_degrees_grouped = |exprs: &Vec>| { + let mut counter = HashMap::new(); + for expr in exprs { + *counter.entry(expr.degree()).or_insert(0) += 1; + } + counter + }; + let is_opcode = system.lk_table_expressions.is_empty() + && system.r_table_expressions.is_empty() + && system.w_table_expressions.is_empty(); + // distinguishing opcodes from tables as done in ZKVMProver::create_proof + if is_opcode { + CircuitStats::OpCode(OpCodeStats { + namespace: system.ns.clone(), + witnesses: system.num_witin as usize, + reads: system.r_expressions.len(), + writes: system.w_expressions.len(), + lookups: system.lk_expressions.len(), + assert_zero_expr_degrees: just_degrees_grouped(&system.assert_zero_expressions), + assert_zero_sumcheck_expr_degrees: just_degrees_grouped( + &system.assert_zero_sumcheck_expressions, + ), + }) + } else { + let table_len = if !system.lk_table_expressions.is_empty() { + system.lk_table_expressions[0].table_len + } else { + 0 + }; + CircuitStats::Table(TableStats { table_len }) + } + } +} + +pub struct Report { + metadata: BTreeMap, + circuits: Vec<(String, INFO)>, +} + +impl Report +where + INFO: serde::Serialize, +{ + pub fn get(&self, circuit_name: &str) -> Option<&INFO> { + self.circuits.iter().find_map(|(name, info)| { + if name == circuit_name { + Some(info) + } else { + None + } + }) + } + + pub fn save_json(&self, filename: &str) { + let json_data = json!({ + "metadata": self.metadata, + "circuits": self.circuits, + }); + + let mut file = File::create(filename).expect("Unable to create file"); + file.write_all(serde_json::to_string_pretty(&json_data).unwrap().as_bytes()) + .expect("Unable to write data"); + } +} +pub type StaticReport = Report; + +impl Report { + pub fn new(zkvm_system: &ZKVMConstraintSystem) -> Self { + Report { + metadata: BTreeMap::default(), + circuits: zkvm_system + .get_css() + .iter() + .map(|(k, v)| (k.clone(), CircuitStats::new(v))) + .collect_vec(), + } + } +} + +#[derive(Clone, Debug, serde::Serialize, Default)] +pub struct CircuitStatsTrace { + static_stats: CircuitStats, + num_instances: usize, +} + +impl CircuitStatsTrace { + pub fn new(static_stats: CircuitStats, num_instances: usize) -> Self { + CircuitStatsTrace { + static_stats, + num_instances, + } + } +} + +pub type TraceReport = Report; + +impl Report { + pub fn new( + static_report: &Report, + num_instances: BTreeMap, + program_name: &str, + ) -> Self { + let mut metadata = static_report.metadata.clone(); + // Note where the num_instances are extracted from + metadata.insert("PROGRAM_NAME".to_owned(), program_name.to_owned()); + + // Ensure we recognize all circuits from the num_instances map + num_instances.keys().for_each(|key| { + assert!(static_report.get(key).is_some(), r"unrecognized key {key}."); + }); + + // Stitch num instances to corresponding entries. Sort by num instances + let mut circuits = static_report + .circuits + .iter() + .map(|(key, value)| { + ( + key.to_owned(), + CircuitStatsTrace::new(value.clone(), *num_instances.get(key).unwrap_or(&0)), + ) + }) + .sorted_by(|lhs, rhs| rhs.1.num_instances.cmp(&lhs.1.num_instances)) + .collect_vec(); + + // aggregate results (for opcode circuits only) + let mut total = CircuitStatsTrace::default(); + for (_, circuit) in &circuits { + if let CircuitStats::OpCode(_) = &circuit.static_stats { + total = CircuitStatsTrace { + num_instances: total.num_instances + circuit.num_instances, + static_stats: total.static_stats + circuit.static_stats.clone(), + } + } + } + circuits.insert(0, ("OPCODES TOTAL".to_owned(), total)); + Report { metadata, circuits } + } + + // Extract num_instances from witness data + pub fn new_via_witnesses( + static_report: &Report, + zkvm_witnesses: &ZKVMWitnesses, + program_name: &str, + ) -> Self { + let num_instances = zkvm_witnesses + .clone() + .into_iter_sorted() + .map(|(key, value)| (key, value.num_instances())) + .collect::>(); + Self::new(static_report, num_instances, program_name) + } + + pub fn save_table(&self, filename: &str) { + let mut opcodes_table = Table::new(); + opcodes_table.add_row(row![ + "opcode_name", + "num_instances", + "lookups", + "reads", + "witnesses", + "writes", + "0_expr_deg", + "0_expr_sumcheck_deg" + ]); + let mut tables_table = Table::new(); + tables_table.add_row(row!["table_name", "num_instances", "table_len"]); + + for (name, circuit) in &self.circuits { + match &circuit.static_stats { + CircuitStats::OpCode(opstats) => { + opcodes_table.add_row(row![ + name.to_owned(), + circuit.num_instances, + opstats.lookups, + opstats.reads, + opstats.witnesses, + opstats.writes, + utils::display_hashmap(&opstats.assert_zero_expr_degrees), + utils::display_hashmap(&opstats.assert_zero_sumcheck_expr_degrees) + ]); + } + CircuitStats::Table(tablestats) => { + tables_table.add_row(row![ + name.to_owned(), + circuit.num_instances, + tablestats.table_len + ]); + } + } + } + let mut file = File::create(filename).expect("Unable to create file"); + _ = opcodes_table.print(&mut file); + _ = tables_table.print(&mut file); + } +} diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 5a504a2c3..96d7f787e 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -170,6 +170,10 @@ impl ZKVMConstraintSystem { SC::finalize_global_state(&mut circuit_builder).expect("global_state_out failed"); } + pub fn get_css(&self) -> &BTreeMap> { + &self.circuit_css + } + pub fn get_cs(&self, name: &String) -> Option<&ConstraintSystem> { self.circuit_css.get(name) } diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index e9308c667..971c2ca0a 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -1,3 +1,5 @@ +use std::{collections::HashMap, fmt::Display, hash::Hash}; + use ff::Field; use ff_ext::ExtensionField; use goldilocks::SmallField; @@ -180,3 +182,21 @@ pub fn transpose(v: Vec>) -> Vec> { pub fn next_pow2_instance_padding(num_instance: usize) -> usize { num_instance.next_power_of_two().max(2) } + +pub fn display_hashmap(map: &HashMap) -> String { + format!( + "[{}]", + map.iter().map(|(k, v)| format!("{k}: {v}")).join(",") + ) +} + +pub fn merge_frequency_tables( + lhs: HashMap, + rhs: HashMap, +) -> HashMap { + let mut ret = lhs; + rhs.into_iter().for_each(|(key, value)| { + *ret.entry(key).or_insert(0) += value; + }); + ret +} diff --git a/sumcheck/benches/devirgo_sumcheck.rs b/sumcheck/benches/devirgo_sumcheck.rs index fd33b9d09..3a101fe53 100644 --- a/sumcheck/benches/devirgo_sumcheck.rs +++ b/sumcheck/benches/devirgo_sumcheck.rs @@ -103,7 +103,7 @@ fn prepare_input<'a, E: ExtensionField>( fn sumcheck_fn(c: &mut Criterion) { type E = GoldilocksExt2; - for nv in NV.into_iter() { + for nv in NV { // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("sumcheck_nv_{}", nv)); group.sample_size(NUM_SAMPLES); @@ -148,7 +148,7 @@ fn devirgo_sumcheck_fn(c: &mut Criterion) { type E = GoldilocksExt2; let threads = max_usable_threads(); - for nv in NV.into_iter() { + for nv in NV { // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("devirgo_nv_{}", nv)); group.sample_size(NUM_SAMPLES); diff --git a/sumcheck/src/macros.rs b/sumcheck/src/macros.rs index 470afdf18..a8c63206a 100644 --- a/sumcheck/src/macros.rs +++ b/sumcheck/src/macros.rs @@ -1,17 +1,21 @@ #[macro_export] macro_rules! entered_span { + ($first:expr, $($fields:tt)*) => { + $crate::tracing_span!($first, $($fields)*).entered() + }; ($first:expr $(,)*) => { $crate::tracing_span!($first).entered() }; } - #[macro_export] macro_rules! tracing_span { + ($first:expr, $($fields:tt)*) => { + tracing::span!(tracing::Level::INFO, $first, $($fields)*) + }; ($first:expr $(,)*) => { - tracing::span!(tracing::Level::DEBUG, $first) + tracing::span!(tracing::Level::INFO, $first) }; } - #[macro_export] macro_rules! exit_span { ($first:expr $(,)*) => { diff --git a/sumcheck/src/prover_v2.rs b/sumcheck/src/prover_v2.rs index d572af73f..80cc949df 100644 --- a/sumcheck/src/prover_v2.rs +++ b/sumcheck/src/prover_v2.rs @@ -92,51 +92,56 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { ); let tx_prover_state = tx_prover_state.clone(); let mut thread_based_transcript = thread_based_transcript.clone(); - + let current_span = tracing::Span::current(); + // NOTE: Apply the span.in_scope(||) pattern to record work of spawned thread inside + // span of parent thread. s.spawn(move |_| { - let mut challenge = None; - let span = entered_span!("prove_rounds"); - for i in 0..num_variables { - let prover_msg = IOPProverStateV2::prove_round_and_update_state( - &mut prover_state, - &challenge, - ); - if thread_id < 2 { - tracing::debug!( - "thread {}: sumcheck round {}/{}", - thread_id, - i + 1, - num_variables + current_span.in_scope(|| { + let mut challenge = None; + let span = entered_span!("prove_rounds"); + for i in 0..num_variables { + let prover_msg = IOPProverStateV2::prove_round_and_update_state( + &mut prover_state, + &challenge, ); - } - thread_based_transcript.append_field_element_exts(&prover_msg.evaluations); + if thread_id < 2 { + tracing::debug!( + "thread {}: sumcheck round {}/{}", + thread_id, + i + 1, + num_variables + ); + } + thread_based_transcript + .append_field_element_exts(&prover_msg.evaluations); - challenge = Some( - thread_based_transcript.get_and_append_challenge(b"Internal round"), - ); - thread_based_transcript.commit_rolling(); - } - exit_span!(span); - // pushing the last challenge point to the state - if let Some(p) = challenge { - prover_state.challenges.push(p); - // fix last challenge to collect final evaluation - prover_state - .poly - .flattened_ml_extensions - .iter_mut() - .for_each(|mle| { - let mle = Arc::get_mut(mle).unwrap(); - if mle.num_vars() > 0 { - mle.fix_variables_in_place(&[p.elements]); - } - }); - tx_prover_state - .send(Some((thread_id, prover_state))) - .unwrap(); - } else { - tx_prover_state.send(None).unwrap(); - } + challenge = Some( + thread_based_transcript.get_and_append_challenge(b"Internal round"), + ); + thread_based_transcript.commit_rolling(); + } + exit_span!(span); + // pushing the last challenge point to the state + if let Some(p) = challenge { + prover_state.challenges.push(p); + // fix last challenge to collect final evaluation + prover_state + .poly + .flattened_ml_extensions + .iter_mut() + .for_each(|mle| { + let mle = Arc::get_mut(mle).unwrap(); + if mle.num_vars() > 0 { + mle.fix_variables_in_place(&[p.elements]); + } + }); + tx_prover_state + .send(Some((thread_id, prover_state))) + .unwrap(); + } else { + tx_prover_state.send(None).unwrap(); + } + }) }); } @@ -148,7 +153,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { let tx_prover_state = tx_prover_state.clone(); let mut thread_based_transcript = thread_based_transcript.clone(); - let span = entered_span!("main_thread_prove_rounds"); + let main_thread_span = entered_span!("main_thread_prove_rounds"); // main thread also be one worker thread // NOTE inline main thread flow with worker thread to improve efficiency // refactor to shared closure cause to 5% throuput drop @@ -167,7 +172,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { evaluations += AdditiveVec(round_poly_coeffs); } - let span = entered_span!("main_thread_get_challenge"); + let get_challenge_span = entered_span!("main_thread_get_challenge"); transcript.append_field_element_exts(&evaluations.0); let next_challenge = transcript.get_and_append_challenge(b"Internal round"); @@ -175,7 +180,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { thread_based_transcript.send_challenge(next_challenge.elements); }); - exit_span!(span); + exit_span!(get_challenge_span); prover_msgs.push(IOPProverMessage { evaluations: evaluations.0, @@ -184,7 +189,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { challenge = Some(next_challenge); thread_based_transcript.commit_rolling(); } - exit_span!(span); + exit_span!(main_thread_span); // pushing the last challenge point to the state if let Some(p) = challenge { prover_state.challenges.push(p);