From 09b6996b9bd3a25586eb5919e069d7448eb31496 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Wed, 4 Sep 2024 20:46:29 +0800 Subject: [PATCH 01/15] rename --- ceno_zkvm/benches/riscv_add.rs | 2 +- ceno_zkvm/examples/riscv_add.rs | 2 +- ceno_zkvm/src/scheme.rs | 10 ++++++++-- ceno_zkvm/src/scheme/prover.rs | 18 ++++++++++++------ ceno_zkvm/src/scheme/tests.rs | 2 +- ceno_zkvm/src/scheme/verifier.rs | 6 +++--- 6 files changed, 26 insertions(+), 14 deletions(-) diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index d40e92941..aa05b218f 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -100,7 +100,7 @@ fn bench_add(c: &mut Criterion) { .collect_vec(); let timer = Instant::now(); let _ = prover - .create_proof( + .create_opcode_proof( wits_in, num_instances, max_threads, diff --git a/ceno_zkvm/examples/riscv_add.rs b/ceno_zkvm/examples/riscv_add.rs index b7c547782..16c0d5dee 100644 --- a/ceno_zkvm/examples/riscv_add.rs +++ b/ceno_zkvm/examples/riscv_add.rs @@ -78,7 +78,7 @@ fn main() { .collect_vec(); let timer = Instant::now(); let _ = prover - .create_proof( + .create_opcode_proof( wits_in, num_instances, max_threads, diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index ea55b1ec2..2fd295d3e 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -13,7 +13,7 @@ pub mod verifier; mod tests; #[derive(Clone)] -pub struct ZKVMProof { +pub struct ZkvmOpcodeProof { // TODO support >1 opcodes pub num_instances: usize, @@ -39,7 +39,7 @@ pub struct ZKVMProof { } #[derive(Clone)] -pub struct ZKVMTableProof { +pub struct ZkvmTableProof { pub num_instances: usize, // logup sum at layer 1 pub lk_p1_out_eval: E, @@ -57,3 +57,9 @@ pub struct ZKVMTableProof { pub fixed_in_evals: Vec, pub wits_in_evals: Vec, } + +#[derive(Clone)] +pub struct ZkvmProof { + opcode_proofs: Vec>, + table_proofs: Vec>, +} \ No newline at end of file diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index e080e0c42..9fec79fc9 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -31,7 +31,7 @@ use crate::{ virtual_polys::VirtualPolynomials, }; -use super::{ZKVMProof, ZKVMTableProof}; +use super::{ZkvmOpcodeProof, ZkvmProof, ZkvmTableProof}; pub struct ZKVMProver { pk: ProvingKey, @@ -42,18 +42,24 @@ impl ZKVMProver { ZKVMProver { pk } } + pub fn create_proof( + + ) -> Result> { + + Ok() + } /// create proof giving witness and num_instances /// major flow break down into /// 1: witness layer inferring from input -> output /// 2: proof (sumcheck reduce) from output to input - pub fn create_proof( + pub fn create_opcode_proof( &self, witnesses: Vec>, num_instances: usize, max_threads: usize, transcript: &mut Transcript, challenges: &[E; 2], - ) -> Result, ZKVMError> { + ) -> Result, ZKVMError> { let cs = self.pk.get_cs(); let log2_num_instances = ceil_log2(num_instances); let next_pow2_instances = 1 << log2_num_instances; @@ -414,7 +420,7 @@ impl ZKVMProver { .collect(); exit_span!(span); - Ok(ZKVMProof { + Ok(ZkvmOpcodeProof { num_instances, record_r_out_evals, record_w_out_evals, @@ -438,7 +444,7 @@ impl ZKVMProver { max_threads: usize, transcript: &mut Transcript, challenges: &[E; 2], - ) -> Result, ZKVMError> { + ) -> Result, ZKVMError> { let cs = self.pk.get_cs(); let fixed = self .pk @@ -610,7 +616,7 @@ impl ZKVMProver { let wits_in_evals = evals; exit_span!(span); - Ok(ZKVMTableProof { + Ok(ZkvmTableProof { num_instances, lk_p1_out_eval, lk_p2_out_eval, diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 3a2008a36..568e7ecf1 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -74,7 +74,7 @@ fn test_rw_lk_expression_combination() { let challenges = [1.into(), 2.into()]; let proof = prover - .create_proof(wits_in, num_instances, 1, &mut transcript, &challenges) + .create_opcode_proof(wits_in, num_instances, 1, &mut transcript, &challenges) .expect("create_proof failed"); let verifier = ZKVMVerifier::new(vk); diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 43f87ded0..8a0c9c116 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -24,7 +24,7 @@ use crate::{ }; use super::{ - constants::MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, utils::eval_by_expr, ZKVMProof, ZKVMTableProof, + constants::MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, utils::eval_by_expr, ZkvmOpcodeProof, ZkvmTableProof, }; pub struct ZKVMVerifier { @@ -39,7 +39,7 @@ impl ZKVMVerifier { /// verify proof and return input opening point pub fn verify( &self, - proof: &ZKVMProof, + proof: &ZkvmOpcodeProof, transcript: &mut Transcript, num_product_fanin: usize, _out_evals: &PointAndEval, @@ -252,7 +252,7 @@ impl ZKVMVerifier { pub fn verify_table_proof( &self, - proof: &ZKVMTableProof, + proof: &ZkvmTableProof, transcript: &mut Transcript, num_logup_fanin: usize, _out_evals: &PointAndEval, From 9ea7739cddb917cd8097bc8411a4cbd1970674db Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 5 Sep 2024 17:00:33 +0800 Subject: [PATCH 02/15] draft vm prover and verifier --- ceno_zkvm/src/circuit_builder.rs | 42 ++++++++-- ceno_zkvm/src/instructions.rs | 2 + ceno_zkvm/src/instructions/riscv/addsub.rs | 6 ++ ceno_zkvm/src/keygen.rs | 30 +++++++ ceno_zkvm/src/lib.rs | 1 + ceno_zkvm/src/scheme.rs | 15 ++-- ceno_zkvm/src/scheme/prover.rs | 92 +++++++++++++++++----- ceno_zkvm/src/scheme/verifier.rs | 91 ++++++++++++++++++--- ceno_zkvm/src/witness.rs | 12 ++- 9 files changed, 245 insertions(+), 46 deletions(-) create mode 100644 ceno_zkvm/src/keygen.rs diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index 794ababb4..41fb54b91 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -1,12 +1,17 @@ -use std::marker::PhantomData; +use itertools::Itertools; +use std::{ + collections::{BTreeMap, HashMap}, + marker::PhantomData, +}; use ff_ext::ExtensionField; -use multilinear_extensions::mle::DenseMultilinearExtension; +use multilinear_extensions::mle::{DenseMultilinearExtension, IntoMLEs}; use crate::{ error::ZKVMError, expression::{Expression, Fixed, WitIn}, structs::WitnessId, + witness::RowMajorMatrix, }; /// namespace used for annotation, preserve meta info during circuit construction @@ -135,7 +140,18 @@ impl ConstraintSystem { } } - pub fn key_gen(self, fixed_traces: Option>>) -> ProvingKey { + pub fn key_gen(self, fixed_traces: Option>) -> ProvingKey { + // TODO: commit to fixed_traces + + // transpose from row-major to column-major + let fixed_traces = fixed_traces.map(|t| { + t.de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec() + }); + ProvingKey { fixed_traces, vk: VerifyingKey { cs: self }, @@ -301,9 +317,6 @@ pub struct ProvingKey { } impl ProvingKey { - // pub fn create_pk(vk: VerifyingKey) -> Self { - // Self { vk } - // } pub fn get_cs(&self) -> &ConstraintSystem { self.vk.get_cs() } @@ -319,3 +332,20 @@ impl VerifyingKey { &self.cs } } + +#[derive(Default)] +pub struct ZKVMConstraintSystem { + pub circuit_css: BTreeMap>, +} + +#[derive(Default)] +pub struct ZKVMProvingKey { + // pk for opcode and table circuits + pub circuit_pks: BTreeMap>, +} + +#[derive(Default)] +pub struct ZKVMVerifyingKey { + // pk for opcode and table circuits + pub circuit_vks: BTreeMap>, +} diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 633eaa8db..ff511f574 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -10,6 +10,8 @@ pub mod riscv; pub trait Instruction { type InstructionConfig: Send + Sync; + + fn name() -> String; fn construct_circuit( circuit_builder: &mut CircuitBuilder, ) -> Result; diff --git a/ceno_zkvm/src/instructions/riscv/addsub.rs b/ceno_zkvm/src/instructions/riscv/addsub.rs index 20609ff0d..523b0fd61 100644 --- a/ceno_zkvm/src/instructions/riscv/addsub.rs +++ b/ceno_zkvm/src/instructions/riscv/addsub.rs @@ -136,6 +136,9 @@ fn add_sub_gadget( impl Instruction for AddInstruction { // const NAME: &'static str = "ADD"; + fn name() -> String { + "ADD".into() + } type InstructionConfig = InstructionConfig; fn construct_circuit( circuit_builder: &mut CircuitBuilder, @@ -182,6 +185,9 @@ impl Instruction for AddInstruction { impl Instruction for SubInstruction { // const NAME: &'static str = "ADD"; + fn name() -> String { + "SUB".into() + } type InstructionConfig = InstructionConfig; fn construct_circuit( circuit_builder: &mut CircuitBuilder, diff --git a/ceno_zkvm/src/keygen.rs b/ceno_zkvm/src/keygen.rs new file mode 100644 index 000000000..5965eee82 --- /dev/null +++ b/ceno_zkvm/src/keygen.rs @@ -0,0 +1,30 @@ +use crate::{ + circuit_builder::{ZKVMConstraintSystem, ZKVMProvingKey}, + witness::RowMajorMatrix, +}; +use ff_ext::ExtensionField; +use std::collections::BTreeMap; + +impl ZKVMConstraintSystem { + pub fn key_gen( + self, + mut vm_fixed_traces: BTreeMap>>, + ) -> ZKVMProvingKey { + let mut vm_pk = ZKVMProvingKey::default(); + + for (c_name, cs) in self.circuit_css.into_iter() { + let fixed_traces = vm_fixed_traces.remove(&c_name).expect( + format!( + "circuit {}'s trace is not present in vm_fixed_traces", + c_name + ) + .as_str(), + ); + + let circuit_pk = cs.key_gen(fixed_traces); + assert!(vm_pk.circuit_pks.insert(c_name, circuit_pk).is_none()); + } + + vm_pk + } +} diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index 3e9612f1a..c29f030ba 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -9,6 +9,7 @@ pub use utils::u64vec; mod chip_handler; pub mod circuit_builder; pub mod expression; +mod keygen; mod structs; mod uint; mod utils; diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 2fd295d3e..7fa5dc79a 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -1,4 +1,5 @@ use ff_ext::ExtensionField; +use std::collections::HashMap; use sumcheck::structs::IOPProverMessage; use crate::structs::TowerProofs; @@ -13,7 +14,7 @@ pub mod verifier; mod tests; #[derive(Clone)] -pub struct ZkvmOpcodeProof { +pub struct ZKVMOpcodeProof { // TODO support >1 opcodes pub num_instances: usize, @@ -39,7 +40,7 @@ pub struct ZkvmOpcodeProof { } #[derive(Clone)] -pub struct ZkvmTableProof { +pub struct ZKVMTableProof { pub num_instances: usize, // logup sum at layer 1 pub lk_p1_out_eval: E, @@ -58,8 +59,8 @@ pub struct ZkvmTableProof { pub wits_in_evals: Vec, } -#[derive(Clone)] -pub struct ZkvmProof { - opcode_proofs: Vec>, - table_proofs: Vec>, -} \ No newline at end of file +#[derive(Default, Clone)] +pub struct ZKVMProof { + opcode_proofs: HashMap>, + table_proofs: HashMap>, +} diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 9fec79fc9..9d12d8a1c 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -1,10 +1,12 @@ -use std::{collections::BTreeSet, sync::Arc}; - use ff_ext::ExtensionField; +use std::{ + collections::{BTreeSet, HashMap}, + sync::Arc, +}; use itertools::Itertools; use multilinear_extensions::{ - mle::{IntoMLE, MultilinearExtension}, + mle::{IntoMLE, IntoMLEs, MultilinearExtension}, util::ceil_log2, virtual_poly::build_eq_x_r_vec, virtual_poly_v2::ArcMultilinearExtension, @@ -17,7 +19,7 @@ use sumcheck::{ use transcript::Transcript; use crate::{ - circuit_builder::ProvingKey, + circuit_builder::{ProvingKey, ZKVMProvingKey}, error::ZKVMError, scheme::{ constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, NUM_FANIN, NUM_FANIN_LOGUP}, @@ -29,24 +31,77 @@ use crate::{ structs::{Point, TowerProofs, TowerProver, TowerProverSpec}, utils::{get_challenge_pows, proper_num_threads}, virtual_polys::VirtualPolynomials, + witness::RowMajorMatrix, }; -use super::{ZkvmOpcodeProof, ZkvmProof, ZkvmTableProof}; +use super::{ZKVMOpcodeProof, ZKVMProof, ZKVMTableProof}; pub struct ZKVMProver { - pk: ProvingKey, + pk: ZKVMProvingKey, } impl ZKVMProver { - pub fn new(pk: ProvingKey) -> Self { + pub fn new(pk: ZKVMProvingKey) -> Self { ZKVMProver { pk } } + /// create proof for zkvm execution pub fn create_proof( + &self, + mut witnesses: HashMap>, + max_threads: usize, + transcript: &mut Transcript, + challenges: &[E; 2], + ) -> Result, ZKVMError> { + let mut vm_proof = ZKVMProof::default(); + for (circuit_name, pk) in self.pk.circuit_pks.iter() { + let witness = witnesses + .remove(circuit_name) + .expect(format!("witness for circuit {} is not found", circuit_name).as_str()); + + // TODO: add an enum for circuit type either in constraint_system or vk + let cs = pk.get_cs(); + let is_opcode_circuit = cs.lk_table_expressions.is_empty(); + let num_instances = witness.num_instances(); + + if is_opcode_circuit { + let opcode_proof = self.create_opcode_proof( + pk, + witness + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + num_instances, + max_threads, + transcript, + challenges, + )?; + vm_proof + .opcode_proofs + .insert(circuit_name.clone(), opcode_proof); + } else { + let table_proof = self.create_table_proof( + pk, + witness + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + num_instances, + max_threads, + transcript, + challenges, + )?; + vm_proof + .table_proofs + .insert(circuit_name.clone(), table_proof); + } + } - ) -> Result> { - - Ok() + Ok(vm_proof) } /// create proof giving witness and num_instances /// major flow break down into @@ -54,13 +109,14 @@ impl ZKVMProver { /// 2: proof (sumcheck reduce) from output to input pub fn create_opcode_proof( &self, + circuit_pk: &ProvingKey, witnesses: Vec>, num_instances: usize, max_threads: usize, transcript: &mut Transcript, challenges: &[E; 2], - ) -> Result, ZKVMError> { - let cs = self.pk.get_cs(); + ) -> Result, ZKVMError> { + let cs = circuit_pk.get_cs(); let log2_num_instances = ceil_log2(num_instances); let next_pow2_instances = 1 << log2_num_instances; let (chip_record_alpha, _) = (challenges[0], challenges[1]); @@ -420,7 +476,7 @@ impl ZKVMProver { .collect(); exit_span!(span); - Ok(ZkvmOpcodeProof { + Ok(ZKVMOpcodeProof { num_instances, record_r_out_evals, record_w_out_evals, @@ -439,15 +495,15 @@ impl ZKVMProver { pub fn create_table_proof( &self, + circuit_pk: &ProvingKey, witnesses: Vec>, num_instances: usize, max_threads: usize, transcript: &mut Transcript, challenges: &[E; 2], - ) -> Result, ZKVMError> { - let cs = self.pk.get_cs(); - let fixed = self - .pk + ) -> Result, ZKVMError> { + let cs = circuit_pk.get_cs(); + let fixed = circuit_pk .fixed_traces .as_ref() .expect("pk.fixed_traces must not be none for table circuit") @@ -616,7 +672,7 @@ impl ZKVMProver { let wits_in_evals = evals; exit_span!(span); - Ok(ZkvmTableProof { + Ok(ZKVMTableProof { num_instances, lk_p1_out_eval, lk_p2_out_eval, diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 8a0c9c116..40b4a818e 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -13,10 +13,10 @@ use sumcheck::structs::{IOPProof, IOPVerifierState}; use transcript::Transcript; use crate::{ - circuit_builder::VerifyingKey, + circuit_builder::{VerifyingKey, ZKVMVerifyingKey}, error::ZKVMError, scheme::{ - constants::{NUM_FANIN, SEL_DEGREE}, + constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEL_DEGREE}, utils::eval_by_expr_with_fixed, }, structs::{Point, PointAndEval, TowerProofs}, @@ -24,28 +24,95 @@ use crate::{ }; use super::{ - constants::MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, utils::eval_by_expr, ZkvmOpcodeProof, ZkvmTableProof, + constants::MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, utils::eval_by_expr, ZKVMOpcodeProof, ZKVMProof, + ZKVMTableProof, }; pub struct ZKVMVerifier { - vk: VerifyingKey, + vk: ZKVMVerifyingKey, } impl ZKVMVerifier { - pub fn new(vk: VerifyingKey) -> Self { + pub fn new(vk: ZKVMVerifyingKey) -> Self { ZKVMVerifier { vk } } + pub fn verify_proof( + &self, + vm_proof: ZKVMProof, + transcript: &mut Transcript, + challenges: &[E; 2], + ) -> Result { + let mut prod_r = E::ONE; + let mut prod_w = E::ONE; + let mut logup_sum = E::ZERO; + let point_eval = PointAndEval::default(); + for (name, opcode_proof) in vm_proof.opcode_proofs { + let circuit_vk = self + .vk + .circuit_vks + .get(&name) + .expect(format!("vk of opcode circuit {} is not present", name).as_str()); + let _rand_point = self.verify_opcode_proof( + circuit_vk, + &opcode_proof, + transcript, + NUM_FANIN, + &point_eval, + challenges, + )?; + + prod_r *= opcode_proof.record_r_out_evals.iter().product::(); + prod_w *= opcode_proof.record_w_out_evals.iter().product::(); + + logup_sum += + opcode_proof.lk_p1_out_eval * opcode_proof.lk_q1_out_eval.invert().unwrap(); + logup_sum += + opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.invert().unwrap(); + } + + for (name, table_proof) in vm_proof.table_proofs { + let circuit_vk = self + .vk + .circuit_vks + .get(&name) + .expect(format!("vk of table circuit {} is not present", name).as_str()); + let _rand_point = self.verify_table_proof( + circuit_vk, + &table_proof, + transcript, + NUM_FANIN_LOGUP, + &point_eval, + challenges, + )?; + + logup_sum -= table_proof.lk_p1_out_eval * table_proof.lk_q1_out_eval.invert().unwrap(); + logup_sum -= table_proof.lk_p2_out_eval * table_proof.lk_q2_out_eval.invert().unwrap(); + } + // check rw_set equality across all proofs + if prod_r != prod_w { + return Ok(false); + } + + // check logup relation across all proofs + if logup_sum != E::ZERO { + return Ok(false); + } + + Ok(true) + } + /// verify proof and return input opening point - pub fn verify( + pub fn verify_opcode_proof( &self, - proof: &ZkvmOpcodeProof, + circuit_vk: &VerifyingKey, + proof: &ZKVMOpcodeProof, transcript: &mut Transcript, num_product_fanin: usize, _out_evals: &PointAndEval, challenges: &[E; 2], // derive challenge from PCS ) -> Result, ZKVMError> { - let cs = self.vk.get_cs(); + let cs = circuit_vk.get_cs(); let (r_counts_per_instance, w_counts_per_instance, lk_counts_per_instance) = ( cs.r_expressions.len(), cs.w_expressions.len(), @@ -64,9 +131,6 @@ impl ZKVMVerifier { // verify and reduce product tower sumcheck let tower_proofs = &proof.tower_proof; - // TODO check rw_set equality across all proofs - // TODO check logup relation across all proofs - let (rt_tower, record_evals, logup_p_evals, logup_q_evals) = TowerVerify::verify( vec![ proof.record_r_out_evals.clone(), @@ -252,13 +316,14 @@ impl ZKVMVerifier { pub fn verify_table_proof( &self, - proof: &ZkvmTableProof, + circuit_vk: &VerifyingKey, + proof: &ZKVMTableProof, transcript: &mut Transcript, num_logup_fanin: usize, _out_evals: &PointAndEval, challenges: &[E; 2], // TODO: derive challenge from PCS ) -> Result, ZKVMError> { - let cs = self.vk.get_cs(); + let cs = circuit_vk.get_cs(); let lk_counts_per_instance = cs.lk_table_expressions.len(); let log2_lk_count = ceil_log2(lk_counts_per_instance); let (chip_record_alpha, _) = (challenges[0], challenges[1]); diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index 9b44586da..3452943e0 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -19,17 +19,25 @@ macro_rules! set_val { pub struct RowMajorMatrix { // represent 2D in 1D linear memory and avoid double indirection by Vec> to improve performance values: Vec>, + num_padding_rows: usize, num_col: usize, } impl RowMajorMatrix { - pub fn new(num_row: usize, num_col: usize) -> Self { + pub fn new(num_rows: usize, num_col: usize) -> Self { + let num_total_rows = num_rows.next_power_of_two(); + let num_padding_rows = num_total_rows - num_rows; RowMajorMatrix { - values: create_uninit_vec(num_row * num_col), + values: create_uninit_vec(num_total_rows * num_col), + num_padding_rows, num_col, } } + pub fn num_instances(&self) -> usize { + self.values.len() / self.num_col - self.num_padding_rows + } + pub fn iter_mut(&mut self) -> ChunksMut> { self.values.chunks_mut(self.num_col) } From f7d8a4cd3cb502f7d0d263067d401f7e3a5a74df Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 5 Sep 2024 17:05:15 +0800 Subject: [PATCH 03/15] add name for blt circuit --- ceno_zkvm/src/instructions/riscv/blt.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ceno_zkvm/src/instructions/riscv/blt.rs b/ceno_zkvm/src/instructions/riscv/blt.rs index bd71b15ab..ff35327c2 100644 --- a/ceno_zkvm/src/instructions/riscv/blt.rs +++ b/ceno_zkvm/src/instructions/riscv/blt.rs @@ -207,6 +207,9 @@ fn blt_gadget( impl Instruction for BltInstruction { // const NAME: &'static str = "BLT"; + fn name() -> String { + "BLT".into() + } type InstructionConfig = InstructionConfig; fn construct_circuit( circuit_builder: &mut CircuitBuilder, From 0bc30e6fca793daf434141080e018a9cdc18eeea Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 5 Sep 2024 19:39:52 +0800 Subject: [PATCH 04/15] wip --- ceno_zkvm/src/circuit_builder.rs | 19 ++- ceno_zkvm/src/instructions/riscv/blt.rs | 9 +- ceno_zkvm/src/keygen.rs | 2 +- ceno_zkvm/src/scheme/prover.rs | 2 +- ceno_zkvm/src/scheme/tests.rs | 96 +++++++-------- ceno_zkvm/src/tables/mod.rs | 25 ++++ ceno_zkvm/src/tables/range.rs | 150 ++++++++---------------- ceno_zkvm/src/witness.rs | 7 ++ multilinear_extensions/src/mle.rs | 2 +- 9 files changed, 145 insertions(+), 167 deletions(-) diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index 41fb54b91..4c9d812a5 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -1,8 +1,5 @@ use itertools::Itertools; -use std::{ - collections::{BTreeMap, HashMap}, - marker::PhantomData, -}; +use std::{collections::BTreeMap, marker::PhantomData}; use ff_ext::ExtensionField; use multilinear_extensions::mle::{DenseMultilinearExtension, IntoMLEs}; @@ -140,7 +137,7 @@ impl ConstraintSystem { } } - pub fn key_gen(self, fixed_traces: Option>) -> ProvingKey { + pub fn key_gen(self, fixed_traces: Option>) -> ProvingKey { // TODO: commit to fixed_traces // transpose from row-major to column-major @@ -344,6 +341,18 @@ pub struct ZKVMProvingKey { pub circuit_pks: BTreeMap>, } +impl ZKVMProvingKey { + pub fn get_vk(&self) -> ZKVMVerifyingKey { + ZKVMVerifyingKey { + circuit_vks: self + .circuit_pks + .iter() + .map(|(name, pk)| (name.clone(), pk.vk.clone())) + .collect(), + } + } +} + #[derive(Default)] pub struct ZKVMVerifyingKey { // pk for opcode and table circuits diff --git a/ceno_zkvm/src/instructions/riscv/blt.rs b/ceno_zkvm/src/instructions/riscv/blt.rs index ff35327c2..7526f8f67 100644 --- a/ceno_zkvm/src/instructions/riscv/blt.rs +++ b/ceno_zkvm/src/instructions/riscv/blt.rs @@ -233,18 +233,11 @@ impl Instruction for BltInstruction { mod test { use super::*; use ceno_emul::StepRecord; - use ff_ext::ExtensionField; use goldilocks::GoldilocksExt2; use itertools::Itertools; use multilinear_extensions::mle::IntoMLEs; - use crate::{ - circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::Instruction, - scheme::mock_prover::MockProver, - }; - - use super::BltInstruction; + use crate::{circuit_builder::ConstraintSystem, scheme::mock_prover::MockProver}; #[test] fn test_blt_circuit() -> Result<(), ZKVMError> { diff --git a/ceno_zkvm/src/keygen.rs b/ceno_zkvm/src/keygen.rs index 5965eee82..c475878bf 100644 --- a/ceno_zkvm/src/keygen.rs +++ b/ceno_zkvm/src/keygen.rs @@ -8,7 +8,7 @@ use std::collections::BTreeMap; impl ZKVMConstraintSystem { pub fn key_gen( self, - mut vm_fixed_traces: BTreeMap>>, + mut vm_fixed_traces: BTreeMap>>, ) -> ZKVMProvingKey { let mut vm_pk = ZKVMProvingKey::default(); diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 9d12d8a1c..e7f0f8680 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -48,7 +48,7 @@ impl ZKVMProver { /// create proof for zkvm execution pub fn create_proof( &self, - mut witnesses: HashMap>, + mut witnesses: HashMap>, max_threads: usize, transcript: &mut Transcript, challenges: &[E; 2], diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 568e7ecf1..a273582a9 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -47,51 +47,51 @@ impl TestCircuit { } } -#[test] -fn test_rw_lk_expression_combination() { - fn test_rw_lk_expression_combination_inner() { - let mut cs = ConstraintSystem::new(|| "test"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); - let _ = TestCircuit::construct_circuit::(&mut circuit_builder); - let pk = cs.key_gen(None); - let vk = pk.vk.clone(); - - // generate mock witness - let num_instances = 1 << 2; - let wits_in = (0..pk.get_cs().num_witin as usize) - .map(|_| { - (0..num_instances) - .map(|_| Goldilocks::ONE) - .collect::>() - .into_mle() - .into() - }) - .collect_vec(); - - // get proof - let prover = ZKVMProver::new(pk); - let mut transcript = Transcript::new(b"test"); - let challenges = [1.into(), 2.into()]; - - let proof = prover - .create_opcode_proof(wits_in, num_instances, 1, &mut transcript, &challenges) - .expect("create_proof failed"); - - let verifier = ZKVMVerifier::new(vk); - let mut v_transcript = Transcript::new(b"test"); - let _rt_input = verifier - .verify( - &proof, - &mut v_transcript, - NUM_FANIN, - &PointAndEval::default(), - &challenges, - ) - .expect("verifier failed"); - } - - // - test_rw_lk_expression_combination_inner::<19, 17>(); - test_rw_lk_expression_combination_inner::<61, 17>(); - test_rw_lk_expression_combination_inner::<17, 61>(); -} +// #[test] +// fn test_rw_lk_expression_combination() { +// fn test_rw_lk_expression_combination_inner() { +// let mut cs = ConstraintSystem::new(|| "test"); +// let mut circuit_builder = CircuitBuilder::::new(&mut cs); +// let _ = TestCircuit::construct_circuit::(&mut circuit_builder); +// let pk = cs.key_gen(None); +// let vk = pk.vk.clone(); +// +// // generate mock witness +// let num_instances = 1 << 2; +// let wits_in = (0..pk.get_cs().num_witin as usize) +// .map(|_| { +// (0..num_instances) +// .map(|_| Goldilocks::ONE) +// .collect::>() +// .into_mle() +// .into() +// }) +// .collect_vec(); +// +// // get proof +// let prover = ZKVMProver::new(pk); +// let mut transcript = Transcript::new(b"test"); +// let challenges = [1.into(), 2.into()]; +// +// let proof = prover +// .create_opcode_proof(wits_in, num_instances, 1, &mut transcript, &challenges) +// .expect("create_proof failed"); +// +// let verifier = ZKVMVerifier::new(vk); +// let mut v_transcript = Transcript::new(b"test"); +// let _rt_input = verifier +// .verify( +// &proof, +// &mut v_transcript, +// NUM_FANIN, +// &PointAndEval::default(), +// &challenges, +// ) +// .expect("verifier failed"); +// } +// +// // +// test_rw_lk_expression_combination_inner::<19, 17>(); +// test_rw_lk_expression_combination_inner::<61, 17>(); +// test_rw_lk_expression_combination_inner::<17, 61>(); +// } diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index b2277ba15..ebd2e78f2 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -1 +1,26 @@ +use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, witness::RowMajorMatrix}; +use ff_ext::ExtensionField; + mod range; + +pub trait TableCircuit { + type TableConfig: Send + Sync; + type Input: Send + Sync; + + fn name() -> String; + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + ) -> Result; + + fn generate_fixed_traces( + config: &Self::TableConfig, + num_fixed: usize, + ) -> RowMajorMatrix; + + fn assign_instances( + config: &Self::TableConfig, + num_witin: usize, + inputs: &[Self::Input], + ) -> Result, ZKVMError>; +} diff --git a/ceno_zkvm/src/tables/range.rs b/ceno_zkvm/src/tables/range.rs index a081ee549..b0b496675 100644 --- a/ceno_zkvm/src/tables/range.rs +++ b/ceno_zkvm/src/tables/range.rs @@ -1,30 +1,27 @@ -use crate::{ - circuit_builder::CircuitBuilder, - error::ZKVMError, - expression::{Expression, Fixed, ToExpr, WitIn}, - structs::{ROMType, WitnessId}, -}; +use std::mem::MaybeUninit; + +use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, expression::{Expression, Fixed, ToExpr, WitIn}, set_fixed_val, set_val, structs::ROMType, tables::TableCircuit, witness::RowMajorMatrix}; use ff_ext::ExtensionField; -use itertools::Itertools; -use multilinear_extensions::mle::DenseMultilinearExtension; -use std::{collections::BTreeMap, marker::PhantomData}; +use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; #[derive(Clone, Debug)] -pub struct RangeTableConfig { +pub struct RangeTableConfig { u16_tbl: Fixed, u16_mlt: WitIn, - _marker: PhantomData, } -#[derive(Default)] -pub struct RangeTableTrace { - pub fixed: BTreeMap>, - pub wits: BTreeMap>, -} +pub struct RangeTableCircuit; + +impl TableCircuit for RangeTableCircuit { + type TableConfig = RangeTableConfig; + type Input = usize; + + fn name() -> String { + "RANGE".into() + } -impl RangeTableConfig { #[allow(unused)] - fn construct_circuit(cb: &mut CircuitBuilder) -> Result, ZKVMError> { + fn construct_circuit(cb: &mut CircuitBuilder) -> Result { let u16_tbl = cb.create_fixed(|| "u16_tbl")?; let u16_mlt = cb.create_witin(|| "u16_mlt")?; @@ -35,97 +32,44 @@ impl RangeTableConfig { cb.lk_table_record(|| "u16 table", u16_table_values, u16_mlt.expr())?; - Ok(RangeTableConfig { - u16_tbl, - u16_mlt, - _marker: Default::default(), - }) + Ok(RangeTableConfig { u16_tbl, u16_mlt }) } + fn generate_fixed_traces( + config: &RangeTableConfig, + num_fixed: usize, + ) -> RowMajorMatrix { + let num_u16s = 1 << 16; + let mut fixed = RowMajorMatrix::::new(num_u16s, num_fixed); + fixed + .par_iter_mut() + .zip((0..num_u16s).into_par_iter()) + .for_each(|(row, i)| { + set_fixed_val!(row, config.u16_tbl.0, E::BaseField::from(i as u64)); + }); + + fixed + } #[allow(unused)] - fn generate_traces(self, inputs: &[u16]) -> RangeTableTrace { - let mut u16_mlt = vec![0; 1 << 16]; + fn assign_instances( + config: &Self::TableConfig, + num_witin: usize, + inputs: &[Self::Input], + ) -> Result, ZKVMError> { + let num_u16s = 1 << 16; + let mut u16_mlt = vec![0; num_u16s]; for limb in inputs { - u16_mlt[*limb as usize] += 1; + u16_mlt[*limb] += 1; } - let u16_tbl = DenseMultilinearExtension::from_evaluations_vec( - 16, - (0..(1 << 16)).map(E::BaseField::from).collect_vec(), - ); - let u16_mlt = DenseMultilinearExtension::from_evaluations_vec( - 16, - u16_mlt.into_iter().map(E::BaseField::from).collect_vec(), - ); - - let config = self.clone(); - let mut traces = RangeTableTrace::default(); - traces.fixed.insert(config.u16_tbl, u16_tbl); - traces.wits.insert(config.u16_mlt.id, u16_mlt); - - traces - } -} - -#[cfg(test)] -mod tests { - use crate::{ - circuit_builder::{CircuitBuilder, ConstraintSystem}, - scheme::{constants::NUM_FANIN_LOGUP, prover::ZKVMProver, verifier::ZKVMVerifier}, - structs::PointAndEval, - tables::range::RangeTableConfig, - }; - use goldilocks::GoldilocksExt2; - use itertools::Itertools; - use transcript::Transcript; - - #[test] - fn test_range_circuit() { - let mut cs = ConstraintSystem::new(|| "riscv"); - let config = cs - .namespace( - || "range", - |cs| { - let mut cb = CircuitBuilder::::new(cs); - RangeTableConfig::construct_circuit(&mut cb) - }, - ) - .expect("construct range table circuit"); - - let traces = config.generate_traces((0..1 << 8).collect_vec().as_slice()); - - let pk = cs.key_gen(Some(traces.fixed.clone().into_values().collect_vec())); - let vk = pk.vk.clone(); - let prover = ZKVMProver::new(pk); - - let mut transcript = Transcript::new(b"range"); - let challenges = [1.into(), 2.into()]; - - let proof = prover - .create_table_proof( - traces - .wits - .into_values() - .map(|mle| mle.into()) - .collect_vec(), - // TODO: fix the verification error for num_instances is not power-of-two case - 1 << 16, - 1, - &mut transcript, - &challenges, - ) - .expect("create proof"); + let mut witness = RowMajorMatrix::::new(u16_mlt.len(), num_witin); + witness + .par_iter_mut() + .zip(u16_mlt.into_par_iter()) + .for_each(|(row, mlt)| { + set_val!(row, config.u16_mlt, E::BaseField::from(mlt)); + }); - let mut transcript = Transcript::new(b"range"); - let verifier = ZKVMVerifier::new(vk); - verifier - .verify_table_proof( - &proof, - &mut transcript, - NUM_FANIN_LOGUP, - &PointAndEval::default(), - &challenges, - ) - .expect("verify proof failed"); + Ok(witness) } } diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index 3452943e0..660f9de33 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -16,6 +16,13 @@ macro_rules! set_val { }; } +#[macro_export] +macro_rules! set_fixed_val { + ($ins:ident, $field:expr, $val:expr) => { + $ins[$field as usize] = MaybeUninit::new($val); + }; +} + pub struct RowMajorMatrix { // represent 2D in 1D linear memory and avoid double indirection by Vec> to improve performance values: Vec>, diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 9c59a6806..a4530919c 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -124,7 +124,7 @@ pub trait IntoMLEs: Sized { fn into_mles(self) -> Vec; } -impl IntoMLEs> for Vec> { +impl> IntoMLEs> for Vec> { fn into_mles(self) -> Vec> { self.into_iter().map(|v| v.into_mle()).collect() } From 7084e09cf3ca28fe2a3d0a89a36d0f4de9ab6b6a Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 5 Sep 2024 23:48:12 +0800 Subject: [PATCH 05/15] update riscv_add example --- ceno_zkvm/examples/riscv_add.rs | 108 +++++++++++++++------ ceno_zkvm/src/circuit_builder.rs | 6 ++ ceno_zkvm/src/instructions/riscv/addsub.rs | 12 +-- ceno_zkvm/src/scheme/prover.rs | 4 +- ceno_zkvm/src/tables/mod.rs | 1 + ceno_zkvm/src/tables/range.rs | 16 ++- 6 files changed, 104 insertions(+), 43 deletions(-) diff --git a/ceno_zkvm/examples/riscv_add.rs b/ceno_zkvm/examples/riscv_add.rs index 16c0d5dee..6d9c65f9a 100644 --- a/ceno_zkvm/examples/riscv_add.rs +++ b/ceno_zkvm/examples/riscv_add.rs @@ -1,4 +1,4 @@ -use std::time::Instant; +use std::{collections::BTreeMap, time::Instant}; use ark_std::test_rng; use ceno_zkvm::{ @@ -8,6 +8,12 @@ use ceno_zkvm::{ }; use const_env::from_env; +use ceno_emul::StepRecord; +use ceno_zkvm::{ + circuit_builder::{ZKVMConstraintSystem, ZKVMVerifyingKey}, + scheme::verifier::ZKVMVerifier, + tables::{RangeTableCircuit, TableCircuit}, +}; use ff_ext::ff::Field; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; @@ -21,6 +27,8 @@ use transcript::Transcript; const RAYON_NUM_THREADS: usize = 8; fn main() { + type E = GoldilocksExt2; + let max_threads = { if !is_power_of_2(RAYON_NUM_THREADS) { #[cfg(not(feature = "non_pow2_rayon_thread"))] @@ -41,16 +49,6 @@ fn main() { RAYON_NUM_THREADS } }; - let mut cs = ConstraintSystem::new(|| "risv_add"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); - let _ = AddInstruction::construct_circuit(&mut circuit_builder); - let pk = cs.key_gen(None); - let num_witin = pk.get_cs().num_witin; - - let prover = ZKVMProver::new(pk); - let mut transcript = Transcript::new(b"riscv"); - let mut rng = test_rng(); - let real_challenges = [E::random(&mut rng), E::random(&mut rng)]; let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap(); let subscriber = Registry::default() @@ -64,34 +62,82 @@ fn main() { .with(flame_layer.with_threads_collapsed(true)); tracing::subscriber::set_global_default(subscriber).unwrap(); + // keygen + let mut zkvm_fixed_traces = BTreeMap::default(); + let mut zkvm_cs = ZKVMConstraintSystem::default(); + + let (add_cs, add_config) = { + let mut cs = ConstraintSystem::new(|| "riscv_add"); + let mut circuit_builder = CircuitBuilder::::new(&mut cs); + let config = AddInstruction::construct_circuit(&mut circuit_builder).unwrap(); + zkvm_cs.add_cs(AddInstruction::::name(), cs.clone()); + zkvm_fixed_traces.insert(AddInstruction::::name(), None); + (cs, config) + }; + let (range_cs, range_config) = { + let mut cs = ConstraintSystem::new(|| "riscv_range"); + let mut circuit_builder = CircuitBuilder::::new(&mut cs); + let config = RangeTableCircuit::construct_circuit(&mut circuit_builder).unwrap(); + zkvm_cs.add_cs( + as TableCircuit>::name(), + cs.clone(), + ); + zkvm_fixed_traces.insert( + as TableCircuit>::name(), + Some(RangeTableCircuit::::generate_fixed_traces( + &config, + cs.num_fixed, + )), + ); + (cs, config) + }; + let pk = zkvm_cs.key_gen(zkvm_fixed_traces); + let vk = pk.get_vk(); + + // proving + let prover = ZKVMProver::new(pk); + let verifier = ZKVMVerifier::new(vk); + for instance_num_vars in 20..22 { - // generate mock witness + // TODO: witness generation from step records emitted by tracer let num_instances = 1 << instance_num_vars; - let wits_in = (0..num_witin as usize) - .map(|_| { - (0..num_instances) - .map(|_| Goldilocks::random(&mut rng)) - .collect::>() - .into_mle() - .into() - }) - .collect_vec(); + let mut zkvm_witness = BTreeMap::default(); + let add_witness = AddInstruction::assign_instances( + &add_config, + add_cs.num_witin as usize, + vec![StepRecord::default(); num_instances], + ) + .unwrap(); + let range_witness = RangeTableCircuit::::assign_instances( + &range_config, + range_cs.num_witin as usize, + &[], + ) + .unwrap(); + + zkvm_witness.insert(AddInstruction::::name(), add_witness); + zkvm_witness.insert(RangeTableCircuit::::name(), range_witness); + let timer = Instant::now(); - let _ = prover - .create_opcode_proof( - wits_in, - num_instances, - max_threads, - &mut transcript, - &real_challenges, - ) + + let mut transcript = Transcript::new(b"riscv"); + let mut rng = test_rng(); + let real_challenges = [E::random(&mut rng), E::random(&mut rng)]; + + let zkvm_proof = prover + .create_proof(zkvm_witness, max_threads, &mut transcript, &real_challenges) .expect("create_proof failed"); + + assert!( + verifier + .verify_proof(zkvm_proof, &mut transcript, &real_challenges,) + .expect("verify proof return with error"), + ); + println!( "AddInstruction::create_proof, instance_num_vars = {}, time = {}", instance_num_vars, timer.elapsed().as_secs_f64() ); } - - type E = GoldilocksExt2; } diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index 4c9d812a5..517f1c04f 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -335,6 +335,12 @@ pub struct ZKVMConstraintSystem { pub circuit_css: BTreeMap>, } +impl ZKVMConstraintSystem { + pub fn add_cs(&mut self, name: String, cs: ConstraintSystem) { + assert!(self.circuit_css.insert(name, cs).is_none()); + } +} + #[derive(Default)] pub struct ZKVMProvingKey { // pk for opcode and table circuits diff --git a/ceno_zkvm/src/instructions/riscv/addsub.rs b/ceno_zkvm/src/instructions/riscv/addsub.rs index 99768c186..0e346b058 100644 --- a/ceno_zkvm/src/instructions/riscv/addsub.rs +++ b/ceno_zkvm/src/instructions/riscv/addsub.rs @@ -17,8 +17,8 @@ use crate::{ }; use core::mem::MaybeUninit; -pub struct AddInstruction; -pub struct SubInstruction; +pub struct AddInstruction(PhantomData); +pub struct SubInstruction(PhantomData); #[derive(Debug)] pub struct InstructionConfig { @@ -37,11 +37,11 @@ pub struct InstructionConfig { phantom: PhantomData, } -impl RIVInstruction for AddInstruction { +impl RIVInstruction for AddInstruction { const OPCODE_TYPE: OpcodeType = OpcodeType::RType(OPType::Op, 0x000, 0x0000000); } -impl RIVInstruction for SubInstruction { +impl RIVInstruction for SubInstruction { const OPCODE_TYPE: OpcodeType = OpcodeType::RType(OPType::Op, 0x000, 0x0100000); } @@ -134,7 +134,7 @@ fn add_sub_gadget( }) } -impl Instruction for AddInstruction { +impl Instruction for AddInstruction { // const NAME: &'static str = "ADD"; fn name() -> String { "ADD".into() @@ -183,7 +183,7 @@ impl Instruction for AddInstruction { } } -impl Instruction for SubInstruction { +impl Instruction for SubInstruction { // const NAME: &'static str = "ADD"; fn name() -> String { "SUB".into() diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index e7f0f8680..4771df151 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -1,6 +1,6 @@ use ff_ext::ExtensionField; use std::{ - collections::{BTreeSet, HashMap}, + collections::{BTreeMap, BTreeSet, HashMap}, sync::Arc, }; @@ -48,7 +48,7 @@ impl ZKVMProver { /// create proof for zkvm execution pub fn create_proof( &self, - mut witnesses: HashMap>, + mut witnesses: BTreeMap>, max_threads: usize, transcript: &mut Transcript, challenges: &[E; 2], diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index ebd2e78f2..b345fda84 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -2,6 +2,7 @@ use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, witness::RowMajor use ff_ext::ExtensionField; mod range; +pub use range::RangeTableCircuit; pub trait TableCircuit { type TableConfig: Send + Sync; diff --git a/ceno_zkvm/src/tables/range.rs b/ceno_zkvm/src/tables/range.rs index b0b496675..0dabf5791 100644 --- a/ceno_zkvm/src/tables/range.rs +++ b/ceno_zkvm/src/tables/range.rs @@ -1,6 +1,14 @@ -use std::mem::MaybeUninit; +use std::{marker::PhantomData, mem::MaybeUninit}; -use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, expression::{Expression, Fixed, ToExpr, WitIn}, set_fixed_val, set_val, structs::ROMType, tables::TableCircuit, witness::RowMajorMatrix}; +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, Fixed, ToExpr, WitIn}, + set_fixed_val, set_val, + structs::ROMType, + tables::TableCircuit, + witness::RowMajorMatrix, +}; use ff_ext::ExtensionField; use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; @@ -10,9 +18,9 @@ pub struct RangeTableConfig { u16_mlt: WitIn, } -pub struct RangeTableCircuit; +pub struct RangeTableCircuit(PhantomData); -impl TableCircuit for RangeTableCircuit { +impl TableCircuit for RangeTableCircuit { type TableConfig = RangeTableConfig; type Input = usize; From e598e98e6e9307d9573a19f17e300464314ef65d Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Mon, 9 Sep 2024 19:35:57 +0800 Subject: [PATCH 06/15] pass example test --- ceno_zkvm/examples/riscv_add.rs | 16 +++++----- ceno_zkvm/src/error.rs | 2 +- ceno_zkvm/src/scheme/prover.rs | 23 +++++++++++++- ceno_zkvm/src/scheme/verifier.rs | 51 ++++++++++++++++++++++++-------- ceno_zkvm/src/tables/range.rs | 1 + 5 files changed, 71 insertions(+), 22 deletions(-) diff --git a/ceno_zkvm/examples/riscv_add.rs b/ceno_zkvm/examples/riscv_add.rs index 6d9c65f9a..bb2dbacab 100644 --- a/ceno_zkvm/examples/riscv_add.rs +++ b/ceno_zkvm/examples/riscv_add.rs @@ -10,14 +10,12 @@ use const_env::from_env; use ceno_emul::StepRecord; use ceno_zkvm::{ - circuit_builder::{ZKVMConstraintSystem, ZKVMVerifyingKey}, + circuit_builder::ZKVMConstraintSystem, scheme::verifier::ZKVMVerifier, tables::{RangeTableCircuit, TableCircuit}, }; use ff_ext::ff::Field; -use goldilocks::{Goldilocks, GoldilocksExt2}; -use itertools::Itertools; -use multilinear_extensions::mle::IntoMLE; +use goldilocks::GoldilocksExt2; use sumcheck::util::is_power_of_2; use tracing_flame::FlameLayer; use tracing_subscriber::{fmt, layer::SubscriberExt, EnvFilter, Registry}; @@ -98,7 +96,7 @@ fn main() { let prover = ZKVMProver::new(pk); let verifier = ZKVMVerifier::new(vk); - for instance_num_vars in 20..22 { + for instance_num_vars in 15..22 { // TODO: witness generation from step records emitted by tracer let num_instances = 1 << instance_num_vars; let mut zkvm_witness = BTreeMap::default(); @@ -111,7 +109,10 @@ fn main() { let range_witness = RangeTableCircuit::::assign_instances( &range_config, range_cs.num_witin as usize, - &[], + // TODO: use real data + vec![vec![0; num_instances * 2], vec![4; num_instances * 6]] + .concat() + .as_slice(), ) .unwrap(); @@ -128,9 +129,10 @@ fn main() { .create_proof(zkvm_witness, max_threads, &mut transcript, &real_challenges) .expect("create_proof failed"); + let mut transcript = Transcript::new(b"riscv"); assert!( verifier - .verify_proof(zkvm_proof, &mut transcript, &real_challenges,) + .verify_proof(zkvm_proof, &mut transcript, &real_challenges) .expect("verify proof return with error"), ); diff --git a/ceno_zkvm/src/error.rs b/ceno_zkvm/src/error.rs index d623364c9..ea59969f8 100644 --- a/ceno_zkvm/src/error.rs +++ b/ceno_zkvm/src/error.rs @@ -7,7 +7,7 @@ pub enum UtilError { pub enum ZKVMError { CircuitError, UtilError(UtilError), - VerifyError(&'static str), + VerifyError(String), } impl From for ZKVMError { diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 4771df151..685e6f1cb 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -1,6 +1,6 @@ use ff_ext::ExtensionField; use std::{ - collections::{BTreeMap, BTreeSet, HashMap}, + collections::{BTreeMap, BTreeSet}, sync::Arc, }; @@ -65,6 +65,17 @@ impl ZKVMProver { let num_instances = witness.num_instances(); if is_opcode_circuit { + tracing::debug!( + "opcode circuit {} has {} witnesses, {} reads, {} writes, {} lookups", + circuit_name, + cs.num_witin, + cs.r_expressions.len(), + cs.w_expressions.len(), + cs.lk_expressions.len(), + ); + for lk_s in cs.lk_expressions_namespace_map.iter() { + tracing::debug!("opcode circuit {}: {}", circuit_name, lk_s); + } let opcode_proof = self.create_opcode_proof( pk, witness @@ -78,6 +89,11 @@ impl ZKVMProver { transcript, challenges, )?; + tracing::info!( + "generated proof for opcode {} with num_instances={}", + circuit_name, + num_instances + ); vm_proof .opcode_proofs .insert(circuit_name.clone(), opcode_proof); @@ -95,6 +111,11 @@ impl ZKVMProver { transcript, challenges, )?; + tracing::info!( + "generated proof for table {} with num_instances={}", + circuit_name, + num_instances + ); vm_proof .table_proofs .insert(circuit_name.clone(), table_proof); diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 40b4a818e..92c4905f7 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -46,7 +46,9 @@ impl ZKVMVerifier { let mut prod_r = E::ONE; let mut prod_w = E::ONE; let mut logup_sum = E::ZERO; + let dummy_table_item = challenges[0]; let point_eval = PointAndEval::default(); + let mut dummy_table_item_multiplicity = 0; for (name, opcode_proof) in vm_proof.opcode_proofs { let circuit_vk = self .vk @@ -61,6 +63,16 @@ impl ZKVMVerifier { &point_eval, challenges, )?; + tracing::info!("verified proof for opcode {}", name); + + // getting the number of dummy padding item that we used in this opcode circuit + let num_lks = circuit_vk.get_cs().lk_expressions.len(); + let num_padded_lks_per_instance = num_lks.next_power_of_two() - num_lks; + let num_padded_instance = + opcode_proof.num_instances.next_power_of_two() - opcode_proof.num_instances; + dummy_table_item_multiplicity += num_padded_lks_per_instance + * opcode_proof.num_instances + + num_lks.next_power_of_two() * num_padded_instance; prod_r *= opcode_proof.record_r_out_evals.iter().product::(); prod_w *= opcode_proof.record_w_out_evals.iter().product::(); @@ -85,18 +97,26 @@ impl ZKVMVerifier { &point_eval, challenges, )?; + tracing::info!("verified proof for table {}", name); logup_sum -= table_proof.lk_p1_out_eval * table_proof.lk_q1_out_eval.invert().unwrap(); logup_sum -= table_proof.lk_p2_out_eval * table_proof.lk_q2_out_eval.invert().unwrap(); } + logup_sum -= + E::from(dummy_table_item_multiplicity as u64) * dummy_table_item.invert().unwrap(); + // check rw_set equality across all proofs - if prod_r != prod_w { - return Ok(false); - } + // TODO: enable this when we have cpu init/finalize and mem init/finalize + // if prod_r != prod_w { + // return Err(ZKVMError::VerifyError("prod_r != prod_w".into())); + // } // check logup relation across all proofs if logup_sum != E::ZERO { - return Ok(false); + return Err(ZKVMError::VerifyError(format!( + "logup_sum({:?}) != 0", + logup_sum + ))); } Ok(true) @@ -159,7 +179,7 @@ impl ZKVMVerifier { // index 0 is LogUp witness for Fixed Lookup table if logup_p_evals[0].eval != E::ONE { return Err(ZKVMError::VerifyError( - "Lookup table witness p(x) != constant 1", + "Lookup table witness p(x) != constant 1".into(), )); } @@ -279,7 +299,7 @@ impl ZKVMVerifier { .sum::(); if computed_evals != expected_evaluation { return Err(ZKVMError::VerifyError( - "main + sel evaluation verify failed", + "main + sel evaluation verify failed".into(), )); } // verify records (degree = 1) statement, thus no sumcheck @@ -298,7 +318,9 @@ impl ZKVMVerifier { eval_by_expr(&proof.wits_in_evals, challenges, expr) != *expected_evals }) { - return Err(ZKVMError::VerifyError("record evaluate != expected_evals")); + return Err(ZKVMError::VerifyError( + "record evaluate != expected_evals".into(), + )); } // verify zero expression (degree = 1) statement, thus no sumcheck @@ -326,7 +348,6 @@ impl ZKVMVerifier { let cs = circuit_vk.get_cs(); let lk_counts_per_instance = cs.lk_table_expressions.len(); let log2_lk_count = ceil_log2(lk_counts_per_instance); - let (chip_record_alpha, _) = (challenges[0], challenges[1]); let num_instances = proof.num_instances; let log2_num_instances = ceil_log2(num_instances); @@ -394,8 +415,7 @@ impl ZKVMVerifier { * ((0..lk_counts_per_instance) .map(|i| proof.lk_d_in_evals[i] * eq_lk[i]) .sum::() - + chip_record_alpha - * (eq_lk[lk_counts_per_instance..].iter().sum::() - E::ONE)), + + (eq_lk[lk_counts_per_instance..].iter().sum::() - E::ONE)), *alpha_lk_n * sel_lk * ((0..lk_counts_per_instance) @@ -405,7 +425,9 @@ impl ZKVMVerifier { .iter() .sum::(); if computed_evals != expected_evaluation { - return Err(ZKVMError::VerifyError("sel evaluation verify failed")); + return Err(ZKVMError::VerifyError( + "sel evaluation verify failed".into(), + )); } // verify records (degree = 1) statement, thus no sumcheck if cs @@ -427,7 +449,9 @@ impl ZKVMVerifier { ) != *expected_evals }) { - return Err(ZKVMError::VerifyError("record evaluate != expected_evals")); + return Err(ZKVMError::VerifyError( + "record evaluate != expected_evals".into(), + )); } Ok(input_opening_point) @@ -524,6 +548,7 @@ impl TowerVerify { }, transcript, ); + tracing::debug!("verified tower proof at layer {}/{}", round + 1, expected_max_round-1); // check expected_evaluation let rt: Point = sumcheck_claim.point.iter().map(|c| c.elements).collect(); @@ -555,7 +580,7 @@ impl TowerVerify { }) .sum::(); if expected_evaluation != sumcheck_claim.expected_evaluation { - return Err(ZKVMError::VerifyError("mismatch tower evaluation")); + return Err(ZKVMError::VerifyError("mismatch tower evaluation".into())); } // derive single eval diff --git a/ceno_zkvm/src/tables/range.rs b/ceno_zkvm/src/tables/range.rs index 0dabf5791..4c1916265 100644 --- a/ceno_zkvm/src/tables/range.rs +++ b/ceno_zkvm/src/tables/range.rs @@ -69,6 +69,7 @@ impl TableCircuit for RangeTableCircuit { for limb in inputs { u16_mlt[*limb] += 1; } + tracing::debug!("u16_mult[4] = {}", u16_mlt[4]); let mut witness = RowMajorMatrix::::new(u16_mlt.len(), num_witin); witness From 865e0cc76689436c3ca0e487e208782c434676fd Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Mon, 9 Sep 2024 22:18:57 +0800 Subject: [PATCH 07/15] add insn in step record when tracing --- ceno_emul/src/rv32im.rs | 6 ++++++ ceno_emul/src/tracer.rs | 10 ++++++++++ ceno_emul/src/vm_state.rs | 8 +++++++- 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index 710d7e7d4..64e758f55 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -183,6 +183,12 @@ pub struct Instruction { pub func7: u32, } +impl Default for Instruction { + fn default() -> Self { + insn(InsnKind::INVALID, InsnCategory::Invalid, 0x00, 0x0, 0x00) + } +} + impl DecodedInstruction { pub fn new(insn: u32) -> Self { Self { diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index e2352ef0f..e8ae6efc5 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -5,6 +5,7 @@ use crate::{ rv32im::DecodedInstruction, CENO_PLATFORM, }; +use crate::rv32im::Instruction; /// An instruction and its context in an execution trace. That is concrete values of registers and memory. /// @@ -22,6 +23,7 @@ pub struct StepRecord { pub cycle: Cycle, pub pc: Change, pub insn_code: Word, + pub insn: Instruction, pub rs1: Option, pub rs2: Option, @@ -62,6 +64,10 @@ impl StepRecord { DecodedInstruction::new(self.insn_code) } + pub fn insn(&self) -> Instruction { + self.insn + } + pub fn rs1(&self) -> Option { self.rs1.clone() } @@ -134,6 +140,10 @@ impl Tracer { self.record.insn_code = value; } + pub fn store_insn(&mut self, insn: Instruction) { + self.record.insn = insn; + } + pub fn load_register(&mut self, idx: RegIdx, value: Word) { let addr = CENO_PLATFORM.register_vma(idx).into(); diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index b478fc5ff..546203fe8 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -82,6 +82,10 @@ impl VMState { Ok(step) } } + + pub fn init_register_unsafe(&mut self, idx: RegIdx, value: Word) { + self.registers[idx] = value; + } } impl EmuContext for VMState { @@ -109,7 +113,9 @@ impl EmuContext for VMState { Err(anyhow!("Trap {:?}", cause)) // Crash. } - fn on_insn_decoded(&mut self, _kind: &Instruction, _decoded: &DecodedInstruction) {} + fn on_insn_decoded(&mut self, insn: &Instruction, _decoded: &DecodedInstruction) { + self.tracer.store_insn(*insn); + } fn on_normal_end(&mut self, _kind: &Instruction, _decoded: &DecodedInstruction) { self.tracer.store_pc(ByteAddr(self.pc)); From fb1174dfb93f59665c0f975b27f6cc290429e106 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Mon, 9 Sep 2024 23:28:28 +0800 Subject: [PATCH 08/15] use tracer to generate step records for riscv_add example --- ceno_zkvm/examples/riscv_add.rs | 77 ++++++++++++++++++---- ceno_zkvm/src/circuit_builder.rs | 9 +-- ceno_zkvm/src/error.rs | 3 + ceno_zkvm/src/instructions/riscv/addsub.rs | 5 +- ceno_zkvm/src/keygen.rs | 15 ++--- ceno_zkvm/src/lib.rs | 2 + ceno_zkvm/src/scheme/prover.rs | 2 +- ceno_zkvm/src/scheme/verifier.rs | 4 +- 8 files changed, 83 insertions(+), 34 deletions(-) diff --git a/ceno_zkvm/examples/riscv_add.rs b/ceno_zkvm/examples/riscv_add.rs index bb2dbacab..729e2f430 100644 --- a/ceno_zkvm/examples/riscv_add.rs +++ b/ceno_zkvm/examples/riscv_add.rs @@ -5,10 +5,11 @@ use ceno_zkvm::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, instructions::{riscv::addsub::AddInstruction, Instruction}, scheme::prover::ZKVMProver, + UIntValue, }; use const_env::from_env; -use ceno_emul::StepRecord; +use ceno_emul::{ByteAddr, InsnKind::ADD, StepRecord, VMState, CENO_PLATFORM}; use ceno_zkvm::{ circuit_builder::ZKVMConstraintSystem, scheme::verifier::ZKVMVerifier, @@ -16,6 +17,7 @@ use ceno_zkvm::{ }; use ff_ext::ff::Field; use goldilocks::GoldilocksExt2; +use itertools::Itertools; use sumcheck::util::is_power_of_2; use tracing_flame::FlameLayer; use tracing_subscriber::{fmt, layer::SubscriberExt, EnvFilter, Registry}; @@ -24,6 +26,20 @@ use transcript::Transcript; #[from_env] const RAYON_NUM_THREADS: usize = 8; +// For now, we assume registers +// - x0 is not touched, +// - x1 is initialized to 1, +// - x2 is initialized to -1, +// - x3 is initialized to loop bound. +// we use x4 to hold the acc_sum. +const PROGRAM_ADD_LOOP: [u32; 4] = [ + // func7 rs2 rs1 f3 rd opcode + 0b_0000000_00100_00001_000_00100_0110011, // add x4, x4, x1 <=> addi x4, x4, 1 + 0b_0000000_00011_00010_000_00011_0110011, // add x3, x3, x2 <=> addi x3, x3, -1 + 0b_1_111111_00000_00011_001_1100_1_1100011, // bne x3, x0, -8 + 0b_000000000000_00000_000_00000_1110011, // ecall halt +]; + fn main() { type E = GoldilocksExt2; @@ -89,30 +105,63 @@ fn main() { ); (cs, config) }; - let pk = zkvm_cs.key_gen(zkvm_fixed_traces); + let pk = zkvm_cs.key_gen(zkvm_fixed_traces).expect("keygen failed"); let vk = pk.get_vk(); // proving let prover = ZKVMProver::new(pk); let verifier = ZKVMVerifier::new(vk); - for instance_num_vars in 15..22 { - // TODO: witness generation from step records emitted by tracer + for instance_num_vars in 8..22 { let num_instances = 1 << instance_num_vars; + let mut vm = VMState::new(CENO_PLATFORM); + let pc_start = ByteAddr(CENO_PLATFORM.pc_start()).waddr(); + + // init vm.x1 = 1, vm.x2 = -1, vm.x3 = num_instances + // vm.x4 += vm.x1 + vm.init_register_unsafe(1usize, 1); + vm.init_register_unsafe(2usize, u32::MAX); // -1 in two's complement + vm.init_register_unsafe(3usize, num_instances as u32); + for (i, inst) in PROGRAM_ADD_LOOP.iter().enumerate() { + vm.init_memory(pc_start + i, *inst); + } + let records = vm + .iter_until_success() + .collect::, _>>() + .expect("vm exec failed") + .into_iter() + .filter(|record| record.insn().kind == ADD) + .collect::>(); + tracing::info!("tracer generated {} ADD records", records.len()); + + // TODO: generate range check inputs from opcode_circuit::assign_instances() + let rc_inputs = records + .iter() + .flat_map(|record| { + let rs1 = UIntValue::new(record.rs1().unwrap().value); + let rs2 = UIntValue::new(record.rs2().unwrap().value); + + let rd_prev = UIntValue::new(record.rd().unwrap().value.before); + let rd = UIntValue::new(record.rd().unwrap().value.after); + let carries = rs1 + .add_u16_carries(&rs2) + .into_iter() + .map(|c| c as u16) + .collect_vec(); + + [rd_prev.limbs, rd.limbs, carries].concat() + }) + .map(|x| x as usize) + .collect::>(); + let mut zkvm_witness = BTreeMap::default(); - let add_witness = AddInstruction::assign_instances( - &add_config, - add_cs.num_witin as usize, - vec![StepRecord::default(); num_instances], - ) - .unwrap(); + let add_witness = + AddInstruction::assign_instances(&add_config, add_cs.num_witin as usize, records) + .unwrap(); let range_witness = RangeTableCircuit::::assign_instances( &range_config, range_cs.num_witin as usize, - // TODO: use real data - vec![vec![0; num_instances * 2], vec![4; num_instances * 6]] - .concat() - .as_slice(), + &rc_inputs, ) .unwrap(); diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index 517f1c04f..be2139f05 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -141,13 +141,8 @@ impl ConstraintSystem { // TODO: commit to fixed_traces // transpose from row-major to column-major - let fixed_traces = fixed_traces.map(|t| { - t.de_interleaving() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec() - }); + let fixed_traces = + fixed_traces.map(|t| t.de_interleaving().into_mles().into_iter().collect_vec()); ProvingKey { fixed_traces, diff --git a/ceno_zkvm/src/error.rs b/ceno_zkvm/src/error.rs index ea59969f8..b7791def9 100644 --- a/ceno_zkvm/src/error.rs +++ b/ceno_zkvm/src/error.rs @@ -7,6 +7,9 @@ pub enum UtilError { pub enum ZKVMError { CircuitError, UtilError(UtilError), + WitnessNotFound(String), + VKNotFound(String), + FixedTraceNotFound(String), VerifyError(String), } diff --git a/ceno_zkvm/src/instructions/riscv/addsub.rs b/ceno_zkvm/src/instructions/riscv/addsub.rs index 6224d016c..2b873d50d 100644 --- a/ceno_zkvm/src/instructions/riscv/addsub.rs +++ b/ceno_zkvm/src/instructions/riscv/addsub.rs @@ -161,9 +161,11 @@ impl Instruction for AddInstruction { set_val!(instance, config.ts, 2); let addend_0 = UIntValue::new(step.rs1().unwrap().value); let addend_1 = UIntValue::new(step.rs2().unwrap().value); + let outcome = UIntValue::new(step.rd().unwrap().value.after); + let rd_prev = UIntValue::new(step.rd().unwrap().value.before); config .prev_rd_value - .assign_limbs(instance, [0, 0].iter().map(E::BaseField::from).collect()); + .assign_limbs(instance, rd_prev.u16_fields()); config .addend_0 .assign_limbs(instance, addend_0.u16_fields()); @@ -178,6 +180,7 @@ impl Instruction for AddInstruction { .map(|carry| E::BaseField::from(carry as u64)) .collect_vec(), ); + config.outcome.assign_limbs(instance, outcome.u16_fields()); // TODO #167 set_val!(instance, config.rs1_id, 2); set_val!(instance, config.rs2_id, 2); diff --git a/ceno_zkvm/src/keygen.rs b/ceno_zkvm/src/keygen.rs index c475878bf..3bce91d11 100644 --- a/ceno_zkvm/src/keygen.rs +++ b/ceno_zkvm/src/keygen.rs @@ -1,5 +1,6 @@ use crate::{ circuit_builder::{ZKVMConstraintSystem, ZKVMProvingKey}, + error::ZKVMError, witness::RowMajorMatrix, }; use ff_ext::ExtensionField; @@ -9,22 +10,18 @@ impl ZKVMConstraintSystem { pub fn key_gen( self, mut vm_fixed_traces: BTreeMap>>, - ) -> ZKVMProvingKey { + ) -> Result, ZKVMError> { let mut vm_pk = ZKVMProvingKey::default(); for (c_name, cs) in self.circuit_css.into_iter() { - let fixed_traces = vm_fixed_traces.remove(&c_name).expect( - format!( - "circuit {}'s trace is not present in vm_fixed_traces", - c_name - ) - .as_str(), - ); + let fixed_traces = vm_fixed_traces + .remove(&c_name) + .ok_or(ZKVMError::FixedTraceNotFound(c_name.clone()))?; let circuit_pk = cs.key_gen(fixed_traces); assert!(vm_pk.circuit_pks.insert(c_name, circuit_pk).is_none()); } - vm_pk + Ok(vm_pk) } } diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index c29f030ba..dbe64e06e 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -15,3 +15,5 @@ mod uint; mod utils; mod virtual_polys; mod witness; + +pub use uint::UIntValue; diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 685e6f1cb..6784029b0 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -57,7 +57,7 @@ impl ZKVMProver { for (circuit_name, pk) in self.pk.circuit_pks.iter() { let witness = witnesses .remove(circuit_name) - .expect(format!("witness for circuit {} is not found", circuit_name).as_str()); + .ok_or(ZKVMError::WitnessNotFound(circuit_name.clone()))?; // TODO: add an enum for circuit type either in constraint_system or vk let cs = pk.get_cs(); diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 92c4905f7..4f6803a24 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -54,7 +54,7 @@ impl ZKVMVerifier { .vk .circuit_vks .get(&name) - .expect(format!("vk of opcode circuit {} is not present", name).as_str()); + .ok_or(ZKVMError::VKNotFound(name.clone()))?; let _rand_point = self.verify_opcode_proof( circuit_vk, &opcode_proof, @@ -88,7 +88,7 @@ impl ZKVMVerifier { .vk .circuit_vks .get(&name) - .expect(format!("vk of table circuit {} is not present", name).as_str()); + .ok_or(ZKVMError::VKNotFound(name.clone()))?; let _rand_point = self.verify_table_proof( circuit_vk, &table_proof, From 3c00947d989692f5eea775dd3619b823b423851e Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Mon, 9 Sep 2024 23:39:50 +0800 Subject: [PATCH 09/15] fix clippy --- ceno_zkvm/src/instructions/riscv/addsub.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/ceno_zkvm/src/instructions/riscv/addsub.rs b/ceno_zkvm/src/instructions/riscv/addsub.rs index 2b873d50d..d171baa3c 100644 --- a/ceno_zkvm/src/instructions/riscv/addsub.rs +++ b/ceno_zkvm/src/instructions/riscv/addsub.rs @@ -1,6 +1,5 @@ use std::marker::PhantomData; -use ark_std::iterable::Iterable; use ceno_emul::StepRecord; use ff_ext::ExtensionField; use itertools::Itertools; From d0745adf01e4f3864881c1378f6c8713979534a2 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Mon, 9 Sep 2024 23:57:34 +0800 Subject: [PATCH 10/15] fix add unit-test failure --- ceno_zkvm/src/instructions/riscv/addsub.rs | 26 +++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/addsub.rs b/ceno_zkvm/src/instructions/riscv/addsub.rs index d171baa3c..b6b02e5f5 100644 --- a/ceno_zkvm/src/instructions/riscv/addsub.rs +++ b/ceno_zkvm/src/instructions/riscv/addsub.rs @@ -242,7 +242,7 @@ impl Instruction for SubInstruction { #[cfg(test)] mod test { - use ceno_emul::{ReadOp, StepRecord}; + use ceno_emul::{Change, ReadOp, StepRecord, WriteOp}; use goldilocks::GoldilocksExt2; use itertools::Itertools; use multilinear_extensions::mle::IntoMLEs; @@ -276,15 +276,23 @@ mod test { cb.cs.num_witin as usize, vec![StepRecord { rs1: Some(ReadOp { - addr: 0.into(), + addr: 2.into(), value: 11u32, previous_cycle: 0, }), rs2: Some(ReadOp { - addr: 0.into(), + addr: 3.into(), value: 0xfffffffeu32, previous_cycle: 0, }), + rd: Some(WriteOp { + addr: 4.into(), + value: Change { + before: 0u32, + after: 9u32, + }, + previous_cycle: 0, + }), ..Default::default() }], ) @@ -323,15 +331,23 @@ mod test { cb.cs.num_witin as usize, vec![StepRecord { rs1: Some(ReadOp { - addr: 0.into(), + addr: 2.into(), value: u32::MAX - 1, previous_cycle: 0, }), rs2: Some(ReadOp { - addr: 0.into(), + addr: 3.into(), value: u32::MAX - 1, previous_cycle: 0, }), + rd: Some(WriteOp { + addr: 4.into(), + value: Change { + before: 0u32, + after: u32::MAX - 2, + }, + previous_cycle: 0, + }), ..Default::default() }], ) From fab77d014b4b0ad8b5ae9734860f059c269df3f2 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 10 Sep 2024 18:42:50 +0800 Subject: [PATCH 11/15] chores: address comments --- ceno_zkvm/src/tables/range.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ceno_zkvm/src/tables/range.rs b/ceno_zkvm/src/tables/range.rs index dc1e427f8..f5eb72d48 100644 --- a/ceno_zkvm/src/tables/range.rs +++ b/ceno_zkvm/src/tables/range.rs @@ -4,6 +4,7 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, expression::{Expression, Fixed, ToExpr, WitIn}, + scheme::constants::MIN_PAR_SIZE, set_fixed_val, set_val, structs::ROMType, tables::TableCircuit, @@ -29,7 +30,6 @@ impl TableCircuit for RangeTableCircuit { "RANGE".into() } - #[allow(unused)] fn construct_circuit(cb: &mut CircuitBuilder) -> Result { let u16_tbl = cb.create_fixed(|| "u16_tbl")?; let u16_mlt = cb.create_witin(|| "u16_mlt")?; @@ -52,6 +52,7 @@ impl TableCircuit for RangeTableCircuit { let mut fixed = RowMajorMatrix::::new(num_u16s, num_fixed); fixed .par_iter_mut() + .with_min_len(MIN_PAR_SIZE) .zip((0..num_u16s).into_par_iter()) .for_each(|(row, i)| { set_fixed_val!(row, config.u16_tbl.0, E::BaseField::from(i as u64)); @@ -59,7 +60,7 @@ impl TableCircuit for RangeTableCircuit { fixed } - #[allow(unused)] + fn assign_instances( config: &Self::TableConfig, num_witin: usize, @@ -73,6 +74,7 @@ impl TableCircuit for RangeTableCircuit { let mut witness = RowMajorMatrix::::new(u16_mlt.len(), num_witin); witness .par_iter_mut() + .with_min_len(MIN_PAR_SIZE) .zip(u16_mlt.into_par_iter()) .for_each(|(row, mlt)| { set_val!(row, config.u16_mlt, E::BaseField::from(mlt as u64)); From 672bbf28f4b2832a8c3a40823d1fa607192114a1 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 10 Sep 2024 22:15:22 +0800 Subject: [PATCH 12/15] add zkvm witness assign and zkvm cs api for better readability --- ceno_zkvm/examples/riscv_add.rs | 74 +++++------- ceno_zkvm/src/circuit_builder.rs | 64 +---------- ceno_zkvm/src/keygen.rs | 7 +- ceno_zkvm/src/lib.rs | 2 +- ceno_zkvm/src/scheme/prover.rs | 14 +-- ceno_zkvm/src/scheme/verifier.rs | 3 +- ceno_zkvm/src/structs.rs | 191 ++++++++++++++++++++++++++++++- ceno_zkvm/src/tables/mod.rs | 2 +- ceno_zkvm/src/tables/range.rs | 3 +- 9 files changed, 233 insertions(+), 127 deletions(-) diff --git a/ceno_zkvm/examples/riscv_add.rs b/ceno_zkvm/examples/riscv_add.rs index a63dc6a10..6e66177c5 100644 --- a/ceno_zkvm/examples/riscv_add.rs +++ b/ceno_zkvm/examples/riscv_add.rs @@ -1,19 +1,17 @@ -use std::{collections::BTreeMap, time::Instant}; +use std::time::Instant; use ark_std::test_rng; use ceno_zkvm::{ - circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::{riscv::addsub::AddInstruction, Instruction}, + instructions::{riscv::addsub::AddInstruction}, scheme::prover::ZKVMProver, - ROMType, }; use const_env::from_env; use ceno_emul::{ByteAddr, InsnKind::ADD, StepRecord, VMState, CENO_PLATFORM}; use ceno_zkvm::{ - circuit_builder::ZKVMConstraintSystem, scheme::verifier::ZKVMVerifier, - tables::{RangeTableCircuit, TableCircuit}, + structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, + tables::RangeTableCircuit, }; use ff_ext::ff::Field; use goldilocks::GoldilocksExt2; @@ -76,35 +74,19 @@ fn main() { tracing::subscriber::set_global_default(subscriber).unwrap(); // keygen - let mut zkvm_fixed_traces = BTreeMap::default(); let mut zkvm_cs = ZKVMConstraintSystem::default(); - - let (add_cs, add_config) = { - let mut cs = ConstraintSystem::new(|| "riscv_add"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); - let config = AddInstruction::construct_circuit(&mut circuit_builder).unwrap(); - zkvm_cs.add_cs(AddInstruction::::name(), cs.clone()); - zkvm_fixed_traces.insert(AddInstruction::::name(), None); - (cs, config) - }; - let (range_cs, range_config) = { - let mut cs = ConstraintSystem::new(|| "riscv_range"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); - let config = RangeTableCircuit::construct_circuit(&mut circuit_builder).unwrap(); - zkvm_cs.add_cs( - as TableCircuit>::name(), - cs.clone(), - ); - zkvm_fixed_traces.insert( - as TableCircuit>::name(), - Some(RangeTableCircuit::::generate_fixed_traces( - &config, - cs.num_fixed, - )), - ); - (cs, config) - }; - let pk = zkvm_cs.key_gen(zkvm_fixed_traces).expect("keygen failed"); + let add_config = zkvm_cs.register_opcode_circuit::>(); + let range_config = zkvm_cs.register_table_circuit::>(); + + let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); + zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); + zkvm_fixed_traces + .register_table_circuit::>(&zkvm_cs, range_config.clone()); + + let pk = zkvm_cs + .clone() + .key_gen(zkvm_fixed_traces) + .expect("keygen failed"); let vk = pk.get_vk(); // proving @@ -133,20 +115,16 @@ fn main() { .collect::>(); tracing::info!("tracer generated {} ADD records", records.len()); - let mut zkvm_witness = BTreeMap::default(); - let (add_witness, table_inputs) = - AddInstruction::assign_instances(&add_config, add_cs.num_witin as usize, records) - .unwrap(); - let table_inputs = table_inputs.into_finalize_result(); - let range_witness = RangeTableCircuit::::assign_instances( - &range_config, - range_cs.num_witin as usize, - &table_inputs[ROMType::U16 as usize], - ) - .unwrap(); - - zkvm_witness.insert(AddInstruction::::name(), add_witness); - zkvm_witness.insert(RangeTableCircuit::::name(), range_witness); + let mut zkvm_witness = ZKVMWitnesses::default(); + // assign opcode circuits + zkvm_witness + .assign_opcode_circuit::>(&zkvm_cs, &add_config, records) + .unwrap(); + zkvm_witness.finalize_lk_multiplicities(); + // assign table circuits + zkvm_witness + .assign_table_circuit::>(&zkvm_cs, &range_config) + .unwrap(); let timer = Instant::now(); diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index be2139f05..fc127b5ba 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -1,13 +1,13 @@ use itertools::Itertools; -use std::{collections::BTreeMap, marker::PhantomData}; +use std::marker::PhantomData; use ff_ext::ExtensionField; -use multilinear_extensions::mle::{DenseMultilinearExtension, IntoMLEs}; +use multilinear_extensions::mle::IntoMLEs; use crate::{ error::ZKVMError, expression::{Expression, Fixed, WitIn}, - structs::WitnessId, + structs::{ProvingKey, VerifyingKey, WitnessId}, witness::RowMajorMatrix, }; @@ -301,61 +301,3 @@ impl ConstraintSystem { pub struct CircuitBuilder<'a, E: ExtensionField> { pub(crate) cs: &'a mut ConstraintSystem, } - -#[derive(Clone, Debug)] -pub struct ProvingKey { - pub fixed_traces: Option>>, - pub vk: VerifyingKey, -} - -impl ProvingKey { - pub fn get_cs(&self) -> &ConstraintSystem { - self.vk.get_cs() - } -} - -#[derive(Clone, Debug)] -pub struct VerifyingKey { - cs: ConstraintSystem, -} - -impl VerifyingKey { - pub fn get_cs(&self) -> &ConstraintSystem { - &self.cs - } -} - -#[derive(Default)] -pub struct ZKVMConstraintSystem { - pub circuit_css: BTreeMap>, -} - -impl ZKVMConstraintSystem { - pub fn add_cs(&mut self, name: String, cs: ConstraintSystem) { - assert!(self.circuit_css.insert(name, cs).is_none()); - } -} - -#[derive(Default)] -pub struct ZKVMProvingKey { - // pk for opcode and table circuits - pub circuit_pks: BTreeMap>, -} - -impl ZKVMProvingKey { - pub fn get_vk(&self) -> ZKVMVerifyingKey { - ZKVMVerifyingKey { - circuit_vks: self - .circuit_pks - .iter() - .map(|(name, pk)| (name.clone(), pk.vk.clone())) - .collect(), - } - } -} - -#[derive(Default)] -pub struct ZKVMVerifyingKey { - // pk for opcode and table circuits - pub circuit_vks: BTreeMap>, -} diff --git a/ceno_zkvm/src/keygen.rs b/ceno_zkvm/src/keygen.rs index 3bce91d11..248a6b240 100644 --- a/ceno_zkvm/src/keygen.rs +++ b/ceno_zkvm/src/keygen.rs @@ -1,20 +1,19 @@ use crate::{ - circuit_builder::{ZKVMConstraintSystem, ZKVMProvingKey}, error::ZKVMError, - witness::RowMajorMatrix, + structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMProvingKey}, }; use ff_ext::ExtensionField; -use std::collections::BTreeMap; impl ZKVMConstraintSystem { pub fn key_gen( self, - mut vm_fixed_traces: BTreeMap>>, + mut vm_fixed_traces: ZKVMFixedTraces, ) -> Result, ZKVMError> { let mut vm_pk = ZKVMProvingKey::default(); for (c_name, cs) in self.circuit_css.into_iter() { let fixed_traces = vm_fixed_traces + .circuit_fixed_traces .remove(&c_name) .ok_or(ZKVMError::FixedTraceNotFound(c_name.clone()))?; diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index 3df49fd92..5aa9da5e6 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -11,7 +11,7 @@ mod chip_handler; pub mod circuit_builder; pub mod expression; mod keygen; -mod structs; +pub mod structs; mod uint; mod utils; mod virtual_polys; diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 6784029b0..8dac99666 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -1,8 +1,5 @@ use ff_ext::ExtensionField; -use std::{ - collections::{BTreeMap, BTreeSet}, - sync::Arc, -}; +use std::{collections::BTreeSet, sync::Arc}; use itertools::Itertools; use multilinear_extensions::{ @@ -19,7 +16,6 @@ use sumcheck::{ use transcript::Transcript; use crate::{ - circuit_builder::{ProvingKey, ZKVMProvingKey}, error::ZKVMError, scheme::{ constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, NUM_FANIN, NUM_FANIN_LOGUP}, @@ -28,10 +24,11 @@ use crate::{ wit_infer_by_expr, }, }, - structs::{Point, TowerProofs, TowerProver, TowerProverSpec}, + structs::{ + Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses, + }, utils::{get_challenge_pows, proper_num_threads}, virtual_polys::VirtualPolynomials, - witness::RowMajorMatrix, }; use super::{ZKVMOpcodeProof, ZKVMProof, ZKVMTableProof}; @@ -48,7 +45,7 @@ impl ZKVMProver { /// create proof for zkvm execution pub fn create_proof( &self, - mut witnesses: BTreeMap>, + mut witnesses: ZKVMWitnesses, max_threads: usize, transcript: &mut Transcript, challenges: &[E; 2], @@ -56,6 +53,7 @@ impl ZKVMProver { let mut vm_proof = ZKVMProof::default(); for (circuit_name, pk) in self.pk.circuit_pks.iter() { let witness = witnesses + .witnesses .remove(circuit_name) .ok_or(ZKVMError::WitnessNotFound(circuit_name.clone()))?; diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 4f6803a24..0facee0e3 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -13,13 +13,12 @@ use sumcheck::structs::{IOPProof, IOPVerifierState}; use transcript::Transcript; use crate::{ - circuit_builder::{VerifyingKey, ZKVMVerifyingKey}, error::ZKVMError, scheme::{ constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEL_DEGREE}, utils::eval_by_expr_with_fixed, }, - structs::{Point, PointAndEval, TowerProofs}, + structs::{Point, PointAndEval, TowerProofs, VerifyingKey, ZKVMVerifyingKey}, utils::{get_challenge_pows, sel_eval}, }; diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 321c37634..84a54bb55 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -1,6 +1,18 @@ +use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + error::ZKVMError, + instructions::Instruction, + tables::TableCircuit, + witness::{LkMultiplicity, RowMajorMatrix}, +}; +use ceno_emul::StepRecord; use ff_ext::ExtensionField; -use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension; +use itertools::Itertools; +use multilinear_extensions::{ + mle::DenseMultilinearExtension, virtual_poly_v2::ArcMultilinearExtension, +}; use serde::Serialize; +use std::collections::{BTreeMap, HashMap}; use sumcheck::structs::IOPProverMessage; pub struct TowerProver; @@ -75,3 +87,180 @@ impl PointAndEval { } } } + +#[derive(Clone, Debug)] +pub struct ProvingKey { + pub fixed_traces: Option>>, + pub vk: VerifyingKey, +} + +impl ProvingKey { + pub fn get_cs(&self) -> &ConstraintSystem { + self.vk.get_cs() + } +} + +#[derive(Clone, Debug)] +pub struct VerifyingKey { + pub(crate) cs: ConstraintSystem, +} + +impl VerifyingKey { + pub fn get_cs(&self) -> &ConstraintSystem { + &self.cs + } +} + +#[derive(Default, Clone)] +pub struct ZKVMConstraintSystem { + pub(crate) circuit_css: BTreeMap>, +} + +impl ZKVMConstraintSystem { + pub fn register_opcode_circuit>(&mut self) -> OC::InstructionConfig { + let mut cs = ConstraintSystem::new(|| format!("riscv_opcode/{}", OC::name())); + let mut circuit_builder = CircuitBuilder::::new(&mut cs); + let config = OC::construct_circuit(&mut circuit_builder).unwrap(); + assert!(self.circuit_css.insert(OC::name(), cs).is_none()); + + config + } + + pub fn register_table_circuit>(&mut self) -> TC::TableConfig { + let mut cs = ConstraintSystem::new(|| format!("riscv_table/{}", TC::name())); + let mut circuit_builder = CircuitBuilder::::new(&mut cs); + let config = TC::construct_circuit(&mut circuit_builder).unwrap(); + assert!(self.circuit_css.insert(TC::name(), cs.clone()).is_none()); + + config + } + + pub fn get_cs(&self, name: &String) -> Option<&ConstraintSystem> { + self.circuit_css.get(name) + } +} + +#[derive(Default)] +pub struct ZKVMFixedTraces { + pub circuit_fixed_traces: BTreeMap>>, +} + +impl ZKVMFixedTraces { + pub fn register_opcode_circuit>(&mut self, _cs: &ZKVMConstraintSystem) { + assert!(self.circuit_fixed_traces.insert(OC::name(), None).is_none()); + } + + pub fn register_table_circuit>( + &mut self, + cs: &ZKVMConstraintSystem, + config: TC::TableConfig, + ) { + let cs = cs.get_cs(&TC::name()).expect("cs not found"); + assert!( + self.circuit_fixed_traces + .insert( + TC::name(), + Some(TC::generate_fixed_traces(&config, cs.num_fixed,)), + ) + .is_none() + ); + } +} + +#[derive(Default)] +pub struct ZKVMWitnesses { + pub witnesses: BTreeMap>, + lk_mlts: BTreeMap, + combined_lk_mlt: Option>>, +} + +impl ZKVMWitnesses { + pub fn assign_opcode_circuit>( + &mut self, + cs: &ZKVMConstraintSystem, + config: &OC::InstructionConfig, + records: Vec, + ) -> Result<(), ZKVMError> { + assert!(self.combined_lk_mlt.is_none()); + + let cs = cs.get_cs(&OC::name()).unwrap(); + let (witness, logup_multiplicity) = + OC::assign_instances(config, cs.num_witin as usize, records)?; + assert!(self.witnesses.insert(OC::name(), witness).is_none()); + assert!( + self.lk_mlts + .insert(OC::name(), logup_multiplicity) + .is_none() + ); + + Ok(()) + } + + // merge the multiplicities in each opcode circuit into one + pub fn finalize_lk_multiplicities(&mut self) { + assert!(self.combined_lk_mlt.is_none()); + assert!(!self.lk_mlts.is_empty()); + + let mut combined_lk_mlt = vec![]; + let keys = self.lk_mlts.keys().cloned().collect_vec(); + for name in keys { + let lk_mlt = self.lk_mlts.remove(&name).unwrap().into_finalize_result(); + if combined_lk_mlt.is_empty() { + combined_lk_mlt = lk_mlt.to_vec(); + } else { + combined_lk_mlt + .iter_mut() + .zip_eq(lk_mlt.iter()) + .for_each(|(m1, m2)| { + for (key, value) in m2 { + *m1.entry(*key).or_insert(0) += value; + } + }); + } + } + + self.combined_lk_mlt = Some(combined_lk_mlt); + } + + pub fn assign_table_circuit>( + &mut self, + cs: &ZKVMConstraintSystem, + config: &TC::TableConfig, + ) -> Result<(), ZKVMError> { + assert!(self.combined_lk_mlt.is_some()); + + let cs = cs.get_cs(&TC::name()).unwrap(); + let witness = TC::assign_instances( + config, + cs.num_witin as usize, + self.combined_lk_mlt.as_ref().unwrap(), + )?; + assert!(self.witnesses.insert(TC::name(), witness).is_none()); + + Ok(()) + } +} + +#[derive(Default)] +pub struct ZKVMProvingKey { + // pk for opcode and table circuits + pub(crate) circuit_pks: BTreeMap>, +} + +impl ZKVMProvingKey { + pub fn get_vk(&self) -> ZKVMVerifyingKey { + ZKVMVerifyingKey { + circuit_vks: self + .circuit_pks + .iter() + .map(|(name, pk)| (name.clone(), pk.vk.clone())) + .collect(), + } + } +} + +#[derive(Default)] +pub struct ZKVMVerifyingKey { + // pk for opcode and table circuits + pub circuit_vks: BTreeMap>, +} diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index ee7d06454..a66094952 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -23,6 +23,6 @@ pub trait TableCircuit { fn assign_instances( config: &Self::TableConfig, num_witin: usize, - multiplicity: &HashMap, + multiplicity: &[HashMap], ) -> Result, ZKVMError>; } diff --git a/ceno_zkvm/src/tables/range.rs b/ceno_zkvm/src/tables/range.rs index f5eb72d48..293077f46 100644 --- a/ceno_zkvm/src/tables/range.rs +++ b/ceno_zkvm/src/tables/range.rs @@ -64,8 +64,9 @@ impl TableCircuit for RangeTableCircuit { fn assign_instances( config: &Self::TableConfig, num_witin: usize, - multiplicity: &HashMap, + multiplicity: &[HashMap], ) -> Result, ZKVMError> { + let multiplicity = &multiplicity[ROMType::U16 as usize]; let mut u16_mlt = vec![0; 1 << RANGE_CHIP_BIT_WIDTH]; for (limb, mlt) in multiplicity { u16_mlt[*limb as usize] = *mlt; From 8d4d0ad17775dbd2300107de687767a104e69abd Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 10 Sep 2024 22:19:09 +0800 Subject: [PATCH 13/15] fmt --- ceno_zkvm/examples/riscv_add.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/ceno_zkvm/examples/riscv_add.rs b/ceno_zkvm/examples/riscv_add.rs index 6e66177c5..ec8732085 100644 --- a/ceno_zkvm/examples/riscv_add.rs +++ b/ceno_zkvm/examples/riscv_add.rs @@ -1,10 +1,7 @@ use std::time::Instant; use ark_std::test_rng; -use ceno_zkvm::{ - instructions::{riscv::addsub::AddInstruction}, - scheme::prover::ZKVMProver, -}; +use ceno_zkvm::{instructions::riscv::addsub::AddInstruction, scheme::prover::ZKVMProver}; use const_env::from_env; use ceno_emul::{ByteAddr, InsnKind::ADD, StepRecord, VMState, CENO_PLATFORM}; From 5514e5e88bd84d8e2b050f457f4a1be1c6e48636 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Wed, 11 Sep 2024 12:25:37 +0800 Subject: [PATCH 14/15] bring tests back --- ceno_zkvm/src/scheme/prover.rs | 2 +- ceno_zkvm/src/scheme/tests.rs | 174 +++++++++++++++++++------------ ceno_zkvm/src/scheme/verifier.rs | 2 +- ceno_zkvm/src/structs.rs | 2 +- 4 files changed, 112 insertions(+), 68 deletions(-) diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 8dac99666..fdf5a9aa2 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -34,7 +34,7 @@ use crate::{ use super::{ZKVMOpcodeProof, ZKVMProof, ZKVMTableProof}; pub struct ZKVMProver { - pk: ZKVMProvingKey, + pub(crate) pk: ZKVMProvingKey, } impl ZKVMProver { diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index a273582a9..3941fcaaa 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -1,97 +1,141 @@ -use std::marker::PhantomData; +use std::{marker::PhantomData, mem::MaybeUninit}; +use ceno_emul::StepRecord; use ff::Field; use ff_ext::ExtensionField; -use goldilocks::{Goldilocks, GoldilocksExt2}; +use goldilocks::GoldilocksExt2; use itertools::Itertools; -use multilinear_extensions::mle::IntoMLE; +use multilinear_extensions::mle::IntoMLEs; +use rand::rngs::ThreadRng; use transcript::Transcript; use crate::{ - circuit_builder::{CircuitBuilder, ConstraintSystem}, + circuit_builder::CircuitBuilder, error::ZKVMError, - expression::{Expression, ToExpr}, - structs::PointAndEval, + expression::{Expression, ToExpr, WitIn}, + instructions::Instruction, + set_val, + structs::{PointAndEval, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, + witness::LkMultiplicity, }; use super::{constants::NUM_FANIN, prover::ZKVMProver, verifier::ZKVMVerifier}; -struct TestCircuit { +struct TestConfig { + pub(crate) reg_id: WitIn, +} +struct TestCircuit { phantom: PhantomData, } -impl TestCircuit { - pub fn construct_circuit( - cb: &mut CircuitBuilder, - ) -> Result { - let regid = cb.create_witin(|| "reg_id")?; +impl Instruction for TestCircuit { + type InstructionConfig = TestConfig; + + fn name() -> String { + "TEST".into() + } + + fn construct_circuit(cb: &mut CircuitBuilder) -> Result { + let reg_id = cb.create_witin(|| "reg_id")?; (0..RW).try_for_each(|_| { let record = cb.rlc_chip_record(vec![ Expression::::Constant(E::BaseField::ONE), - regid.expr(), + reg_id.expr(), ]); cb.read_record(|| "read", record.clone())?; cb.write_record(|| "write", record)?; Result::<(), ZKVMError>::Ok(()) })?; (0..L).try_for_each(|_| { - cb.assert_ux::<_, _, 16>(|| "regid_in_range", regid.expr())?; + cb.assert_ux::<_, _, 16>(|| "regid_in_range", reg_id.expr())?; Result::<(), ZKVMError>::Ok(()) })?; assert_eq!(cb.cs.lk_expressions.len(), L); assert_eq!(cb.cs.r_expressions.len(), RW); assert_eq!(cb.cs.w_expressions.len(), RW); - Ok(Self { - phantom: PhantomData, - }) + + Ok(TestConfig { reg_id }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [MaybeUninit], + _lk_multiplicity: &mut LkMultiplicity, + _step: StepRecord, + ) -> Result<(), ZKVMError> { + set_val!(instance, config.reg_id, E::BaseField::ONE); + + Ok(()) } } -// #[test] -// fn test_rw_lk_expression_combination() { -// fn test_rw_lk_expression_combination_inner() { -// let mut cs = ConstraintSystem::new(|| "test"); -// let mut circuit_builder = CircuitBuilder::::new(&mut cs); -// let _ = TestCircuit::construct_circuit::(&mut circuit_builder); -// let pk = cs.key_gen(None); -// let vk = pk.vk.clone(); -// -// // generate mock witness -// let num_instances = 1 << 2; -// let wits_in = (0..pk.get_cs().num_witin as usize) -// .map(|_| { -// (0..num_instances) -// .map(|_| Goldilocks::ONE) -// .collect::>() -// .into_mle() -// .into() -// }) -// .collect_vec(); -// -// // get proof -// let prover = ZKVMProver::new(pk); -// let mut transcript = Transcript::new(b"test"); -// let challenges = [1.into(), 2.into()]; -// -// let proof = prover -// .create_opcode_proof(wits_in, num_instances, 1, &mut transcript, &challenges) -// .expect("create_proof failed"); -// -// let verifier = ZKVMVerifier::new(vk); -// let mut v_transcript = Transcript::new(b"test"); -// let _rt_input = verifier -// .verify( -// &proof, -// &mut v_transcript, -// NUM_FANIN, -// &PointAndEval::default(), -// &challenges, -// ) -// .expect("verifier failed"); -// } -// -// // -// test_rw_lk_expression_combination_inner::<19, 17>(); -// test_rw_lk_expression_combination_inner::<61, 17>(); -// test_rw_lk_expression_combination_inner::<17, 61>(); -// } +#[test] +fn test_rw_lk_expression_combination() { + fn test_rw_lk_expression_combination_inner() { + type E = GoldilocksExt2; + let name = TestCircuit::::name(); + let mut zkvm_cs = ZKVMConstraintSystem::default(); + let config = zkvm_cs.register_opcode_circuit::>(); + + let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); + zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); + + let pk = zkvm_cs.clone().key_gen(zkvm_fixed_traces).unwrap(); + let vk = pk.get_vk(); + + // generate mock witness + let num_instances = 1 << 2; + let mut zkvm_witness = ZKVMWitnesses::default(); + zkvm_witness + .assign_opcode_circuit::>( + &zkvm_cs, + &config, + vec![StepRecord::default(); num_instances], + ) + .unwrap(); + + // get proof + let prover = ZKVMProver::new(pk); + let mut transcript = Transcript::new(b"test"); + let mut rng = ThreadRng::default(); + let challenges = [E::random(&mut rng), E::random(&mut rng)]; + + let wits_in = zkvm_witness + .witnesses + .remove(&name) + .unwrap() + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(); + let proof = prover + .create_opcode_proof( + &prover.pk.circuit_pks.get(&name).unwrap(), + wits_in, + num_instances, + 1, + &mut transcript, + &challenges, + ) + .expect("create_proof failed"); + + let verifier = ZKVMVerifier::new(vk.clone()); + let mut v_transcript = Transcript::new(b"test"); + let _rt_input = verifier + .verify_opcode_proof( + verifier.vk.circuit_vks.get(&name).unwrap(), + &proof, + &mut v_transcript, + NUM_FANIN, + &PointAndEval::default(), + &challenges, + ) + .expect("verifier failed"); + } + + // + test_rw_lk_expression_combination_inner::<19, 17>(); + test_rw_lk_expression_combination_inner::<61, 17>(); + test_rw_lk_expression_combination_inner::<17, 61>(); +} diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 0facee0e3..fb2f2e97a 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -28,7 +28,7 @@ use super::{ }; pub struct ZKVMVerifier { - vk: ZKVMVerifyingKey, + pub(crate) vk: ZKVMVerifyingKey, } impl ZKVMVerifier { diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 84a54bb55..16af00f7c 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -259,7 +259,7 @@ impl ZKVMProvingKey { } } -#[derive(Default)] +#[derive(Default, Clone)] pub struct ZKVMVerifyingKey { // pk for opcode and table circuits pub circuit_vks: BTreeMap>, From 7209c396ba964ee692d55ae208db63547c829b2b Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Wed, 11 Sep 2024 12:37:26 +0800 Subject: [PATCH 15/15] fix --- ceno_zkvm/src/scheme/tests.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 3941fcaaa..587beeb01 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -61,7 +61,7 @@ impl Instruction for Test config: &Self::InstructionConfig, instance: &mut [MaybeUninit], _lk_multiplicity: &mut LkMultiplicity, - _step: StepRecord, + _step: &StepRecord, ) -> Result<(), ZKVMError> { set_val!(instance, config.reg_id, E::BaseField::ONE);