From dc2107b9620c57cb6c0c11181971238e7c7f843a Mon Sep 17 00:00:00 2001 From: Ming Date: Thu, 12 Sep 2024 22:59:54 +0800 Subject: [PATCH] fix add bench build error & example (#211) --- ceno_zkvm/benches/riscv_add.rs | 25 ++++++++++++++++------ ceno_zkvm/examples/riscv_add.rs | 17 ++++++++------- ceno_zkvm/src/instructions/riscv/addsub.rs | 18 ++++++++-------- ceno_zkvm/src/instructions/riscv/config.rs | 13 ++++++++--- ceno_zkvm/src/structs.rs | 2 +- 5 files changed, 48 insertions(+), 27 deletions(-) diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index aa05b218f..c1bb982f9 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -2,9 +2,10 @@ use std::time::{Duration, Instant}; use ark_std::test_rng; use ceno_zkvm::{ - circuit_builder::{CircuitBuilder, ConstraintSystem}, + self, instructions::{riscv::addsub::AddInstruction, Instruction}, scheme::prover::ZKVMProver, + structs::{ZKVMConstraintSystem, ZKVMFixedTraces}, }; use const_env::from_env; use criterion::*; @@ -62,11 +63,22 @@ fn bench_add(c: &mut Criterion) { RAYON_NUM_THREADS } }; - let mut cs = ConstraintSystem::new(|| "risv_add"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); - let _ = AddInstruction::construct_circuit(&mut circuit_builder); - let pk = cs.key_gen(None); - let num_witin = pk.get_cs().num_witin; + let mut zkvm_cs = ZKVMConstraintSystem::default(); + let _ = zkvm_cs.register_opcode_circuit::>(); + let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); + zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); + + let pk = zkvm_cs + .clone() + .key_gen(zkvm_fixed_traces) + .expect("keygen failed"); + + let circuit_pk = pk + .circuit_pks + .get(&AddInstruction::::name()) + .unwrap() + .clone(); + let num_witin = circuit_pk.get_cs().num_witin; let prover = ZKVMProver::new(pk); let mut transcript = Transcript::new(b"riscv"); @@ -101,6 +113,7 @@ fn bench_add(c: &mut Criterion) { let timer = Instant::now(); let _ = prover .create_opcode_proof( + &circuit_pk, wits_in, num_instances, max_threads, diff --git a/ceno_zkvm/examples/riscv_add.rs b/ceno_zkvm/examples/riscv_add.rs index 5ff5df03a..773fbcf01 100644 --- a/ceno_zkvm/examples/riscv_add.rs +++ b/ceno_zkvm/examples/riscv_add.rs @@ -27,6 +27,7 @@ const RAYON_NUM_THREADS: usize = 8; // - x2 is initialized to -1, // - x3 is initialized to loop bound. // we use x4 to hold the acc_sum. +#[allow(clippy::unusual_byte_groupings)] const PROGRAM_ADD_LOOP: [u32; 4] = [ // func7 rs2 rs1 f3 rd opcode 0b_0000000_00100_00001_000_00100_0110011, // add x4, x4, x1 <=> addi x4, x4, 1 @@ -106,7 +107,7 @@ fn main() { let verifier = ZKVMVerifier::new(vk); for instance_num_vars in args.start..args.end { - let num_instances = 1 << instance_num_vars; + let step_loop = 1 << (instance_num_vars - 1); // 1 step in loop contribute to 2 add instance let mut vm = VMState::new(CENO_PLATFORM); let pc_start = ByteAddr(CENO_PLATFORM.pc_start()).waddr(); @@ -114,7 +115,7 @@ fn main() { // vm.x4 += vm.x1 vm.init_register_unsafe(1usize, 1); vm.init_register_unsafe(2usize, u32::MAX); // -1 in two's complement - vm.init_register_unsafe(3usize, num_instances as u32); + vm.init_register_unsafe(3usize, step_loop as u32); for (i, inst) in PROGRAM_ADD_LOOP.iter().enumerate() { vm.init_memory(pc_start + i, *inst); } @@ -148,17 +149,17 @@ fn main() { .create_proof(zkvm_witness, max_threads, &mut transcript, &real_challenges) .expect("create_proof failed"); + println!( + "AddInstruction::create_proof, instance_num_vars = {}, time = {}", + instance_num_vars, + timer.elapsed().as_secs_f64() + ); + let mut transcript = Transcript::new(b"riscv"); assert!( verifier .verify_proof(zkvm_proof, &mut transcript, &real_challenges) .expect("verify proof return with error"), ); - - println!( - "AddInstruction::create_proof, instance_num_vars = {}, time = {}", - instance_num_vars, - timer.elapsed().as_secs_f64() - ); } } diff --git a/ceno_zkvm/src/instructions/riscv/addsub.rs b/ceno_zkvm/src/instructions/riscv/addsub.rs index 7f5af92fe..27254dc82 100644 --- a/ceno_zkvm/src/instructions/riscv/addsub.rs +++ b/ceno_zkvm/src/instructions/riscv/addsub.rs @@ -188,20 +188,20 @@ fn add_sub_assignment( set_val!(instance, config.rs2_id, step.insn().rs2() as u64); set_val!(instance, config.rd_id, step.insn().rd() as u64); ExprLtInput { - lhs: step.rs1().unwrap().previous_cycle, // rs1 - rhs: step.cycle(), // cur_ts + lhs: step.rs1().unwrap().previous_cycle, + rhs: step.cycle(), } - .assign(instance, &config.lt_rs1_cfg); + .assign(instance, &config.lt_rs1_cfg, lk_multiplicity); ExprLtInput { - lhs: step.rs2().unwrap().previous_cycle, // rs2 - rhs: step.cycle() + 1, // cur_ts + lhs: step.rs2().unwrap().previous_cycle, + rhs: step.cycle() + 1, } - .assign(instance, &config.lt_rs2_cfg); + .assign(instance, &config.lt_rs2_cfg, lk_multiplicity); ExprLtInput { - lhs: step.rd().unwrap().previous_cycle, // rd - rhs: step.cycle() + 2, // cur_ts + lhs: step.rd().unwrap().previous_cycle, + rhs: step.cycle() + 2, } - .assign(instance, &config.lt_prev_ts_cfg); + .assign(instance, &config.lt_prev_ts_cfg, lk_multiplicity); set_val!( instance, config.prev_rs1_ts, diff --git a/ceno_zkvm/src/instructions/riscv/config.rs b/ceno_zkvm/src/instructions/riscv/config.rs index 378a1a9b1..2396475cd 100644 --- a/ceno_zkvm/src/instructions/riscv/config.rs +++ b/ceno_zkvm/src/instructions/riscv/config.rs @@ -1,6 +1,6 @@ use std::mem::MaybeUninit; -use crate::{expression::WitIn, set_val, utils::i64_to_base}; +use crate::{expression::WitIn, set_val, utils::i64_to_base, witness::LkMultiplicity}; use goldilocks::SmallField; use itertools::Itertools; @@ -184,7 +184,12 @@ pub struct ExprLtInput { } impl ExprLtInput { - pub fn assign(&self, instance: &mut [MaybeUninit], config: &ExprLtConfig) { + pub fn assign( + &self, + instance: &mut [MaybeUninit], + config: &ExprLtConfig, + lkm: &mut LkMultiplicity, + ) { let is_lt = if let Some(is_lt_wit) = config.is_lt { let is_lt = self.lhs < self.rhs; set_val!(instance, is_lt_wit, is_lt as u64); @@ -197,7 +202,9 @@ impl ExprLtInput { let diff = if is_lt { 1u64 << u32::BITS } else { 0 } + self.lhs - self.rhs; config.diff.iter().enumerate().for_each(|(i, wit)| { // extract the 16 bit limb from diff and assign to instance - set_val!(instance, wit, (diff >> (i * u16::BITS as usize)) & 0xffff); + let val = (diff >> (i * u16::BITS as usize)) & 0xffff; + lkm.assert_ux::<16>(val); + set_val!(instance, wit, val); }); } } diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index c7c136d66..e2bbe0b56 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -245,7 +245,7 @@ impl ZKVMWitnesses { #[derive(Default)] pub struct ZKVMProvingKey { // pk for opcode and table circuits - pub(crate) circuit_pks: BTreeMap>, + pub circuit_pks: BTreeMap>, } impl ZKVMProvingKey {