Skip to content

Commit

Permalink
Feat: add e2e prover (#188)
Browse files Browse the repository at this point in the history
  • Loading branch information
kunxian-xia authored Sep 11, 2024
1 parent 5a68423 commit 0cfa10b
Show file tree
Hide file tree
Showing 22 changed files with 792 additions and 269 deletions.
6 changes: 6 additions & 0 deletions ceno_emul/src/rv32im.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@ pub struct Instruction {
pub func7: u32,
}

impl Default for Instruction {
fn default() -> Self {
insn(InsnKind::INVALID, InsnCategory::Invalid, 0x00, 0x0, 0x00)
}
}

impl DecodedInstruction {
pub fn new(insn: u32) -> Self {
Self {
Expand Down
10 changes: 10 additions & 0 deletions ceno_emul/src/tracer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{
rv32im::DecodedInstruction,
CENO_PLATFORM,
};
use crate::rv32im::Instruction;

/// An instruction and its context in an execution trace. That is concrete values of registers and memory.
///
Expand All @@ -22,6 +23,7 @@ pub struct StepRecord {
pub cycle: Cycle,
pub pc: Change<ByteAddr>,
pub insn_code: Word,
pub insn: Instruction,

pub rs1: Option<ReadOp>,
pub rs2: Option<ReadOp>,
Expand Down Expand Up @@ -69,6 +71,10 @@ impl StepRecord {
DecodedInstruction::new(self.insn_code)
}

pub fn insn(&self) -> Instruction {
self.insn
}

pub fn rs1(&self) -> Option<ReadOp> {
self.rs1.clone()
}
Expand Down Expand Up @@ -141,6 +147,10 @@ impl Tracer {
self.record.insn_code = value;
}

pub fn store_insn(&mut self, insn: Instruction) {
self.record.insn = insn;
}

pub fn load_register(&mut self, idx: RegIdx, value: Word) {
let addr = CENO_PLATFORM.register_vma(idx).into();

Expand Down
8 changes: 7 additions & 1 deletion ceno_emul/src/vm_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ impl VMState {
Ok(step)
}
}

pub fn init_register_unsafe(&mut self, idx: RegIdx, value: Word) {
self.registers[idx] = value;
}
}

impl EmuContext for VMState {
Expand Down Expand Up @@ -109,7 +113,9 @@ impl EmuContext for VMState {
Err(anyhow!("Trap {:?}", cause)) // Crash.
}

fn on_insn_decoded(&mut self, _kind: &Instruction, _decoded: &DecodedInstruction) {}
fn on_insn_decoded(&mut self, insn: &Instruction, _decoded: &DecodedInstruction) {
self.tracer.store_insn(*insn);
}

fn on_normal_end(&mut self, _kind: &Instruction, _decoded: &DecodedInstruction) {
self.tracer.store_pc(ByteAddr(self.pc));
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/benches/riscv_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ fn bench_add(c: &mut Criterion) {
.collect_vec();
let timer = Instant::now();
let _ = prover
.create_proof(
.create_opcode_proof(
wits_in,
num_instances,
max_threads,
Expand Down
130 changes: 91 additions & 39 deletions ceno_zkvm/examples/riscv_add.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
use std::time::Instant;

use ark_std::test_rng;
use ceno_zkvm::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
instructions::{riscv::addsub::AddInstruction, Instruction},
scheme::prover::ZKVMProver,
};
use ceno_zkvm::{instructions::riscv::addsub::AddInstruction, scheme::prover::ZKVMProver};
use const_env::from_env;

use ceno_emul::{ByteAddr, InsnKind::ADD, StepRecord, VMState, CENO_PLATFORM};
use ceno_zkvm::{
scheme::verifier::ZKVMVerifier,
structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses},
tables::RangeTableCircuit,
};
use ff_ext::ff::Field;
use goldilocks::{Goldilocks, GoldilocksExt2};
use itertools::Itertools;
use multilinear_extensions::mle::IntoMLE;
use goldilocks::GoldilocksExt2;
use sumcheck::util::is_power_of_2;
use tracing_flame::FlameLayer;
use tracing_subscriber::{fmt, layer::SubscriberExt, EnvFilter, Registry};
Expand All @@ -20,7 +20,23 @@ use transcript::Transcript;
#[from_env]
const RAYON_NUM_THREADS: usize = 8;

// For now, we assume registers
// - x0 is not touched,
// - x1 is initialized to 1,
// - x2 is initialized to -1,
// - x3 is initialized to loop bound.
// we use x4 to hold the acc_sum.
const PROGRAM_ADD_LOOP: [u32; 4] = [
// func7 rs2 rs1 f3 rd opcode
0b_0000000_00100_00001_000_00100_0110011, // add x4, x4, x1 <=> addi x4, x4, 1
0b_0000000_00011_00010_000_00011_0110011, // add x3, x3, x2 <=> addi x3, x3, -1
0b_1_111111_00000_00011_001_1100_1_1100011, // bne x3, x0, -8
0b_000000000000_00000_000_00000_1110011, // ecall halt
];

fn main() {
type E = GoldilocksExt2;

let max_threads = {
if !is_power_of_2(RAYON_NUM_THREADS) {
#[cfg(not(feature = "non_pow2_rayon_thread"))]
Expand All @@ -41,16 +57,6 @@ fn main() {
RAYON_NUM_THREADS
}
};
let mut cs = ConstraintSystem::new(|| "risv_add");
let mut circuit_builder = CircuitBuilder::<GoldilocksExt2>::new(&mut cs);
let _ = AddInstruction::construct_circuit(&mut circuit_builder);
let pk = cs.key_gen(None);
let num_witin = pk.get_cs().num_witin;

let prover = ZKVMProver::new(pk);
let mut transcript = Transcript::new(b"riscv");
let mut rng = test_rng();
let real_challenges = [E::random(&mut rng), E::random(&mut rng)];

let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap();
let subscriber = Registry::default()
Expand All @@ -64,34 +70,80 @@ fn main() {
.with(flame_layer.with_threads_collapsed(true));
tracing::subscriber::set_global_default(subscriber).unwrap();

for instance_num_vars in 20..22 {
// generate mock witness
// keygen
let mut zkvm_cs = ZKVMConstraintSystem::default();
let add_config = zkvm_cs.register_opcode_circuit::<AddInstruction<E>>();
let range_config = zkvm_cs.register_table_circuit::<RangeTableCircuit<E>>();

let mut zkvm_fixed_traces = ZKVMFixedTraces::default();
zkvm_fixed_traces.register_opcode_circuit::<AddInstruction<E>>(&zkvm_cs);
zkvm_fixed_traces
.register_table_circuit::<RangeTableCircuit<E>>(&zkvm_cs, range_config.clone());

let pk = zkvm_cs
.clone()
.key_gen(zkvm_fixed_traces)
.expect("keygen failed");
let vk = pk.get_vk();

// proving
let prover = ZKVMProver::new(pk);
let verifier = ZKVMVerifier::new(vk);

for instance_num_vars in 8..22 {
let num_instances = 1 << instance_num_vars;
let wits_in = (0..num_witin as usize)
.map(|_| {
(0..num_instances)
.map(|_| Goldilocks::random(&mut rng))
.collect::<Vec<Goldilocks>>()
.into_mle()
.into()
})
.collect_vec();
let mut vm = VMState::new(CENO_PLATFORM);
let pc_start = ByteAddr(CENO_PLATFORM.pc_start()).waddr();

// init vm.x1 = 1, vm.x2 = -1, vm.x3 = num_instances
// vm.x4 += vm.x1
vm.init_register_unsafe(1usize, 1);
vm.init_register_unsafe(2usize, u32::MAX); // -1 in two's complement
vm.init_register_unsafe(3usize, num_instances as u32);
for (i, inst) in PROGRAM_ADD_LOOP.iter().enumerate() {
vm.init_memory(pc_start + i, *inst);
}
let records = vm
.iter_until_success()
.collect::<Result<Vec<StepRecord>, _>>()
.expect("vm exec failed")
.into_iter()
.filter(|record| record.insn().kind == ADD)
.collect::<Vec<_>>();
tracing::info!("tracer generated {} ADD records", records.len());

let mut zkvm_witness = ZKVMWitnesses::default();
// assign opcode circuits
zkvm_witness
.assign_opcode_circuit::<AddInstruction<E>>(&zkvm_cs, &add_config, records)
.unwrap();
zkvm_witness.finalize_lk_multiplicities();
// assign table circuits
zkvm_witness
.assign_table_circuit::<RangeTableCircuit<E>>(&zkvm_cs, &range_config)
.unwrap();

let timer = Instant::now();
let _ = prover
.create_proof(
wits_in,
num_instances,
max_threads,
&mut transcript,
&real_challenges,
)

let mut transcript = Transcript::new(b"riscv");
let mut rng = test_rng();
let real_challenges = [E::random(&mut rng), E::random(&mut rng)];

let zkvm_proof = prover
.create_proof(zkvm_witness, max_threads, &mut transcript, &real_challenges)
.expect("create_proof failed");

let mut transcript = Transcript::new(b"riscv");
assert!(
verifier
.verify_proof(zkvm_proof, &mut transcript, &real_challenges)
.expect("verify proof return with error"),
);

println!(
"AddInstruction::create_proof, instance_num_vars = {}, time = {}",
instance_num_vars,
timer.elapsed().as_secs_f64()
);
}

type E = GoldilocksExt2;
}
40 changes: 11 additions & 29 deletions ceno_zkvm/src/circuit_builder.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use itertools::Itertools;
use std::marker::PhantomData;

use ff_ext::ExtensionField;
use multilinear_extensions::mle::DenseMultilinearExtension;
use multilinear_extensions::mle::IntoMLEs;

use crate::{
error::ZKVMError,
expression::{Expression, Fixed, WitIn},
structs::WitnessId,
structs::{ProvingKey, VerifyingKey, WitnessId},
witness::RowMajorMatrix,
};

/// namespace used for annotation, preserve meta info during circuit construction
Expand Down Expand Up @@ -135,7 +137,13 @@ impl<E: ExtensionField> ConstraintSystem<E> {
}
}

pub fn key_gen(self, fixed_traces: Option<Vec<DenseMultilinearExtension<E>>>) -> ProvingKey<E> {
pub fn key_gen(self, fixed_traces: Option<RowMajorMatrix<E::BaseField>>) -> ProvingKey<E> {
// TODO: commit to fixed_traces

// transpose from row-major to column-major
let fixed_traces =
fixed_traces.map(|t| t.de_interleaving().into_mles().into_iter().collect_vec());

ProvingKey {
fixed_traces,
vk: VerifyingKey { cs: self },
Expand Down Expand Up @@ -293,29 +301,3 @@ impl<E: ExtensionField> ConstraintSystem<E> {
pub struct CircuitBuilder<'a, E: ExtensionField> {
pub(crate) cs: &'a mut ConstraintSystem<E>,
}

#[derive(Clone, Debug)]
pub struct ProvingKey<E: ExtensionField> {
pub fixed_traces: Option<Vec<DenseMultilinearExtension<E>>>,
pub vk: VerifyingKey<E>,
}

impl<E: ExtensionField> ProvingKey<E> {
// pub fn create_pk(vk: VerifyingKey<E>) -> Self {
// Self { vk }
// }
pub fn get_cs(&self) -> &ConstraintSystem<E> {
self.vk.get_cs()
}
}

#[derive(Clone, Debug)]
pub struct VerifyingKey<E: ExtensionField> {
cs: ConstraintSystem<E>,
}

impl<E: ExtensionField> VerifyingKey<E> {
pub fn get_cs(&self) -> &ConstraintSystem<E> {
&self.cs
}
}
5 changes: 4 additions & 1 deletion ceno_zkvm/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ pub enum UtilError {
pub enum ZKVMError {
CircuitError,
UtilError(UtilError),
VerifyError(&'static str),
WitnessNotFound(String),
VKNotFound(String),
FixedTraceNotFound(String),
VerifyError(String),
}

impl From<UtilError> for ZKVMError {
Expand Down
2 changes: 2 additions & 0 deletions ceno_zkvm/src/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ pub mod riscv;

pub trait Instruction<E: ExtensionField> {
type InstructionConfig: Send + Sync;

fn name() -> String;
fn construct_circuit(
circuit_builder: &mut CircuitBuilder<E>,
) -> Result<Self::InstructionConfig, ZKVMError>;
Expand Down
Loading

0 comments on commit 0cfa10b

Please sign in to comment.