diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index 8f1746d79..e071cb000 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -29,9 +29,8 @@ impl VMState { /// 32 architectural registers + 1 register RD_NULL for dark writes to x0. pub const REG_COUNT: usize = 32 + 1; - pub fn new(platform: Platform, program: Program) -> Self { + pub fn new(platform: Platform, program: Arc) -> Self { let pc = program.entry; - let program = Arc::new(program); let mut vm = Self { pc, @@ -52,7 +51,7 @@ impl VMState { } pub fn new_from_elf(platform: Platform, elf: &[u8]) -> Result { - let program = Program::load_elf(elf, u32::MAX)?; + let program = Arc::new(Program::load_elf(elf, u32::MAX)?); Ok(Self::new(platform, program)) } diff --git a/ceno_emul/tests/test_vm_trace.rs b/ceno_emul/tests/test_vm_trace.rs index d931a6a9c..2aa5f0da2 100644 --- a/ceno_emul/tests/test_vm_trace.rs +++ b/ceno_emul/tests/test_vm_trace.rs @@ -1,6 +1,9 @@ #![allow(clippy::unusual_byte_groupings)] use anyhow::Result; -use std::collections::{BTreeMap, HashMap}; +use std::{ + collections::{BTreeMap, HashMap}, + sync::Arc, +}; use ceno_emul::{ CENO_PLATFORM, Cycle, EmuContext, InsnKind, Platform, Program, StepRecord, Tracer, VMState, @@ -24,7 +27,7 @@ fn test_vm_trace() -> Result<()> { }) .collect(), ); - let mut ctx = VMState::new(CENO_PLATFORM, program); + let mut ctx = VMState::new(CENO_PLATFORM, Arc::new(program)); let steps = run(&mut ctx)?; @@ -52,7 +55,7 @@ fn test_empty_program() -> Result<()> { vec![], BTreeMap::new(), ); - let mut ctx = VMState::new(CENO_PLATFORM, empty_program); + let mut ctx = VMState::new(CENO_PLATFORM, Arc::new(empty_program)); let res = run(&mut ctx); assert!(matches!(res, Err(e) if e.to_string().contains("InstructionAccessFault")),); Ok(()) diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index f5bd125e7..4bb427eb9 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -67,3 +67,7 @@ name = "riscv_add" [[bench]] harness = false name = "fibonacci" + +[[bench]] +harness = false +name = "fibonacci_witness" diff --git a/ceno_zkvm/benches/fibonacci.rs b/ceno_zkvm/benches/fibonacci.rs index a90521cfc..f75c4ddd1 100644 --- a/ceno_zkvm/benches/fibonacci.rs +++ b/ceno_zkvm/benches/fibonacci.rs @@ -7,7 +7,7 @@ use std::{ use ceno_emul::{CENO_PLATFORM, Platform, Program, WORD_SIZE}; use ceno_zkvm::{ self, - e2e::{run_e2e_gen_witness, run_e2e_proof}, + e2e::{Checkpoint, run_e2e_with_checkpoint}, }; use criterion::*; @@ -15,17 +15,20 @@ use goldilocks::GoldilocksExt2; use mpcs::BasefoldDefault; criterion_group! { - name = fibonacci; + name = fibonacci_prove_group; config = Criterion::default().warm_up_time(Duration::from_millis(20000)); - targets = bench_e2e + targets = fibonacci_prove, } -criterion_main!(fibonacci); +criterion_main!(fibonacci_prove_group); const NUM_SAMPLES: usize = 10; -fn bench_e2e(c: &mut Criterion) { - type Pcs = BasefoldDefault; +type Pcs = BasefoldDefault; +type E = GoldilocksExt2; + +// Relevant init data for fibonacci run +fn setup() -> (Program, Platform, u32, u32) { let mut file_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); file_path.push("examples/fibonacci.elf"); let stack_size = 32768; @@ -33,7 +36,6 @@ fn bench_e2e(c: &mut Criterion) { let elf_bytes = fs::read(&file_path).expect("read elf file"); let program = Program::load_elf(&elf_bytes, u32::MAX).unwrap(); - // use sp1 platform let platform = Platform { // The stack section is not mentioned in ELF headers, so we repeat the constant STACK_TOP here. stack_top: 0x0020_0400, @@ -44,6 +46,11 @@ fn bench_e2e(c: &mut Criterion) { ..CENO_PLATFORM }; + (program, platform, stack_size, heap_size) +} + +fn fibonacci_prove(c: &mut Criterion) { + let (program, platform, stack_size, heap_size) = setup(); for max_steps in [1usize << 20, 1usize << 21, 1usize << 22] { // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("fibonacci_max_steps_{}", max_steps)); @@ -58,18 +65,20 @@ fn bench_e2e(c: &mut Criterion) { |b| { b.iter_with_setup( || { - run_e2e_gen_witness::( + run_e2e_with_checkpoint::( program.clone(), platform.clone(), stack_size, heap_size, vec![], max_steps, + Checkpoint::PrepE2EProving, ) }, - |(prover, _, zkvm_witness, pi, _, _, _)| { + |(_, run_e2e_proof)| { let timer = Instant::now(); - let _ = run_e2e_proof(prover, zkvm_witness, pi); + + run_e2e_proof(); println!( "Fibonacci::create_proof, max_steps = {}, time = {}", max_steps, @@ -82,6 +91,4 @@ fn bench_e2e(c: &mut Criterion) { group.finish(); } - - type E = GoldilocksExt2; } diff --git a/ceno_zkvm/benches/fibonacci_witness.rs b/ceno_zkvm/benches/fibonacci_witness.rs new file mode 100644 index 000000000..2f09adaee --- /dev/null +++ b/ceno_zkvm/benches/fibonacci_witness.rs @@ -0,0 +1,83 @@ +use std::{fs, path::PathBuf, time::Duration}; + +use ceno_emul::{CENO_PLATFORM, Platform, Program, WORD_SIZE}; +use ceno_zkvm::{ + self, + e2e::{Checkpoint, run_e2e_with_checkpoint}, +}; +use criterion::*; + +use goldilocks::GoldilocksExt2; +use mpcs::BasefoldDefault; + +criterion_group! { + name = fibonacci; + config = Criterion::default().warm_up_time(Duration::from_millis(20000)); + targets = fibonacci_witness +} + +criterion_main!(fibonacci); + +const NUM_SAMPLES: usize = 10; +type Pcs = BasefoldDefault; +type E = GoldilocksExt2; + +// Relevant init data for fibonacci run +fn setup() -> (Program, Platform, u32, u32) { + let mut file_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + file_path.push("examples/fibonacci.elf"); + let stack_size = 32768; + let heap_size = 2097152; + let elf_bytes = fs::read(&file_path).expect("read elf file"); + let program = Program::load_elf(&elf_bytes, u32::MAX).unwrap(); + + let platform = Platform { + // The stack section is not mentioned in ELF headers, so we repeat the constant STACK_TOP here. + stack_top: 0x0020_0400, + rom: program.base_address + ..program.base_address + (program.instructions.len() * WORD_SIZE) as u32, + ram: 0x0010_0000..0xFFFF_0000, + unsafe_ecall_nop: true, + ..CENO_PLATFORM + }; + + (program, platform, stack_size, heap_size) +} + +fn fibonacci_witness(c: &mut Criterion) { + let (program, platform, stack_size, heap_size) = setup(); + + let max_steps = usize::MAX; + let mut group = c.benchmark_group(format!("fib_wit_max_steps_{}", max_steps)); + group.sample_size(NUM_SAMPLES); + + // Benchmark the proving time + group.bench_function( + BenchmarkId::new( + "fibonacci_witness", + format!("fib_wit_max_steps_{}", max_steps), + ), + |b| { + b.iter_with_setup( + || { + run_e2e_with_checkpoint::( + program.clone(), + platform.clone(), + stack_size, + heap_size, + vec![], + max_steps, + Checkpoint::PrepWitnessGen, + ) + }, + |(_, generate_witness)| { + generate_witness(); + }, + ); + }, + ); + + group.finish(); + + type E = GoldilocksExt2; +} diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index 8ed2c4f7e..fe119bb25 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -1,4 +1,4 @@ -use std::{panic, time::Instant}; +use std::{panic, sync::Arc, time::Instant}; use ceno_zkvm::{ declare_program, @@ -177,7 +177,7 @@ fn main() { // init vm.x1 = 1, vm.x2 = -1, vm.x3 = step_loop let public_io_init = init_public_io(&[1, u32::MAX, step_loop]); - let mut vm = VMState::new(CENO_PLATFORM, program.clone()); + let mut vm = VMState::new(CENO_PLATFORM, Arc::new(program.clone())); // init memory mapped IO for record in &public_io_init { @@ -290,12 +290,7 @@ fn main() { trace_report.save_json("report.json"); trace_report.save_table("report.txt"); - MockProver::assert_satisfied_full( - zkvm_cs.clone(), - zkvm_fixed_traces.clone(), - &zkvm_witness, - &pi, - ); + MockProver::assert_satisfied_full(&zkvm_cs, zkvm_fixed_traces.clone(), &zkvm_witness, &pi); let timer = Instant::now(); diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index 60b1e6bce..07baf8998 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -1,6 +1,6 @@ use ceno_emul::{CENO_PLATFORM, IterAddresses, Platform, Program, WORD_SIZE, Word}; use ceno_zkvm::{ - e2e::{run_e2e_gen_witness, run_e2e_proof, run_e2e_verify}, + e2e::{Checkpoint, run_e2e_with_checkpoint}, with_panic_hook, }; use clap::{Parser, ValueEnum}; @@ -8,7 +8,7 @@ use ff_ext::ff::Field; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; use mpcs::{Basefold, BasefoldRSParams}; -use std::{fs, panic, time::Instant}; +use std::{fs, panic}; use tracing::level_filters::LevelFilter; use tracing_forest::ForestLayer; use tracing_subscriber::{ @@ -143,37 +143,17 @@ fn main() { type B = Goldilocks; type Pcs = Basefold; - let (prover, verifier, zkvm_witness, pi, cycle_num, e2e_start, exit_code) = - run_e2e_gen_witness::( - program, - platform, - args.stack_size, - args.heap_size, - hints, - max_steps, - ); - - let timer = Instant::now(); - let mut zkvm_proof = run_e2e_proof(prover, zkvm_witness, pi); - let proving_time = timer.elapsed().as_secs_f64(); - let e2e_time = e2e_start.elapsed().as_secs_f64(); - let witgen_time = e2e_time - proving_time; - println!( - "Proving finished.\n\ -\tProving time = {:.3}s, freq = {:.3}khz\n\ -\tWitgen time = {:.3}s, freq = {:.3}khz\n\ -\tTotal time = {:.3}s, freq = {:.3}khz\n\ -\tthread num: {}", - proving_time, - cycle_num as f64 / proving_time / 1000.0, - witgen_time, - cycle_num as f64 / witgen_time / 1000.0, - e2e_time, - cycle_num as f64 / e2e_time / 1000.0, - rayon::current_num_threads() + let (state, _) = run_e2e_with_checkpoint::( + program, + platform, + args.stack_size, + args.heap_size, + hints, + max_steps, + Checkpoint::PrepSanityCheck, ); - run_e2e_verify(&verifier, zkvm_proof.clone(), exit_code, max_steps); + let (mut zkvm_proof, verifier) = state.expect("PrepSanityCheck should yield state."); // do sanity check let transcript = Transcript::new(b"riscv"); @@ -207,7 +187,6 @@ fn main() { } }; } - fn memory_from_file(path: &Option) -> Vec { path.as_ref() .map(|path| { diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 512f56982..b1d5b4d77 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -5,136 +5,69 @@ use crate::{ prover::ZKVMProver, verifier::ZKVMVerifier, }, state::GlobalState, - structs::{ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, - tables::{MemFinalRecord, MemInitRecord, ProgramTableCircuit}, + structs::{ + ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMProvingKey, ZKVMWitnesses, + }, + tables::{MemFinalRecord, MemInitRecord, ProgramTableCircuit, ProgramTableConfig}, }; use ceno_emul::{ ByteAddr, EmuContext, InsnKind::EANY, IterAddresses, Platform, Program, StepRecord, Tracer, VMState, WORD_SIZE, WordAddr, }; use ff_ext::ExtensionField; -use itertools::{Itertools, MinMaxResult, chain, enumerate}; +use itertools::{Itertools, MinMaxResult, chain}; use mpcs::PolynomialCommitmentScheme; use std::{ collections::{HashMap, HashSet}, iter::zip, - time::Instant, + ops::Deref, + sync::Arc, }; use transcript::BasicTranscript as Transcript; -type E2EWitnessGen = ( - ZKVMProver, - ZKVMVerifier, - ZKVMWitnesses, - PublicValues, - usize, // number of cycles - Instant, // e2e start, excluding key gen time - Option, -); - -pub fn run_e2e_gen_witness>( - program: Program, - platform: Platform, - stack_size: u32, - heap_size: u32, - hints: Vec, - max_steps: usize, -) -> E2EWitnessGen { - let stack_addrs = platform.stack_top - stack_size..platform.stack_top; - - // Detect heap as starting after program data. - let heap_start = program.image.keys().max().unwrap() + WORD_SIZE as u32; - let heap_addrs = heap_start..heap_start + heap_size; - - let mut mem_padder = MemPadder::new(heap_addrs.end..platform.ram.end); - - let mem_init = { - let program_addrs = program.image.iter().map(|(addr, value)| MemInitRecord { - addr: *addr, - value: *value, - }); - - let stack = stack_addrs - .iter_addresses() - .map(|addr| MemInitRecord { addr, value: 0 }); +pub struct FullMemState { + mem: Vec, + io: Vec, + reg: Vec, + priv_io: Vec, +} - let heap = heap_addrs - .iter_addresses() - .map(|addr| MemInitRecord { addr, value: 0 }); +type InitMemState = FullMemState; +type FinalMemState = FullMemState; - let mem_init = chain!(program_addrs, stack, heap).collect_vec(); +pub struct EmulationResult { + exit_code: Option, + all_records: Vec, + final_mem_state: FinalMemState, + pi: PublicValues, +} - mem_padder.padded_sorted(mem_init.len().next_power_of_two(), mem_init) - }; +fn emulate_program( + program: Arc, + max_steps: usize, + init_mem_state: InitMemState, + platform: &Platform, + hints: Vec, +) -> EmulationResult { + let InitMemState { + mem: mem_init, + io: io_init, + reg: reg_init, + priv_io: _, + } = init_mem_state; - let mut vm = VMState::new(platform.clone(), program); + let mut vm: VMState = VMState::new(platform.clone(), program); for (addr, value) in zip(platform.hints.iter_addresses(), &hints) { vm.init_memory(addr.into(), *value); } - // keygen - let pcs_param = PCS::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup"); - let (pp, vp) = PCS::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); - let program_params = ProgramParams { - platform: platform.clone(), - program_size: vm.program().instructions.len(), - static_memory_len: mem_init.len(), - ..ProgramParams::default() - }; - let mut zkvm_cs = ZKVMConstraintSystem::new_with_platform(program_params); - - let config = Rv32imConfig::::construct_circuits(&mut zkvm_cs); - let mmu_config = MmuConfig::::construct_circuits(&mut zkvm_cs); - let dummy_config = DummyExtraConfig::::construct_circuits(&mut zkvm_cs); - let prog_config = zkvm_cs.register_table_circuit::>(); - zkvm_cs.register_global_state::(); - - let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); - - zkvm_fixed_traces.register_table_circuit::>( - &zkvm_cs, - &prog_config, - vm.program(), - ); - - // IO is not used in this program, but it must have a particular size at the moment. - let io_init = mem_padder.padded_sorted(mmu_config.public_io_len(), vec![]); - - let reg_init = mmu_config.initial_registers(); - config.generate_fixed_traces(&zkvm_cs, &mut zkvm_fixed_traces); - mmu_config.generate_fixed_traces( - &zkvm_cs, - &mut zkvm_fixed_traces, - ®_init, - &mem_init, - &io_init.iter().map(|rec| rec.addr).collect_vec(), - ); - dummy_config.generate_fixed_traces(&zkvm_cs, &mut zkvm_fixed_traces); - - let pk = zkvm_cs - .clone() - .key_gen::(pp.clone(), vp.clone(), zkvm_fixed_traces.clone()) - .expect("keygen failed"); - let vk = pk.get_vk(); - - // proving - let e2e_start = Instant::now(); - let prover = ZKVMProver::new(pk); - let verifier = ZKVMVerifier::new(vk); - let all_records = vm .iter_until_halt() .take(max_steps) .collect::, _>>() .expect("vm exec failed"); - let cycle_num = all_records.len(); - tracing::info!("Proving {} execution steps", cycle_num); - for (i, step) in enumerate(&all_records).rev().take(5).rev() { - tracing::trace!("Step {i}: {:?} - {:?}\n", step.insn().codes().kind, step); - } - // Find the exit code from the HALT step, if halting at all. let exit_code = all_records .iter() @@ -158,16 +91,6 @@ pub fn run_e2e_gen_witness io_init.iter().map(|rec| rec.value).collect_vec(), ); - let mut zkvm_witness = ZKVMWitnesses::default(); - // assign opcode circuits - let dummy_records = config - .assign_opcode_circuit(&zkvm_cs, &mut zkvm_witness, all_records) - .unwrap(); - dummy_config - .assign_opcode_circuit(&zkvm_cs, &mut zkvm_witness, dummy_records) - .unwrap(); - zkvm_witness.finalize_lk_multiplicities(); - // Find the final register values and cycles. let reg_final = reg_init .iter() @@ -208,7 +131,11 @@ pub fn run_e2e_gen_witness // Find the final public IO cycles. let io_final = io_init .iter() - .map(|rec| *final_access.get(&rec.addr.into()).unwrap_or(&0)) + .map(|rec| MemFinalRecord { + addr: rec.addr, + value: rec.value, + cycle: *final_access.get(&rec.addr.into()).unwrap_or(&0), + }) .collect_vec(); let priv_io_final = zip(platform.hints.iter_addresses(), &hints) @@ -219,45 +146,331 @@ pub fn run_e2e_gen_witness }) .collect_vec(); + EmulationResult { + pi, + exit_code, + all_records, + final_mem_state: FinalMemState { + reg: reg_final, + io: io_final, + mem: mem_final, + priv_io: priv_io_final, + }, + } +} + +fn init_mem( + program: &Program, + platform: &Platform, + mem_padder: &mut MemPadder, + stack_size: u32, + heap_size: u32, +) -> Vec { + let stack_addrs = platform.stack_top - stack_size..platform.stack_top; + // Detect heap as starting after program data. + let heap_start = program.image.keys().max().unwrap() + WORD_SIZE as u32; + let heap_addrs = heap_start..heap_start + heap_size; + let program_addrs = program.image.iter().map(|(addr, value)| MemInitRecord { + addr: *addr, + value: *value, + }); + + let stack = stack_addrs + .iter_addresses() + .map(|addr| MemInitRecord { addr, value: 0 }); + + let heap = heap_addrs + .iter_addresses() + .map(|addr| MemInitRecord { addr, value: 0 }); + + let mem_init = chain!(program_addrs, stack, heap).collect_vec(); + + mem_padder.padded_sorted(mem_init.len().next_power_of_two(), mem_init) +} + +pub struct ConstraintSystemConfig { + zkvm_cs: ZKVMConstraintSystem, + config: Rv32imConfig, + mmu_config: MmuConfig, + dummy_config: DummyExtraConfig, + prog_config: ProgramTableConfig, +} + +fn construct_configs( + program_params: ProgramParams, +) -> ConstraintSystemConfig { + let mut zkvm_cs = ZKVMConstraintSystem::new_with_platform(program_params); + + let config = Rv32imConfig::::construct_circuits(&mut zkvm_cs); + let mmu_config = MmuConfig::::construct_circuits(&mut zkvm_cs); + let dummy_config = DummyExtraConfig::::construct_circuits(&mut zkvm_cs); + let prog_config = zkvm_cs.register_table_circuit::>(); + zkvm_cs.register_global_state::(); + ConstraintSystemConfig { + zkvm_cs, + config, + mmu_config, + dummy_config, + prog_config, + } +} + +fn generate_fixed_traces( + system_config: &ConstraintSystemConfig, + init_mem_state: &InitMemState, + program: &Program, +) -> ZKVMFixedTraces { + let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); + + zkvm_fixed_traces.register_table_circuit::>( + &system_config.zkvm_cs, + &system_config.prog_config, + program, + ); + + system_config + .config + .generate_fixed_traces(&system_config.zkvm_cs, &mut zkvm_fixed_traces); + system_config.mmu_config.generate_fixed_traces( + &system_config.zkvm_cs, + &mut zkvm_fixed_traces, + &init_mem_state.reg, + &init_mem_state.mem, + &init_mem_state.io.iter().map(|rec| rec.addr).collect_vec(), + ); + system_config + .dummy_config + .generate_fixed_traces(&system_config.zkvm_cs, &mut zkvm_fixed_traces); + + zkvm_fixed_traces +} + +pub fn generate_witness( + system_config: &ConstraintSystemConfig, + emul_result: EmulationResult, + program: &Program, +) -> ZKVMWitnesses { + let mut zkvm_witness = ZKVMWitnesses::default(); + // assign opcode circuits + let dummy_records = system_config + .config + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut zkvm_witness, + emul_result.all_records, + ) + .unwrap(); + system_config + .dummy_config + .assign_opcode_circuit(&system_config.zkvm_cs, &mut zkvm_witness, dummy_records) + .unwrap(); + zkvm_witness.finalize_lk_multiplicities(); + // assign table circuits - config - .assign_table_circuit(&zkvm_cs, &mut zkvm_witness) + system_config + .config + .assign_table_circuit(&system_config.zkvm_cs, &mut zkvm_witness) .unwrap(); - mmu_config + system_config + .mmu_config .assign_table_circuit( - &zkvm_cs, + &system_config.zkvm_cs, &mut zkvm_witness, - ®_final, - &mem_final, - &io_final, - &priv_io_final, + &emul_result.final_mem_state.reg, + &emul_result.final_mem_state.mem, + &emul_result + .final_mem_state + .io + .iter() + .map(|rec| rec.cycle) + .collect_vec(), + &emul_result.final_mem_state.priv_io, ) .unwrap(); // assign program circuit zkvm_witness - .assign_table_circuit::>(&zkvm_cs, &prog_config, vm.program()) + .assign_table_circuit::>( + &system_config.zkvm_cs, + &system_config.prog_config, + program, + ) .unwrap(); + zkvm_witness +} + +// Encodes useful early return points of the e2e pipeline +pub enum Checkpoint { + PrepE2EProving, + PrepWitnessGen, + PrepSanityCheck, + Complete, +} + +// Currently handles state required by the sanity check in `bin/e2e.rs` +// Future cases would require this to be an enum +pub type IntermediateState = (ZKVMProof, ZKVMVerifier); + +// Runs end-to-end pipeline, stopping at a certain checkpoint and yielding useful state. +// +// The return type is a pair of: +// 1. Explicit state +// 2. A no-input-no-ouptut closure +// +// (2.) is useful when you want to setup a certain action and run it +// elsewhere (i.e, in a benchmark) +// (1.) is useful for exposing state which must be further combined with +// state external to this pipeline (e.g, sanity check in bin/e2e.rs) + +#[allow(clippy::type_complexity)] +pub fn run_e2e_with_checkpoint + 'static>( + program: Program, + platform: Platform, + stack_size: u32, + heap_size: u32, + hints: Vec, + max_steps: usize, + checkpoint: Checkpoint, +) -> (Option>, Box) { + // Detect heap as starting after program data. + let heap_start = program.image.keys().max().unwrap() + WORD_SIZE as u32; + let heap_addrs = heap_start..heap_start + heap_size; + let mut mem_padder = MemPadder::new(heap_addrs.end..platform.ram.end); + let mem_init = init_mem(&program, &platform, &mut mem_padder, stack_size, heap_size); + + let program_params = ProgramParams { + platform: platform.clone(), + program_size: program.instructions.len(), + static_memory_len: mem_init.len(), + ..ProgramParams::default() + }; + + let program = Arc::new(program); + let system_config = construct_configs::(program_params); + + // IO is not used in this program, but it must have a particular size at the moment. + let io_init = mem_padder.padded_sorted(system_config.mmu_config.public_io_len(), vec![]); + let reg_init = system_config.mmu_config.initial_registers(); + + let init_full_mem = InitMemState { + mem: mem_init, + reg: reg_init, + io: io_init, + priv_io: vec![], + }; + + // Generate fixed traces + let zkvm_fixed_traces = generate_fixed_traces(&system_config, &init_full_mem, &program); + + // Keygen + let pcs_param = PCS::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup"); + let (pp, vp) = PCS::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); + let pk = system_config + .zkvm_cs + .clone() + .key_gen::(pp.clone(), vp.clone(), zkvm_fixed_traces.clone()) + .expect("keygen failed"); + let vk = pk.get_vk(); + + if let Checkpoint::PrepE2EProving = checkpoint { + return ( + None, + Box::new(move || { + _ = run_e2e_proof( + program, + max_steps, + init_full_mem, + platform, + hints, + &system_config, + pk, + zkvm_fixed_traces, + ) + }), + ); + } + + // Emulate program + let emul_result = emulate_program(program.clone(), max_steps, init_full_mem, &platform, hints); + + // Clone some emul_result fields before consuming + let pi = emul_result.pi.clone(); + let exit_code = emul_result.exit_code; + + if let Checkpoint::PrepWitnessGen = checkpoint { + return ( + None, + Box::new(move || _ = generate_witness(&system_config, emul_result, program.deref())), + ); + } + + // Generate witness + let zkvm_witness = generate_witness(&system_config, emul_result, &program); + + // proving + let prover = ZKVMProver::new(pk); + if std::env::var("MOCK_PROVING").is_ok() { - MockProver::assert_satisfied_full(zkvm_cs, zkvm_fixed_traces, &zkvm_witness, &pi); + MockProver::assert_satisfied_full( + &system_config.zkvm_cs, + zkvm_fixed_traces.clone(), + &zkvm_witness, + &pi, + ); tracing::info!("Mock proving passed"); } - ( - prover, - verifier, - zkvm_witness, - pi, - cycle_num, - e2e_start, - exit_code, - ) + + // Run proof phase + let transcript = Transcript::new(b"riscv"); + let zkvm_proof = prover + .create_proof(zkvm_witness, pi, transcript) + .expect("create_proof failed"); + + let verifier = ZKVMVerifier::new(vk); + + run_e2e_verify(&verifier, zkvm_proof.clone(), exit_code, max_steps); + + if let Checkpoint::PrepSanityCheck = checkpoint { + return (Some((zkvm_proof, verifier)), Box::new(|| ())); + } + + (None, Box::new(|| ())) } +// Runs program emulation + witness generation + proving +#[allow(clippy::too_many_arguments)] pub fn run_e2e_proof>( - prover: ZKVMProver, - zkvm_witness: ZKVMWitnesses, - pi: PublicValues, + program: Arc, + max_steps: usize, + init_full_mem: InitMemState, + platform: Platform, + hints: Vec, + system_config: &ConstraintSystemConfig, + pk: ZKVMProvingKey, + zkvm_fixed_traces: ZKVMFixedTraces, ) -> ZKVMProof { + // Emulate program + let emul_result = emulate_program(program.clone(), max_steps, init_full_mem, &platform, hints); + + // clone pi before consuming + let pi = emul_result.pi.clone(); + + // Generate witness + let zkvm_witness = generate_witness(system_config, emul_result, program.deref()); + + // proving + let prover = ZKVMProver::new(pk); + + if std::env::var("MOCK_PROVING").is_ok() { + MockProver::assert_satisfied_full( + &system_config.zkvm_cs, + zkvm_fixed_traces.clone(), + &zkvm_witness, + &pi, + ); + tracing::info!("Mock proving passed"); + } + let transcript = Transcript::new(b"riscv"); prover .create_proof(zkvm_witness, pi, transcript) diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index c47295dae..3724ceafd 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -746,7 +746,7 @@ Hints: } pub fn assert_satisfied_full( - cs: ZKVMConstraintSystem, + cs: &ZKVMConstraintSystem, mut fixed_trace: ZKVMFixedTraces, witnesses: &ZKVMWitnesses, pi: &PublicValues, diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index fe3ce8f07..13ef29a66 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -261,7 +261,7 @@ fn test_single_add_instance_e2e() { let vk = pk.get_vk(); // single instance - let mut vm = VMState::new(CENO_PLATFORM, program.clone()); + let mut vm = VMState::new(CENO_PLATFORM, program.clone().into()); let all_records = vm .iter_until_halt() .collect::, _>>() diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index ded9f35cb..8a15d1b19 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -14,7 +14,7 @@ mod ops; pub use ops::*; mod program; -pub use program::{InsnRecord, ProgramTableCircuit}; +pub use program::{InsnRecord, ProgramTableCircuit, ProgramTableConfig}; mod ram; pub use ram::*;