From bb465c293598f48c5043c623e6d41957cc0260b1 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 18 Oct 2024 15:08:32 +0800 Subject: [PATCH 1/3] wip --- ceno_zkvm/src/scheme.rs | 18 ++ ceno_zkvm/src/scheme/prover.rs | 372 ++++++++++++++++++++++++++++++++- 2 files changed, 389 insertions(+), 1 deletion(-) diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 3e07f64d8..687cf3445 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -64,6 +64,24 @@ pub struct ZKVMTableProof> pub wits_opening_proof: PCS::Proof, } +#[derive(Clone, Serialize, Deserialize)] +pub struct ZKVMMemProof> { + // tower evaluation at layer 1 + pub r_out_evals: Vec<[E; 2]>, + pub w_out_evals: Vec<[E; 2]>, + + pub tower_proof: TowerProofs, + + // tower leafs layer witin + pub rw_in_evals: Vec, + + pub fixed_in_evals: Vec, + pub fixed_opening_proof: PCS::Proof, + pub wits_commit: PCS::Commitment, + pub wits_in_evals: Vec, + pub wits_opening_proof: PCS::Proof, +} + #[derive(Default, Clone, Debug)] pub struct PublicValues { exit_code: T, diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 2c92197a5..07db9d5a4 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -36,7 +36,7 @@ use crate::{ virtual_polys::VirtualPolynomials, }; -use super::{PublicValues, ZKVMOpcodeProof, ZKVMProof, ZKVMTableProof}; +use super::{PublicValues, ZKVMMemProof, ZKVMOpcodeProof, ZKVMProof, ZKVMTableProof}; pub struct ZKVMProver> { pub pk: ZKVMProvingKey, @@ -1035,6 +1035,376 @@ impl> ZKVMProver { wits_opening_proof, }) } + + #[allow(clippy::too_many_arguments)] + /// create memory proof for initialize/finalize + pub fn create_mem_proof( + &self, + name: &str, + pp: &PCS::ProverParam, + circuit_pk: &ProvingKey, + witnesses: Vec>, + wits_commit: PCS::CommitmentWithData, + // pi: &[E::BaseField], + addr_index: usize, + max_threads: usize, + transcript: &mut Transcript, + challenges: &[E; 2], + ) -> Result, ZKVMError> { + let cs = circuit_pk.get_cs(); + let fixed = circuit_pk + .fixed_traces + .as_ref() + .expect("pk.fixed_traces must not be none for table circuit") + .iter() + .map(|f| -> ArcMultilinearExtension { Arc::new(f.get_ranged_mle(1, 0)) }) + .collect::>>(); + + // assert!(addr_index < fixed.len()); + // let address_fixed = fixed[addr_index]; + + // sanity check + assert_eq!(witnesses.len(), cs.num_witin as usize); + assert_eq!(fixed.len(), cs.num_fixed); + // check all witness size are power of 2 + assert!( + witnesses + .iter() + .all(|v| { v.evaluations().len().is_power_of_two() }) + ); + assert!(!cs.r_table_expressions.is_empty() || !cs.w_table_expressions.is_empty()); + assert!( + cs.r_table_expressions + .iter() + .zip_eq(cs.w_table_expressions.iter()) + .all(|(r, w)| r.table_len == w.table_len) + ); + + // non-uniform PIOP by selecting expression via auxiliary input addr_index + let w_table_expr = cs.w_table_expressions[addr_index]; + assert_eq!(w_table_expr.values.degree(), 1); + let r_table_expr = cs.r_table_expressions[addr_index]; + assert_eq!(r_table_expr.values.degree(), 1); + + // main constraint: lookup denominator and numerator record witness inference + let span = entered_span!("wit_inference::record"); + let (w_set_wit, r_set_wit) = rayon::join( + || wit_infer_by_expr(&fixed, &witnesses, pi, challenges, &w_table_expr), + || wit_infer_by_expr(&fixed, &witnesses, pi, challenges, &r_table_expr), + ); + exit_span!(span); + + // infer all tower witness after last layer + let span = entered_span!("wit_inference::tower_witness_lk_last_layer"); + let mut r_set_last_layer = { + let (first, second) = r_set_wit + .get_ext_field_vec() + .split_at(r_set_wit.evaluations().len() / 2); + let res = vec![ + first.to_vec().into_mle().into(), + second.to_vec().into_mle().into(), + ]; + assert_eq!(res.len(), NUM_FANIN_LOGUP); + res + }; + let mut w_set_last_layer = { + let (first, second) = w_set_wit + .get_ext_field_vec() + .split_at(r_set_wit.evaluations().len() / 2); + let res = vec![ + first.to_vec().into_mle().into(), + second.to_vec().into_mle().into(), + ]; + assert_eq!(res.len(), NUM_FANIN_LOGUP); + res + }; + exit_span!(span); + + let span = entered_span!("wit_inference::tower_witness_rw_layers"); + let r_wit_layers = + infer_tower_product_witness(r_set_wit.num_vars(), r_set_last_layer, NUM_FANIN); + let w_wit_layers = + infer_tower_product_witness(w_set_wit.num_vars(), w_set_last_layer, NUM_FANIN); + exit_span!(span); + + if cfg!(test) { + // sanity check + assert_eq!(r_wit_layers.len(), cs.r_table_expressions.len()); + assert!( + r_wit_layers + .iter() + .zip(r_set_wit.iter()) // depth equals to num_vars + .all(|(layers, origin_mle)| layers.len() == origin_mle.num_vars()) + ); + assert!(r_wit_layers.iter().all(|layers| { + layers.iter().enumerate().all(|(i, w)| { + let expected_size = 1 << i; + w[0].evaluations().len() == expected_size + && w[1].evaluations().len() == expected_size + }) + })); + + assert_eq!(w_wit_layers.len(), cs.w_table_expressions.len()); + assert!( + w_wit_layers + .iter() + .zip(w_set_wit.iter()) // depth equals to num_vars + .all(|(layers, origin_mle)| layers.len() == origin_mle.num_vars()) + ); + assert!(w_wit_layers.iter().all(|layers| { + layers.iter().enumerate().all(|(i, w)| { + let expected_size = 1 << i; + w[0].evaluations().len() == expected_size + && w[1].evaluations().len() == expected_size + }) + })); + + assert_eq!(lk_wit_layers.len(), cs.lk_table_expressions.len()); + assert!( + lk_wit_layers + .iter() + .zip(lk_n_wit.iter()) // depth equals to num_vars + .all(|(layers, origin_mle)| layers.len() == origin_mle.num_vars()) + ); + assert!(lk_wit_layers.iter().all(|layers| { + layers.iter().enumerate().all(|(i, w)| { + let expected_size = 1 << i; + let (p1, p2, q1, q2) = (&w[0], &w[1], &w[2], &w[3]); + p1.evaluations().len() == expected_size + && p2.evaluations().len() == expected_size + && q1.evaluations().len() == expected_size + && q2.evaluations().len() == expected_size + }) + })); + } + + // product constraint tower sumcheck + let span = entered_span!("sumcheck::tower"); + // final evals for verifier + let r_out_evals = r_wit_layers + .iter() + .map(|r_wit_layers| { + [ + r_wit_layers[0][0].get_ext_field_vec()[0], + r_wit_layers[0][1].get_ext_field_vec()[0], + ] + }) + .collect_vec(); + let w_out_evals = w_wit_layers + .iter() + .map(|w_wit_layers| { + [ + w_wit_layers[0][0].get_ext_field_vec()[0], + w_wit_layers[0][1].get_ext_field_vec()[0], + ] + }) + .collect_vec(); + let lk_out_evals = lk_wit_layers + .iter() + .map(|lk_wit_layers| { + [ + // p1, p2, q1, q2 + lk_wit_layers[0][0].get_ext_field_vec()[0], + lk_wit_layers[0][1].get_ext_field_vec()[0], + lk_wit_layers[0][2].get_ext_field_vec()[0], + lk_wit_layers[0][3].get_ext_field_vec()[0], + ] + }) + .collect_vec(); + + let (rt_tower, tower_proof) = TowerProver::create_proof( + max_threads, + // pattern [r1, w1, r2, w2, ...] same pair are chain together + r_wit_layers + .into_iter() + .zip(w_wit_layers) + .flat_map(|(r, w)| { + vec![TowerProverSpec { witness: r }, TowerProverSpec { + witness: w, + }] + }) + .collect_vec(), + lk_wit_layers + .into_iter() + .map(|lk_wit_layers| TowerProverSpec { + witness: lk_wit_layers, + }) + .collect_vec(), + NUM_FANIN_LOGUP, + transcript, + ); + assert_eq!( + rt_tower.len(), // num var length should equal to max_num_instance + max_log2_num_instance + ); + exit_span!(span); + + // same point sumcheck is optional when all witin + fixed are in same num_vars + let is_skip_same_point_sumcheck = witnesses + .iter() + .chain(fixed.iter()) + .map(|v| v.num_vars()) + .all_equal(); + + let (input_open_point, same_r_sumcheck_proofs, rw_in_evals, lk_in_evals) = + if is_skip_same_point_sumcheck { + (rt_tower, None, vec![], vec![]) + } else { + // one sumcheck to make them opening on same point r (with different prefix) + // If all table length are the same, we can skip this sumcheck + let span = entered_span!("sumcheck::opening_same_point"); + // NOTE: max concurrency will be dominated by smallest table since it will blo + let num_threads = proper_num_threads(min_log2_num_instance, max_threads); + let alpha_pow = get_challenge_pows( + cs.r_table_expressions.len() + + cs.w_table_expressions.len() + + cs.lk_table_expressions.len() * 2, + transcript, + ); + let mut alpha_pow_iter = alpha_pow.iter(); + + // create eq + // TODO same size rt lead to same identical poly eq which can be merged together + let eq = tower_proof + .prod_specs_points + .iter() + .step_by(2) // r,w are in same length therefore share same point + .chain(tower_proof.logup_specs_points.iter()) + .map(|layer_points| { + let rt = layer_points.last().unwrap(); + build_eq_x_r_vec(rt).into_mle().into() + }) + .collect::>>(); + + let (eq_rw, eq_lk) = eq.split_at(cs.r_table_expressions.len()); + + let mut virtual_polys = + VirtualPolynomials::::new(num_threads, max_log2_num_instance); + + // alpha_r{i} * eq(rt_{i}, s) * r(s) + alpha_w{i} * eq(rt_{i}, s) * w(s) + for ((r_set_wit, w_set_wit), eq) in r_set_wit + .iter() + .zip_eq(w_set_wit.iter()) + .zip_eq(eq_rw.iter()) + { + let alpha = alpha_pow_iter.next().unwrap(); + virtual_polys.add_mle_list(vec![eq, r_set_wit], *alpha); + let alpha = alpha_pow_iter.next().unwrap(); + virtual_polys.add_mle_list(vec![eq, w_set_wit], *alpha); + } + + // alpha_lkn{i} * eq(rt_{i}, s) * lk_n(s) + alpha_lkd{i} * eq(rt_{i}, s) * lk_d(s) + for ((lk_n_wit, lk_d_wit), eq) in + lk_n_wit.iter().zip_eq(lk_d_wit.iter()).zip_eq(eq_lk.iter()) + { + let alpha = alpha_pow_iter.next().unwrap(); + virtual_polys.add_mle_list(vec![eq, lk_n_wit], *alpha); + let alpha = alpha_pow_iter.next().unwrap(); + virtual_polys.add_mle_list(vec![eq, lk_d_wit], *alpha); + } + + let (same_r_sumcheck_proofs, state) = IOPProverStateV2::prove_batch_polys( + num_threads, + virtual_polys.get_batched_polys(), + transcript, + ); + let evals = state.get_mle_final_evaluations(); + let mut evals_iter = evals.into_iter(); + let rw_in_evals = cs + // r, w table len are identical + .r_table_expressions + .iter() + .flat_map(|_table| { + let _eq = evals_iter.next().unwrap(); // skip eq + [evals_iter.next().unwrap(), evals_iter.next().unwrap()] // r, w + }) + .collect_vec(); + let lk_in_evals = cs + .lk_table_expressions + .iter() + .flat_map(|_table| { + let _eq = evals_iter.next().unwrap(); // skip eq + [evals_iter.next().unwrap(), evals_iter.next().unwrap()] // n, d + }) + .collect_vec(); + assert_eq!(evals_iter.count(), 0); + + let input_open_point = same_r_sumcheck_proofs.point.clone(); + assert_eq!(input_open_point.len(), max_log2_num_instance); + exit_span!(span); + + ( + input_open_point, + Some(same_r_sumcheck_proofs.proofs), + rw_in_evals, + lk_in_evals, + ) + }; + + let span = entered_span!("fixed::evals + witin::evals"); + let mut evals = witnesses + .par_iter() + .chain(fixed.par_iter()) + .map(|poly| poly.evaluate(&input_open_point[..poly.num_vars()])) + .collect::>(); + let fixed_in_evals = evals.split_off(witnesses.len()); + let wits_in_evals = evals; + exit_span!(span); + + let span = entered_span!("pcs_opening"); + let fixed_opening_proof = PCS::simple_batch_open( + pp, + &fixed, + circuit_pk.fixed_commit_wd.as_ref().unwrap(), + &input_open_point, + fixed_in_evals.as_slice(), + transcript, + ) + .map_err(ZKVMError::PCSError)?; + let fixed_commit = PCS::get_pure_commitment(circuit_pk.fixed_commit_wd.as_ref().unwrap()); + tracing::debug!( + "[table {}] build opening proof for {} fixed polys at {:?}: values = {:?}, commit = {:?}", + name, + fixed.len(), + input_open_point, + fixed_in_evals, + fixed_commit, + ); + let wits_opening_proof = PCS::simple_batch_open( + pp, + &witnesses, + &wits_commit, + &input_open_point, + wits_in_evals.as_slice(), + transcript, + ) + .map_err(ZKVMError::PCSError)?; + exit_span!(span); + let wits_commit = PCS::get_pure_commitment(&wits_commit); + tracing::debug!( + "[table {}] build opening proof for {} polys at {:?}: values = {:?}, commit = {:?}", + name, + witnesses.len(), + input_open_point, + wits_in_evals, + wits_commit, + ); + + Ok(ZKVMTableProof { + r_out_evals, + w_out_evals, + lk_out_evals, + same_r_sumcheck_proofs, + rw_in_evals, + lk_in_evals, + tower_proof, + fixed_in_evals, + fixed_opening_proof, + wits_in_evals, + wits_commit, + wits_opening_proof, + }) + } } /// TowerProofs From 1a3927a4a5de432f9d97e12167ce9051cbcbb6d5 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Sat, 19 Oct 2024 10:24:48 +0800 Subject: [PATCH 2/3] mem still impl ram table trait to reuse padding --- ceno_zkvm/src/scheme.rs | 7 +- ceno_zkvm/src/scheme/prover.rs | 260 ++++---------------- ceno_zkvm/src/structs.rs | 2 +- ceno_zkvm/src/tables/mod.rs | 6 +- ceno_zkvm/src/tables/ops/ops_circuit.rs | 1 + ceno_zkvm/src/tables/program.rs | 1 + ceno_zkvm/src/tables/ram/ram_circuit.rs | 1 + ceno_zkvm/src/tables/range/range_circuit.rs | 1 + 8 files changed, 63 insertions(+), 216 deletions(-) diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 687cf3445..d762d3b3b 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -67,14 +67,11 @@ pub struct ZKVMTableProof> #[derive(Clone, Serialize, Deserialize)] pub struct ZKVMMemProof> { // tower evaluation at layer 1 - pub r_out_evals: Vec<[E; 2]>, - pub w_out_evals: Vec<[E; 2]>, + pub r_out_evals: [E; 2], + pub w_out_evals: [E; 2], pub tower_proof: TowerProofs, - // tower leafs layer witin - pub rw_in_evals: Vec, - pub fixed_in_evals: Vec, pub fixed_opening_proof: PCS::Proof, pub wits_commit: PCS::Commitment, diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 07db9d5a4..9de6e2398 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -32,6 +32,7 @@ use crate::{ structs::{ Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses, }, + tables::MemTableConfig, utils::{get_challenge_pows, next_pow2_instance_padding, proper_num_threads}, virtual_polys::VirtualPolynomials, }; @@ -1046,7 +1047,6 @@ impl> ZKVMProver { witnesses: Vec>, wits_commit: PCS::CommitmentWithData, // pi: &[E::BaseField], - addr_index: usize, max_threads: usize, transcript: &mut Transcript, challenges: &[E; 2], @@ -1060,11 +1060,12 @@ impl> ZKVMProver { .map(|f| -> ArcMultilinearExtension { Arc::new(f.get_ranged_mle(1, 0)) }) .collect::>>(); - // assert!(addr_index < fixed.len()); - // let address_fixed = fixed[addr_index]; - // sanity check assert_eq!(witnesses.len(), cs.num_witin as usize); + assert_eq!(witnesses.len(), 1); + + let num_vars = witnesses[0].num_vars(); + assert_eq!(fixed.len(), cs.num_fixed); // check all witness size are power of 2 assert!( @@ -1080,23 +1081,27 @@ impl> ZKVMProver { .all(|(r, w)| r.table_len == w.table_len) ); + assert!( + num_vars < MemTableConfig::ADDR_RANGE[1] && num_vars >= MemTableConfig::ADDR_RANGE[0] + ); + let fixed_addr_index = num_vars - MemTableConfig::ADDR_RANGE[0]; + // non-uniform PIOP by selecting expression via auxiliary input addr_index - let w_table_expr = cs.w_table_expressions[addr_index]; + let w_table_expr = &cs.w_table_expressions[fixed_addr_index]; assert_eq!(w_table_expr.values.degree(), 1); - let r_table_expr = cs.r_table_expressions[addr_index]; + let r_table_expr = &cs.r_table_expressions[fixed_addr_index]; assert_eq!(r_table_expr.values.degree(), 1); - // main constraint: lookup denominator and numerator record witness inference let span = entered_span!("wit_inference::record"); let (w_set_wit, r_set_wit) = rayon::join( - || wit_infer_by_expr(&fixed, &witnesses, pi, challenges, &w_table_expr), - || wit_infer_by_expr(&fixed, &witnesses, pi, challenges, &r_table_expr), + || wit_infer_by_expr(&fixed, &witnesses, &[], challenges, &w_table_expr.values), + || wit_infer_by_expr(&fixed, &witnesses, &[], challenges, &r_table_expr.values), ); exit_span!(span); // infer all tower witness after last layer let span = entered_span!("wit_inference::tower_witness_lk_last_layer"); - let mut r_set_last_layer = { + let r_set_last_layer = { let (first, second) = r_set_wit .get_ext_field_vec() .split_at(r_set_wit.evaluations().len() / 2); @@ -1107,7 +1112,7 @@ impl> ZKVMProver { assert_eq!(res.len(), NUM_FANIN_LOGUP); res }; - let mut w_set_last_layer = { + let w_set_last_layer = { let (first, second) = w_set_wit .get_ext_field_vec() .split_at(r_set_wit.evaluations().len() / 2); @@ -1129,217 +1134,58 @@ impl> ZKVMProver { if cfg!(test) { // sanity check - assert_eq!(r_wit_layers.len(), cs.r_table_expressions.len()); - assert!( - r_wit_layers - .iter() - .zip(r_set_wit.iter()) // depth equals to num_vars - .all(|(layers, origin_mle)| layers.len() == origin_mle.num_vars()) - ); - assert!(r_wit_layers.iter().all(|layers| { - layers.iter().enumerate().all(|(i, w)| { - let expected_size = 1 << i; - w[0].evaluations().len() == expected_size - && w[1].evaluations().len() == expected_size - }) - })); - - assert_eq!(w_wit_layers.len(), cs.w_table_expressions.len()); - assert!( - w_wit_layers - .iter() - .zip(w_set_wit.iter()) // depth equals to num_vars - .all(|(layers, origin_mle)| layers.len() == origin_mle.num_vars()) - ); - assert!(w_wit_layers.iter().all(|layers| { - layers.iter().enumerate().all(|(i, w)| { - let expected_size = 1 << i; - w[0].evaluations().len() == expected_size - && w[1].evaluations().len() == expected_size - }) + assert_eq!(r_wit_layers.len(), r_set_wit.num_vars()); + assert!(r_wit_layers.iter().enumerate().all(|(i, w)| { + let expected_size = 1 << i; + w[0].evaluations().len() == expected_size + && w[1].evaluations().len() == expected_size })); - - assert_eq!(lk_wit_layers.len(), cs.lk_table_expressions.len()); - assert!( - lk_wit_layers - .iter() - .zip(lk_n_wit.iter()) // depth equals to num_vars - .all(|(layers, origin_mle)| layers.len() == origin_mle.num_vars()) - ); - assert!(lk_wit_layers.iter().all(|layers| { - layers.iter().enumerate().all(|(i, w)| { - let expected_size = 1 << i; - let (p1, p2, q1, q2) = (&w[0], &w[1], &w[2], &w[3]); - p1.evaluations().len() == expected_size - && p2.evaluations().len() == expected_size - && q1.evaluations().len() == expected_size - && q2.evaluations().len() == expected_size - }) + assert_eq!(w_wit_layers.len(), w_set_wit.num_vars()); + assert!(w_wit_layers.iter().enumerate().all(|(i, w)| { + let expected_size = 1 << i; + w[0].evaluations().len() == expected_size + && w[1].evaluations().len() == expected_size })); } // product constraint tower sumcheck let span = entered_span!("sumcheck::tower"); // final evals for verifier - let r_out_evals = r_wit_layers - .iter() - .map(|r_wit_layers| { - [ - r_wit_layers[0][0].get_ext_field_vec()[0], - r_wit_layers[0][1].get_ext_field_vec()[0], - ] - }) - .collect_vec(); - let w_out_evals = w_wit_layers - .iter() - .map(|w_wit_layers| { - [ - w_wit_layers[0][0].get_ext_field_vec()[0], - w_wit_layers[0][1].get_ext_field_vec()[0], - ] - }) - .collect_vec(); - let lk_out_evals = lk_wit_layers - .iter() - .map(|lk_wit_layers| { - [ - // p1, p2, q1, q2 - lk_wit_layers[0][0].get_ext_field_vec()[0], - lk_wit_layers[0][1].get_ext_field_vec()[0], - lk_wit_layers[0][2].get_ext_field_vec()[0], - lk_wit_layers[0][3].get_ext_field_vec()[0], - ] - }) - .collect_vec(); + let r_out_evals: [E; 2] = { + [ + r_wit_layers[0][0].get_ext_field_vec()[0], + r_wit_layers[0][1].get_ext_field_vec()[0], + ] + }; + let w_out_evals: [E; 2] = { + [ + w_wit_layers[0][0].get_ext_field_vec()[0], + w_wit_layers[0][1].get_ext_field_vec()[0], + ] + }; let (rt_tower, tower_proof) = TowerProver::create_proof( max_threads, - // pattern [r1, w1, r2, w2, ...] same pair are chain together - r_wit_layers - .into_iter() - .zip(w_wit_layers) - .flat_map(|(r, w)| { - vec![TowerProverSpec { witness: r }, TowerProverSpec { - witness: w, - }] - }) - .collect_vec(), - lk_wit_layers - .into_iter() - .map(|lk_wit_layers| TowerProverSpec { - witness: lk_wit_layers, - }) - .collect_vec(), + // pattern [r1, w1] same pair are chain together + vec![ + TowerProverSpec { + witness: r_wit_layers, + }, + TowerProverSpec { + witness: w_wit_layers, + }, + ], + vec![], NUM_FANIN_LOGUP, transcript, ); assert_eq!( rt_tower.len(), // num var length should equal to max_num_instance - max_log2_num_instance + num_vars ); exit_span!(span); - // same point sumcheck is optional when all witin + fixed are in same num_vars - let is_skip_same_point_sumcheck = witnesses - .iter() - .chain(fixed.iter()) - .map(|v| v.num_vars()) - .all_equal(); - - let (input_open_point, same_r_sumcheck_proofs, rw_in_evals, lk_in_evals) = - if is_skip_same_point_sumcheck { - (rt_tower, None, vec![], vec![]) - } else { - // one sumcheck to make them opening on same point r (with different prefix) - // If all table length are the same, we can skip this sumcheck - let span = entered_span!("sumcheck::opening_same_point"); - // NOTE: max concurrency will be dominated by smallest table since it will blo - let num_threads = proper_num_threads(min_log2_num_instance, max_threads); - let alpha_pow = get_challenge_pows( - cs.r_table_expressions.len() - + cs.w_table_expressions.len() - + cs.lk_table_expressions.len() * 2, - transcript, - ); - let mut alpha_pow_iter = alpha_pow.iter(); - - // create eq - // TODO same size rt lead to same identical poly eq which can be merged together - let eq = tower_proof - .prod_specs_points - .iter() - .step_by(2) // r,w are in same length therefore share same point - .chain(tower_proof.logup_specs_points.iter()) - .map(|layer_points| { - let rt = layer_points.last().unwrap(); - build_eq_x_r_vec(rt).into_mle().into() - }) - .collect::>>(); - - let (eq_rw, eq_lk) = eq.split_at(cs.r_table_expressions.len()); - - let mut virtual_polys = - VirtualPolynomials::::new(num_threads, max_log2_num_instance); - - // alpha_r{i} * eq(rt_{i}, s) * r(s) + alpha_w{i} * eq(rt_{i}, s) * w(s) - for ((r_set_wit, w_set_wit), eq) in r_set_wit - .iter() - .zip_eq(w_set_wit.iter()) - .zip_eq(eq_rw.iter()) - { - let alpha = alpha_pow_iter.next().unwrap(); - virtual_polys.add_mle_list(vec![eq, r_set_wit], *alpha); - let alpha = alpha_pow_iter.next().unwrap(); - virtual_polys.add_mle_list(vec![eq, w_set_wit], *alpha); - } - - // alpha_lkn{i} * eq(rt_{i}, s) * lk_n(s) + alpha_lkd{i} * eq(rt_{i}, s) * lk_d(s) - for ((lk_n_wit, lk_d_wit), eq) in - lk_n_wit.iter().zip_eq(lk_d_wit.iter()).zip_eq(eq_lk.iter()) - { - let alpha = alpha_pow_iter.next().unwrap(); - virtual_polys.add_mle_list(vec![eq, lk_n_wit], *alpha); - let alpha = alpha_pow_iter.next().unwrap(); - virtual_polys.add_mle_list(vec![eq, lk_d_wit], *alpha); - } - - let (same_r_sumcheck_proofs, state) = IOPProverStateV2::prove_batch_polys( - num_threads, - virtual_polys.get_batched_polys(), - transcript, - ); - let evals = state.get_mle_final_evaluations(); - let mut evals_iter = evals.into_iter(); - let rw_in_evals = cs - // r, w table len are identical - .r_table_expressions - .iter() - .flat_map(|_table| { - let _eq = evals_iter.next().unwrap(); // skip eq - [evals_iter.next().unwrap(), evals_iter.next().unwrap()] // r, w - }) - .collect_vec(); - let lk_in_evals = cs - .lk_table_expressions - .iter() - .flat_map(|_table| { - let _eq = evals_iter.next().unwrap(); // skip eq - [evals_iter.next().unwrap(), evals_iter.next().unwrap()] // n, d - }) - .collect_vec(); - assert_eq!(evals_iter.count(), 0); - - let input_open_point = same_r_sumcheck_proofs.point.clone(); - assert_eq!(input_open_point.len(), max_log2_num_instance); - exit_span!(span); - - ( - input_open_point, - Some(same_r_sumcheck_proofs.proofs), - rw_in_evals, - lk_in_evals, - ) - }; + let input_open_point = rt_tower; let span = entered_span!("fixed::evals + witin::evals"); let mut evals = witnesses @@ -1390,13 +1236,9 @@ impl> ZKVMProver { wits_commit, ); - Ok(ZKVMTableProof { + Ok(ZKVMMemProof { r_out_evals, w_out_evals, - lk_out_evals, - same_r_sumcheck_proofs, - rw_in_evals, - lk_in_evals, tower_proof, fixed_in_evals, fixed_opening_proof, diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index a9b78f416..9ff46392f 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -185,7 +185,7 @@ impl ZKVMFixedTraces { assert!(self.circuit_fixed_traces.insert(OC::name(), None).is_none()); } - pub fn register_table_circuit>( + pub fn register_table_circuit>>( &mut self, cs: &ZKVMConstraintSystem, config: TC::TableConfig, diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index 2ef7e293a..4e575f4a0 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -18,10 +18,14 @@ pub use program::{InsnRecord, ProgramTableCircuit}; mod ram; pub use ram::*; +mod mem; +pub use mem::*; + pub trait TableCircuit { type TableConfig: Send + Sync; type FixedInput: Send + Sync + ?Sized; type WitnessInput: Send + Sync + ?Sized; + type FixedOutput: Send + Sync + ?Sized; fn name() -> String; @@ -33,7 +37,7 @@ pub trait TableCircuit { config: &Self::TableConfig, num_fixed: usize, input: &Self::FixedInput, - ) -> RowMajorMatrix; + ) -> Self::FixedOutput; fn assign_instances( config: &Self::TableConfig, diff --git a/ceno_zkvm/src/tables/ops/ops_circuit.rs b/ceno_zkvm/src/tables/ops/ops_circuit.rs index cd48ebe19..83f909ab0 100644 --- a/ceno_zkvm/src/tables/ops/ops_circuit.rs +++ b/ceno_zkvm/src/tables/ops/ops_circuit.rs @@ -34,6 +34,7 @@ impl TableCircuit for OpsTableCircuit type TableConfig = OpTableConfig; type FixedInput = (); type WitnessInput = (); + type FixedOutput = RowMajorMatrix; fn name() -> String { format!("OPS_{:?}", OP::ROM_TYPE) diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 3514365c8..9fb838db3 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -118,6 +118,7 @@ impl TableCircuit { type TableConfig = ProgramTableConfig; type FixedInput = [u32; PROGRAM_SIZE]; + type FixedOutput = RowMajorMatrix; type WitnessInput = usize; fn name() -> String { diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index eb5c1fbe7..51081742c 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -34,6 +34,7 @@ impl TableCircuit type TableConfig = RamTableConfig; type FixedInput = Option>; type WitnessInput = Vec; + type FixedOutput = RowMajorMatrix; fn name() -> String { format!("RAM_{:?}", RAM::RAM_TYPE) diff --git a/ceno_zkvm/src/tables/range/range_circuit.rs b/ceno_zkvm/src/tables/range/range_circuit.rs index bb7c83448..ca18fd3ab 100644 --- a/ceno_zkvm/src/tables/range/range_circuit.rs +++ b/ceno_zkvm/src/tables/range/range_circuit.rs @@ -27,6 +27,7 @@ impl TableCircuit for RangeTableCircuit type TableConfig = RangeTableConfig; type FixedInput = (); type WitnessInput = (); + type FixedOutput = RowMajorMatrix; fn name() -> String { format!("RANGE_{:?}", RANGE::ROM_TYPE) From fad5cb14647dce2a4ec3df7de653a77a2ef3cef5 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Sat, 19 Oct 2024 11:40:36 +0800 Subject: [PATCH 3/3] wip --- ceno_zkvm/src/tables/mem.rs | 156 ++++++++++++++++++++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 ceno_zkvm/src/tables/mem.rs diff --git a/ceno_zkvm/src/tables/mem.rs b/ceno_zkvm/src/tables/mem.rs new file mode 100644 index 000000000..a80da69d0 --- /dev/null +++ b/ceno_zkvm/src/tables/mem.rs @@ -0,0 +1,156 @@ +use ff_ext::ExtensionField; +use goldilocks::SmallField; +use itertools::{Itertools, izip}; +use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; +use std::{collections::HashMap, marker::PhantomData, mem::MaybeUninit}; + +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, Fixed, ToExpr, WitIn}, + instructions::riscv::constants::UINT_LIMBS, + scheme::constants::MIN_PAR_SIZE, + set_val, + structs::RAMType, + witness::RowMajorMatrix, +}; + +use super::TableCircuit; + +#[derive(Clone, Debug)] +pub struct MemTableConfig { + addrs: Vec, + + final_v: Vec, +} + +impl MemTableConfig { + const V_LIMBS: usize = UINT_LIMBS + 1; // + 1 for ts + + #[cfg(test)] + pub const ADDR_RANGE: [usize; 2] = [16, 16]; + + #[cfg(not(test))] + pub const ADDR_RANGE: [usize; 2] = [16, 26]; + + pub fn construct_circuit( + cb: &mut CircuitBuilder, + ) -> Result { + // a list of fixed address for non-uniform circuit design + let addrs = (Self::ADDR_RANGE[0]..Self::ADDR_RANGE[1]) + .map(|size| cb.create_fixed(|| format!("addr_{size}",))) + .collect::, ZKVMError>>()?; + + let final_v = (0..Self::V_LIMBS) + .map(|i| cb.create_witin(|| format!("final_v_limb_{i}"))) + .collect::, ZKVMError>>()?; + + izip!(&addrs, Self::ADDR_RANGE[0]..Self::ADDR_RANGE[1]) + .map(|(addr, size)| { + let init_table_expr = cb.rlc_chip_record( + [ + vec![(RAMType::Memory as usize).into()], + vec![addr.expr()], + (0..Self::V_LIMBS) + .map(|_| Expression::ZERO) + .collect::>>(), + ] + .concat(), + ); + cb.w_table_record( + || format!("init_table_{}", size), + 1 << size, + init_table_expr, + )?; + let final_table_expr = cb.rlc_chip_record( + [ + vec![(RAMType::Memory as usize).into()], + vec![addr.expr()], + final_v.iter().map(|v| v.expr()).collect_vec(), + ] + .concat(), + ); + cb.r_table_record( + || format!("final_table_{}", size), + 1 << size, + final_table_expr, + )?; + Ok(()) + }) + .collect::>()?; + Ok(Self { addrs, final_v }) + } + + pub fn gen_init_state(&self, num_fixed: usize) -> Vec> { + assert_eq!(num_fixed, Self::ADDR_RANGE[1] - Self::ADDR_RANGE[0]); // +1 for addr + + let addrs = (Self::ADDR_RANGE[0]..Self::ADDR_RANGE[1]) + .map(|size| (0u32..(1 << size)).map(|i| i << 2).collect_vec()) + .collect_vec(); // riv32 + + addrs + } + + /// TODO consider taking RowMajorMatrix from externally, since both pattern are 1D vector + /// with that, we can save one allocation cost + pub fn assign_instances( + &self, + num_witness: usize, + final_v: &[u32], // value limb are concated into 1d slice + ) -> Result, ZKVMError> { + assert_eq!(num_witness, Self::V_LIMBS); + assert!(final_v.len().is_power_of_two()); + assert_eq!(final_v.len() % Self::V_LIMBS, 0); + let mut final_table = + RowMajorMatrix::::new(final_v.len() / Self::V_LIMBS, Self::V_LIMBS); + + final_table + .par_iter_mut() + .with_min_len(MIN_PAR_SIZE) + .zip(final_v.into_par_iter().chunks(Self::V_LIMBS)) + .for_each(|(row, v)| { + self.final_v.iter().zip(v).for_each(|(c, v)| { + set_val!(row, c, *v as u64); + }); + }); + + Ok(final_table) + } +} + +pub struct MemCircuit(PhantomData); + +impl TableCircuit for MemCircuit { + type TableConfig = MemTableConfig; + type FixedInput = (); + type FixedOutput = Vec>; + type WitnessInput = Vec; + + fn name() -> String { + format!("MEM_{:?}", RAMType::Memory) + } + + fn construct_circuit(cb: &mut CircuitBuilder) -> Result { + cb.namespace(|| Self::name(), |cb| MemTableConfig::construct_circuit(cb)) + } + + // address vector + fn generate_fixed_traces( + config: &MemTableConfig, + num_fixed: usize, + _input: &Self::FixedInput, + ) -> Vec> { + config.gen_init_state(num_fixed) + } + + fn assign_instances( + config: &MemTableConfig, + num_witin: usize, + _multiplicity: &[HashMap], + final_v: &Self::WitnessInput, + ) -> Result, ZKVMError> { + let mut table = config.assign_instances(num_witin, final_v)?; + Self::padding_zero(&mut table, num_witin)?; + Ok(table) + } +}