From 0cfa10bb6e80ea01700bce0db74f82ea8934d935 Mon Sep 17 00:00:00 2001 From: xkx Date: Wed, 11 Sep 2024 13:42:00 +0800 Subject: [PATCH] Feat: add e2e prover (#188) --- ceno_emul/src/rv32im.rs | 6 + ceno_emul/src/tracer.rs | 10 ++ ceno_emul/src/vm_state.rs | 8 +- ceno_zkvm/benches/riscv_add.rs | 2 +- ceno_zkvm/examples/riscv_add.rs | 130 ++++++++++---- ceno_zkvm/src/circuit_builder.rs | 40 ++--- ceno_zkvm/src/error.rs | 5 +- ceno_zkvm/src/instructions.rs | 2 + ceno_zkvm/src/instructions/riscv/addsub.rs | 50 ++++-- ceno_zkvm/src/instructions/riscv/blt.rs | 3 + ceno_zkvm/src/keygen.rs | 26 +++ ceno_zkvm/src/lib.rs | 6 +- ceno_zkvm/src/scheme.rs | 9 +- ceno_zkvm/src/scheme/prover.rs | 111 ++++++++++-- ceno_zkvm/src/scheme/tests.rs | 114 ++++++++---- ceno_zkvm/src/scheme/verifier.rs | 133 +++++++++++--- ceno_zkvm/src/structs.rs | 199 ++++++++++++++++++++- ceno_zkvm/src/tables/mod.rs | 27 +++ ceno_zkvm/src/tables/range.rs | 155 ++++++---------- ceno_zkvm/src/uint.rs | 2 +- ceno_zkvm/src/witness.rs | 21 ++- multilinear_extensions/src/mle.rs | 2 +- 22 files changed, 792 insertions(+), 269 deletions(-) create mode 100644 ceno_zkvm/src/keygen.rs diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index 710d7e7d4..64e758f55 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -183,6 +183,12 @@ pub struct Instruction { pub func7: u32, } +impl Default for Instruction { + fn default() -> Self { + insn(InsnKind::INVALID, InsnCategory::Invalid, 0x00, 0x0, 0x00) + } +} + impl DecodedInstruction { pub fn new(insn: u32) -> Self { Self { diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index f7845cf2a..625b86ef1 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -5,6 +5,7 @@ use crate::{ rv32im::DecodedInstruction, CENO_PLATFORM, }; +use crate::rv32im::Instruction; /// An instruction and its context in an execution trace. That is concrete values of registers and memory. /// @@ -22,6 +23,7 @@ pub struct StepRecord { pub cycle: Cycle, pub pc: Change, pub insn_code: Word, + pub insn: Instruction, pub rs1: Option, pub rs2: Option, @@ -69,6 +71,10 @@ impl StepRecord { DecodedInstruction::new(self.insn_code) } + pub fn insn(&self) -> Instruction { + self.insn + } + pub fn rs1(&self) -> Option { self.rs1.clone() } @@ -141,6 +147,10 @@ impl Tracer { self.record.insn_code = value; } + pub fn store_insn(&mut self, insn: Instruction) { + self.record.insn = insn; + } + pub fn load_register(&mut self, idx: RegIdx, value: Word) { let addr = CENO_PLATFORM.register_vma(idx).into(); diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index 778c77839..ffd449535 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -82,6 +82,10 @@ impl VMState { Ok(step) } } + + pub fn init_register_unsafe(&mut self, idx: RegIdx, value: Word) { + self.registers[idx] = value; + } } impl EmuContext for VMState { @@ -109,7 +113,9 @@ impl EmuContext for VMState { Err(anyhow!("Trap {:?}", cause)) // Crash. } - fn on_insn_decoded(&mut self, _kind: &Instruction, _decoded: &DecodedInstruction) {} + fn on_insn_decoded(&mut self, insn: &Instruction, _decoded: &DecodedInstruction) { + self.tracer.store_insn(*insn); + } fn on_normal_end(&mut self, _kind: &Instruction, _decoded: &DecodedInstruction) { self.tracer.store_pc(ByteAddr(self.pc)); diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index d40e92941..aa05b218f 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -100,7 +100,7 @@ fn bench_add(c: &mut Criterion) { .collect_vec(); let timer = Instant::now(); let _ = prover - .create_proof( + .create_opcode_proof( wits_in, num_instances, max_threads, diff --git a/ceno_zkvm/examples/riscv_add.rs b/ceno_zkvm/examples/riscv_add.rs index b7c547782..ec8732085 100644 --- a/ceno_zkvm/examples/riscv_add.rs +++ b/ceno_zkvm/examples/riscv_add.rs @@ -1,17 +1,17 @@ use std::time::Instant; use ark_std::test_rng; -use ceno_zkvm::{ - circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::{riscv::addsub::AddInstruction, Instruction}, - scheme::prover::ZKVMProver, -}; +use ceno_zkvm::{instructions::riscv::addsub::AddInstruction, scheme::prover::ZKVMProver}; use const_env::from_env; +use ceno_emul::{ByteAddr, InsnKind::ADD, StepRecord, VMState, CENO_PLATFORM}; +use ceno_zkvm::{ + scheme::verifier::ZKVMVerifier, + structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, + tables::RangeTableCircuit, +}; use ff_ext::ff::Field; -use goldilocks::{Goldilocks, GoldilocksExt2}; -use itertools::Itertools; -use multilinear_extensions::mle::IntoMLE; +use goldilocks::GoldilocksExt2; use sumcheck::util::is_power_of_2; use tracing_flame::FlameLayer; use tracing_subscriber::{fmt, layer::SubscriberExt, EnvFilter, Registry}; @@ -20,7 +20,23 @@ use transcript::Transcript; #[from_env] const RAYON_NUM_THREADS: usize = 8; +// For now, we assume registers +// - x0 is not touched, +// - x1 is initialized to 1, +// - x2 is initialized to -1, +// - x3 is initialized to loop bound. +// we use x4 to hold the acc_sum. +const PROGRAM_ADD_LOOP: [u32; 4] = [ + // func7 rs2 rs1 f3 rd opcode + 0b_0000000_00100_00001_000_00100_0110011, // add x4, x4, x1 <=> addi x4, x4, 1 + 0b_0000000_00011_00010_000_00011_0110011, // add x3, x3, x2 <=> addi x3, x3, -1 + 0b_1_111111_00000_00011_001_1100_1_1100011, // bne x3, x0, -8 + 0b_000000000000_00000_000_00000_1110011, // ecall halt +]; + fn main() { + type E = GoldilocksExt2; + let max_threads = { if !is_power_of_2(RAYON_NUM_THREADS) { #[cfg(not(feature = "non_pow2_rayon_thread"))] @@ -41,16 +57,6 @@ fn main() { RAYON_NUM_THREADS } }; - let mut cs = ConstraintSystem::new(|| "risv_add"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); - let _ = AddInstruction::construct_circuit(&mut circuit_builder); - let pk = cs.key_gen(None); - let num_witin = pk.get_cs().num_witin; - - let prover = ZKVMProver::new(pk); - let mut transcript = Transcript::new(b"riscv"); - let mut rng = test_rng(); - let real_challenges = [E::random(&mut rng), E::random(&mut rng)]; let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap(); let subscriber = Registry::default() @@ -64,34 +70,80 @@ fn main() { .with(flame_layer.with_threads_collapsed(true)); tracing::subscriber::set_global_default(subscriber).unwrap(); - for instance_num_vars in 20..22 { - // generate mock witness + // keygen + let mut zkvm_cs = ZKVMConstraintSystem::default(); + let add_config = zkvm_cs.register_opcode_circuit::>(); + let range_config = zkvm_cs.register_table_circuit::>(); + + let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); + zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); + zkvm_fixed_traces + .register_table_circuit::>(&zkvm_cs, range_config.clone()); + + let pk = zkvm_cs + .clone() + .key_gen(zkvm_fixed_traces) + .expect("keygen failed"); + let vk = pk.get_vk(); + + // proving + let prover = ZKVMProver::new(pk); + let verifier = ZKVMVerifier::new(vk); + + for instance_num_vars in 8..22 { let num_instances = 1 << instance_num_vars; - let wits_in = (0..num_witin as usize) - .map(|_| { - (0..num_instances) - .map(|_| Goldilocks::random(&mut rng)) - .collect::>() - .into_mle() - .into() - }) - .collect_vec(); + let mut vm = VMState::new(CENO_PLATFORM); + let pc_start = ByteAddr(CENO_PLATFORM.pc_start()).waddr(); + + // init vm.x1 = 1, vm.x2 = -1, vm.x3 = num_instances + // vm.x4 += vm.x1 + vm.init_register_unsafe(1usize, 1); + vm.init_register_unsafe(2usize, u32::MAX); // -1 in two's complement + vm.init_register_unsafe(3usize, num_instances as u32); + for (i, inst) in PROGRAM_ADD_LOOP.iter().enumerate() { + vm.init_memory(pc_start + i, *inst); + } + let records = vm + .iter_until_success() + .collect::, _>>() + .expect("vm exec failed") + .into_iter() + .filter(|record| record.insn().kind == ADD) + .collect::>(); + tracing::info!("tracer generated {} ADD records", records.len()); + + let mut zkvm_witness = ZKVMWitnesses::default(); + // assign opcode circuits + zkvm_witness + .assign_opcode_circuit::>(&zkvm_cs, &add_config, records) + .unwrap(); + zkvm_witness.finalize_lk_multiplicities(); + // assign table circuits + zkvm_witness + .assign_table_circuit::>(&zkvm_cs, &range_config) + .unwrap(); + let timer = Instant::now(); - let _ = prover - .create_proof( - wits_in, - num_instances, - max_threads, - &mut transcript, - &real_challenges, - ) + + let mut transcript = Transcript::new(b"riscv"); + let mut rng = test_rng(); + let real_challenges = [E::random(&mut rng), E::random(&mut rng)]; + + let zkvm_proof = prover + .create_proof(zkvm_witness, max_threads, &mut transcript, &real_challenges) .expect("create_proof failed"); + + let mut transcript = Transcript::new(b"riscv"); + assert!( + verifier + .verify_proof(zkvm_proof, &mut transcript, &real_challenges) + .expect("verify proof return with error"), + ); + println!( "AddInstruction::create_proof, instance_num_vars = {}, time = {}", instance_num_vars, timer.elapsed().as_secs_f64() ); } - - type E = GoldilocksExt2; } diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index 794ababb4..fc127b5ba 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -1,12 +1,14 @@ +use itertools::Itertools; use std::marker::PhantomData; use ff_ext::ExtensionField; -use multilinear_extensions::mle::DenseMultilinearExtension; +use multilinear_extensions::mle::IntoMLEs; use crate::{ error::ZKVMError, expression::{Expression, Fixed, WitIn}, - structs::WitnessId, + structs::{ProvingKey, VerifyingKey, WitnessId}, + witness::RowMajorMatrix, }; /// namespace used for annotation, preserve meta info during circuit construction @@ -135,7 +137,13 @@ impl ConstraintSystem { } } - pub fn key_gen(self, fixed_traces: Option>>) -> ProvingKey { + pub fn key_gen(self, fixed_traces: Option>) -> ProvingKey { + // TODO: commit to fixed_traces + + // transpose from row-major to column-major + let fixed_traces = + fixed_traces.map(|t| t.de_interleaving().into_mles().into_iter().collect_vec()); + ProvingKey { fixed_traces, vk: VerifyingKey { cs: self }, @@ -293,29 +301,3 @@ impl ConstraintSystem { pub struct CircuitBuilder<'a, E: ExtensionField> { pub(crate) cs: &'a mut ConstraintSystem, } - -#[derive(Clone, Debug)] -pub struct ProvingKey { - pub fixed_traces: Option>>, - pub vk: VerifyingKey, -} - -impl ProvingKey { - // pub fn create_pk(vk: VerifyingKey) -> Self { - // Self { vk } - // } - pub fn get_cs(&self) -> &ConstraintSystem { - self.vk.get_cs() - } -} - -#[derive(Clone, Debug)] -pub struct VerifyingKey { - cs: ConstraintSystem, -} - -impl VerifyingKey { - pub fn get_cs(&self) -> &ConstraintSystem { - &self.cs - } -} diff --git a/ceno_zkvm/src/error.rs b/ceno_zkvm/src/error.rs index d623364c9..b7791def9 100644 --- a/ceno_zkvm/src/error.rs +++ b/ceno_zkvm/src/error.rs @@ -7,7 +7,10 @@ pub enum UtilError { pub enum ZKVMError { CircuitError, UtilError(UtilError), - VerifyError(&'static str), + WitnessNotFound(String), + VKNotFound(String), + FixedTraceNotFound(String), + VerifyError(String), } impl From for ZKVMError { diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 228fce182..e717a7682 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -17,6 +17,8 @@ pub mod riscv; pub trait Instruction { type InstructionConfig: Send + Sync; + + fn name() -> String; fn construct_circuit( circuit_builder: &mut CircuitBuilder, ) -> Result; diff --git a/ceno_zkvm/src/instructions/riscv/addsub.rs b/ceno_zkvm/src/instructions/riscv/addsub.rs index da3377545..f563ba68b 100644 --- a/ceno_zkvm/src/instructions/riscv/addsub.rs +++ b/ceno_zkvm/src/instructions/riscv/addsub.rs @@ -1,6 +1,5 @@ use std::marker::PhantomData; -use ark_std::iterable::Iterable; use ceno_emul::StepRecord; use ff_ext::ExtensionField; use itertools::Itertools; @@ -21,8 +20,8 @@ use crate::{ }; use core::mem::MaybeUninit; -pub struct AddInstruction; -pub struct SubInstruction; +pub struct AddInstruction(PhantomData); +pub struct SubInstruction(PhantomData); #[derive(Debug)] pub struct InstructionConfig { @@ -41,11 +40,11 @@ pub struct InstructionConfig { phantom: PhantomData, } -impl RIVInstruction for AddInstruction { +impl RIVInstruction for AddInstruction { const OPCODE_TYPE: OpcodeType = OpcodeType::RType(OPType::Op, 0x000, 0x0000000); } -impl RIVInstruction for SubInstruction { +impl RIVInstruction for SubInstruction { const OPCODE_TYPE: OpcodeType = OpcodeType::RType(OPType::Op, 0x000, 0x0100000); } @@ -61,7 +60,7 @@ fn add_sub_gadget( let next_pc = pc.expr() + PC_STEP_SIZE.into(); // Execution result = addend0 + addend1, with carry. - let prev_rd_value = RegUInt::new(|| "prev_rd_value", circuit_builder)?; + let prev_rd_value = RegUInt::new_unchecked(|| "prev_rd_value", circuit_builder)?; let (addend_0, addend_1, outcome) = if IS_ADD { // outcome = addend_0 + addend_1 @@ -139,8 +138,11 @@ fn add_sub_gadget( }) } -impl Instruction for AddInstruction { +impl Instruction for AddInstruction { // const NAME: &'static str = "ADD"; + fn name() -> String { + "ADD".into() + } type InstructionConfig = InstructionConfig; fn construct_circuit( circuit_builder: &mut CircuitBuilder, @@ -160,9 +162,10 @@ impl Instruction for AddInstruction { set_val!(instance, config.ts, 2); let addend_0 = UIntValue::new_unchecked(step.rs1().unwrap().value); let addend_1 = UIntValue::new_unchecked(step.rs2().unwrap().value); + let rd_prev = UIntValue::new_unchecked(step.rd().unwrap().value.before); config .prev_rd_value - .assign_limbs(instance, [0, 0].iter().map(E::BaseField::from).collect()); + .assign_limbs(instance, rd_prev.u16_fields()); config .addend_0 .assign_limbs(instance, addend_0.u16_fields()); @@ -188,8 +191,11 @@ impl Instruction for AddInstruction { } } -impl Instruction for SubInstruction { +impl Instruction for SubInstruction { // const NAME: &'static str = "ADD"; + fn name() -> String { + "SUB".into() + } type InstructionConfig = InstructionConfig; fn construct_circuit( circuit_builder: &mut CircuitBuilder, @@ -237,7 +243,7 @@ impl Instruction for SubInstruction { #[cfg(test)] mod test { - use ceno_emul::{ReadOp, StepRecord}; + use ceno_emul::{Change, ReadOp, StepRecord, WriteOp}; use goldilocks::GoldilocksExt2; use itertools::Itertools; use multilinear_extensions::mle::IntoMLEs; @@ -271,15 +277,23 @@ mod test { cb.cs.num_witin as usize, vec![StepRecord { rs1: Some(ReadOp { - addr: 0.into(), + addr: 2.into(), value: 11u32, previous_cycle: 0, }), rs2: Some(ReadOp { - addr: 0.into(), + addr: 3.into(), value: 0xfffffffeu32, previous_cycle: 0, }), + rd: Some(WriteOp { + addr: 4.into(), + value: Change { + before: 0u32, + after: 9u32, + }, + previous_cycle: 0, + }), ..Default::default() }], ) @@ -318,15 +332,23 @@ mod test { cb.cs.num_witin as usize, vec![StepRecord { rs1: Some(ReadOp { - addr: 0.into(), + addr: 2.into(), value: u32::MAX - 1, previous_cycle: 0, }), rs2: Some(ReadOp { - addr: 0.into(), + addr: 3.into(), value: u32::MAX - 1, previous_cycle: 0, }), + rd: Some(WriteOp { + addr: 4.into(), + value: Change { + before: 0u32, + after: u32::MAX - 2, + }, + previous_cycle: 0, + }), ..Default::default() }], ) diff --git a/ceno_zkvm/src/instructions/riscv/blt.rs b/ceno_zkvm/src/instructions/riscv/blt.rs index afec12458..c6bc8ab53 100644 --- a/ceno_zkvm/src/instructions/riscv/blt.rs +++ b/ceno_zkvm/src/instructions/riscv/blt.rs @@ -213,6 +213,9 @@ fn blt_gadget( impl Instruction for BltInstruction { // const NAME: &'static str = "BLT"; + fn name() -> String { + "BLT".into() + } type InstructionConfig = InstructionConfig; fn construct_circuit( circuit_builder: &mut CircuitBuilder, diff --git a/ceno_zkvm/src/keygen.rs b/ceno_zkvm/src/keygen.rs new file mode 100644 index 000000000..248a6b240 --- /dev/null +++ b/ceno_zkvm/src/keygen.rs @@ -0,0 +1,26 @@ +use crate::{ + error::ZKVMError, + structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMProvingKey}, +}; +use ff_ext::ExtensionField; + +impl ZKVMConstraintSystem { + pub fn key_gen( + self, + mut vm_fixed_traces: ZKVMFixedTraces, + ) -> Result, ZKVMError> { + let mut vm_pk = ZKVMProvingKey::default(); + + for (c_name, cs) in self.circuit_css.into_iter() { + let fixed_traces = vm_fixed_traces + .circuit_fixed_traces + .remove(&c_name) + .ok_or(ZKVMError::FixedTraceNotFound(c_name.clone()))?; + + let circuit_pk = cs.key_gen(fixed_traces); + assert!(vm_pk.circuit_pks.insert(c_name, circuit_pk).is_none()); + } + + Ok(vm_pk) + } +} diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index 811132e63..5aa9da5e6 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -10,8 +10,12 @@ pub use utils::u64vec; mod chip_handler; pub mod circuit_builder; pub mod expression; -mod structs; +mod keygen; +pub mod structs; mod uint; mod utils; mod virtual_polys; mod witness; + +pub use structs::ROMType; +pub use uint::UIntValue; diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index a1aa33217..cc9b7d353 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -1,4 +1,5 @@ use ff_ext::ExtensionField; +use std::collections::HashMap; use sumcheck::structs::IOPProverMessage; use crate::structs::TowerProofs; @@ -14,7 +15,7 @@ pub mod mock_prover; mod tests; #[derive(Clone)] -pub struct ZKVMProof { +pub struct ZKVMOpcodeProof { // TODO support >1 opcodes pub num_instances: usize, @@ -58,3 +59,9 @@ pub struct ZKVMTableProof { pub fixed_in_evals: Vec, pub wits_in_evals: Vec, } + +#[derive(Default, Clone)] +pub struct ZKVMProof { + opcode_proofs: HashMap>, + table_proofs: HashMap>, +} diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index e080e0c42..fdf5a9aa2 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -1,10 +1,9 @@ -use std::{collections::BTreeSet, sync::Arc}; - use ff_ext::ExtensionField; +use std::{collections::BTreeSet, sync::Arc}; use itertools::Itertools; use multilinear_extensions::{ - mle::{IntoMLE, MultilinearExtension}, + mle::{IntoMLE, IntoMLEs, MultilinearExtension}, util::ceil_log2, virtual_poly::build_eq_x_r_vec, virtual_poly_v2::ArcMultilinearExtension, @@ -17,7 +16,6 @@ use sumcheck::{ use transcript::Transcript; use crate::{ - circuit_builder::ProvingKey, error::ZKVMError, scheme::{ constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, NUM_FANIN, NUM_FANIN_LOGUP}, @@ -26,35 +24,118 @@ use crate::{ wit_infer_by_expr, }, }, - structs::{Point, TowerProofs, TowerProver, TowerProverSpec}, + structs::{ + Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses, + }, utils::{get_challenge_pows, proper_num_threads}, virtual_polys::VirtualPolynomials, }; -use super::{ZKVMProof, ZKVMTableProof}; +use super::{ZKVMOpcodeProof, ZKVMProof, ZKVMTableProof}; pub struct ZKVMProver { - pk: ProvingKey, + pub(crate) pk: ZKVMProvingKey, } impl ZKVMProver { - pub fn new(pk: ProvingKey) -> Self { + pub fn new(pk: ZKVMProvingKey) -> Self { ZKVMProver { pk } } + /// create proof for zkvm execution + pub fn create_proof( + &self, + mut witnesses: ZKVMWitnesses, + max_threads: usize, + transcript: &mut Transcript, + challenges: &[E; 2], + ) -> Result, ZKVMError> { + let mut vm_proof = ZKVMProof::default(); + for (circuit_name, pk) in self.pk.circuit_pks.iter() { + let witness = witnesses + .witnesses + .remove(circuit_name) + .ok_or(ZKVMError::WitnessNotFound(circuit_name.clone()))?; + + // TODO: add an enum for circuit type either in constraint_system or vk + let cs = pk.get_cs(); + let is_opcode_circuit = cs.lk_table_expressions.is_empty(); + let num_instances = witness.num_instances(); + + if is_opcode_circuit { + tracing::debug!( + "opcode circuit {} has {} witnesses, {} reads, {} writes, {} lookups", + circuit_name, + cs.num_witin, + cs.r_expressions.len(), + cs.w_expressions.len(), + cs.lk_expressions.len(), + ); + for lk_s in cs.lk_expressions_namespace_map.iter() { + tracing::debug!("opcode circuit {}: {}", circuit_name, lk_s); + } + let opcode_proof = self.create_opcode_proof( + pk, + witness + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + num_instances, + max_threads, + transcript, + challenges, + )?; + tracing::info!( + "generated proof for opcode {} with num_instances={}", + circuit_name, + num_instances + ); + vm_proof + .opcode_proofs + .insert(circuit_name.clone(), opcode_proof); + } else { + let table_proof = self.create_table_proof( + pk, + witness + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + num_instances, + max_threads, + transcript, + challenges, + )?; + tracing::info!( + "generated proof for table {} with num_instances={}", + circuit_name, + num_instances + ); + vm_proof + .table_proofs + .insert(circuit_name.clone(), table_proof); + } + } + + Ok(vm_proof) + } /// create proof giving witness and num_instances /// major flow break down into /// 1: witness layer inferring from input -> output /// 2: proof (sumcheck reduce) from output to input - pub fn create_proof( + pub fn create_opcode_proof( &self, + circuit_pk: &ProvingKey, witnesses: Vec>, num_instances: usize, max_threads: usize, transcript: &mut Transcript, challenges: &[E; 2], - ) -> Result, ZKVMError> { - let cs = self.pk.get_cs(); + ) -> Result, ZKVMError> { + let cs = circuit_pk.get_cs(); let log2_num_instances = ceil_log2(num_instances); let next_pow2_instances = 1 << log2_num_instances; let (chip_record_alpha, _) = (challenges[0], challenges[1]); @@ -414,7 +495,7 @@ impl ZKVMProver { .collect(); exit_span!(span); - Ok(ZKVMProof { + Ok(ZKVMOpcodeProof { num_instances, record_r_out_evals, record_w_out_evals, @@ -433,15 +514,15 @@ impl ZKVMProver { pub fn create_table_proof( &self, + circuit_pk: &ProvingKey, witnesses: Vec>, num_instances: usize, max_threads: usize, transcript: &mut Transcript, challenges: &[E; 2], ) -> Result, ZKVMError> { - let cs = self.pk.get_cs(); - let fixed = self - .pk + let cs = circuit_pk.get_cs(); + let fixed = circuit_pk .fixed_traces .as_ref() .expect("pk.fixed_traces must not be none for table circuit") diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 3a2008a36..587beeb01 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -1,86 +1,130 @@ -use std::marker::PhantomData; +use std::{marker::PhantomData, mem::MaybeUninit}; +use ceno_emul::StepRecord; use ff::Field; use ff_ext::ExtensionField; -use goldilocks::{Goldilocks, GoldilocksExt2}; +use goldilocks::GoldilocksExt2; use itertools::Itertools; -use multilinear_extensions::mle::IntoMLE; +use multilinear_extensions::mle::IntoMLEs; +use rand::rngs::ThreadRng; use transcript::Transcript; use crate::{ - circuit_builder::{CircuitBuilder, ConstraintSystem}, + circuit_builder::CircuitBuilder, error::ZKVMError, - expression::{Expression, ToExpr}, - structs::PointAndEval, + expression::{Expression, ToExpr, WitIn}, + instructions::Instruction, + set_val, + structs::{PointAndEval, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, + witness::LkMultiplicity, }; use super::{constants::NUM_FANIN, prover::ZKVMProver, verifier::ZKVMVerifier}; -struct TestCircuit { +struct TestConfig { + pub(crate) reg_id: WitIn, +} +struct TestCircuit { phantom: PhantomData, } -impl TestCircuit { - pub fn construct_circuit( - cb: &mut CircuitBuilder, - ) -> Result { - let regid = cb.create_witin(|| "reg_id")?; +impl Instruction for TestCircuit { + type InstructionConfig = TestConfig; + + fn name() -> String { + "TEST".into() + } + + fn construct_circuit(cb: &mut CircuitBuilder) -> Result { + let reg_id = cb.create_witin(|| "reg_id")?; (0..RW).try_for_each(|_| { let record = cb.rlc_chip_record(vec![ Expression::::Constant(E::BaseField::ONE), - regid.expr(), + reg_id.expr(), ]); cb.read_record(|| "read", record.clone())?; cb.write_record(|| "write", record)?; Result::<(), ZKVMError>::Ok(()) })?; (0..L).try_for_each(|_| { - cb.assert_ux::<_, _, 16>(|| "regid_in_range", regid.expr())?; + cb.assert_ux::<_, _, 16>(|| "regid_in_range", reg_id.expr())?; Result::<(), ZKVMError>::Ok(()) })?; assert_eq!(cb.cs.lk_expressions.len(), L); assert_eq!(cb.cs.r_expressions.len(), RW); assert_eq!(cb.cs.w_expressions.len(), RW); - Ok(Self { - phantom: PhantomData, - }) + + Ok(TestConfig { reg_id }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [MaybeUninit], + _lk_multiplicity: &mut LkMultiplicity, + _step: &StepRecord, + ) -> Result<(), ZKVMError> { + set_val!(instance, config.reg_id, E::BaseField::ONE); + + Ok(()) } } #[test] fn test_rw_lk_expression_combination() { fn test_rw_lk_expression_combination_inner() { - let mut cs = ConstraintSystem::new(|| "test"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); - let _ = TestCircuit::construct_circuit::(&mut circuit_builder); - let pk = cs.key_gen(None); - let vk = pk.vk.clone(); + type E = GoldilocksExt2; + let name = TestCircuit::::name(); + let mut zkvm_cs = ZKVMConstraintSystem::default(); + let config = zkvm_cs.register_opcode_circuit::>(); + + let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); + zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); + + let pk = zkvm_cs.clone().key_gen(zkvm_fixed_traces).unwrap(); + let vk = pk.get_vk(); // generate mock witness let num_instances = 1 << 2; - let wits_in = (0..pk.get_cs().num_witin as usize) - .map(|_| { - (0..num_instances) - .map(|_| Goldilocks::ONE) - .collect::>() - .into_mle() - .into() - }) - .collect_vec(); + let mut zkvm_witness = ZKVMWitnesses::default(); + zkvm_witness + .assign_opcode_circuit::>( + &zkvm_cs, + &config, + vec![StepRecord::default(); num_instances], + ) + .unwrap(); // get proof let prover = ZKVMProver::new(pk); let mut transcript = Transcript::new(b"test"); - let challenges = [1.into(), 2.into()]; + let mut rng = ThreadRng::default(); + let challenges = [E::random(&mut rng), E::random(&mut rng)]; + let wits_in = zkvm_witness + .witnesses + .remove(&name) + .unwrap() + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(); let proof = prover - .create_proof(wits_in, num_instances, 1, &mut transcript, &challenges) + .create_opcode_proof( + &prover.pk.circuit_pks.get(&name).unwrap(), + wits_in, + num_instances, + 1, + &mut transcript, + &challenges, + ) .expect("create_proof failed"); - let verifier = ZKVMVerifier::new(vk); + let verifier = ZKVMVerifier::new(vk.clone()); let mut v_transcript = Transcript::new(b"test"); let _rt_input = verifier - .verify( + .verify_opcode_proof( + verifier.vk.circuit_vks.get(&name).unwrap(), &proof, &mut v_transcript, NUM_FANIN, diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 43f87ded0..fb2f2e97a 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -13,39 +13,125 @@ use sumcheck::structs::{IOPProof, IOPVerifierState}; use transcript::Transcript; use crate::{ - circuit_builder::VerifyingKey, error::ZKVMError, scheme::{ - constants::{NUM_FANIN, SEL_DEGREE}, + constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEL_DEGREE}, utils::eval_by_expr_with_fixed, }, - structs::{Point, PointAndEval, TowerProofs}, + structs::{Point, PointAndEval, TowerProofs, VerifyingKey, ZKVMVerifyingKey}, utils::{get_challenge_pows, sel_eval}, }; use super::{ - constants::MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, utils::eval_by_expr, ZKVMProof, ZKVMTableProof, + constants::MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, utils::eval_by_expr, ZKVMOpcodeProof, ZKVMProof, + ZKVMTableProof, }; pub struct ZKVMVerifier { - vk: VerifyingKey, + pub(crate) vk: ZKVMVerifyingKey, } impl ZKVMVerifier { - pub fn new(vk: VerifyingKey) -> Self { + pub fn new(vk: ZKVMVerifyingKey) -> Self { ZKVMVerifier { vk } } + pub fn verify_proof( + &self, + vm_proof: ZKVMProof, + transcript: &mut Transcript, + challenges: &[E; 2], + ) -> Result { + let mut prod_r = E::ONE; + let mut prod_w = E::ONE; + let mut logup_sum = E::ZERO; + let dummy_table_item = challenges[0]; + let point_eval = PointAndEval::default(); + let mut dummy_table_item_multiplicity = 0; + for (name, opcode_proof) in vm_proof.opcode_proofs { + let circuit_vk = self + .vk + .circuit_vks + .get(&name) + .ok_or(ZKVMError::VKNotFound(name.clone()))?; + let _rand_point = self.verify_opcode_proof( + circuit_vk, + &opcode_proof, + transcript, + NUM_FANIN, + &point_eval, + challenges, + )?; + tracing::info!("verified proof for opcode {}", name); + + // getting the number of dummy padding item that we used in this opcode circuit + let num_lks = circuit_vk.get_cs().lk_expressions.len(); + let num_padded_lks_per_instance = num_lks.next_power_of_two() - num_lks; + let num_padded_instance = + opcode_proof.num_instances.next_power_of_two() - opcode_proof.num_instances; + dummy_table_item_multiplicity += num_padded_lks_per_instance + * opcode_proof.num_instances + + num_lks.next_power_of_two() * num_padded_instance; + + prod_r *= opcode_proof.record_r_out_evals.iter().product::(); + prod_w *= opcode_proof.record_w_out_evals.iter().product::(); + + logup_sum += + opcode_proof.lk_p1_out_eval * opcode_proof.lk_q1_out_eval.invert().unwrap(); + logup_sum += + opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.invert().unwrap(); + } + + for (name, table_proof) in vm_proof.table_proofs { + let circuit_vk = self + .vk + .circuit_vks + .get(&name) + .ok_or(ZKVMError::VKNotFound(name.clone()))?; + let _rand_point = self.verify_table_proof( + circuit_vk, + &table_proof, + transcript, + NUM_FANIN_LOGUP, + &point_eval, + challenges, + )?; + tracing::info!("verified proof for table {}", name); + + logup_sum -= table_proof.lk_p1_out_eval * table_proof.lk_q1_out_eval.invert().unwrap(); + logup_sum -= table_proof.lk_p2_out_eval * table_proof.lk_q2_out_eval.invert().unwrap(); + } + logup_sum -= + E::from(dummy_table_item_multiplicity as u64) * dummy_table_item.invert().unwrap(); + + // check rw_set equality across all proofs + // TODO: enable this when we have cpu init/finalize and mem init/finalize + // if prod_r != prod_w { + // return Err(ZKVMError::VerifyError("prod_r != prod_w".into())); + // } + + // check logup relation across all proofs + if logup_sum != E::ZERO { + return Err(ZKVMError::VerifyError(format!( + "logup_sum({:?}) != 0", + logup_sum + ))); + } + + Ok(true) + } + /// verify proof and return input opening point - pub fn verify( + pub fn verify_opcode_proof( &self, - proof: &ZKVMProof, + circuit_vk: &VerifyingKey, + proof: &ZKVMOpcodeProof, transcript: &mut Transcript, num_product_fanin: usize, _out_evals: &PointAndEval, challenges: &[E; 2], // derive challenge from PCS ) -> Result, ZKVMError> { - let cs = self.vk.get_cs(); + let cs = circuit_vk.get_cs(); let (r_counts_per_instance, w_counts_per_instance, lk_counts_per_instance) = ( cs.r_expressions.len(), cs.w_expressions.len(), @@ -64,9 +150,6 @@ impl ZKVMVerifier { // verify and reduce product tower sumcheck let tower_proofs = &proof.tower_proof; - // TODO check rw_set equality across all proofs - // TODO check logup relation across all proofs - let (rt_tower, record_evals, logup_p_evals, logup_q_evals) = TowerVerify::verify( vec![ proof.record_r_out_evals.clone(), @@ -95,7 +178,7 @@ impl ZKVMVerifier { // index 0 is LogUp witness for Fixed Lookup table if logup_p_evals[0].eval != E::ONE { return Err(ZKVMError::VerifyError( - "Lookup table witness p(x) != constant 1", + "Lookup table witness p(x) != constant 1".into(), )); } @@ -215,7 +298,7 @@ impl ZKVMVerifier { .sum::(); if computed_evals != expected_evaluation { return Err(ZKVMError::VerifyError( - "main + sel evaluation verify failed", + "main + sel evaluation verify failed".into(), )); } // verify records (degree = 1) statement, thus no sumcheck @@ -234,7 +317,9 @@ impl ZKVMVerifier { eval_by_expr(&proof.wits_in_evals, challenges, expr) != *expected_evals }) { - return Err(ZKVMError::VerifyError("record evaluate != expected_evals")); + return Err(ZKVMError::VerifyError( + "record evaluate != expected_evals".into(), + )); } // verify zero expression (degree = 1) statement, thus no sumcheck @@ -252,16 +337,16 @@ impl ZKVMVerifier { pub fn verify_table_proof( &self, + circuit_vk: &VerifyingKey, proof: &ZKVMTableProof, transcript: &mut Transcript, num_logup_fanin: usize, _out_evals: &PointAndEval, challenges: &[E; 2], // TODO: derive challenge from PCS ) -> Result, ZKVMError> { - let cs = self.vk.get_cs(); + let cs = circuit_vk.get_cs(); let lk_counts_per_instance = cs.lk_table_expressions.len(); let log2_lk_count = ceil_log2(lk_counts_per_instance); - let (chip_record_alpha, _) = (challenges[0], challenges[1]); let num_instances = proof.num_instances; let log2_num_instances = ceil_log2(num_instances); @@ -329,8 +414,7 @@ impl ZKVMVerifier { * ((0..lk_counts_per_instance) .map(|i| proof.lk_d_in_evals[i] * eq_lk[i]) .sum::() - + chip_record_alpha - * (eq_lk[lk_counts_per_instance..].iter().sum::() - E::ONE)), + + (eq_lk[lk_counts_per_instance..].iter().sum::() - E::ONE)), *alpha_lk_n * sel_lk * ((0..lk_counts_per_instance) @@ -340,7 +424,9 @@ impl ZKVMVerifier { .iter() .sum::(); if computed_evals != expected_evaluation { - return Err(ZKVMError::VerifyError("sel evaluation verify failed")); + return Err(ZKVMError::VerifyError( + "sel evaluation verify failed".into(), + )); } // verify records (degree = 1) statement, thus no sumcheck if cs @@ -362,7 +448,9 @@ impl ZKVMVerifier { ) != *expected_evals }) { - return Err(ZKVMError::VerifyError("record evaluate != expected_evals")); + return Err(ZKVMError::VerifyError( + "record evaluate != expected_evals".into(), + )); } Ok(input_opening_point) @@ -459,6 +547,7 @@ impl TowerVerify { }, transcript, ); + tracing::debug!("verified tower proof at layer {}/{}", round + 1, expected_max_round-1); // check expected_evaluation let rt: Point = sumcheck_claim.point.iter().map(|c| c.elements).collect(); @@ -490,7 +579,7 @@ impl TowerVerify { }) .sum::(); if expected_evaluation != sumcheck_claim.expected_evaluation { - return Err(ZKVMError::VerifyError("mismatch tower evaluation")); + return Err(ZKVMError::VerifyError("mismatch tower evaluation".into())); } // derive single eval diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 8105ea50b..16af00f7c 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -1,6 +1,18 @@ +use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + error::ZKVMError, + instructions::Instruction, + tables::TableCircuit, + witness::{LkMultiplicity, RowMajorMatrix}, +}; +use ceno_emul::StepRecord; use ff_ext::ExtensionField; -use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension; +use itertools::Itertools; +use multilinear_extensions::{ + mle::DenseMultilinearExtension, virtual_poly_v2::ArcMultilinearExtension, +}; use serde::Serialize; +use std::collections::{BTreeMap, HashMap}; use sumcheck::structs::IOPProverMessage; pub struct TowerProver; @@ -28,10 +40,10 @@ pub type WitnessId = u16; pub type ChallengeId = u16; pub enum ROMType { - U5, // 2^5 = 32 - U16, // 2^16 = 65,536 - And, // a ^ b where a, b are bytes - Ltu, // a <(usign) b where a, b are bytes + U5 = 0, // 2^5 = 32 + U16, // 2^16 = 65,536 + And, // a ^ b where a, b are bytes + Ltu, // a <(usign) b where a, b are bytes } #[derive(Clone, Debug, Copy)] @@ -75,3 +87,180 @@ impl PointAndEval { } } } + +#[derive(Clone, Debug)] +pub struct ProvingKey { + pub fixed_traces: Option>>, + pub vk: VerifyingKey, +} + +impl ProvingKey { + pub fn get_cs(&self) -> &ConstraintSystem { + self.vk.get_cs() + } +} + +#[derive(Clone, Debug)] +pub struct VerifyingKey { + pub(crate) cs: ConstraintSystem, +} + +impl VerifyingKey { + pub fn get_cs(&self) -> &ConstraintSystem { + &self.cs + } +} + +#[derive(Default, Clone)] +pub struct ZKVMConstraintSystem { + pub(crate) circuit_css: BTreeMap>, +} + +impl ZKVMConstraintSystem { + pub fn register_opcode_circuit>(&mut self) -> OC::InstructionConfig { + let mut cs = ConstraintSystem::new(|| format!("riscv_opcode/{}", OC::name())); + let mut circuit_builder = CircuitBuilder::::new(&mut cs); + let config = OC::construct_circuit(&mut circuit_builder).unwrap(); + assert!(self.circuit_css.insert(OC::name(), cs).is_none()); + + config + } + + pub fn register_table_circuit>(&mut self) -> TC::TableConfig { + let mut cs = ConstraintSystem::new(|| format!("riscv_table/{}", TC::name())); + let mut circuit_builder = CircuitBuilder::::new(&mut cs); + let config = TC::construct_circuit(&mut circuit_builder).unwrap(); + assert!(self.circuit_css.insert(TC::name(), cs.clone()).is_none()); + + config + } + + pub fn get_cs(&self, name: &String) -> Option<&ConstraintSystem> { + self.circuit_css.get(name) + } +} + +#[derive(Default)] +pub struct ZKVMFixedTraces { + pub circuit_fixed_traces: BTreeMap>>, +} + +impl ZKVMFixedTraces { + pub fn register_opcode_circuit>(&mut self, _cs: &ZKVMConstraintSystem) { + assert!(self.circuit_fixed_traces.insert(OC::name(), None).is_none()); + } + + pub fn register_table_circuit>( + &mut self, + cs: &ZKVMConstraintSystem, + config: TC::TableConfig, + ) { + let cs = cs.get_cs(&TC::name()).expect("cs not found"); + assert!( + self.circuit_fixed_traces + .insert( + TC::name(), + Some(TC::generate_fixed_traces(&config, cs.num_fixed,)), + ) + .is_none() + ); + } +} + +#[derive(Default)] +pub struct ZKVMWitnesses { + pub witnesses: BTreeMap>, + lk_mlts: BTreeMap, + combined_lk_mlt: Option>>, +} + +impl ZKVMWitnesses { + pub fn assign_opcode_circuit>( + &mut self, + cs: &ZKVMConstraintSystem, + config: &OC::InstructionConfig, + records: Vec, + ) -> Result<(), ZKVMError> { + assert!(self.combined_lk_mlt.is_none()); + + let cs = cs.get_cs(&OC::name()).unwrap(); + let (witness, logup_multiplicity) = + OC::assign_instances(config, cs.num_witin as usize, records)?; + assert!(self.witnesses.insert(OC::name(), witness).is_none()); + assert!( + self.lk_mlts + .insert(OC::name(), logup_multiplicity) + .is_none() + ); + + Ok(()) + } + + // merge the multiplicities in each opcode circuit into one + pub fn finalize_lk_multiplicities(&mut self) { + assert!(self.combined_lk_mlt.is_none()); + assert!(!self.lk_mlts.is_empty()); + + let mut combined_lk_mlt = vec![]; + let keys = self.lk_mlts.keys().cloned().collect_vec(); + for name in keys { + let lk_mlt = self.lk_mlts.remove(&name).unwrap().into_finalize_result(); + if combined_lk_mlt.is_empty() { + combined_lk_mlt = lk_mlt.to_vec(); + } else { + combined_lk_mlt + .iter_mut() + .zip_eq(lk_mlt.iter()) + .for_each(|(m1, m2)| { + for (key, value) in m2 { + *m1.entry(*key).or_insert(0) += value; + } + }); + } + } + + self.combined_lk_mlt = Some(combined_lk_mlt); + } + + pub fn assign_table_circuit>( + &mut self, + cs: &ZKVMConstraintSystem, + config: &TC::TableConfig, + ) -> Result<(), ZKVMError> { + assert!(self.combined_lk_mlt.is_some()); + + let cs = cs.get_cs(&TC::name()).unwrap(); + let witness = TC::assign_instances( + config, + cs.num_witin as usize, + self.combined_lk_mlt.as_ref().unwrap(), + )?; + assert!(self.witnesses.insert(TC::name(), witness).is_none()); + + Ok(()) + } +} + +#[derive(Default)] +pub struct ZKVMProvingKey { + // pk for opcode and table circuits + pub(crate) circuit_pks: BTreeMap>, +} + +impl ZKVMProvingKey { + pub fn get_vk(&self) -> ZKVMVerifyingKey { + ZKVMVerifyingKey { + circuit_vks: self + .circuit_pks + .iter() + .map(|(name, pk)| (name.clone(), pk.vk.clone())) + .collect(), + } + } +} + +#[derive(Default, Clone)] +pub struct ZKVMVerifyingKey { + // pk for opcode and table circuits + pub circuit_vks: BTreeMap>, +} diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index b2277ba15..a66094952 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -1 +1,28 @@ +use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, witness::RowMajorMatrix}; +use ff_ext::ExtensionField; +use std::collections::HashMap; + mod range; +pub use range::RangeTableCircuit; + +pub trait TableCircuit { + type TableConfig: Send + Sync; + type Input: Send + Sync; + + fn name() -> String; + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + ) -> Result; + + fn generate_fixed_traces( + config: &Self::TableConfig, + num_fixed: usize, + ) -> RowMajorMatrix; + + fn assign_instances( + config: &Self::TableConfig, + num_witin: usize, + multiplicity: &[HashMap], + ) -> Result, ZKVMError>; +} diff --git a/ceno_zkvm/src/tables/range.rs b/ceno_zkvm/src/tables/range.rs index a081ee549..293077f46 100644 --- a/ceno_zkvm/src/tables/range.rs +++ b/ceno_zkvm/src/tables/range.rs @@ -1,30 +1,36 @@ +use std::{collections::HashMap, marker::PhantomData, mem::MaybeUninit}; + use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, expression::{Expression, Fixed, ToExpr, WitIn}, - structs::{ROMType, WitnessId}, + scheme::constants::MIN_PAR_SIZE, + set_fixed_val, set_val, + structs::ROMType, + tables::TableCircuit, + uint::constants::RANGE_CHIP_BIT_WIDTH, + witness::RowMajorMatrix, }; use ff_ext::ExtensionField; -use itertools::Itertools; -use multilinear_extensions::mle::DenseMultilinearExtension; -use std::{collections::BTreeMap, marker::PhantomData}; +use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; #[derive(Clone, Debug)] -pub struct RangeTableConfig { +pub struct RangeTableConfig { u16_tbl: Fixed, u16_mlt: WitIn, - _marker: PhantomData, } -#[derive(Default)] -pub struct RangeTableTrace { - pub fixed: BTreeMap>, - pub wits: BTreeMap>, -} +pub struct RangeTableCircuit(PhantomData); + +impl TableCircuit for RangeTableCircuit { + type TableConfig = RangeTableConfig; + type Input = u64; + + fn name() -> String { + "RANGE".into() + } -impl RangeTableConfig { - #[allow(unused)] - fn construct_circuit(cb: &mut CircuitBuilder) -> Result, ZKVMError> { + fn construct_circuit(cb: &mut CircuitBuilder) -> Result { let u16_tbl = cb.create_fixed(|| "u16_tbl")?; let u16_mlt = cb.create_witin(|| "u16_mlt")?; @@ -35,97 +41,46 @@ impl RangeTableConfig { cb.lk_table_record(|| "u16 table", u16_table_values, u16_mlt.expr())?; - Ok(RangeTableConfig { - u16_tbl, - u16_mlt, - _marker: Default::default(), - }) + Ok(RangeTableConfig { u16_tbl, u16_mlt }) } - #[allow(unused)] - fn generate_traces(self, inputs: &[u16]) -> RangeTableTrace { - let mut u16_mlt = vec![0; 1 << 16]; - for limb in inputs { - u16_mlt[*limb as usize] += 1; - } - - let u16_tbl = DenseMultilinearExtension::from_evaluations_vec( - 16, - (0..(1 << 16)).map(E::BaseField::from).collect_vec(), - ); - let u16_mlt = DenseMultilinearExtension::from_evaluations_vec( - 16, - u16_mlt.into_iter().map(E::BaseField::from).collect_vec(), - ); - - let config = self.clone(); - let mut traces = RangeTableTrace::default(); - traces.fixed.insert(config.u16_tbl, u16_tbl); - traces.wits.insert(config.u16_mlt.id, u16_mlt); - - traces + fn generate_fixed_traces( + config: &RangeTableConfig, + num_fixed: usize, + ) -> RowMajorMatrix { + let num_u16s = 1 << 16; + let mut fixed = RowMajorMatrix::::new(num_u16s, num_fixed); + fixed + .par_iter_mut() + .with_min_len(MIN_PAR_SIZE) + .zip((0..num_u16s).into_par_iter()) + .for_each(|(row, i)| { + set_fixed_val!(row, config.u16_tbl.0, E::BaseField::from(i as u64)); + }); + + fixed } -} - -#[cfg(test)] -mod tests { - use crate::{ - circuit_builder::{CircuitBuilder, ConstraintSystem}, - scheme::{constants::NUM_FANIN_LOGUP, prover::ZKVMProver, verifier::ZKVMVerifier}, - structs::PointAndEval, - tables::range::RangeTableConfig, - }; - use goldilocks::GoldilocksExt2; - use itertools::Itertools; - use transcript::Transcript; - - #[test] - fn test_range_circuit() { - let mut cs = ConstraintSystem::new(|| "riscv"); - let config = cs - .namespace( - || "range", - |cs| { - let mut cb = CircuitBuilder::::new(cs); - RangeTableConfig::construct_circuit(&mut cb) - }, - ) - .expect("construct range table circuit"); - let traces = config.generate_traces((0..1 << 8).collect_vec().as_slice()); - - let pk = cs.key_gen(Some(traces.fixed.clone().into_values().collect_vec())); - let vk = pk.vk.clone(); - let prover = ZKVMProver::new(pk); - - let mut transcript = Transcript::new(b"range"); - let challenges = [1.into(), 2.into()]; + fn assign_instances( + config: &Self::TableConfig, + num_witin: usize, + multiplicity: &[HashMap], + ) -> Result, ZKVMError> { + let multiplicity = &multiplicity[ROMType::U16 as usize]; + let mut u16_mlt = vec![0; 1 << RANGE_CHIP_BIT_WIDTH]; + for (limb, mlt) in multiplicity { + u16_mlt[*limb as usize] = *mlt; + } - let proof = prover - .create_table_proof( - traces - .wits - .into_values() - .map(|mle| mle.into()) - .collect_vec(), - // TODO: fix the verification error for num_instances is not power-of-two case - 1 << 16, - 1, - &mut transcript, - &challenges, - ) - .expect("create proof"); + let mut witness = RowMajorMatrix::::new(u16_mlt.len(), num_witin); + witness + .par_iter_mut() + .with_min_len(MIN_PAR_SIZE) + .zip(u16_mlt.into_par_iter()) + .for_each(|(row, mlt)| { + set_val!(row, config.u16_mlt, E::BaseField::from(mlt as u64)); + }); - let mut transcript = Transcript::new(b"range"); - let verifier = ZKVMVerifier::new(vk); - verifier - .verify_table_proof( - &proof, - &mut transcript, - NUM_FANIN_LOGUP, - &PointAndEval::default(), - &challenges, - ) - .expect("verify proof failed"); + Ok(witness) } } diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index 44927014d..993a2d8b3 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -1,5 +1,5 @@ mod arithmetic; -mod constants; +pub mod constants; pub mod util; use crate::{ diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index 93dca0023..a77d9c4c3 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -23,20 +23,35 @@ macro_rules! set_val { }; } +#[macro_export] +macro_rules! set_fixed_val { + ($ins:ident, $field:expr, $val:expr) => { + $ins[$field as usize] = MaybeUninit::new($val); + }; +} + pub struct RowMajorMatrix { // represent 2D in 1D linear memory and avoid double indirection by Vec> to improve performance values: Vec>, + num_padding_rows: usize, num_col: usize, } impl RowMajorMatrix { - pub fn new(num_row: usize, num_col: usize) -> Self { + pub fn new(num_rows: usize, num_col: usize) -> Self { + let num_total_rows = num_rows.next_power_of_two(); + let num_padding_rows = num_total_rows - num_rows; RowMajorMatrix { - values: create_uninit_vec(num_row * num_col), + values: create_uninit_vec(num_total_rows * num_col), + num_padding_rows, num_col, } } + pub fn num_instances(&self) -> usize { + self.values.len() / self.num_col - self.num_padding_rows + } + pub fn iter_mut(&mut self) -> ChunksMut> { self.values.chunks_mut(self.num_col) } @@ -128,7 +143,7 @@ impl LkMultiplicity { } /// merge result from multiple thread local to single result - fn into_finalize_result(self) -> [HashMap; mem::variant_count::()] { + pub fn into_finalize_result(self) -> [HashMap; mem::variant_count::()] { Arc::try_unwrap(self.multiplicity) .unwrap() .into_iter() diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 9c59a6806..a4530919c 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -124,7 +124,7 @@ pub trait IntoMLEs: Sized { fn into_mles(self) -> Vec; } -impl IntoMLEs> for Vec> { +impl> IntoMLEs> for Vec> { fn into_mles(self) -> Vec> { self.into_iter().map(|v| v.into_mle()).collect() }