Skip to content

Commit

Permalink
Further break down e2e logic (scroll-tech#703)
Browse files Browse the repository at this point in the history
Fixes scroll-tech#691.

- Refactors e2e into smaller steps. 
- Introduces a `run_partial` method which takes a `prefix` argument,
indicating when to stop the run of the pipeline. According to this
argument, different parts of the state are yielded back. This part isn't
too pretty type-wise, but ultimately I think it's not dangerous and a
reasonable compromise for the moment.

Later edit: I've also included a benchmark for witness generation of the
Fibonacci program, since it is a consumer of this refactor and helps to
exemplify it.
  • Loading branch information
mcalancea authored and 10to4 committed Dec 12, 2024
1 parent 614f385 commit 774bab5
Show file tree
Hide file tree
Showing 11 changed files with 479 additions and 196 deletions.
5 changes: 2 additions & 3 deletions ceno_emul/src/vm_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@ impl VMState {
/// 32 architectural registers + 1 register RD_NULL for dark writes to x0.
pub const REG_COUNT: usize = 32 + 1;

pub fn new(platform: Platform, program: Program) -> Self {
pub fn new(platform: Platform, program: Arc<Program>) -> Self {
let pc = program.entry;
let program = Arc::new(program);

let mut vm = Self {
pc,
Expand All @@ -52,7 +51,7 @@ impl VMState {
}

pub fn new_from_elf(platform: Platform, elf: &[u8]) -> Result<Self> {
let program = Program::load_elf(elf, u32::MAX)?;
let program = Arc::new(Program::load_elf(elf, u32::MAX)?);
Ok(Self::new(platform, program))
}

Expand Down
9 changes: 6 additions & 3 deletions ceno_emul/tests/test_vm_trace.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#![allow(clippy::unusual_byte_groupings)]
use anyhow::Result;
use std::collections::{BTreeMap, HashMap};
use std::{
collections::{BTreeMap, HashMap},
sync::Arc,
};

use ceno_emul::{
CENO_PLATFORM, Cycle, EmuContext, InsnKind, Platform, Program, StepRecord, Tracer, VMState,
Expand All @@ -24,7 +27,7 @@ fn test_vm_trace() -> Result<()> {
})
.collect(),
);
let mut ctx = VMState::new(CENO_PLATFORM, program);
let mut ctx = VMState::new(CENO_PLATFORM, Arc::new(program));

let steps = run(&mut ctx)?;

Expand Down Expand Up @@ -52,7 +55,7 @@ fn test_empty_program() -> Result<()> {
vec![],
BTreeMap::new(),
);
let mut ctx = VMState::new(CENO_PLATFORM, empty_program);
let mut ctx = VMState::new(CENO_PLATFORM, Arc::new(empty_program));
let res = run(&mut ctx);
assert!(matches!(res, Err(e) if e.to_string().contains("InstructionAccessFault")),);
Ok(())
Expand Down
4 changes: 4 additions & 0 deletions ceno_zkvm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,7 @@ name = "riscv_add"
[[bench]]
harness = false
name = "fibonacci"

[[bench]]
harness = false
name = "fibonacci_witness"
31 changes: 19 additions & 12 deletions ceno_zkvm/benches/fibonacci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,35 @@ use std::{
use ceno_emul::{CENO_PLATFORM, Platform, Program, WORD_SIZE};
use ceno_zkvm::{
self,
e2e::{run_e2e_gen_witness, run_e2e_proof},
e2e::{Checkpoint, run_e2e_with_checkpoint},
};
use criterion::*;

use goldilocks::GoldilocksExt2;
use mpcs::BasefoldDefault;

criterion_group! {
name = fibonacci;
name = fibonacci_prove_group;
config = Criterion::default().warm_up_time(Duration::from_millis(20000));
targets = bench_e2e
targets = fibonacci_prove,
}

criterion_main!(fibonacci);
criterion_main!(fibonacci_prove_group);

const NUM_SAMPLES: usize = 10;

fn bench_e2e(c: &mut Criterion) {
type Pcs = BasefoldDefault<E>;
type Pcs = BasefoldDefault<E>;
type E = GoldilocksExt2;

// Relevant init data for fibonacci run
fn setup() -> (Program, Platform, u32, u32) {
let mut file_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
file_path.push("examples/fibonacci.elf");
let stack_size = 32768;
let heap_size = 2097152;
let elf_bytes = fs::read(&file_path).expect("read elf file");
let program = Program::load_elf(&elf_bytes, u32::MAX).unwrap();

// use sp1 platform
let platform = Platform {
// The stack section is not mentioned in ELF headers, so we repeat the constant STACK_TOP here.
stack_top: 0x0020_0400,
Expand All @@ -44,6 +46,11 @@ fn bench_e2e(c: &mut Criterion) {
..CENO_PLATFORM
};

(program, platform, stack_size, heap_size)
}

fn fibonacci_prove(c: &mut Criterion) {
let (program, platform, stack_size, heap_size) = setup();
for max_steps in [1usize << 20, 1usize << 21, 1usize << 22] {
// expand more input size once runtime is acceptable
let mut group = c.benchmark_group(format!("fibonacci_max_steps_{}", max_steps));
Expand All @@ -58,18 +65,20 @@ fn bench_e2e(c: &mut Criterion) {
|b| {
b.iter_with_setup(
|| {
run_e2e_gen_witness::<E, Pcs>(
run_e2e_with_checkpoint::<E, Pcs>(
program.clone(),
platform.clone(),
stack_size,
heap_size,
vec![],
max_steps,
Checkpoint::PrepE2EProving,
)
},
|(prover, _, zkvm_witness, pi, _, _, _)| {
|(_, run_e2e_proof)| {
let timer = Instant::now();
let _ = run_e2e_proof(prover, zkvm_witness, pi);

run_e2e_proof();
println!(
"Fibonacci::create_proof, max_steps = {}, time = {}",
max_steps,
Expand All @@ -82,6 +91,4 @@ fn bench_e2e(c: &mut Criterion) {

group.finish();
}

type E = GoldilocksExt2;
}
83 changes: 83 additions & 0 deletions ceno_zkvm/benches/fibonacci_witness.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use std::{fs, path::PathBuf, time::Duration};

use ceno_emul::{CENO_PLATFORM, Platform, Program, WORD_SIZE};
use ceno_zkvm::{
self,
e2e::{Checkpoint, run_e2e_with_checkpoint},
};
use criterion::*;

use goldilocks::GoldilocksExt2;
use mpcs::BasefoldDefault;

criterion_group! {
name = fibonacci;
config = Criterion::default().warm_up_time(Duration::from_millis(20000));
targets = fibonacci_witness
}

criterion_main!(fibonacci);

const NUM_SAMPLES: usize = 10;
type Pcs = BasefoldDefault<E>;
type E = GoldilocksExt2;

// Relevant init data for fibonacci run
fn setup() -> (Program, Platform, u32, u32) {
let mut file_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
file_path.push("examples/fibonacci.elf");
let stack_size = 32768;
let heap_size = 2097152;
let elf_bytes = fs::read(&file_path).expect("read elf file");
let program = Program::load_elf(&elf_bytes, u32::MAX).unwrap();

let platform = Platform {
// The stack section is not mentioned in ELF headers, so we repeat the constant STACK_TOP here.
stack_top: 0x0020_0400,
rom: program.base_address
..program.base_address + (program.instructions.len() * WORD_SIZE) as u32,
ram: 0x0010_0000..0xFFFF_0000,
unsafe_ecall_nop: true,
..CENO_PLATFORM
};

(program, platform, stack_size, heap_size)
}

fn fibonacci_witness(c: &mut Criterion) {
let (program, platform, stack_size, heap_size) = setup();

let max_steps = usize::MAX;
let mut group = c.benchmark_group(format!("fib_wit_max_steps_{}", max_steps));
group.sample_size(NUM_SAMPLES);

// Benchmark the proving time
group.bench_function(
BenchmarkId::new(
"fibonacci_witness",
format!("fib_wit_max_steps_{}", max_steps),
),
|b| {
b.iter_with_setup(
|| {
run_e2e_with_checkpoint::<E, Pcs>(
program.clone(),
platform.clone(),
stack_size,
heap_size,
vec![],
max_steps,
Checkpoint::PrepWitnessGen,
)
},
|(_, generate_witness)| {
generate_witness();
},
);
},
);

group.finish();

type E = GoldilocksExt2;
}
11 changes: 3 additions & 8 deletions ceno_zkvm/examples/riscv_opcodes.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{panic, time::Instant};
use std::{panic, sync::Arc, time::Instant};

use ceno_zkvm::{
declare_program,
Expand Down Expand Up @@ -177,7 +177,7 @@ fn main() {
// init vm.x1 = 1, vm.x2 = -1, vm.x3 = step_loop
let public_io_init = init_public_io(&[1, u32::MAX, step_loop]);

let mut vm = VMState::new(CENO_PLATFORM, program.clone());
let mut vm = VMState::new(CENO_PLATFORM, Arc::new(program.clone()));

// init memory mapped IO
for record in &public_io_init {
Expand Down Expand Up @@ -290,12 +290,7 @@ fn main() {
trace_report.save_json("report.json");
trace_report.save_table("report.txt");

MockProver::assert_satisfied_full(
zkvm_cs.clone(),
zkvm_fixed_traces.clone(),
&zkvm_witness,
&pi,
);
MockProver::assert_satisfied_full(&zkvm_cs, zkvm_fixed_traces.clone(), &zkvm_witness, &pi);

let timer = Instant::now();

Expand Down
43 changes: 11 additions & 32 deletions ceno_zkvm/src/bin/e2e.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use ceno_emul::{CENO_PLATFORM, IterAddresses, Platform, Program, WORD_SIZE, Word};
use ceno_zkvm::{
e2e::{run_e2e_gen_witness, run_e2e_proof, run_e2e_verify},
e2e::{Checkpoint, run_e2e_with_checkpoint},
with_panic_hook,
};
use clap::{Parser, ValueEnum};
use ff_ext::ff::Field;
use goldilocks::{Goldilocks, GoldilocksExt2};
use itertools::Itertools;
use mpcs::{Basefold, BasefoldRSParams};
use std::{fs, panic, time::Instant};
use std::{fs, panic};
use tracing::level_filters::LevelFilter;
use tracing_forest::ForestLayer;
use tracing_subscriber::{
Expand Down Expand Up @@ -143,37 +143,17 @@ fn main() {
type B = Goldilocks;
type Pcs = Basefold<GoldilocksExt2, BasefoldRSParams>;

let (prover, verifier, zkvm_witness, pi, cycle_num, e2e_start, exit_code) =
run_e2e_gen_witness::<E, Pcs>(
program,
platform,
args.stack_size,
args.heap_size,
hints,
max_steps,
);

let timer = Instant::now();
let mut zkvm_proof = run_e2e_proof(prover, zkvm_witness, pi);
let proving_time = timer.elapsed().as_secs_f64();
let e2e_time = e2e_start.elapsed().as_secs_f64();
let witgen_time = e2e_time - proving_time;
println!(
"Proving finished.\n\
\tProving time = {:.3}s, freq = {:.3}khz\n\
\tWitgen time = {:.3}s, freq = {:.3}khz\n\
\tTotal time = {:.3}s, freq = {:.3}khz\n\
\tthread num: {}",
proving_time,
cycle_num as f64 / proving_time / 1000.0,
witgen_time,
cycle_num as f64 / witgen_time / 1000.0,
e2e_time,
cycle_num as f64 / e2e_time / 1000.0,
rayon::current_num_threads()
let (state, _) = run_e2e_with_checkpoint::<E, Pcs>(
program,
platform,
args.stack_size,
args.heap_size,
hints,
max_steps,
Checkpoint::PrepSanityCheck,
);

run_e2e_verify(&verifier, zkvm_proof.clone(), exit_code, max_steps);
let (mut zkvm_proof, verifier) = state.expect("PrepSanityCheck should yield state.");

// do sanity check
let transcript = Transcript::new(b"riscv");
Expand Down Expand Up @@ -207,7 +187,6 @@ fn main() {
}
};
}

fn memory_from_file(path: &Option<String>) -> Vec<u32> {
path.as_ref()
.map(|path| {
Expand Down
Loading

0 comments on commit 774bab5

Please sign in to comment.