From e7bebcddcd82c34ef7b01ea1eed4cb044e3b9d53 Mon Sep 17 00:00:00 2001 From: mcalancea Date: Fri, 15 Nov 2024 09:37:25 +0200 Subject: [PATCH] Add circuit reports to riscv_opcodes example (#585) implement circuit statistics --- Cargo.lock | 101 ++++++++++ Cargo.toml | 1 + Makefile.toml | 5 + ceno_zkvm/Cargo.toml | 2 + ceno_zkvm/examples/riscv_opcodes.rs | 15 +- ceno_zkvm/src/bin/riscv_stats.rs | 18 ++ ceno_zkvm/src/circuit_builder.rs | 2 +- ceno_zkvm/src/lib.rs | 1 + ceno_zkvm/src/stats.rs | 279 ++++++++++++++++++++++++++++ ceno_zkvm/src/structs.rs | 4 + ceno_zkvm/src/utils.rs | 20 ++ 11 files changed, 446 insertions(+), 2 deletions(-) create mode 100644 ceno_zkvm/src/bin/riscv_stats.rs create mode 100644 ceno_zkvm/src/stats.rs diff --git a/Cargo.lock b/Cargo.lock index 6cc7ba700..ef42b1b82 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -273,6 +273,7 @@ dependencies = [ "multilinear_extensions", "paste", "pprof", + "prettytable-rs", "rand", "rand_chacha", "rayon", @@ -524,6 +525,27 @@ dependencies = [ "typenum", ] +[[package]] +name = "csv" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +dependencies = [ + "memchr", +] + [[package]] name = "ctr" version = "0.9.2" @@ -542,6 +564,27 @@ dependencies = [ "uuid", ] +[[package]] +name = "dirs-next" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1" +dependencies = [ + "cfg-if", + "dirs-sys-next", +] + +[[package]] +name = "dirs-sys-next" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" +dependencies = [ + "libc", + "redox_users", + "winapi", +] + [[package]] name = "either" version = "1.13.0" @@ -560,6 +603,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "361a90feb7004eca4019fb28352a9465666b24f840f5c3cddf0ff13920590b89" +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "env_filter" version = "0.1.2" @@ -923,6 +972,16 @@ version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" +[[package]] +name = "libredox" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" +dependencies = [ + "bitflags 2.6.0", + "libc", +] + [[package]] name = "linux-raw-sys" version = "0.4.14" @@ -1332,6 +1391,20 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "prettytable-rs" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eea25e07510aa6ab6547308ebe3c036016d162b8da920dbb079e3ba8acf3d95a" +dependencies = [ + "csv", + "encode_unicode", + "is-terminal", + "lazy_static", + "term", + "unicode-width", +] + [[package]] name = "primitive-types" version = "0.10.1" @@ -1435,6 +1508,17 @@ dependencies = [ "bitflags 2.6.0", ] +[[package]] +name = "redox_users" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" +dependencies = [ + "getrandom", + "libredox", + "thiserror", +] + [[package]] name = "regex" version = "1.11.0" @@ -1754,6 +1838,17 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "term" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c59df8ac95d96ff9bede18eb7300b0fda5e5d8d90960e76f8e14ae765eedbf1f" +dependencies = [ + "dirs-next", + "rustversion", + "winapi", +] + [[package]] name = "thiserror" version = "1.0.64" @@ -1911,6 +2006,12 @@ version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" +[[package]] +name = "unicode-width" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" + [[package]] name = "unroll" version = "0.1.5" diff --git a/Cargo.toml b/Cargo.toml index 49ac69e28..42b288184 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,7 @@ paste = "1" plonky2 = "0.2" poseidon = { path = "./poseidon" } pprof = { version = "0.13", features = ["flamegraph"] } +prettytable-rs = "^0.10" rand = "0.8" rand_chacha = { version = "0.3", features = ["serde1"] } rand_core = "0.6" diff --git a/Makefile.toml b/Makefile.toml index 2a4a079a2..b9f231db4 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -53,6 +53,11 @@ args = ["fmt", "-p", "ceno_zkvm", "--", "--check"] command = "cargo" workspace = false +[tasks.riscv_stats] +args = ["run", "--bin", "riscv_stats"] +command = "cargo" +workspace = false + [tasks.clippy] args = [ "clippy", diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 6e50e1d32..f67c0c3f9 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -23,12 +23,14 @@ transcript = { path = "../transcript" } itertools.workspace = true paste.workspace = true +prettytable-rs.workspace = true strum.workspace = true strum_macros.workspace = true tracing.workspace = true tracing-flame.workspace = true tracing-subscriber.workspace = true + clap = { version = "4.5", features = ["derive"] } generic_static = "0.2" rand.workspace = true diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index 1decd91b2..a0e4a580a 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -19,6 +19,7 @@ use ceno_emul::{ }; use ceno_zkvm::{ scheme::{PublicValues, constants::MAX_NUM_VARIABLES, verifier::ZKVMVerifier}, + stats::{StaticReport, TraceReport}, structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, }; use ff_ext::ff::Field; @@ -29,7 +30,6 @@ use sumcheck::{entered_span, exit_span}; use tracing_flame::FlameLayer; use tracing_subscriber::{EnvFilter, Registry, fmt, fmt::format::FmtSpan, layer::SubscriberExt}; use transcript::Transcript; - const PROGRAM_SIZE: usize = 16; // For now, we assume registers // - x0 is not touched, @@ -122,6 +122,7 @@ fn main() { let mut zkvm_cs = ZKVMConstraintSystem::default(); let config = Rv32imConfig::::construct_circuits(&mut zkvm_cs); + let prog_config = zkvm_cs.register_table_circuit::>(); zkvm_cs.register_global_state::(); @@ -133,6 +134,8 @@ fn main() { &program, ); + let static_report = StaticReport::new(&zkvm_cs); + let reg_init = initial_registers(); // Define program constant here let program_data: &[u32] = &[]; @@ -288,6 +291,16 @@ fn main() { .assign_table_circuit::>(&zkvm_cs, &prog_config, &program) .unwrap(); + // get instance counts from witness matrices + let trace_report = TraceReport::new_via_witnesses( + &static_report, + &zkvm_witness, + "EXAMPLE_PROGRAM in riscv_opcodes.rs", + ); + + trace_report.save_json("report.json"); + trace_report.save_table("report.txt"); + MockProver::assert_satisfied_full( zkvm_cs.clone(), zkvm_fixed_traces.clone(), diff --git a/ceno_zkvm/src/bin/riscv_stats.rs b/ceno_zkvm/src/bin/riscv_stats.rs new file mode 100644 index 000000000..01ca11763 --- /dev/null +++ b/ceno_zkvm/src/bin/riscv_stats.rs @@ -0,0 +1,18 @@ +use std::collections::BTreeMap; + +use ceno_zkvm::{ + instructions::riscv::Rv32imConfig, + stats::{StaticReport, TraceReport}, + structs::ZKVMConstraintSystem, +}; +use goldilocks::GoldilocksExt2; +type E = GoldilocksExt2; +fn main() { + let mut zkvm_cs = ZKVMConstraintSystem::default(); + + let _ = Rv32imConfig::::construct_circuits(&mut zkvm_cs); + let static_report = StaticReport::new(&zkvm_cs); + let report = TraceReport::new(&static_report, BTreeMap::new(), "no program"); + report.save_table("riscv_stats.txt"); + println!("INFO: generated riscv_stats.txt"); +} diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index c10006983..204f842a8 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -15,7 +15,7 @@ use crate::{ }; /// namespace used for annotation, preserve meta info during circuit construction -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default, serde::Serialize)] pub struct NameSpace { namespace: Vec, } diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index 75b4b377d..a3c2ff02f 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -14,6 +14,7 @@ pub mod expression; pub mod gadgets; mod keygen; pub mod state; +pub mod stats; pub mod structs; mod uint; mod utils; diff --git a/ceno_zkvm/src/stats.rs b/ceno_zkvm/src/stats.rs new file mode 100644 index 000000000..7643d0c12 --- /dev/null +++ b/ceno_zkvm/src/stats.rs @@ -0,0 +1,279 @@ +use crate::{ + circuit_builder::{ConstraintSystem, NameSpace}, + expression::Expression, + structs::{ZKVMConstraintSystem, ZKVMWitnesses}, + utils, +}; +use ff_ext::ExtensionField; +use itertools::Itertools; +use prettytable::{Table, row}; +use serde_json::json; +use std::{ + collections::{BTreeMap, HashMap}, + fs::File, + io::Write, +}; +#[derive(Clone, Debug, serde::Serialize, Default)] +pub struct OpCodeStats { + namespace: NameSpace, + witnesses: usize, + reads: usize, + writes: usize, + lookups: usize, + // store degrees as frequency maps + assert_zero_expr_degrees: HashMap, + assert_zero_sumcheck_expr_degrees: HashMap, +} + +impl std::ops::Add for OpCodeStats { + type Output = OpCodeStats; + fn add(self, rhs: Self) -> Self::Output { + OpCodeStats { + namespace: NameSpace::default(), + witnesses: self.witnesses + rhs.witnesses, + reads: self.reads + rhs.reads, + writes: self.writes + rhs.writes, + lookups: self.lookups + rhs.lookups, + assert_zero_expr_degrees: utils::merge_frequency_tables( + self.assert_zero_expr_degrees, + rhs.assert_zero_expr_degrees, + ), + assert_zero_sumcheck_expr_degrees: utils::merge_frequency_tables( + self.assert_zero_sumcheck_expr_degrees, + rhs.assert_zero_sumcheck_expr_degrees, + ), + } + } +} + +#[derive(Clone, Debug, serde::Serialize)] +pub struct TableStats { + table_len: usize, +} + +#[derive(Clone, Debug, serde::Serialize)] +pub enum CircuitStats { + OpCode(OpCodeStats), + Table(TableStats), +} + +impl Default for CircuitStats { + fn default() -> Self { + CircuitStats::OpCode(OpCodeStats::default()) + } +} + +// logic to aggregate two circuit stats; ignore tables +impl std::ops::Add for CircuitStats { + type Output = CircuitStats; + fn add(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (CircuitStats::Table(_), CircuitStats::Table(_)) => { + CircuitStats::OpCode(OpCodeStats::default()) + } + (CircuitStats::Table(_), rhs) => rhs, + (lhs, CircuitStats::Table(_)) => lhs, + (CircuitStats::OpCode(lhs), CircuitStats::OpCode(rhs)) => { + CircuitStats::OpCode(lhs + rhs) + } + } + } +} + +impl CircuitStats { + pub fn new(system: &ConstraintSystem) -> Self { + let just_degrees_grouped = |exprs: &Vec>| { + let mut counter = HashMap::new(); + for expr in exprs { + *counter.entry(expr.degree()).or_insert(0) += 1; + } + counter + }; + let is_opcode = system.lk_table_expressions.is_empty() + && system.r_table_expressions.is_empty() + && system.w_table_expressions.is_empty(); + // distinguishing opcodes from tables as done in ZKVMProver::create_proof + if is_opcode { + CircuitStats::OpCode(OpCodeStats { + namespace: system.ns.clone(), + witnesses: system.num_witin as usize, + reads: system.r_expressions.len(), + writes: system.w_expressions.len(), + lookups: system.lk_expressions.len(), + assert_zero_expr_degrees: just_degrees_grouped(&system.assert_zero_expressions), + assert_zero_sumcheck_expr_degrees: just_degrees_grouped( + &system.assert_zero_sumcheck_expressions, + ), + }) + } else { + let table_len = if !system.lk_table_expressions.is_empty() { + system.lk_table_expressions[0].table_len + } else { + 0 + }; + CircuitStats::Table(TableStats { table_len }) + } + } +} + +pub struct Report { + metadata: BTreeMap, + circuits: Vec<(String, INFO)>, +} + +impl Report +where + INFO: serde::Serialize, +{ + pub fn get(&self, circuit_name: &str) -> Option<&INFO> { + self.circuits.iter().find_map(|(name, info)| { + if name == circuit_name { + Some(info) + } else { + None + } + }) + } + + pub fn save_json(&self, filename: &str) { + let json_data = json!({ + "metadata": self.metadata, + "circuits": self.circuits, + }); + + let mut file = File::create(filename).expect("Unable to create file"); + file.write_all(serde_json::to_string_pretty(&json_data).unwrap().as_bytes()) + .expect("Unable to write data"); + } +} +pub type StaticReport = Report; + +impl Report { + pub fn new(zkvm_system: &ZKVMConstraintSystem) -> Self { + Report { + metadata: BTreeMap::default(), + circuits: zkvm_system + .get_css() + .iter() + .map(|(k, v)| (k.clone(), CircuitStats::new(v))) + .collect_vec(), + } + } +} + +#[derive(Clone, Debug, serde::Serialize, Default)] +pub struct CircuitStatsTrace { + static_stats: CircuitStats, + num_instances: usize, +} + +impl CircuitStatsTrace { + pub fn new(static_stats: CircuitStats, num_instances: usize) -> Self { + CircuitStatsTrace { + static_stats, + num_instances, + } + } +} + +pub type TraceReport = Report; + +impl Report { + pub fn new( + static_report: &Report, + num_instances: BTreeMap, + program_name: &str, + ) -> Self { + let mut metadata = static_report.metadata.clone(); + // Note where the num_instances are extracted from + metadata.insert("PROGRAM_NAME".to_owned(), program_name.to_owned()); + + // Ensure we recognize all circuits from the num_instances map + num_instances.keys().for_each(|key| { + assert!(static_report.get(key).is_some(), r"unrecognized key {key}."); + }); + + // Stitch num instances to corresponding entries. Sort by num instances + let mut circuits = static_report + .circuits + .iter() + .map(|(key, value)| { + ( + key.to_owned(), + CircuitStatsTrace::new(value.clone(), *num_instances.get(key).unwrap_or(&0)), + ) + }) + .sorted_by(|lhs, rhs| rhs.1.num_instances.cmp(&lhs.1.num_instances)) + .collect_vec(); + + // aggregate results (for opcode circuits only) + let mut total = CircuitStatsTrace::default(); + for (_, circuit) in &circuits { + if let CircuitStats::OpCode(_) = &circuit.static_stats { + total = CircuitStatsTrace { + num_instances: total.num_instances + circuit.num_instances, + static_stats: total.static_stats + circuit.static_stats.clone(), + } + } + } + circuits.insert(0, ("OPCODES TOTAL".to_owned(), total)); + Report { metadata, circuits } + } + + // Extract num_instances from witness data + pub fn new_via_witnesses( + static_report: &Report, + zkvm_witnesses: &ZKVMWitnesses, + program_name: &str, + ) -> Self { + let num_instances = zkvm_witnesses + .clone() + .into_iter_sorted() + .map(|(key, value)| (key, value.num_instances())) + .collect::>(); + Self::new(static_report, num_instances, program_name) + } + + pub fn save_table(&self, filename: &str) { + let mut opcodes_table = Table::new(); + opcodes_table.add_row(row![ + "opcode_name", + "num_instances", + "lookups", + "reads", + "witnesses", + "writes", + "0_expr_deg", + "0_expr_sumcheck_deg" + ]); + let mut tables_table = Table::new(); + tables_table.add_row(row!["table_name", "num_instances", "table_len"]); + + for (name, circuit) in &self.circuits { + match &circuit.static_stats { + CircuitStats::OpCode(opstats) => { + opcodes_table.add_row(row![ + name.to_owned(), + circuit.num_instances, + opstats.lookups, + opstats.reads, + opstats.witnesses, + opstats.writes, + utils::display_hashmap(&opstats.assert_zero_expr_degrees), + utils::display_hashmap(&opstats.assert_zero_sumcheck_expr_degrees) + ]); + } + CircuitStats::Table(tablestats) => { + tables_table.add_row(row![ + name.to_owned(), + circuit.num_instances, + tablestats.table_len + ]); + } + } + } + let mut file = File::create(filename).expect("Unable to create file"); + _ = opcodes_table.print(&mut file); + _ = tables_table.print(&mut file); + } +} diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 5a504a2c3..96d7f787e 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -170,6 +170,10 @@ impl ZKVMConstraintSystem { SC::finalize_global_state(&mut circuit_builder).expect("global_state_out failed"); } + pub fn get_css(&self) -> &BTreeMap> { + &self.circuit_css + } + pub fn get_cs(&self, name: &String) -> Option<&ConstraintSystem> { self.circuit_css.get(name) } diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index e9308c667..971c2ca0a 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -1,3 +1,5 @@ +use std::{collections::HashMap, fmt::Display, hash::Hash}; + use ff::Field; use ff_ext::ExtensionField; use goldilocks::SmallField; @@ -180,3 +182,21 @@ pub fn transpose(v: Vec>) -> Vec> { pub fn next_pow2_instance_padding(num_instance: usize) -> usize { num_instance.next_power_of_two().max(2) } + +pub fn display_hashmap(map: &HashMap) -> String { + format!( + "[{}]", + map.iter().map(|(k, v)| format!("{k}: {v}")).join(",") + ) +} + +pub fn merge_frequency_tables( + lhs: HashMap, + rhs: HashMap, +) -> HashMap { + let mut ret = lhs; + rhs.into_iter().for_each(|(key, value)| { + *ret.entry(key).or_insert(0) += value; + }); + ret +}