Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: add e2e prover #188

Merged
merged 22 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions ceno_emul/src/rv32im.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@ pub struct Instruction {
pub func7: u32,
}

impl Default for Instruction {
fn default() -> Self {
insn(InsnKind::INVALID, InsnCategory::Invalid, 0x00, 0x0, 0x00)
}
}

impl DecodedInstruction {
pub fn new(insn: u32) -> Self {
Self {
Expand Down
10 changes: 10 additions & 0 deletions ceno_emul/src/tracer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{
rv32im::DecodedInstruction,
CENO_PLATFORM,
};
use crate::rv32im::Instruction;

/// An instruction and its context in an execution trace. That is concrete values of registers and memory.
///
Expand All @@ -22,6 +23,7 @@ pub struct StepRecord {
pub cycle: Cycle,
pub pc: Change<ByteAddr>,
pub insn_code: Word,
pub insn: Instruction,

pub rs1: Option<ReadOp>,
pub rs2: Option<ReadOp>,
Expand Down Expand Up @@ -62,6 +64,10 @@ impl StepRecord {
DecodedInstruction::new(self.insn_code)
}

pub fn insn(&self) -> Instruction {
self.insn
}

pub fn rs1(&self) -> Option<ReadOp> {
self.rs1.clone()
}
Expand Down Expand Up @@ -134,6 +140,10 @@ impl Tracer {
self.record.insn_code = value;
}

pub fn store_insn(&mut self, insn: Instruction) {
self.record.insn = insn;
}

pub fn load_register(&mut self, idx: RegIdx, value: Word) {
let addr = CENO_PLATFORM.register_vma(idx).into();

Expand Down
8 changes: 7 additions & 1 deletion ceno_emul/src/vm_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ impl VMState {
Ok(step)
}
}

pub fn init_register_unsafe(&mut self, idx: RegIdx, value: Word) {
self.registers[idx] = value;
}
}

impl EmuContext for VMState {
Expand Down Expand Up @@ -109,7 +113,9 @@ impl EmuContext for VMState {
Err(anyhow!("Trap {:?}", cause)) // Crash.
}

fn on_insn_decoded(&mut self, _kind: &Instruction, _decoded: &DecodedInstruction) {}
fn on_insn_decoded(&mut self, insn: &Instruction, _decoded: &DecodedInstruction) {
self.tracer.store_insn(*insn);
}

fn on_normal_end(&mut self, _kind: &Instruction, _decoded: &DecodedInstruction) {
self.tracer.store_pc(ByteAddr(self.pc));
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/benches/riscv_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ fn bench_add(c: &mut Criterion) {
.collect_vec();
let timer = Instant::now();
let _ = prover
.create_proof(
.create_opcode_proof(
wits_in,
num_instances,
max_threads,
Expand Down
163 changes: 130 additions & 33 deletions ceno_zkvm/examples/riscv_add.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
use std::time::Instant;
use std::{collections::BTreeMap, time::Instant};

use ark_std::test_rng;
use ceno_zkvm::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
instructions::{riscv::addsub::AddInstruction, Instruction},
scheme::prover::ZKVMProver,
UIntValue,
};
use const_env::from_env;

use ceno_emul::{ByteAddr, InsnKind::ADD, StepRecord, VMState, CENO_PLATFORM};
use ceno_zkvm::{
circuit_builder::ZKVMConstraintSystem,
scheme::verifier::ZKVMVerifier,
tables::{RangeTableCircuit, TableCircuit},
};
use ff_ext::ff::Field;
use goldilocks::{Goldilocks, GoldilocksExt2};
use goldilocks::GoldilocksExt2;
use itertools::Itertools;
use multilinear_extensions::mle::IntoMLE;
use sumcheck::util::is_power_of_2;
use tracing_flame::FlameLayer;
use tracing_subscriber::{fmt, layer::SubscriberExt, EnvFilter, Registry};
Expand All @@ -20,7 +26,23 @@ use transcript::Transcript;
#[from_env]
const RAYON_NUM_THREADS: usize = 8;

// For now, we assume registers
// - x0 is not touched,
// - x1 is initialized to 1,
// - x2 is initialized to -1,
// - x3 is initialized to loop bound.
// we use x4 to hold the acc_sum.
const PROGRAM_ADD_LOOP: [u32; 4] = [
// func7 rs2 rs1 f3 rd opcode
0b_0000000_00100_00001_000_00100_0110011, // add x4, x4, x1 <=> addi x4, x4, 1
0b_0000000_00011_00010_000_00011_0110011, // add x3, x3, x2 <=> addi x3, x3, -1
0b_1_111111_00000_00011_001_1100_1_1100011, // bne x3, x0, -8
0b_000000000000_00000_000_00000_1110011, // ecall halt
];

fn main() {
type E = GoldilocksExt2;

let max_threads = {
if !is_power_of_2(RAYON_NUM_THREADS) {
#[cfg(not(feature = "non_pow2_rayon_thread"))]
Expand All @@ -41,16 +63,6 @@ fn main() {
RAYON_NUM_THREADS
}
};
let mut cs = ConstraintSystem::new(|| "risv_add");
let mut circuit_builder = CircuitBuilder::<GoldilocksExt2>::new(&mut cs);
let _ = AddInstruction::construct_circuit(&mut circuit_builder);
let pk = cs.key_gen(None);
let num_witin = pk.get_cs().num_witin;

let prover = ZKVMProver::new(pk);
let mut transcript = Transcript::new(b"riscv");
let mut rng = test_rng();
let real_challenges = [E::random(&mut rng), E::random(&mut rng)];

let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap();
let subscriber = Registry::default()
Expand All @@ -64,34 +76,119 @@ fn main() {
.with(flame_layer.with_threads_collapsed(true));
tracing::subscriber::set_global_default(subscriber).unwrap();

for instance_num_vars in 20..22 {
// generate mock witness
// keygen
let mut zkvm_fixed_traces = BTreeMap::default();
let mut zkvm_cs = ZKVMConstraintSystem::default();

let (add_cs, add_config) = {
hero78119 marked this conversation as resolved.
Show resolved Hide resolved
let mut cs = ConstraintSystem::new(|| "riscv_add");
let mut circuit_builder = CircuitBuilder::<E>::new(&mut cs);
let config = AddInstruction::construct_circuit(&mut circuit_builder).unwrap();
zkvm_cs.add_cs(AddInstruction::<E>::name(), cs.clone());
zkvm_fixed_traces.insert(AddInstruction::<E>::name(), None);
(cs, config)
};
let (range_cs, range_config) = {
let mut cs = ConstraintSystem::new(|| "riscv_range");
let mut circuit_builder = CircuitBuilder::<E>::new(&mut cs);
let config = RangeTableCircuit::construct_circuit(&mut circuit_builder).unwrap();
zkvm_cs.add_cs(
<RangeTableCircuit<E> as TableCircuit<E>>::name(),
cs.clone(),
);
zkvm_fixed_traces.insert(
<RangeTableCircuit<E> as TableCircuit<E>>::name(),
Some(RangeTableCircuit::<E>::generate_fixed_traces(
&config,
cs.num_fixed,
)),
);
(cs, config)
};
let pk = zkvm_cs.key_gen(zkvm_fixed_traces).expect("keygen failed");
let vk = pk.get_vk();

// proving
let prover = ZKVMProver::new(pk);
let verifier = ZKVMVerifier::new(vk);

for instance_num_vars in 8..22 {
let num_instances = 1 << instance_num_vars;
let wits_in = (0..num_witin as usize)
.map(|_| {
(0..num_instances)
.map(|_| Goldilocks::random(&mut rng))
.collect::<Vec<Goldilocks>>()
.into_mle()
.into()
let mut vm = VMState::new(CENO_PLATFORM);
let pc_start = ByteAddr(CENO_PLATFORM.pc_start()).waddr();

// init vm.x1 = 1, vm.x2 = -1, vm.x3 = num_instances
// vm.x4 += vm.x1
vm.init_register_unsafe(1usize, 1);
vm.init_register_unsafe(2usize, u32::MAX); // -1 in two's complement
vm.init_register_unsafe(3usize, num_instances as u32);
for (i, inst) in PROGRAM_ADD_LOOP.iter().enumerate() {
vm.init_memory(pc_start + i, *inst);
}
let records = vm
.iter_until_success()
.collect::<Result<Vec<StepRecord>, _>>()
.expect("vm exec failed")
.into_iter()
.filter(|record| record.insn().kind == ADD)
.collect::<Vec<_>>();
tracing::info!("tracer generated {} ADD records", records.len());

// TODO: generate range check inputs from opcode_circuit::assign_instances()
let rc_inputs = records
.iter()
.flat_map(|record| {
let rs1 = UIntValue::new(record.rs1().unwrap().value);
let rs2 = UIntValue::new(record.rs2().unwrap().value);

let rd_prev = UIntValue::new(record.rd().unwrap().value.before);
let rd = UIntValue::new(record.rd().unwrap().value.after);
let carries = rs1
.add_u16_carries(&rs2)
.into_iter()
.map(|c| c as u16)
.collect_vec();

[rd_prev.limbs, rd.limbs, carries].concat()
})
.collect_vec();
.map(|x| x as usize)
.collect::<Vec<_>>();

let mut zkvm_witness = BTreeMap::default();
let add_witness =
AddInstruction::assign_instances(&add_config, add_cs.num_witin as usize, records)
.unwrap();
let range_witness = RangeTableCircuit::<E>::assign_instances(
&range_config,
range_cs.num_witin as usize,
&rc_inputs,
)
.unwrap();

zkvm_witness.insert(AddInstruction::<E>::name(), add_witness);
zkvm_witness.insert(RangeTableCircuit::<E>::name(), range_witness);
hero78119 marked this conversation as resolved.
Show resolved Hide resolved

let timer = Instant::now();
let _ = prover
.create_proof(
wits_in,
num_instances,
max_threads,
&mut transcript,
&real_challenges,
)

let mut transcript = Transcript::new(b"riscv");
let mut rng = test_rng();
let real_challenges = [E::random(&mut rng), E::random(&mut rng)];

let zkvm_proof = prover
.create_proof(zkvm_witness, max_threads, &mut transcript, &real_challenges)
.expect("create_proof failed");

let mut transcript = Transcript::new(b"riscv");
assert!(
verifier
.verify_proof(zkvm_proof, &mut transcript, &real_challenges)
.expect("verify proof return with error"),
);

println!(
"AddInstruction::create_proof, instance_num_vars = {}, time = {}",
instance_num_vars,
timer.elapsed().as_secs_f64()
);
}

type E = GoldilocksExt2;
}
52 changes: 46 additions & 6 deletions ceno_zkvm/src/circuit_builder.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use std::marker::PhantomData;
use itertools::Itertools;
use std::{collections::BTreeMap, marker::PhantomData};

use ff_ext::ExtensionField;
use multilinear_extensions::mle::DenseMultilinearExtension;
use multilinear_extensions::mle::{DenseMultilinearExtension, IntoMLEs};

use crate::{
error::ZKVMError,
expression::{Expression, Fixed, WitIn},
structs::WitnessId,
witness::RowMajorMatrix,
};

/// namespace used for annotation, preserve meta info during circuit construction
Expand Down Expand Up @@ -135,7 +137,13 @@ impl<E: ExtensionField> ConstraintSystem<E> {
}
}

pub fn key_gen(self, fixed_traces: Option<Vec<DenseMultilinearExtension<E>>>) -> ProvingKey<E> {
pub fn key_gen(self, fixed_traces: Option<RowMajorMatrix<E::BaseField>>) -> ProvingKey<E> {
// TODO: commit to fixed_traces

// transpose from row-major to column-major
let fixed_traces =
fixed_traces.map(|t| t.de_interleaving().into_mles().into_iter().collect_vec());

ProvingKey {
fixed_traces,
vk: VerifyingKey { cs: self },
Expand Down Expand Up @@ -301,9 +309,6 @@ pub struct ProvingKey<E: ExtensionField> {
}

impl<E: ExtensionField> ProvingKey<E> {
// pub fn create_pk(vk: VerifyingKey<E>) -> Self {
// Self { vk }
// }
pub fn get_cs(&self) -> &ConstraintSystem<E> {
self.vk.get_cs()
}
Expand All @@ -319,3 +324,38 @@ impl<E: ExtensionField> VerifyingKey<E> {
&self.cs
}
}

#[derive(Default)]
pub struct ZKVMConstraintSystem<E: ExtensionField> {
pub circuit_css: BTreeMap<String, ConstraintSystem<E>>,
}

impl<E: ExtensionField> ZKVMConstraintSystem<E> {
pub fn add_cs(&mut self, name: String, cs: ConstraintSystem<E>) {
assert!(self.circuit_css.insert(name, cs).is_none());
}
}

#[derive(Default)]
pub struct ZKVMProvingKey<E: ExtensionField> {
// pk for opcode and table circuits
pub circuit_pks: BTreeMap<String, ProvingKey<E>>,
}

impl<E: ExtensionField> ZKVMProvingKey<E> {
pub fn get_vk(&self) -> ZKVMVerifyingKey<E> {
ZKVMVerifyingKey {
circuit_vks: self
.circuit_pks
.iter()
.map(|(name, pk)| (name.clone(), pk.vk.clone()))
.collect(),
}
}
}

#[derive(Default)]
pub struct ZKVMVerifyingKey<E: ExtensionField> {
// pk for opcode and table circuits
pub circuit_vks: BTreeMap<String, VerifyingKey<E>>,
}
5 changes: 4 additions & 1 deletion ceno_zkvm/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ pub enum UtilError {
pub enum ZKVMError {
CircuitError,
UtilError(UtilError),
VerifyError(&'static str),
WitnessNotFound(String),
VKNotFound(String),
FixedTraceNotFound(String),
VerifyError(String),
}

impl From<UtilError> for ZKVMError {
Expand Down
2 changes: 2 additions & 0 deletions ceno_zkvm/src/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ pub mod riscv;

pub trait Instruction<E: ExtensionField> {
type InstructionConfig: Send + Sync;

fn name() -> String;
fn construct_circuit(
circuit_builder: &mut CircuitBuilder<E>,
) -> Result<Self::InstructionConfig, ZKVMError>;
Expand Down
Loading
Loading