diff --git a/gkr-graph/src/circuit_graph_builder.rs b/gkr-graph/src/circuit_graph_builder.rs index d12d54c41..5ae1854cc 100644 --- a/gkr-graph/src/circuit_graph_builder.rs +++ b/gkr-graph/src/circuit_graph_builder.rs @@ -76,6 +76,7 @@ impl CircuitGraphBuilder { ), }; let old_num_instances = self.witness.node_witnesses[*id].n_instances(); + // TODO find way to avoid expensive clone for wit_in let new_instances = match pred { PredType::PredWire(_) => { let new_size = (old_num_instances * out[0].len()) / num_instances; diff --git a/gkr-graph/src/prover.rs b/gkr-graph/src/prover.rs index df7ef8f26..316abd05a 100644 --- a/gkr-graph/src/prover.rs +++ b/gkr-graph/src/prover.rs @@ -21,6 +21,7 @@ impl IOPProverState { expected_max_thread_id: usize, ) -> Result, GKRGraphError> { assert_eq!(target_evals.0.len(), circuit.targets.len()); + assert_eq!(circuit_witness.node_witnesses.len(), circuit.nodes.len()); let mut output_evals = vec![vec![]; circuit.nodes.len()]; let mut wit_out_evals = circuit @@ -36,10 +37,42 @@ impl IOPProverState { let gkr_proofs = izip!(&circuit.nodes, &circuit_witness.node_witnesses) .rev() .map(|(node, witness)| { - // println!("expected_max_thread_id {:?}", expected_max_thread_id); let max_thread_id = witness.n_instances().min(expected_max_thread_id); - // println!("max_thread_id {:?}", max_thread_id); - let timer = std::time::Instant::now(); + + // sanity check for witness poly evaluation + if cfg!(debug_assertions) { + + // TODO figure out a way to do sanity check on output_evals + // it doens't work for now because output evaluation + // might only take partial range of output layer witness + // assert!(output_evals[node.id].len() <= 1); + // if !output_evals[node.id].is_empty() { + // debug_assert_eq!( + // witness + // .output_layer_witness_ref() + // .instances + // .as_slice() + // .original_mle() + // .evaluate(&point_and_eval.point), + // point_and_eval.eval, + // "node_id {} output eval failed", + // node.id, + // ); + // } + + for (witness_id, point_and_eval) in wit_out_evals[node.id].iter().enumerate() { + let mle = witness.witness_out_ref()[witness_id] + .instances + .as_slice() + .original_mle(); + debug_assert_eq!( + mle.evaluate(&point_and_eval.point), + point_and_eval.eval, + "node_id {} output eval failed", + node.id, + ); + } + } let (proof, input_claim) = GKRProverState::prove_parallel( &node.circuit, witness, @@ -48,6 +81,7 @@ impl IOPProverState { max_thread_id, transcript, ); + // println!( // "Proving node {}, label {}, num_instances:{}, took {}s", // node.id, @@ -56,24 +90,31 @@ impl IOPProverState { // timer.elapsed().as_secs_f64() // ); - izip!(&node.preds, input_claim.point_and_evals) + izip!(&node.preds, &input_claim.point_and_evals) .enumerate() - .for_each(|(wire_id, (pred, point_and_eval))| match pred { + .for_each(|(wire_id, (pred_type, point_and_eval))| match pred_type { PredType::Source => { - debug_assert_eq!( - witness.witness_in_ref()[wire_id as usize] + if cfg!(debug_assertions) { + let input_layer_poly = witness.witness_in_ref()[wire_id] .instances .as_slice() - .original_mle() - .evaluate(&point_and_eval.point), - point_and_eval.eval - ); + .original_mle(); + debug_assert_eq!( + input_layer_poly.evaluate(&point_and_eval.point), + point_and_eval.eval, + "mismatch at node.id {:?} wire_id {:?}, input_claim.point_and_evals.point {:?}, node.preds {:?}", + node.id, + wire_id, + input_claim.point_and_evals[0].point, + node.preds + ); + } } - PredType::PredWire(out) | PredType::PredWireDup(out) => { - let point = match pred { + PredType::PredWire(pred_out) | PredType::PredWireDup(pred_out) => { + let point = match pred_type { PredType::PredWire(_) => point_and_eval.point.clone(), PredType::PredWireDup(out) => { - let node_id = match out { + let pred_node_id = match out { NodeOutputType::OutputLayer(id) => id, NodeOutputType::WireOut(id, _) => id, }; @@ -81,27 +122,28 @@ impl IOPProverState { // [single_instance_slice || // new_instance_index_slice]. The old point // is [single_instance_slices || - // new_instance_index_slices[(new_instance_num_vars - // - old_instance_num_vars)..]] - let old_instance_num_vars = circuit_witness.node_witnesses - [*node_id] + // new_instance_index_slices[(instance_num_vars + // - pred_instance_num_vars)..]] + let pred_instance_num_vars = circuit_witness.node_witnesses + [*pred_node_id] .instance_num_vars(); - let new_instance_num_vars = witness.instance_num_vars(); - let num_vars = - point_and_eval.point.len() - new_instance_num_vars; + let instance_num_vars = witness.instance_num_vars(); + let num_vars = point_and_eval.point.len() - instance_num_vars; [ point_and_eval.point[..num_vars].to_vec(), point_and_eval.point[num_vars - + (new_instance_num_vars - old_instance_num_vars)..] + + (instance_num_vars - pred_instance_num_vars)..] .to_vec(), ] .concat() } _ => unreachable!(), }; - match out { - NodeOutputType::OutputLayer(id) => output_evals[*id] - .push(PointAndEval::new_from_ref(&point, &point_and_eval.eval)), + match pred_out { + NodeOutputType::OutputLayer(id) => { + output_evals[*id] + .push(PointAndEval::new_from_ref(&point, &point_and_eval.eval)) + }, NodeOutputType::WireOut(id, wire_id) => { let evals = &mut wit_out_evals[*id][*wire_id as usize]; assert!( diff --git a/gkr-graph/src/verifier.rs b/gkr-graph/src/verifier.rs index 4aabbd7a4..7094dfb85 100644 --- a/gkr-graph/src/verifier.rs +++ b/gkr-graph/src/verifier.rs @@ -50,52 +50,56 @@ impl IOPVerifierState { let new_instance_num_vars = aux_info.instance_num_vars[node.id]; - izip!(&node.preds, input_claim.point_and_evals).for_each(|(pred, point_and_eval)| { - match pred { - PredType::Source => { - // TODO: collect `(proof.point.clone(), *eval)` as `TargetEvaluations` for later PCS open? - } - PredType::PredWire(out) | PredType::PredWireDup(out) => { - let old_point = match pred { - PredType::PredWire(_) => point_and_eval.point.clone(), - PredType::PredWireDup(out) => { - let node_id = match out { - NodeOutputType::OutputLayer(id) => *id, - NodeOutputType::WireOut(id, _) => *id, - }; - // Suppose the new point is - // [single_instance_slice || - // new_instance_index_slice]. The old point - // is [single_instance_slices || - // new_instance_index_slices[(new_instance_num_vars - // - old_instance_num_vars)..]] - let old_instance_num_vars = aux_info.instance_num_vars[node_id]; - let num_vars = point_and_eval.point.len() - new_instance_num_vars; - [ - point_and_eval.point[..num_vars].to_vec(), - point_and_eval.point[num_vars - + (new_instance_num_vars - old_instance_num_vars)..] - .to_vec(), - ] - .concat() - } - _ => unreachable!(), - }; - match out { - NodeOutputType::OutputLayer(id) => output_evals[*id] - .push(PointAndEval::new_from_ref(&old_point, &point_and_eval.eval)), - NodeOutputType::WireOut(id, wire_id) => { - let evals = &mut wit_out_evals[*id][*wire_id as usize]; - assert!( - evals.point.is_empty() && evals.eval.is_zero_vartime(), - "unimplemented", - ); - *evals = PointAndEval::new(old_point, point_and_eval.eval); + izip!(&node.preds, input_claim.point_and_evals).for_each( + |(pred_type, point_and_eval)| { + match pred_type { + PredType::Source => { + // TODO: collect `(proof.point.clone(), *eval)` as `TargetEvaluations` + // for later PCS open? + } + PredType::PredWire(pred_out) | PredType::PredWireDup(pred_out) => { + let point = match pred_type { + PredType::PredWire(_) => point_and_eval.point.clone(), + PredType::PredWireDup(out) => { + let node_id = match out { + NodeOutputType::OutputLayer(id) => *id, + NodeOutputType::WireOut(id, _) => *id, + }; + // Suppose the new point is + // [single_instance_slice || + // new_instance_index_slice]. The old point + // is [single_instance_slices || + // new_instance_index_slices[(new_instance_num_vars + // - old_instance_num_vars)..]] + let old_instance_num_vars = aux_info.instance_num_vars[node_id]; + let num_vars = + point_and_eval.point.len() - new_instance_num_vars; + [ + point_and_eval.point[..num_vars].to_vec(), + point_and_eval.point[num_vars + + (new_instance_num_vars - old_instance_num_vars)..] + .to_vec(), + ] + .concat() + } + _ => unreachable!(), + }; + match pred_out { + NodeOutputType::OutputLayer(id) => output_evals[*id] + .push(PointAndEval::new_from_ref(&point, &point_and_eval.eval)), + NodeOutputType::WireOut(id, wire_id) => { + let evals = &mut wit_out_evals[*id][*wire_id as usize]; + assert!( + evals.point.is_empty() && evals.eval.is_zero_vartime(), + "unimplemented", + ); + *evals = PointAndEval::new(point, point_and_eval.eval); + } } } } - } - }); + }, + ); } Ok(()) diff --git a/gkr/benches/keccak256.rs b/gkr/benches/keccak256.rs index c7c261e78..b27b37e14 100644 --- a/gkr/benches/keccak256.rs +++ b/gkr/benches/keccak256.rs @@ -10,7 +10,6 @@ use gkr::gadgets::keccak256::{keccak256_circuit, prove_keccak256, verify_keccak2 use goldilocks::GoldilocksExt2; use sumcheck::util::is_power_of_2; -// cargo bench --bench keccak256 --features parallel --features flamegraph --package gkr -- --profile-time cfg_if::cfg_if! { if #[cfg(feature = "flamegraph")] { criterion_group! { @@ -48,8 +47,7 @@ fn bench_keccak256(c: &mut Criterion) { #[cfg(feature = "non_pow2_rayon_thread")] { - use sumcheck::local_thread_pool::create_local_pool_once; - use sumcheck::util::ceil_log2; + use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); max_thread_id diff --git a/gkr/examples/keccak256.rs b/gkr/examples/keccak256.rs index d8a1433ac..a105e0930 100644 --- a/gkr/examples/keccak256.rs +++ b/gkr/examples/keccak256.rs @@ -36,8 +36,7 @@ fn main() { #[cfg(feature = "non_pow2_rayon_thread")] { - use sumcheck::local_thread_pool::create_local_pool_once; - use sumcheck::util::ceil_log2; + use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; max_thread_id = 1 << ceil_log2(max_thread_id); create_local_pool_once(max_thread_id, true); } diff --git a/gkr/src/circuit/circuit_layout.rs b/gkr/src/circuit/circuit_layout.rs index 3fb7d7787..8e71bd4cb 100644 --- a/gkr/src/circuit/circuit_layout.rs +++ b/gkr/src/circuit/circuit_layout.rs @@ -132,7 +132,8 @@ impl Circuit { }); let segment = ( wire_ids_in_layer[in_cell_ids[0]], - wire_ids_in_layer[in_cell_ids[in_cell_ids.len() - 1]] + 1, + wire_ids_in_layer[in_cell_ids[in_cell_ids.len() - 1]] + 1, /* + 1 for exclusive + * last index */ ); match ty { InType::Witness(wit_id) => { @@ -258,9 +259,10 @@ impl Circuit { .push(output_subsets.update_wire_id(old_layer_id, old_wire_id)); } OutType::AssertConst(constant) => { + let new_wire_id = output_subsets.update_wire_id(old_layer_id, old_wire_id); output_assert_const.push(GateCIn { idx_in: [], - idx_out: output_subsets.update_wire_id(old_layer_id, old_wire_id), + idx_out: new_wire_id, scalar: ConstantType::Field(i64_to_field(constant)), }); } @@ -288,8 +290,7 @@ impl Circuit { } else { let last_layer = &layers[(layer_id - 1) as usize]; if !last_layer.is_linear() || !layer.copy_to.is_empty() { - curr_sc_steps - .extend([SumcheckStepType::Phase1Step1, SumcheckStepType::Phase1Step2]); + curr_sc_steps.extend([SumcheckStepType::Phase1Step1]); } } @@ -900,7 +901,7 @@ mod tests { // Single input witness, therefore no input phase 2 steps. assert_eq!( circuit.layers[2].sumcheck_steps, - vec![SumcheckStepType::Phase1Step1, SumcheckStepType::Phase1Step2,] + vec![SumcheckStepType::Phase1Step1] ); // There are only one incoming evals since the last layer is linear, and // no subset evals. Therefore, there are no phase1 steps. @@ -931,7 +932,7 @@ mod tests { // Single input witness, therefore no input phase 2 steps. assert_eq!( circuit.layers[1].sumcheck_steps, - vec![SumcheckStepType::Phase1Step1, SumcheckStepType::Phase1Step2] + vec![SumcheckStepType::Phase1Step1] ); // Output layer, single output witness, therefore no output phase 1 steps. assert_eq!( diff --git a/gkr/src/circuit/circuit_witness.rs b/gkr/src/circuit/circuit_witness.rs index ee692b41e..f7351e652 100644 --- a/gkr/src/circuit/circuit_witness.rs +++ b/gkr/src/circuit/circuit_witness.rs @@ -53,6 +53,7 @@ impl CircuitWitness { let mut layer_wit = vec![vec![F::ZERO; circuit.layers[n_layers - 1].size()]; n_instances]; for instance_id in 0..n_instances { + assert_eq!(wits_in.len(), circuit.paste_from_wits_in.len()); for (wit_id, (l, r)) in circuit.paste_from_wits_in.iter().enumerate() { for i in *l..*r { layer_wit[instance_id][i] = @@ -175,34 +176,45 @@ impl CircuitWitness { pub fn add_instances( &mut self, circuit: &Circuit, - wits_in: Vec>, + new_wits_in: Vec>, n_instances: usize, ) where E: ExtensionField, { - assert_eq!(wits_in.len(), circuit.n_witness_in); + assert_eq!(new_wits_in.len(), circuit.n_witness_in); assert!(n_instances.is_power_of_two()); - assert!(!wits_in + assert!(!new_wits_in .iter() .any(|wit_in| wit_in.instances.len() != n_instances)); - let (new_layer_wits, new_wits_out) = - CircuitWitness::new_instances(circuit, &wits_in, &self.challenges, n_instances); + let (inferred_layer_wits, inferred_wits_out) = + CircuitWitness::new_instances(circuit, &new_wits_in, &self.challenges, n_instances); // Merge self and circuit_witness. - for (layer_wit, new_layer_wit) in self.layers.iter_mut().zip(new_layer_wits.into_iter()) { - layer_wit.instances.extend(new_layer_wit.instances); + for (layer_wit, inferred_layer_wit) in + self.layers.iter_mut().zip(inferred_layer_wits.into_iter()) + { + layer_wit.instances.extend(inferred_layer_wit.instances); } - for (wit_out, new_wit_out) in self.witness_out.iter_mut().zip(new_wits_out.into_iter()) { - wit_out.instances.extend(new_wit_out.instances); + for (wit_out, inferred_wits_out) in self + .witness_out + .iter_mut() + .zip(inferred_wits_out.into_iter()) + { + wit_out.instances.extend(inferred_wits_out.instances); } - for (wit_in, new_wit_in) in self.witness_in.iter_mut().zip(wits_in.into_iter()) { + for (wit_in, new_wit_in) in self.witness_in.iter_mut().zip(new_wits_in.into_iter()) { wit_in.instances.extend(new_wit_in.instances); } self.n_instances += n_instances; + + // check correctness in debug build + if cfg!(debug_assertions) { + self.check_correctness(circuit); + } } pub fn instance_num_vars(&self) -> usize { diff --git a/gkr/src/lib.rs b/gkr/src/lib.rs index 7ca4d1284..214126a5b 100644 --- a/gkr/src/lib.rs +++ b/gkr/src/lib.rs @@ -11,5 +11,7 @@ pub mod unsafe_utils; pub mod utils; mod verifier; +pub use sumcheck::util; + #[cfg(test)] mod test; diff --git a/gkr/src/prover.rs b/gkr/src/prover.rs index 26a8c6c87..ee1e6890b 100644 --- a/gkr/src/prover.rs +++ b/gkr/src/prover.rs @@ -8,7 +8,8 @@ use multilinear_extensions::{ virtual_poly::{build_eq_x_r_vec, VirtualPolynomial}, }; use rayon::iter::{ - IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator, + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, + IntoParallelRefMutIterator, ParallelIterator, }; use simple_frontend::structs::LayerId; use transcript::Transcript; @@ -85,21 +86,44 @@ impl IOPProverState { transcript, )].to_vec() }, - (SumcheckStepType::Phase1Step1, SumcheckStepType::Phase1Step2, _) => - [ - prover_state - .prove_and_update_state_phase1_step1( - circuit, - circuit_witness, - transcript, - ), - prover_state - .prove_and_update_state_phase1_step2( - circuit, - circuit_witness, - transcript, - ), - ].to_vec() + (SumcheckStepType::Phase1Step1, _, _) => { + let alpha = transcript + .get_and_append_challenge(b"combine subset evals") + .elements; + let hi_num_vars = circuit_witness.instance_num_vars(); + let eq_t = prover_state.to_next_phase_point_and_evals.par_iter().chain(prover_state.subset_point_and_evals[layer_id as usize].par_iter().map(|(_, point_and_eval)| point_and_eval)).map(|point_and_eval|{ + let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; + build_eq_x_r_vec(&point_and_eval.point[point_lo_num_vars..]) + }).collect::>>(); + let virtual_polys: Vec> = (0..max_thread_id).into_par_iter().map(|thread_id| { + let span = entered_span!("build_poly"); + let virtual_poly = IOPProverState::build_phase1_step1_sumcheck_poly( + &prover_state, + layer_id, + alpha, + &eq_t, + circuit, + circuit_witness, + (thread_id, max_thread_id), + ); + exit_span!(span); + virtual_poly + }).collect(); + + let (sumcheck_proof, sumcheck_prover_state) = sumcheck::structs::IOPProverState::::prove_batch_polys( + max_thread_id, + virtual_polys.try_into().unwrap(), + transcript, + ); + + let prover_msg = prover_state.combine_phase1_step1_evals( + sumcheck_proof, + sumcheck_prover_state, + ); + + vec![prover_msg] + + } , (SumcheckStepType::Phase2Step1, step2, _) => { let span = entered_span!("phase2_gkr"); @@ -261,19 +285,11 @@ impl IOPProverState { }) .collect_vec(); let to_next_phase_point_and_evals = output_evals; - subset_point_and_evals[0] = wires_out_evals.into_iter().map(|p| (0, p)).collect(); - - let phase1_layer_polys = (0..n_layers) - .into_par_iter() - .map(|layer_id| { - let num_vars = circuit.layers[layer_id].num_vars; - mem::take(&mut circuit_witness.layer_poly( - layer_id.try_into().unwrap(), - num_vars, - (0, 1), - )) - }) + subset_point_and_evals[0] = wires_out_evals + .into_iter() + .map(|p| (0 as LayerId, p)) .collect(); + Self { to_next_phase_point_and_evals, subset_point_and_evals, @@ -282,7 +298,7 @@ impl IOPProverState { assert_point, // Default layer_id: 0, - phase1_layer_polys, + phase1_layer_poly: ArcDenseMultilinearExtension::default(), g1_values: vec![], } } diff --git a/gkr/src/prover/phase1.rs b/gkr/src/prover/phase1.rs index 3e0bb117b..6e52040f7 100644 --- a/gkr/src/prover/phase1.rs +++ b/gkr/src/prover/phase1.rs @@ -1,45 +1,42 @@ use ark_std::{end_timer, start_timer}; use ff::Field; use ff_ext::ExtensionField; -use itertools::{chain, Itertools}; +use itertools::{izip, Itertools}; use multilinear_extensions::{ mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension}, - virtual_poly::{build_eq_x_r_vec, VirtualPolynomial}, + virtual_poly::{build_eq_x_r_vec_sequential, VirtualPolynomial}, }; -use std::mem; -use sumcheck::entered_span; -use transcript::Transcript; - -#[cfg(feature = "parallel")] -use rayon::iter::{IndexedParallelIterator, ParallelIterator}; +use simple_frontend::structs::LayerId; +use std::sync::Arc; +use sumcheck::{entered_span, util::ceil_log2}; use crate::{ - exit_span, izip_parallizable, - prover::SumcheckState, - structs::{Circuit, CircuitWitness, IOPProverState, IOPProverStepMessage, Point, PointAndEval}, - tracing_span, - utils::MatrixMLERowFirst, + exit_span, + structs::{ + Circuit, CircuitWitness, IOPProverState, IOPProverStepMessage, PointAndEval, SumcheckProof, + }, + utils::{tensor_product, MatrixMLERowFirst}, }; // Prove the items copied from the current layer to later layers for data parallel circuits. -// \sum_j( \alpha^j * subset[i][j](rt_j || ry_j) ) -// = \sum_y( \sum_j( \alpha^j copy_to[j](ry_j, y) \sum_t( eq(rt_j, t) * layers[i](t || y) ) ) ) impl IOPProverState { - /// Sumcheck 1: sigma = \sum_y( \sum_j f1^{(j)}(y) * g1^{(j)}(y) ) + /// Sumcheck 1: sigma = \sum_{t || y} \sum_j ( f1^{(j)}(t || y) * g1^{(j)}(t || y) ) /// sigma = \sum_j( \alpha^j * subset[i][j](rt_j || ry_j) ) - /// f1^{(j)}(y) = \sum_t( eq(rt_j, t) * layers[i](t || y) ) - /// g1^{(j)}(y) = \alpha^j copy_to[j](ry_j, y) - #[tracing::instrument(skip_all, name = "prove_and_update_state_phase1_step1")] - pub(super) fn prove_and_update_state_phase1_step1( - &mut self, + /// f1^{(j)}(y) = layers[i](t || y) + /// g1^{(j)}(y) = \alpha^j * eq(rt_j, t) * eq(ry_j, y) + /// g1^{(j)}(y) = \alpha^j * eq(rt_j, t) * copy_to[j](ry_j, y) + pub(super) fn build_phase1_step1_sumcheck_poly( + &self, + layer_id: LayerId, + alpha: E, + eq_t: &Vec>, circuit: &Circuit, circuit_witness: &CircuitWitness, - transcript: &mut Transcript, - ) -> IOPProverStepMessage { + multi_threads_meta: (usize, usize), + ) -> VirtualPolynomial { + let span = entered_span!("preparation"); let timer = start_timer!(|| "Prover sumcheck phase 1 step 1"); - let alpha = transcript - .get_and_append_challenge(b"combine subset evals") - .elements; + let total_length = self.to_next_phase_point_and_evals.len() + self.subset_point_and_evals[self.layer_id as usize].len() + 1; @@ -54,198 +51,136 @@ impl IOPProverState { let lo_num_vars = circuit.layers[self.layer_id as usize].num_vars; let hi_num_vars = circuit_witness.instance_num_vars(); - // sigma = \sum_j( \alpha^j * subset[i][j](rt_j || ry_j) ) - // f1^{(j)}(y) = \sum_t( eq(rt_j, t) * layers[i](t || y) ) - // g1^{(j)}(y) = \alpha^j copy_to[j](ry_j, y) - let span = entered_span!("fg"); - let (mut f1, mut g1): ( - Vec>, - Vec>, - ) = tracing_span!("f1g1").in_scope(|| { - izip_parallizable!(&self.to_next_phase_point_and_evals, &alpha_pows) - .map(|(point_and_eval, alpha_pow)| { - let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; + // parallel unit logic handling + let (thread_id, max_thread_id) = multi_threads_meta; + let log2_max_thread_id = ceil_log2(max_thread_id); - let f1_j = self.phase1_layer_polys[self.layer_id as usize] - .fix_high_variables(&point_and_eval.point[point_lo_num_vars..]); + exit_span!(span); - let g1_j = build_eq_x_r_vec(&point_and_eval.point[..point_lo_num_vars]) - .into_iter() - .take(1 << lo_num_vars) - .map(|eq| *alpha_pow * eq) - .collect_vec(); + // f1^{(j)}(y) = layers[i](t || y) + let f1: Arc> = circuit_witness + .layer_poly::( + (layer_id).try_into().unwrap(), + lo_num_vars, + multi_threads_meta, + ) + .into(); - assert_eq!(g1_j.len(), 1 << lo_num_vars); - ( - f1_j.into(), - DenseMultilinearExtension::from_evaluations_ext_vec(lo_num_vars, g1_j) - .into(), - ) - }) - .unzip() - }); + assert_eq!( + f1.evaluations.len(), + 1 << (hi_num_vars + lo_num_vars - log2_max_thread_id) + ); + let span = entered_span!("g1"); + // g1^{(j)}(y) = \alpha^j * eq(rt_j, t) * eq(ry_j, y) + // g1^{(j)}(y) = \alpha^j * eq(rt_j, t) * copy_to[j](ry_j, y) let copy_to_matrices = &circuit.layers[self.layer_id as usize].copy_to; - let (f1_subset_point_and_evals, g1_subset_point_and_evals): ( - Vec>, - Vec>, - ) = tracing_span!("f1_j_g1_j").in_scope(|| { - izip_parallizable!( - &self.subset_point_and_evals[self.layer_id as usize], - &alpha_pows[self.to_next_phase_point_and_evals.len()..] - ) - .map(|((new_layer_id, point_and_eval), alpha_pow)| { - let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; - let copy_to = ©_to_matrices[new_layer_id]; - let lo_eq_w_p = build_eq_x_r_vec(&point_and_eval.point[..point_lo_num_vars]); - - let f1_j = self.phase1_layer_polys[self.layer_id as usize] - .fix_high_variables(&point_and_eval.point[point_lo_num_vars..]); - - assert!(copy_to.len() <= lo_eq_w_p.len()); - let g1_j = copy_to.as_slice().fix_row_row_first_with_scalar( - &lo_eq_w_p, - lo_num_vars, - alpha_pow, - ); - ( - f1_j.into(), - DenseMultilinearExtension::from_evaluations_ext_vec(lo_num_vars, g1_j).into(), - ) - }) - .unzip() - }); - exit_span!(span); + let g1: ArcDenseMultilinearExtension = { + let gs = izip!(&self.to_next_phase_point_and_evals, &alpha_pows, eq_t) + .map(|(point_and_eval, alpha_pow, eq_t)| { + // g1^{(j)}(y) = \alpha^j * eq(rt_j, t) * eq(ry_j, y) + let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; - f1.extend(f1_subset_point_and_evals); - g1.extend(g1_subset_point_and_evals); + let eq_y = + build_eq_x_r_vec_sequential(&point_and_eval.point[..point_lo_num_vars]) + .into_iter() + .take(1 << lo_num_vars) + .map(|eq| *alpha_pow * eq) + .collect_vec(); - // sumcheck: sigma = \sum_y( \sum_j f1^{(j)}(y) * g1^{(j)}(y) ) - let virtual_poly_1 = tracing_span!("virtual_poly").in_scope(|| { - let mut virtual_poly_1 = VirtualPolynomial::new(lo_num_vars); - for (f1_j, g1_j) in f1.into_iter().zip(g1.into_iter()) { - let mut tmp = VirtualPolynomial::new_from_mle(f1_j, E::BaseField::ONE); - tmp.mul_by_mle(g1_j, E::BaseField::ONE); - virtual_poly_1.merge(&tmp); - } - virtual_poly_1 - }); + let eq_t_unit_len = eq_t.len() / max_thread_id; + let start_index = thread_id * eq_t_unit_len; + let g1_j = + tensor_product(&eq_t[start_index..(start_index + eq_t_unit_len)], &eq_y); - let (sumcheck_proof_1, prover_state) = - SumcheckState::prove_parallel(virtual_poly_1, transcript); - let (f1, g1): (Vec<_>, Vec<_>) = prover_state - .get_mle_final_evaluations() - .into_iter() - .enumerate() - .partition(|(i, _)| i % 2 == 0); - let eval_value_1 = f1.into_iter().map(|(_, f1_j)| f1_j).collect_vec(); + assert_eq!( + g1_j.len(), + (1 << (hi_num_vars + lo_num_vars - log2_max_thread_id)) + ); - self.to_next_step_point = sumcheck_proof_1.point.clone(); - self.g1_values = g1.into_iter().map(|(_, g1_j)| g1_j).collect_vec(); + g1_j + }) + .chain( + izip!( + &self.subset_point_and_evals[self.layer_id as usize], + &alpha_pows[self.to_next_phase_point_and_evals.len()..], + eq_t.iter().skip(self.to_next_phase_point_and_evals.len()) + ) + .map( + |((new_layer_id, point_and_eval), alpha_pow, eq_t)| { + let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; + let copy_to = ©_to_matrices[new_layer_id]; + let lo_eq_w_p = build_eq_x_r_vec_sequential( + &point_and_eval.point[..point_lo_num_vars], + ); + + // g2^{(j)}(y) = \alpha^j * eq(rt_j, t) * copy_to[j](ry_j, y) + let eq_t_unit_len = eq_t.len() / max_thread_id; + let start_index = thread_id * eq_t_unit_len; + let g2_j = tensor_product( + &eq_t[start_index..(start_index + eq_t_unit_len)], + ©_to.as_slice().fix_row_row_first_with_scalar( + &lo_eq_w_p, + lo_num_vars, + alpha_pow, + ), + ); + + assert_eq!( + g2_j.len(), + (1 << (hi_num_vars + lo_num_vars - log2_max_thread_id)) + ); + g2_j + }, + ), + ) + .collect::>>(); + + DenseMultilinearExtension::from_evaluations_ext_vec( + hi_num_vars + lo_num_vars - log2_max_thread_id, + gs.into_iter() + .fold(vec![E::ZERO; 1 << f1.num_vars], |mut acc, g| { + assert_eq!(1 << f1.num_vars, g.len()); + acc.iter_mut().enumerate().for_each(|(i, v)| *v += g[i]); + acc + }), + ) + .into() + }; + exit_span!(span); + // sumcheck: sigma = \sum_y( \sum_j f1^{(j)}(y) * g1^{(j)}(y)) + let span = entered_span!("virtual_poly"); + let mut virtual_poly_1 = VirtualPolynomial::new_from_mle(f1, E::BaseField::ONE); + virtual_poly_1.mul_by_mle(g1, E::BaseField::ONE); + exit_span!(span); end_timer!(timer); - IOPProverStepMessage { - sumcheck_proof: sumcheck_proof_1, - sumcheck_eval_values: eval_value_1, - } + + virtual_poly_1 } - /// Sumcheck 2: sigma = \sum_t( \sum_j( f2^{(j)}(t) ) ) * g2(t) - /// sigma = \sum_j( f1^{(j)}(ry) * g1^{(j)}(ry) ) - /// f2(t) = layers[i](t || ry) - /// g2^{(j)}(t) = \alpha^j copy_to[j](ry_j, ry) eq(rt_j, t) - #[tracing::instrument(skip_all, name = "prove_and_update_state_phase1_step2")] - pub(super) fn prove_and_update_state_phase1_step2( + pub(super) fn combine_phase1_step1_evals( &mut self, - _: &Circuit, - circuit_witness: &CircuitWitness, - transcript: &mut Transcript, + sumcheck_proof_1: SumcheckProof, + prover_state: sumcheck::structs::IOPProverState, ) -> IOPProverStepMessage { - let timer = start_timer!(|| "Prover sumcheck phase 1 step 2"); - let hi_num_vars = circuit_witness.instance_num_vars(); - - let span = entered_span!("f2_fix_variables"); - // f2(t) = layers[i](t || ry) - let f2 = mem::take(&mut self.phase1_layer_polys[self.layer_id as usize]) - .fix_variables_parallel(&self.to_next_step_point) - .into(); - exit_span!(span); - - // g2^{(j)}(t) = \alpha^j copy_to[j](ry_j, ry) eq(rt_j, t) - let output_points: Vec<&Point> = chain![ - self.to_next_phase_point_and_evals.iter().map(|x| &x.point), - self.subset_point_and_evals[self.layer_id as usize] - .iter() - .map(|x| &x.1.point), - ] - .collect(); - let span = entered_span!("g2"); - - #[cfg(not(feature = "parallel"))] - let zeros = vec![E::ZERO; 1 << hi_num_vars]; - - #[cfg(feature = "parallel")] - let zeros = || vec![E::ZERO; 1 << hi_num_vars]; - - let g2 = izip_parallizable!(output_points, &self.g1_values) - .map(|(point, &g1_value)| { - let point_lo_num_vars = point.len() - hi_num_vars; - build_eq_x_r_vec(&point[point_lo_num_vars..]) - .into_iter() - .map(|eq| g1_value * eq) - .collect_vec() - }) - .fold(zeros, |acc, nxt| { - acc.into_iter() - .zip(nxt.into_iter()) - .map(|(a, b)| a + b) - .collect_vec() - }); - - #[cfg(not(feature = "parallel"))] - let g2 = DenseMultilinearExtension::from_evaluations_ext_vec(hi_num_vars, g2); - - // When rayon is used, the `fold` operation results in a iterator of `Vec` rather than a single `Vec`. In this case, we simply need to sum them. - #[cfg(feature = "parallel")] - let g2 = DenseMultilinearExtension::from_evaluations_ext_vec( - hi_num_vars, - g2.reduce(zeros, |acc, nxt| { - acc.into_iter() - .zip(nxt.into_iter()) - .map(|(a, b)| a + b) - .collect_vec() - }), - ); - - exit_span!(span); - - // sumcheck: sigma = \sum_t( \sum_j( g2^{(j)}(t) ) ) * f2(t) - let mut virtual_poly_2 = VirtualPolynomial::new_from_mle(f2, E::BaseField::ONE); - virtual_poly_2.mul_by_mle(g2.into(), E::BaseField::ONE); - - let (sumcheck_proof_2, prover_state) = - SumcheckState::prove_parallel(virtual_poly_2, transcript); - let (mut f2, _): (Vec<_>, Vec<_>) = prover_state + let (mut f1, _): (Vec<_>, Vec<_>) = prover_state .get_mle_final_evaluations() .into_iter() .enumerate() .partition(|(i, _)| i % 2 == 0); - let eval_value_2 = f2.remove(0).1; - self.to_next_step_point = [ - mem::take(&mut self.to_next_step_point), - sumcheck_proof_2.point.clone(), - ] - .concat(); + let eval_value_1 = f1.remove(0).1; + + self.to_next_step_point = sumcheck_proof_1.point.clone(); self.to_next_phase_point_and_evals = vec![PointAndEval::new_from_ref( &self.to_next_step_point, - &eval_value_2, + &eval_value_1, )]; self.subset_point_and_evals[self.layer_id as usize].clear(); - end_timer!(timer); IOPProverStepMessage { - sumcheck_proof: sumcheck_proof_2, - sumcheck_eval_values: vec![eval_value_2], + sumcheck_proof: sumcheck_proof_1, + sumcheck_eval_values: vec![eval_value_1], } } } diff --git a/gkr/src/prover/phase1_output.rs b/gkr/src/prover/phase1_output.rs index 3c26eb112..385e49abf 100644 --- a/gkr/src/prover/phase1_output.rs +++ b/gkr/src/prover/phase1_output.rs @@ -1,15 +1,16 @@ use ark_std::{end_timer, start_timer}; use ff::Field; use ff_ext::ExtensionField; -use itertools::{chain, Itertools}; +use itertools::{chain, izip, Itertools}; use multilinear_extensions::{ - mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension}, + mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, FieldType}, virtual_poly::{build_eq_x_r_vec, VirtualPolynomial}, }; -use std::{iter, mem}; +use std::{iter, mem, sync::Arc}; use transcript::Transcript; use crate::{ + circuit::EvaluateConstant, izip_parallizable, prover::SumcheckState, structs::{Circuit, CircuitWitness, IOPProverState, IOPProverStepMessage, PointAndEval}, @@ -21,12 +22,13 @@ use rayon::iter::{IndexedParallelIterator, ParallelIterator}; // Prove the items copied from the output layer to the output witness for data parallel circuits. // \sum_j( \alpha^j * subset[i][j](rt_j || ry_j) ) -// = \sum_y( \sum_j( \alpha^j (eq or copy_to[j] or assert_subset_eq)(ry_j, y) \sum_t( eq(rt_j, t) * layers[i](t || y) ) ) ) +// = \sum_y( \sum_j( \alpha^j (eq or copy_to[j] or assert_subset_eq)(ry_j, y) \sum_t( eq(rt_j, +// t) * layers[i](t || y) ) ) ) impl IOPProverState { /// Sumcheck 1: sigma = \sum_y( \sum_j f1^{(j)}(y) * g1^{(j)}(y) ) /// sigma = \sum_j( \alpha^j * wit_out_eval[j](rt_j || ry_j) ) /// + \alpha^{wit_out_eval[j].len()} * assert_const(rt || ry) ) - /// f1^{(j)}(y) = \sum_t( eq(rt_j, t) * layers[i](t || y) ) + /// f1^{(j)}(y) = layers[i](rt_j || y) /// g1^{(j)}(y) = \alpha^j eq(ry_j, y) // or \alpha^j copy_to[j](ry_j, y) // or \alpha^j assert_subset_eq(ry, y) @@ -41,6 +43,7 @@ impl IOPProverState { let alpha = transcript .get_and_append_challenge(b"combine subset evals") .elements; + let total_length = self.to_next_phase_point_and_evals.len() + self.subset_point_and_evals[self.layer_id as usize].len() + 1; @@ -55,11 +58,15 @@ impl IOPProverState { let lo_num_vars = circuit.layers[self.layer_id as usize].num_vars; let hi_num_vars = circuit_witness.instance_num_vars(); + self.phase1_layer_poly = circuit_witness + .layer_poly::((self.layer_id).try_into().unwrap(), lo_num_vars, (0, 1)) + .into(); + // sigma = \sum_j( \alpha^j * subset[i][j](rt_j || ry_j) ) - // f1^{(j)}(y) = \sum_t( eq(rt_j, t) * layers[i](t || y) ) + // f1^{(j)}(y) = layers[i](rt_j || y) // g1^{(j)}(y) = \alpha^j eq(ry_j, y) // or \alpha^j copy_to[j](ry_j, y) - // or \alpha^j assert_subset_eq[j](ry, y) + // or \alpha^j assert_subset_eq(ry, y) // TODO: Double check the soundness here. let (mut f1, mut g1): ( Vec>, @@ -70,7 +77,8 @@ impl IOPProverState { let point = &point_and_eval.point; let lo_eq_w_p = build_eq_x_r_vec(&point_and_eval.point[..point_lo_num_vars]); - let f1_j = self.phase1_layer_polys[self.layer_id as usize] + let f1_j = self + .phase1_layer_poly .fix_high_variables(&point[point_lo_num_vars..]); let g1_j = lo_eq_w_p @@ -88,18 +96,20 @@ impl IOPProverState { let (f1_copy_to, g1_copy_to): ( Vec>, Vec>, - ) = izip_parallizable!( + ) = izip!( &circuit.copy_to_wits_out, &self.subset_point_and_evals[self.layer_id as usize], &alpha_pows[self.to_next_phase_point_and_evals.len()..] ) .map(|(copy_to, (_, point_and_eval), alpha_pow)| { - let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; let point = &point_and_eval.point; - let lo_eq_w_p = build_eq_x_r_vec(&point_and_eval.point[..point_lo_num_vars]); + let point_lo_num_vars = point.len() - hi_num_vars; + + let lo_eq_w_p = build_eq_x_r_vec(&point[..point_lo_num_vars]); assert!(copy_to.len() <= lo_eq_w_p.len()); - let f1_j = self.phase1_layer_polys[self.layer_id as usize] + let f1_j = self + .phase1_layer_poly .fix_high_variables(&point[point_lo_num_vars..]); let g1_j = copy_to.as_slice().fix_row_row_first_with_scalar( @@ -118,12 +128,14 @@ impl IOPProverState { f1.extend(f1_copy_to); g1.extend(g1_copy_to); - let f1_j = self.phase1_layer_polys[self.layer_id as usize] + let f1_j = self + .phase1_layer_poly .fix_high_variables(&self.assert_point[lo_num_vars..]); f1.push(f1_j.into()); let alpha_pow = alpha_pows.last().unwrap(); let lo_eq_w_p = build_eq_x_r_vec(&self.assert_point[..lo_num_vars]); + let mut g_last = vec![E::ZERO; 1 << lo_num_vars]; circuit.assert_consts.iter().for_each(|gate| { g_last[gate.idx_out as usize] = lo_eq_w_p[gate.idx_out as usize] * alpha_pow; @@ -175,9 +187,9 @@ impl IOPProverState { let hi_num_vars = circuit_witness.instance_num_vars(); // f2(t) = layers[i](t || ry) - let f2 = self.phase1_layer_polys[self.layer_id as usize] - .fix_variables_parallel(&self.to_next_step_point) - .into(); + let mut f2 = mem::take(&mut self.phase1_layer_poly); + + Arc::make_mut(&mut f2).fix_variables_in_place_parallel(&self.to_next_step_point); // g2(t) = \sum_j \alpha^j (eq or copy_to[j] or assert_subset)(ry_j, ry) eq(rt_j, t) let output_points = chain![ diff --git a/gkr/src/prover/phase2.rs b/gkr/src/prover/phase2.rs index a8f786039..be8b60be0 100644 --- a/gkr/src/prover/phase2.rs +++ b/gkr/src/prover/phase2.rs @@ -45,8 +45,8 @@ macro_rules! prepare_stepx_g_fn { // The number of terms depends on the gate. // Here is an example of degree 3: // layers[i](rt || ry) = \sum_{s1}( \sum_{s2}( \sum_{s3}( \sum_{x1}( \sum_{x2}( \sum_{x3}( -// eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s1 || x1) * layers[i + 1](s2 || x2) * layers[i + 1](s3 || x3) -// ) ) ) ) ) ) + sum_s1( sum_s2( sum_{x1}( sum_{x2}( +// eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s1 || x1) * layers[i + 1](s2 || x2) +// * layers[i + 1](s3 || x3) ) ) ) ) ) ) + sum_s1( sum_s2( sum_{x1}( sum_{x2}( // eq(rt, s1, s2) * mul2(ry, x1, x2) * layers[i + 1](s1 || x1) * layers[i + 1](s2 || x2) // ) ) ) ) + \sum_{s1}( \sum_{x1}( // eq(rt, s1) * add(ry, x1) * layers[i + 1](s1 || x1) @@ -54,16 +54,17 @@ macro_rules! prepare_stepx_g_fn { // \sum_j eq(rt, s1) paste_from[j](ry, x1) * subset[j][i](s1 || x1) // ) ) + add_const(ry) impl IOPProverState { - /// Sumcheck 1: sigma = \sum_{s1 || x1} f1(s1 || x1) * g1(s1 || x1) + \sum_j f1'_j(s1 || x1) * g1'_j(s1 || x1) - /// sigma = layers[i](rt || ry) - add_const(ry), + /// Sumcheck 1: sigma = \sum_{s1 || x1} f1(s1 || x1) * g1(s1 || x1) + \sum_j f1'_j(s1 || x1) * + /// g1'_j(s1 || x1) sigma = layers[i](rt || ry) - add_const(ry), /// f1(s1 || x1) = layers[i + 1](s1 || x1) /// g1(s1 || x1) = \sum_{s2}( \sum_{s3}( \sum_{x2}( \sum_{x3}( - /// eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s2 || x2) * layers[i + 1](s3 || x3) - /// ) ) ) ) + \sum_{s2}( \sum_{x2}( + /// eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s2 || x2) * layers[i + + /// 1](s3 || x3) ) ) ) ) + \sum_{s2}( \sum_{x2}( /// eq(rt, s1, s2) * mul2(ry, x1, x2) * layers[i + 1](s2 || x2) /// ) ) + eq(rt, s1) * add(ry, x1) /// f1'^{(j)}(s1 || x1) = subset[j][i](s1 || x1) /// g1'^{(j)}(s1 || x1) = eq(rt, s1) paste_from[j](ry, x1) + /// s1 || x1 || 0, s1 || x1 || 1 #[tracing::instrument(skip_all, name = "build_phase2_step1_sumcheck_poly")] pub(super) fn build_phase2_step1_sumcheck_poly( eq: &[Vec; 1], @@ -105,8 +106,8 @@ impl IOPProverState { let f1 = phase2_next_layer_polys_v2.clone(); // g1(s1 || x1) = \sum_{s2}( \sum_{s3}( \sum_{x2}( \sum_{x3}( - // eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s2 || x2) * layers[i + 1](s3 || x3) - // ) ) ) ) + \sum_{s2}( \sum_{x2}( + // eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s2 || x2) * layers[i + + // 1](s3 || x3) ) ) ) ) + \sum_{s2}( \sum_{x2}( // eq(rt, s1, s2) * mul2(ry, x1, x2) * layers[i + 1](s2 || x2) // ) ) + eq(rt, s1) * add(ry, x1) let mut g1 = vec![E::ZERO; 1 << f1.num_vars]; @@ -176,7 +177,8 @@ impl IOPProverState { Vec>, ) = ([vec![f1], f1_j].concat(), [vec![g1], g1_j].concat()); - // sumcheck: sigma = \sum_{s1 || x1} f1(s1 || x1) * g1(s1 || x1) + \sum_j f1'_j(s1 || x1) * g1'_j(s1 || x1) + // sumcheck: sigma = \sum_{s1 || x1} f1(s1 || x1) * g1(s1 || x1) + \sum_j f1'_j(s1 || x1) * + // g1'_j(s1 || x1) let mut virtual_poly_1 = VirtualPolynomial::new(f[0].num_vars); for (f, g) in f.into_iter().zip(g.into_iter()) { let mut tmp = VirtualPolynomial::new_from_mle(f, E::BaseField::ONE); diff --git a/gkr/src/prover/phase2_input.rs b/gkr/src/prover/phase2_input.rs index c3bfbe4db..8a5961fac 100644 --- a/gkr/src/prover/phase2_input.rs +++ b/gkr/src/prover/phase2_input.rs @@ -16,6 +16,7 @@ use crate::{ izip_parallizable, prover::SumcheckState, structs::{Circuit, CircuitWitness, IOPProverState, IOPProverStepMessage, PointAndEval}, + utils::MultilinearExtensionFromVectors, }; // Prove the computation in the current layer for data parallel circuits. @@ -65,7 +66,6 @@ impl IOPProverState { } g[subset_wire_id] = eq_y_ry[new_wire_id]; } - ( { let mut f = DenseMultilinearExtension::from_evaluations_vec( @@ -133,8 +133,8 @@ impl IOPProverState { .partition(|(i, _)| i % 2 == 0); let eval_values_f = f_vec .into_iter() - .take(wits_in.len()) .map(|(_, f)| f) + .take(wits_in.len()) .collect_vec(); self.to_next_phase_point_and_evals = izip!(paste_from_wit_in.iter(), eval_values_f.iter()) @@ -151,6 +151,7 @@ impl IOPProverState { PointAndEval::new_from_ref(&point, &wit_in_eval) }) .collect_vec(); + self.to_next_step_point = [&eval_point, hi_point].concat(); end_timer!(timer); diff --git a/gkr/src/structs.rs b/gkr/src/structs.rs index ec4489551..3f667a7a6 100644 --- a/gkr/src/structs.rs +++ b/gkr/src/structs.rs @@ -62,14 +62,12 @@ pub struct IOPProverState { /// The point to the next step. pub(crate) to_next_step_point: Point, - /// layer poly for phase 1 which point to current layer - pub(crate) phase1_layer_polys: Vec>, // Especially for output phase1. + pub(crate) phase1_layer_poly: ArcDenseMultilinearExtension, pub(crate) assert_point: Point, // Especially for phase1. pub(crate) g1_values: Vec, - // Especially for phase2. } /// Represent the verifier state for each layer in the IOP protocol. @@ -126,7 +124,6 @@ pub(crate) enum SumcheckStepType { OutputPhase1Step1, OutputPhase1Step2, Phase1Step1, - Phase1Step2, Phase2Step1, Phase2Step2, Phase2Step2NoStep3, @@ -151,11 +148,11 @@ pub struct Layer { // Gates. Should be all None if it's the input layer. pub(crate) add_consts: Vec>>, pub(crate) adds: Vec>>, - pub(crate) adds_fanin_mapping: [BTreeMap>>>; 1], // grouping for 1 fanins + pub(crate) adds_fanin_mapping: [BTreeMap>>>; 1], /* grouping for 1 fanins */ pub(crate) mul2s: Vec>>, - pub(crate) mul2s_fanin_mapping: [BTreeMap>>>; 2], // grouping for 2 fanins + pub(crate) mul2s_fanin_mapping: [BTreeMap>>>; 2], /* grouping for 2 fanins */ pub(crate) mul3s: Vec>>, - pub(crate) mul3s_fanin_mapping: [BTreeMap>>>; 3], // grouping for 3 fanins + pub(crate) mul3s_fanin_mapping: [BTreeMap>>>; 3], /* grouping for 3 fanins */ /// The corresponding wires copied from this layer to later layers. It is /// (later layer id -> current wire id to be copied). It stores the non-zero @@ -201,7 +198,7 @@ pub struct Circuit { pub paste_from_wits_in: Vec<(CellId, CellId)>, /// The endpoints in the input layer copied from counter. pub paste_from_counter_in: Vec<(usize, (CellId, CellId))>, - /// The endpoints in the output layer copied to each output witness. + /// The endpoints in the input layer copied from constants pub paste_from_consts_in: Vec<(i64, (CellId, CellId))>, /// The wires copied to the output witness pub copy_to_wits_out: Vec>, diff --git a/gkr/src/utils.rs b/gkr/src/utils.rs index 284e9f09d..f9436e6e7 100644 --- a/gkr/src/utils.rs +++ b/gkr/src/utils.rs @@ -20,7 +20,8 @@ pub fn i64_to_field(x: i64) -> F { /// This is to compute a segment indicator. Specifically, it is an MLE of the /// following vector: /// segment_{\mathbf{x}} -/// = \sum_{\mathbf{b}=min_idx + 1}^{2^n - 1} \prod_{i=0}^{n-1} (x_i b_i + (1 - x_i)(1 - b_i)) +/// = \sum_{\mathbf{b}=min_idx + 1}^{2^n - 1} \prod_{i=0}^{n-1} (x_i b_i + (1 - x_i)(1 - +/// b_i)) pub(crate) fn segment_eval_greater_than(min_idx: usize, a: &[E]) -> E { let running_product2 = { let mut running_product = vec![E::ZERO; a.len() + 1]; @@ -50,7 +51,8 @@ pub(crate) fn segment_eval_greater_than(min_idx: usize, a: &[ /// This is to compute a variant of eq(\mathbf{x}, \mathbf{y}) for indices in /// (min_idx, 2^n]. Specifically, it is an MLE of the following vector: /// partial_eq_{\mathbf{x}}(\mathbf{y}) -/// = \sum_{\mathbf{b}=min_idx + 1}^{2^n - 1} \prod_{i=0}^{n-1} (x_i y_i b_i + (1 - x_i)(1 - y_i)(1 - b_i)) +/// = \sum_{\mathbf{b}=min_idx + 1}^{2^n - 1} \prod_{i=0}^{n-1} (x_i y_i b_i + (1 - x_i)(1 - +/// y_i)(1 - b_i)) #[allow(dead_code)] pub(crate) fn eq_eval_greater_than(min_idx: usize, a: &[F], b: &[F]) -> F { assert!(a.len() >= b.len()); @@ -97,7 +99,8 @@ pub(crate) fn eq_eval_greater_than(min_idx: usize, a: &[F], b: &[ /// This is to compute a variant of eq(\mathbf{x}, \mathbf{y}) for indices in /// [0, max_idx]. Specifically, it is an MLE of the following vector: /// partial_eq_{\mathbf{x}}(\mathbf{y}) -/// = \sum_{\mathbf{b}=0}^{max_idx} \prod_{i=0}^{n-1} (x_i y_i b_i + (1 - x_i)(1 - y_i)(1 - b_i)) +/// = \sum_{\mathbf{b}=0}^{max_idx} \prod_{i=0}^{n-1} (x_i y_i b_i + (1 - x_i)(1 - y_i)(1 - +/// b_i)) pub(crate) fn eq_eval_less_or_equal_than(max_idx: usize, a: &[E], b: &[E]) -> E { assert!(a.len() >= b.len()); // Compute running product of ( x_i y_i + (1 - x_i)(1 - y_i) )_{0 <= i <= n} @@ -176,7 +179,7 @@ pub fn eq4_eval(x: &[E], y: &[E], z: &[E], w: &[E]) -> E { res } -pub fn tensor_product(a: &[F], b: &[F]) -> Vec { +pub fn tensor_product(a: &[F], b: &[F]) -> Vec { let mut res = vec![F::ZERO; a.len() * b.len()]; for i in 0..a.len() { for j in 0..b.len() { diff --git a/gkr/src/verifier.rs b/gkr/src/verifier.rs index b197e9e58..d84ff6ce2 100644 --- a/gkr/src/verifier.rs +++ b/gkr/src/verifier.rs @@ -1,6 +1,6 @@ use ark_std::{end_timer, start_timer}; use ff_ext::ExtensionField; -use itertools::Itertools; +use itertools::{izip, Itertools}; use simple_frontend::structs::{ChallengeConst, LayerId}; use std::collections::HashMap; use transcript::Transcript; @@ -8,7 +8,8 @@ use transcript::Transcript; use crate::{ error::GKRError, structs::{ - Circuit, GKRInputClaims, IOPProof, IOPVerifierState, PointAndEval, SumcheckStepType, + Circuit, GKRInputClaims, IOPProof, IOPProverStepMessage, IOPVerifierState, PointAndEval, + SumcheckStepType, }, }; @@ -44,48 +45,44 @@ impl IOPVerifierState { circuit.layers[0].num_vars + instance_num_vars, ); - let mut step_proof_iter = proof.sumcheck_proofs.into_iter(); + let mut sumcheck_proofs_iter = proof.sumcheck_proofs.into_iter(); for layer_id in 0..circuit.layers.len() { let timer = start_timer!(|| format!("Verify layer {}", layer_id)); verifier_state.layer_id = layer_id as LayerId; let layer = &circuit.layers[layer_id as usize]; - for step in layer.sumcheck_steps.iter() { - let step_msg = step_proof_iter - .next() - .ok_or(GKRError::VerifyError("Wrong number of step proofs"))?; + for (step, step_proof) in izip!(layer.sumcheck_steps.iter(), &mut sumcheck_proofs_iter) + { match step { SumcheckStepType::OutputPhase1Step1 => verifier_state .verify_and_update_state_output_phase1_step1( - circuit, step_msg, transcript, + circuit, step_proof, transcript, )?, SumcheckStepType::OutputPhase1Step2 => verifier_state .verify_and_update_state_output_phase1_step2( - circuit, step_msg, transcript, + circuit, step_proof, transcript, )?, SumcheckStepType::Phase1Step1 => verifier_state - .verify_and_update_state_phase1_step1(circuit, step_msg, transcript)?, - SumcheckStepType::Phase1Step2 => verifier_state - .verify_and_update_state_phase1_step2(circuit, step_msg, transcript)?, + .verify_and_update_state_phase1_step1(circuit, step_proof, transcript)?, SumcheckStepType::Phase2Step1 => verifier_state - .verify_and_update_state_phase2_step1(circuit, step_msg, transcript)?, + .verify_and_update_state_phase2_step1(circuit, step_proof, transcript)?, SumcheckStepType::Phase2Step2 => verifier_state .verify_and_update_state_phase2_step2( - circuit, step_msg, transcript, false, + circuit, step_proof, transcript, false, )?, SumcheckStepType::Phase2Step2NoStep3 => verifier_state .verify_and_update_state_phase2_step2( - circuit, step_msg, transcript, true, + circuit, step_proof, transcript, true, )?, SumcheckStepType::Phase2Step3 => verifier_state - .verify_and_update_state_phase2_step3(circuit, step_msg, transcript)?, + .verify_and_update_state_phase2_step3(circuit, step_proof, transcript)?, SumcheckStepType::LinearPhase2Step1 => verifier_state .verify_and_update_state_linear_phase2_step1( - circuit, step_msg, transcript, + circuit, step_proof, transcript, )?, SumcheckStepType::InputPhase2Step1 => verifier_state .verify_and_update_state_input_phase2_step1( - circuit, step_msg, transcript, + circuit, step_proof, transcript, )?, _ => unreachable!(), } diff --git a/gkr/src/verifier/phase1.rs b/gkr/src/verifier/phase1.rs index 6d30c4e08..6a3b0c947 100644 --- a/gkr/src/verifier/phase1.rs +++ b/gkr/src/verifier/phase1.rs @@ -2,7 +2,7 @@ use ark_std::{end_timer, start_timer}; use ff_ext::ExtensionField; use itertools::{chain, izip, Itertools}; use multilinear_extensions::virtual_poly::{build_eq_x_r_vec, eq_eval, VPAuxInfo}; -use std::{marker::PhantomData, mem}; +use std::marker::PhantomData; use transcript::Transcript; use crate::{ @@ -52,27 +52,42 @@ impl IOPVerifierState { .fold(E::ZERO, |acc, ((_, point_and_eval), alpha_pow)| { acc + point_and_eval.eval * alpha_pow }); - // Sumcheck 1: sigma = \sum_y( \sum_j f1^{(j)}(y) * g1^{(j)}(y) ) - // f1^{(j)}(y) = \sum_t( eq(rt_j, t) * layers[i](t || y) ) - // g1^{(j)}(y) = \alpha^j copy_to[j](ry_j, y) + + // Sumcheck: sigma = \sum_{t || y}( \sum_j f1^{(j)}( t || y) * g1^{(j)}(t || y) ) + // f1^{(j)}(y) = layers[i](t || y) + // g1^{(j)}(t || y) = \alpha^j * eq(rt_j, t) * eq(ry_j, y) + // g1^{(j)}(t || y) = \alpha^j * eq(rt_j, t) * copy_to[j](ry_j, y) let claim_1 = SumcheckState::verify( sigma_1, &step_msg.sumcheck_proof, &VPAuxInfo { max_degree: 2, - num_variables: lo_num_vars, + num_variables: lo_num_vars + hi_num_vars, phantom: PhantomData, }, transcript, ); + let claim1_point = claim_1.point.iter().map(|x| x.elements).collect_vec(); - let eq_y_ry = build_eq_x_r_vec(&claim1_point); + let claim1_point_lo_num_vars = claim1_point.len() - hi_num_vars; + let eq_y_ry = build_eq_x_r_vec(&claim1_point[..claim1_point_lo_num_vars]); - self.g1_values = chain![ + assert_eq!(step_msg.sumcheck_eval_values.len(), 1); + let f_value = step_msg.sumcheck_eval_values[0]; + + let g_value: E = chain![ izip!(self.to_next_phase_point_and_evals.iter(), alpha_pows.iter()).map( |(point_and_eval, alpha_pow)| { let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; - eq_eval(&point_and_eval.point[..point_lo_num_vars], &claim1_point) * alpha_pow + let eq_t = eq_eval( + &point_and_eval.point[point_lo_num_vars..], + &claim1_point[(claim1_point.len() - hi_num_vars)..], + ); + let eq_y = eq_eval( + &point_and_eval.point[..point_lo_num_vars], + &claim1_point[..point_lo_num_vars], + ); + eq_t * eq_y * alpha_pow } ), izip!( @@ -83,88 +98,28 @@ impl IOPVerifierState { ) .map(|((new_layer_id, point_and_eval), alpha_pow)| { let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; + let eq_t = eq_eval( + &point_and_eval.point[point_lo_num_vars..], + &claim1_point[(claim1_point.len() - hi_num_vars)..], + ); let eq_yj_ryj = build_eq_x_r_vec(&point_and_eval.point[..point_lo_num_vars]); - circuit.layers[self.layer_id as usize].copy_to[new_layer_id] + eq_t * circuit.layers[self.layer_id as usize].copy_to[new_layer_id] .as_slice() .eval_row_first(&eq_yj_ryj, &eq_y_ry) * alpha_pow }), ] - .collect_vec(); - - let got_value_1 = izip!(step_msg.sumcheck_eval_values.iter(), self.g1_values.iter()) - .fold(E::ZERO, |acc, (&f1, g1)| acc + f1 * g1); - end_timer!(timer); - - if claim_1.expected_evaluation != got_value_1 { - return Err(GKRError::VerifyError("phase1 step1 failed")); - } - - self.to_next_step_point_and_eval = - PointAndEval::new(claim1_point, claim_1.expected_evaluation); - - Ok(()) - } - - pub(super) fn verify_and_update_state_phase1_step2( - &mut self, - _: &Circuit, - step_msg: IOPProverStepMessage, - transcript: &mut Transcript, - ) -> Result<(), GKRError> { - let timer = start_timer!(|| "Verifier sumcheck phase 1 step 2"); - let hi_num_vars = self.instance_num_vars; + .sum(); - // sigma = \sum_j( f1^{(j)}(ry) * g1^{(j)}(ry) ) - // Sumcheck 2: sigma = \sum_t( \sum_j( g2^{(j)}(t) ) ) * f2(t) - // f2(t) = layers[i](t || ry) - // g2^{(j)}(t) = \alpha^j copy_to[j](ry_j, r_y) eq(rt_j, t) - let claim_2 = SumcheckState::verify( - self.to_next_step_point_and_eval.eval, - &step_msg.sumcheck_proof, - &VPAuxInfo { - max_degree: 2, - num_variables: hi_num_vars, - phantom: PhantomData, - }, - transcript, - ); - let claim2_point = claim_2.point.iter().map(|x| x.elements).collect_vec(); - - let output_points = chain![ - self.to_next_phase_point_and_evals.iter().map(|x| &x.point), - self.subset_point_and_evals[self.layer_id as usize] - .iter() - .map(|x| &x.1.point), - ]; - let f2_value = step_msg.sumcheck_eval_values[0]; - let g2_value = output_points - .zip(self.g1_values.iter()) - .map(|(point, g1_value)| { - let point_lo_num_vars = point.len() - hi_num_vars; - *g1_value * eq_eval(&point[point_lo_num_vars..], &claim2_point) - }) - .fold(E::ZERO, |acc, value| acc + value); - - let got_value_2 = f2_value * g2_value; + let got_value = f_value * g_value; end_timer!(timer); - if claim_2.expected_evaluation != got_value_2 { - return Err(GKRError::VerifyError("output phase1 step2 failed")); + if claim_1.expected_evaluation != got_value { + return Err(GKRError::VerifyError("phase1 step1 failed")); } + self.to_next_step_point_and_eval = PointAndEval::new_from_ref(&claim1_point, &f_value); + self.to_next_phase_point_and_evals = vec![self.to_next_step_point_and_eval.clone()]; - self.to_next_step_point_and_eval = PointAndEval::new( - [ - mem::take(&mut self.to_next_step_point_and_eval.point), - claim2_point, - ] - .concat(), - f2_value, - ); - self.to_next_phase_point_and_evals = vec![PointAndEval::new_from_ref( - &self.to_next_step_point_and_eval.point, - &f2_value, - )]; self.subset_point_and_evals[self.layer_id as usize].clear(); Ok(()) diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index c2f3d5515..018b782d2 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -8,7 +8,6 @@ use ff_ext::ExtensionField; use rayon::iter::IntoParallelRefIterator; use serde::{Deserialize, Serialize}; -#[cfg(feature = "parallel")] use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; #[derive(Clone, PartialEq, Eq, Hash, Default, Debug, Serialize, Deserialize)] @@ -68,12 +67,14 @@ impl DenseMultilinearExtension { } } - /// Identical to [`from_evaluations_slice`], with and exception that evaluation vector is in extension field + /// Identical to [`from_evaluations_slice`], with and exception that evaluation vector is in + /// extension field pub fn from_evaluations_ext_slice(num_vars: usize, evaluations: &[E]) -> Self { Self::from_evaluations_ext_vec(num_vars, evaluations.to_vec()) } - /// Identical to [`from_evaluations_vec`], with and exception that evaluation vector is in extension field + /// Identical to [`from_evaluations_vec`], with and exception that evaluation vector is in + /// extension field pub fn from_evaluations_ext_vec(num_vars: usize, evaluations: Vec) -> Self { // assert that the number of variables matches the size of evaluations // TODO: return error. @@ -146,7 +147,8 @@ impl DenseMultilinearExtension { let nv = self.num_vars; // evaluate single variable of partial point from left to right for (i, point) in partial_point.iter().enumerate() { - // override buf[b1, b2,..bt, 0] = (1-point) * buf[b1, b2,..bt, 0] + point * buf[b1, b2,..bt, 1] in parallel + // override buf[b1, b2,..bt, 0] = (1-point) * buf[b1, b2,..bt, 0] + point * buf[b1, + // b2,..bt, 1] in parallel match &mut self.evaluations { FieldType::Base(evaluations) => { let evaluations_ext = evaluations @@ -443,7 +445,8 @@ impl DenseMultilinearExtension { // evaluate single variable of partial point from left to right for (i, point) in partial_point.iter().enumerate() { let max_log2_size = nv - i; - // override buf[b1, b2,..bt, 0] = (1-point) * buf[b1, b2,..bt, 0] + point * buf[b1, b2,..bt, 1] in parallel + // override buf[b1, b2,..bt, 0] = (1-point) * buf[b1, b2,..bt, 0] + point * buf[b1, + // b2,..bt, 1] in parallel match &mut self.evaluations { FieldType::Base(evaluations) => { let evaluations_ext = evaluations diff --git a/multilinear_extensions/src/virtual_poly.rs b/multilinear_extensions/src/virtual_poly.rs index 34b96a047..eb3edac92 100644 --- a/multilinear_extensions/src/virtual_poly.rs +++ b/multilinear_extensions/src/virtual_poly.rs @@ -1,18 +1,17 @@ -use std::cmp::max; -use std::mem; -use std::{collections::HashMap, marker::PhantomData, sync::Arc}; - -use crate::mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension}; -use crate::util::bit_decompose; -use ark_std::rand::Rng; -use ark_std::{end_timer, start_timer}; -use ff::Field; -use ff::PrimeField; +use std::{cmp::max, collections::HashMap, marker::PhantomData, mem, sync::Arc}; + +use crate::{ + mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension}, + util::bit_decompose, +}; +use ark_std::{end_timer, rand::Rng, start_timer}; +use ff::{Field, PrimeField}; use ff_ext::ExtensionField; -use rayon::iter::IntoParallelIterator; -use rayon::iter::IntoParallelRefIterator; -use rayon::prelude::{IndexedParallelIterator, ParallelIterator}; -use rayon::slice::ParallelSliceMut; +use rayon::{ + iter::{IntoParallelIterator, IntoParallelRefIterator}, + prelude::{IndexedParallelIterator, ParallelIterator}, + slice::ParallelSliceMut, +}; use serde::{Deserialize, Serialize}; #[rustfmt::skip] diff --git a/singer-pro/src/basic_block/bb_final.rs b/singer-pro/src/basic_block/bb_final.rs index 8bcd88f69..eca9d2259 100644 --- a/singer-pro/src/basic_block/bb_final.rs +++ b/singer-pro/src/basic_block/bb_final.rs @@ -14,7 +14,7 @@ use singer_utils::{ structs::{ChipChallenges, InstOutChipType, PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, uint::UIntAddSub, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::{ component::{BBFinalCircuit, BBFinalLayout, FromBBStart, FromPredInst, FromWitness}, diff --git a/singer-pro/src/basic_block/bb_ret.rs b/singer-pro/src/basic_block/bb_ret.rs index 18a55fd9c..8c828cf21 100644 --- a/singer-pro/src/basic_block/bb_ret.rs +++ b/singer-pro/src/basic_block/bb_ret.rs @@ -9,7 +9,7 @@ use singer_utils::{ register_witness, structs::{ChipChallenges, InstOutChipType, RAMHandler, ROMHandler, StackUInt, TSUInt}, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::{ component::{ diff --git a/singer-pro/src/instructions/add.rs b/singer-pro/src/instructions/add.rs index e136a8744..60eda4c80 100644 --- a/singer-pro/src/instructions/add.rs +++ b/singer-pro/src/instructions/add.rs @@ -10,7 +10,7 @@ use singer_utils::{ structs::{ChipChallenges, InstOutChipType, ROMHandler, StackUInt, TSUInt}, uint::UIntAddSub, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::{ component::{FromPredInst, FromWitness, InstCircuit, InstLayout, ToSuccInst}, diff --git a/singer-pro/src/instructions/calldataload.rs b/singer-pro/src/instructions/calldataload.rs index 2d9436c32..82c93d1f9 100644 --- a/singer-pro/src/instructions/calldataload.rs +++ b/singer-pro/src/instructions/calldataload.rs @@ -9,7 +9,7 @@ use singer_utils::{ register_witness, structs::{ChipChallenges, InstOutChipType, ROMHandler, StackUInt, TSUInt, UInt64}, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::{ component::{FromPredInst, FromWitness, InstCircuit, InstLayout, ToSuccInst}, diff --git a/singer-pro/src/instructions/gt.rs b/singer-pro/src/instructions/gt.rs index 99099070b..6b53a7b93 100644 --- a/singer-pro/src/instructions/gt.rs +++ b/singer-pro/src/instructions/gt.rs @@ -10,7 +10,7 @@ use singer_utils::{ structs::{ChipChallenges, InstOutChipType, ROMHandler, StackUInt, TSUInt}, uint::UIntCmp, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::{ component::{FromPredInst, FromWitness, InstCircuit, InstLayout, ToSuccInst}, diff --git a/singer-pro/src/instructions/jumpi.rs b/singer-pro/src/instructions/jumpi.rs index 3ab9ef735..acfd3a36b 100644 --- a/singer-pro/src/instructions/jumpi.rs +++ b/singer-pro/src/instructions/jumpi.rs @@ -11,7 +11,7 @@ use singer_utils::{ register_witness, structs::{ChipChallenges, InstOutChipType, PCUInt, ROMHandler, StackUInt, TSUInt}, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::{ component::{FromPredInst, FromWitness, InstCircuit, InstLayout, ToSuccInst}, diff --git a/singer-pro/src/instructions/mstore.rs b/singer-pro/src/instructions/mstore.rs index dcc290c7e..6f60ade91 100644 --- a/singer-pro/src/instructions/mstore.rs +++ b/singer-pro/src/instructions/mstore.rs @@ -12,7 +12,7 @@ use singer_utils::{ structs::{ChipChallenges, InstOutChipType, RAMHandler, ROMHandler, StackUInt, TSUInt}, uint::{UIntAddSub, UIntCmp}, }; -use std::{mem, sync::Arc}; +use std::{collections::BTreeMap, mem, sync::Arc}; use crate::{ component::{ diff --git a/singer-pro/src/instructions/ret.rs b/singer-pro/src/instructions/ret.rs index 605475560..8d315f0a7 100644 --- a/singer-pro/src/instructions/ret.rs +++ b/singer-pro/src/instructions/ret.rs @@ -11,7 +11,7 @@ use singer_utils::{ structs::{ChipChallenges, InstOutChipType, RAMHandler, ROMHandler, StackUInt, TSUInt}, uint::UIntAddSub, }; -use std::{mem, sync::Arc}; +use std::{collections::BTreeMap, mem, sync::Arc}; use crate::{ component::{ diff --git a/singer-utils/src/macros.rs b/singer-utils/src/macros.rs index 612de1966..b869e3ee5 100644 --- a/singer-utils/src/macros.rs +++ b/singer-utils/src/macros.rs @@ -1,5 +1,6 @@ #[macro_export] macro_rules! register_witness { + // phaseX_size() implementation ($struct_name:ident, $($wire_name:ident { $($slice_name:ident => $length:expr),* }),*) => { paste! { impl $struct_name { @@ -10,11 +11,24 @@ macro_rules! register_witness { } register_witness!(@internal $wire_name, 0usize; $($slice_name => $length),*); + + #[inline] + pub fn [<$wire_name _ idxes_map>]() -> BTreeMap<&'static str, std::ops::Range> { + let mut map = BTreeMap::new(); + + $( + map.insert(stringify!([<$wire_name _ $slice_name>]), Self::[<$wire_name _ $slice_name>]()); + )* + + map + } + )* } } }; + ($struct_name:ident, $($wire_name:ident { $($slice_name:ident => $length:expr),* }),*) => { paste! { impl $struct_name { @@ -25,6 +39,18 @@ macro_rules! register_witness { } register_witness!(@internal $wire_name, 0usize; $($slice_name => $length),*); + + #[inline] + pub fn [<$wire_name _ idxes_map>]() -> BTreeMap<&'static str, std::ops::Range> { + let mut map = BTreeMap::new(); + + $( + map.insert(stringify!([<$wire_name _ $slice_name>]), Self::[<$wire_name _ $slice_name>]()); + )* + + map + } + )* } } @@ -32,9 +58,13 @@ macro_rules! register_witness { (@internal $wire_name:ident, $offset:expr; $name:ident => $length:expr $(, $rest:ident => $rest_length:expr)*) => { paste! { - fn [<$wire_name _ $name>]() -> std::ops::Range { + pub fn [<$wire_name _ $name>]() -> std::ops::Range { $offset..$offset + $length } + + pub fn [<$wire_name _ $name _ str>]() -> &'static str { + stringify!([<$wire_name _ $name>]) + } register_witness!(@internal $wire_name, $offset + $length; $($rest => $rest_length),*); } }; diff --git a/singer-utils/src/uint.rs b/singer-utils/src/uint.rs index 09dade8f2..f62057e02 100644 --- a/singer-utils/src/uint.rs +++ b/singer-utils/src/uint.rs @@ -36,6 +36,8 @@ impl TryFrom> for UInt { } impl UInt { + pub const M: usize = M; + pub const C: usize = C; pub const N_OPRAND_CELLS: usize = (M + C - 1) / C; const N_CARRY_CELLS: usize = Self::N_OPRAND_CELLS; diff --git a/singer/benches/add.rs b/singer/benches/add.rs index b55578198..d674f9795 100644 --- a/singer/benches/add.rs +++ b/singer/benches/add.rs @@ -7,8 +7,7 @@ use ark_std::test_rng; use const_env::from_env; use criterion::*; -use ff_ext::ff::Field; -use ff_ext::ExtensionField; +use ff_ext::{ff::Field, ExtensionField}; use gkr::structs::LayerWitness; use goldilocks::GoldilocksExt2; use itertools::Itertools; @@ -57,8 +56,7 @@ fn bench_add(c: &mut Criterion) { #[cfg(feature = "non_pow2_rayon_thread")] { - use sumcheck::local_thread_pool::create_local_pool_once; - use sumcheck::util::ceil_log2; + use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); max_thread_id @@ -71,85 +69,80 @@ fn bench_add(c: &mut Criterion) { let circuit_builder = SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); - for instance_num_vars in 11..12 { + for instance_num_vars in 10..14 { // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("add_op_{}", instance_num_vars)); group.sample_size(NUM_SAMPLES); // Benchmark the proving time group.bench_function( - BenchmarkId::new("prove_keccak256", format!("keccak256_log2_{}", instance_num_vars)), + BenchmarkId::new("prove_add", format!("prove_add_log2_{}", instance_num_vars)), |b| { b.iter_with_setup( || { let mut rng = test_rng(); let singer_builder = SingerGraphBuilder::::new(); - - let real_challenges = vec![E::random(&mut rng), E::random(&mut rng)]; - (rng, singer_builder, real_challenges) + let real_challenges = vec![E::random(&mut rng), E::random(&mut rng)]; + (rng, singer_builder, real_challenges) }, - | (mut rng,mut singer_builder, real_challenges)| { - - let size = AddInstruction::phase0_size(); - - let phase0: CircuitWiresIn< - ::BaseField, - > = vec![LayerWitness { - instances: (0..(1 << instance_num_vars)) - .map(|_| { - (0..size) - .map(|_| { - ::BaseField::random( - &mut rng, - ) - }) - .collect_vec() - }) - .collect_vec(), - }]; - - - let timer = Instant::now(); - - let _ = AddInstruction::construct_graph_and_witness( - &mut singer_builder.graph_builder, - &mut singer_builder.chip_builder, - &circuit_builder.insts_circuits - [>::OPCODE as usize], - vec![phase0], - &real_challenges, - 1 << instance_num_vars, - &SingerParams::default(), - ) - .expect("gkr graph construction failed"); - - let (graph, wit) = singer_builder.graph_builder.finalize_graph_and_witness(); - - println!( - "AddInstruction::construct_graph_and_witness, instance_num_vars = {}, time = {}", - instance_num_vars, - timer.elapsed().as_secs_f64() - ); - - let point = vec![E::random(&mut rng), E::random(&mut rng)]; - let target_evals = graph.target_evals(&wit, &point); - - let mut prover_transcript = &mut Transcript::new(b"Singer"); - - let timer = Instant::now(); - let _ = GKRGraphProverState::prove( - &graph, - &wit, - &target_evals, - &mut prover_transcript, - (1 << instance_num_vars).min(max_thread_id), - ) - .expect("prove failed"); - println!( - "AddInstruction::prove, instance_num_vars = {}, time = {}", - instance_num_vars, - timer.elapsed().as_secs_f64() - ); + |(mut rng,mut singer_builder, real_challenges)| { + let size = AddInstruction::phase0_size(); + let phase0: CircuitWiresIn<::BaseField> = vec![LayerWitness { + instances: (0..(1 << instance_num_vars)) + .map(|_| { + (0..size) + .map(|_| { + ::BaseField::random( + &mut rng, + ) + }) + .collect_vec() + }) + .collect_vec(), + }]; + + + let timer = Instant::now(); + + let _ = AddInstruction::construct_graph_and_witness( + &mut singer_builder.graph_builder, + &mut singer_builder.chip_builder, + &circuit_builder.insts_circuits + [>::OPCODE as usize], + vec![phase0], + &real_challenges, + 1 << instance_num_vars, + &SingerParams::default(), + ) + .expect("gkr graph construction failed"); + + let (graph, wit) = singer_builder.graph_builder.finalize_graph_and_witness(); + + println!( + "AddInstruction::construct_graph_and_witness, instance_num_vars = {}, time = {}", + instance_num_vars, + timer.elapsed().as_secs_f64() + ); + + let point = vec![E::random(&mut rng), E::random(&mut rng)]; + let target_evals = graph.target_evals(&wit, &point); + + let mut prover_transcript = &mut Transcript::new(b"Singer"); + + let timer = Instant::now(); + let _ = GKRGraphProverState::prove( + &graph, + &wit, + &target_evals, + &mut prover_transcript, + (1 << instance_num_vars).min(max_thread_id), + ) + .expect("prove failed"); + println!( + "AddInstruction::prove, instance_num_vars = {}, time = {}", + instance_num_vars, + timer.elapsed().as_secs_f64() + ); }); }, ); diff --git a/singer/examples/add.rs b/singer/examples/add.rs index 28a4f0a96..185e48e29 100644 --- a/singer/examples/add.rs +++ b/singer/examples/add.rs @@ -1,26 +1,117 @@ -use std::time::{Duration, Instant}; +use std::{collections::BTreeMap, time::Instant}; use ark_std::test_rng; -use const_env::from_env; -use criterion::*; - -use ff_ext::ff::Field; -use ff_ext::ExtensionField; +use ff_ext::{ff::Field, ExtensionField}; use gkr::structs::LayerWitness; -use goldilocks::GoldilocksExt2; +use gkr_graph::structs::CircuitGraphAuxInfo; +use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; +use simple_frontend::structs::CellId; use singer::{ instructions::{add::AddInstruction, Instruction, InstructionGraph, SingerCircuitBuilder}, - scheme::GKRGraphProverState, - CircuitWiresIn, SingerGraphBuilder, SingerParams, + scheme::{GKRGraphProverState, GKRGraphVerifierState}, + u64vec, CircuitWiresIn, SingerGraphBuilder, SingerParams, +}; +use singer_utils::{ + constants::RANGE_CHIP_BIT_WIDTH, + structs::{ChipChallenges, StackUInt, TSUInt}, }; -use singer_utils::structs::ChipChallenges; use transcript::Transcript; +fn get_single_instance_values_map() -> BTreeMap<&'static str, Vec> { + let mut phase0_values_map = BTreeMap::<&'static str, Vec>::new(); + phase0_values_map.insert( + AddInstruction::phase0_pc_str(), + vec![Goldilocks::from(1u64)], + ); + phase0_values_map.insert( + AddInstruction::phase0_stack_ts_str(), + vec![Goldilocks::from(3u64)], + ); + phase0_values_map.insert( + AddInstruction::phase0_memory_ts_str(), + vec![Goldilocks::from(1u64)], + ); + phase0_values_map.insert( + AddInstruction::phase0_stack_top_str(), + vec![Goldilocks::from(100u64)], + ); + phase0_values_map.insert( + AddInstruction::phase0_clk_str(), + vec![Goldilocks::from(1u64)], + ); + phase0_values_map.insert( + AddInstruction::phase0_pc_add_str(), + vec![], // carry is 0, may test carry using larger values in PCUInt + ); + phase0_values_map.insert( + AddInstruction::phase0_stack_ts_add_str(), + vec![ + Goldilocks::from(4u64), /* first TSUInt::N_RANGE_CHECK_CELLS = 1*(56/16) = 4 + * cells are range values, stack_ts + 1 = 4 */ + Goldilocks::from(0u64), + Goldilocks::from(0u64), + Goldilocks::from(0u64), + // no place for carry + ], + ); + phase0_values_map.insert( + AddInstruction::phase0_old_stack_ts0_str(), + vec![Goldilocks::from(2u64)], + ); + let m: u64 = (1 << TSUInt::C) - 1; + let range_values = u64vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + phase0_values_map.insert( + AddInstruction::phase0_old_stack_ts_lt0_str(), + vec![ + Goldilocks::from(range_values[0]), + Goldilocks::from(range_values[1]), + Goldilocks::from(range_values[2]), + Goldilocks::from(range_values[3]), + Goldilocks::from(1u64), // borrow + ], + ); + phase0_values_map.insert( + AddInstruction::phase0_old_stack_ts1_str(), + vec![Goldilocks::from(1u64)], + ); + let m: u64 = (1 << TSUInt::C) - 2; + let range_values = u64vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + phase0_values_map.insert( + AddInstruction::phase0_old_stack_ts_lt1_str(), + vec![ + Goldilocks::from(range_values[0]), + Goldilocks::from(range_values[1]), + Goldilocks::from(range_values[2]), + Goldilocks::from(range_values[3]), + Goldilocks::from(1u64), // borrow + ], + ); + let m: u64 = (1 << StackUInt::C) - 1; + phase0_values_map.insert( + AddInstruction::phase0_addend_0_str(), + vec![Goldilocks::from(m)], + ); + phase0_values_map.insert( + AddInstruction::phase0_addend_1_str(), + vec![Goldilocks::from(1u64)], + ); + let range_values = u64vec::<{ StackUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m + 1); + let mut wit_phase0_instruction_add: Vec = vec![]; + for i in 0..16 { + wit_phase0_instruction_add.push(Goldilocks::from(range_values[i])) + } + wit_phase0_instruction_add.push(Goldilocks::from(1u64)); // carry is [1, 0, ...] + phase0_values_map.insert( + AddInstruction::phase0_instruction_add_str(), + wit_phase0_instruction_add, + ); + phase0_values_map +} fn main() { - let max_thread_id = 1; - let instance_num_vars = 9; + let max_thread_id = 8; + let instance_num_vars = 11; type E = GoldilocksExt2; let chip_challenges = ChipChallenges::default(); let circuit_builder = @@ -29,14 +120,31 @@ fn main() { let mut rng = test_rng(); let size = AddInstruction::phase0_size(); + let phase0_values_map = get_single_instance_values_map(); + let phase0_idx_map = AddInstruction::phase0_idxes_map(); + + let mut single_witness_in = vec![::BaseField::ZERO; size]; + + for key in phase0_idx_map.keys() { + let range = phase0_idx_map + .get(key) + .unwrap() + .clone() + .collect::>(); + let values = phase0_values_map + .get(key) + .expect(&("unknown key ".to_owned() + key)); + for (value_idx, cell_idx) in range.into_iter().enumerate() { + if value_idx < values.len() { + single_witness_in[cell_idx] = values[value_idx]; + } + } + } + let phase0: CircuitWiresIn<::BaseField> = vec![LayerWitness { instances: (0..(1 << instance_num_vars)) - .map(|_| { - (0..size) - .map(|_| ::BaseField::random(&mut rng)) - .collect_vec() - }) + .map(|_| single_witness_in.clone()) .collect_vec(), }]; @@ -66,20 +174,37 @@ fn main() { let point = vec![E::random(&mut rng), E::random(&mut rng)]; let target_evals = graph.target_evals(&wit, &point); - let mut prover_transcript = &mut Transcript::new(b"Singer"); - - let timer = Instant::now(); - let _ = GKRGraphProverState::prove( - &graph, - &wit, - &target_evals, - &mut prover_transcript, - (1 << instance_num_vars).min(max_thread_id), - ) - .expect("prove failed"); - println!( - "AddInstruction::prove, instance_num_vars = {}, time = {}", - instance_num_vars, - timer.elapsed().as_secs_f64() - ); + for _ in 0..5 { + let mut prover_transcript = &mut Transcript::new(b"Singer"); + let timer = Instant::now(); + let proof = GKRGraphProverState::prove( + &graph, + &wit, + &target_evals, + &mut prover_transcript, + (1 << instance_num_vars).min(max_thread_id), + ) + .expect("prove failed"); + println!( + "AddInstruction::prove, instance_num_vars = {}, time = {}", + instance_num_vars, + timer.elapsed().as_secs_f64() + ); + let mut verifier_transcript = Transcript::new(b"Singer"); + let _ = GKRGraphVerifierState::verify( + &graph, + &real_challenges, + &target_evals, + proof, + &CircuitGraphAuxInfo { + instance_num_vars: wit + .node_witnesses + .iter() + .map(|witness| witness.instance_num_vars()) + .collect(), + }, + &mut verifier_transcript, + ) + .expect("verify failed"); + } } diff --git a/singer/src/instructions.rs b/singer/src/instructions.rs index 863fcc3ef..772c233f0 100644 --- a/singer/src/instructions.rs +++ b/singer/src/instructions.rs @@ -221,6 +221,7 @@ pub trait InstructionGraph { real_n_instances: usize, _: &SingerParams, ) -> Result, ZKVMError> { + assert_eq!(sources.len(), 1, "unknown source length"); let inst_circuit = &inst_circuits[0]; let inst_wires_in = mem::take(&mut sources[0]); let node_id = graph_builder.add_node_with_witness( diff --git a/singer/src/instructions/add.rs b/singer/src/instructions/add.rs index f03c9fc85..0302f8cd0 100644 --- a/singer/src/instructions/add.rs +++ b/singer/src/instructions/add.rs @@ -13,7 +13,7 @@ use singer_utils::{ structs::{PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, uint::{UIntAddSub, UIntCmp}, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::error::ZKVMError; @@ -54,6 +54,7 @@ impl Instruction for AddInstruction { fn construct_circuit(challenges: ChipChallenges) -> Result, ZKVMError> { let mut circuit_builder = CircuitBuilder::new(); let (phase0_wire_id, phase0) = circuit_builder.create_witness_in(Self::phase0_size()); + let mut ram_handler = RAMHandler::new(&challenges); let mut rom_handler = ROMHandler::new(&challenges); @@ -116,6 +117,7 @@ impl Instruction for AddInstruction { // Pop two values from stack let old_stack_ts0 = (&phase0[Self::phase0_old_stack_ts0()]).try_into()?; + UIntCmp::::assert_lt( &mut circuit_builder, &mut rom_handler, @@ -123,6 +125,7 @@ impl Instruction for AddInstruction { &stack_ts, &phase0[Self::phase0_old_stack_ts_lt0()], )?; + ram_handler.stack_pop( &mut circuit_builder, stack_top_expr.sub(E::BaseField::from(1)), @@ -180,66 +183,27 @@ impl Instruction for AddInstruction { #[cfg(test)] mod test { use ark_std::test_rng; - use core::ops::Range; use ff::Field; use ff_ext::ExtensionField; use gkr::structs::LayerWitness; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; - use simple_frontend::structs::CellId; - use singer_utils::constants::RANGE_CHIP_BIT_WIDTH; - use singer_utils::structs::{StackUInt, TSUInt}; - use std::collections::BTreeMap; - use std::time::Instant; + use singer_utils::{ + constants::RANGE_CHIP_BIT_WIDTH, + structs::{StackUInt, TSUInt}, + }; + use std::{collections::BTreeMap, time::Instant}; use transcript::Transcript; - use crate::instructions::{ - AddInstruction, ChipChallenges, Instruction, InstructionGraph, SingerCircuitBuilder, + use crate::{ + instructions::{ + AddInstruction, ChipChallenges, Instruction, InstructionGraph, SingerCircuitBuilder, + }, + scheme::GKRGraphProverState, + test::{get_uint_params, test_opcode_circuit_v2}, + utils::u64vec, + CircuitWiresIn, SingerGraphBuilder, SingerParams, }; - use crate::scheme::GKRGraphProverState; - use crate::test::{get_uint_params, test_opcode_circuit, u2vec}; - use crate::{CircuitWiresIn, SingerGraphBuilder, SingerParams}; - - impl AddInstruction { - #[inline] - fn phase0_idxes_map() -> BTreeMap> { - let mut map = BTreeMap::new(); - map.insert("phase0_pc".to_string(), Self::phase0_pc()); - map.insert("phase0_stack_ts".to_string(), Self::phase0_stack_ts()); - map.insert("phase0_memory_ts".to_string(), Self::phase0_memory_ts()); - map.insert("phase0_stack_top".to_string(), Self::phase0_stack_top()); - map.insert("phase0_clk".to_string(), Self::phase0_clk()); - map.insert("phase0_pc_add".to_string(), Self::phase0_pc_add()); - map.insert( - "phase0_stack_ts_add".to_string(), - Self::phase0_stack_ts_add(), - ); - map.insert( - "phase0_old_stack_ts0".to_string(), - Self::phase0_old_stack_ts0(), - ); - map.insert( - "phase0_old_stack_ts_lt0".to_string(), - Self::phase0_old_stack_ts_lt0(), - ); - map.insert( - "phase0_old_stack_ts1".to_string(), - Self::phase0_old_stack_ts1(), - ); - map.insert( - "phase0_old_stack_ts_lt1".to_string(), - Self::phase0_old_stack_ts_lt1(), - ); - map.insert("phase0_addend_0".to_string(), Self::phase0_addend_0()); - map.insert("phase0_addend_1".to_string(), Self::phase0_addend_1()); - map.insert( - "phase0_instruction_add".to_string(), - Self::phase0_instruction_add(), - ); - - map - } - } #[test] fn test_add_construct_circuit() { @@ -260,23 +224,36 @@ mod test { #[cfg(feature = "test-dbg")] println!("{:?}", inst_circuit); - let mut phase0_values_map = BTreeMap::>::new(); - phase0_values_map.insert("phase0_pc".to_string(), vec![Goldilocks::from(1u64)]); - phase0_values_map.insert("phase0_stack_ts".to_string(), vec![Goldilocks::from(3u64)]); - phase0_values_map.insert("phase0_memory_ts".to_string(), vec![Goldilocks::from(1u64)]); + let mut phase0_values_map = BTreeMap::<&'static str, Vec>::new(); + phase0_values_map.insert( + AddInstruction::phase0_pc_str(), + vec![Goldilocks::from(1u64)], + ); + phase0_values_map.insert( + AddInstruction::phase0_stack_ts_str(), + vec![Goldilocks::from(3u64)], + ); phase0_values_map.insert( - "phase0_stack_top".to_string(), + AddInstruction::phase0_memory_ts_str(), + vec![Goldilocks::from(1u64)], + ); + phase0_values_map.insert( + AddInstruction::phase0_stack_top_str(), vec![Goldilocks::from(100u64)], ); - phase0_values_map.insert("phase0_clk".to_string(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert( - "phase0_pc_add".to_string(), + AddInstruction::phase0_clk_str(), + vec![Goldilocks::from(1u64)], + ); + phase0_values_map.insert( + AddInstruction::phase0_pc_add_str(), vec![], // carry is 0, may test carry using larger values in PCUInt ); phase0_values_map.insert( - "phase0_stack_ts_add".to_string(), + AddInstruction::phase0_stack_ts_add_str(), vec![ - Goldilocks::from(4u64), // first TSUInt::N_RANGE_CHECK_CELLS = 1*(56/16) = 4 cells are range values, stack_ts + 1 = 4 + Goldilocks::from(4u64), /* first TSUInt::N_RANGE_CHECK_CELLS = 1*(56/16) = 4 + * cells are range values, stack_ts + 1 = 4 */ Goldilocks::from(0u64), Goldilocks::from(0u64), Goldilocks::from(0u64), @@ -284,13 +261,13 @@ mod test { ], ); phase0_values_map.insert( - "phase0_old_stack_ts0".to_string(), + AddInstruction::phase0_old_stack_ts0_str(), vec![Goldilocks::from(2u64)], ); let m: u64 = (1 << get_uint_params::().1) - 1; - let range_values = u2vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( - "phase0_old_stack_ts_lt0".to_string(), + AddInstruction::phase0_old_stack_ts_lt0_str(), vec![ Goldilocks::from(range_values[0]), Goldilocks::from(range_values[1]), @@ -300,13 +277,13 @@ mod test { ], ); phase0_values_map.insert( - "phase0_old_stack_ts1".to_string(), + AddInstruction::phase0_old_stack_ts1_str(), vec![Goldilocks::from(1u64)], ); let m: u64 = (1 << get_uint_params::().1) - 2; - let range_values = u2vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( - "phase0_old_stack_ts_lt1".to_string(), + AddInstruction::phase0_old_stack_ts_lt1_str(), vec![ Goldilocks::from(range_values[0]), Goldilocks::from(range_values[1]), @@ -316,16 +293,23 @@ mod test { ], ); let m: u64 = (1 << get_uint_params::().1) - 1; - phase0_values_map.insert("phase0_addend_0".to_string(), vec![Goldilocks::from(m)]); - phase0_values_map.insert("phase0_addend_1".to_string(), vec![Goldilocks::from(1u64)]); - let range_values = u2vec::<{ StackUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m + 1); + phase0_values_map.insert( + AddInstruction::phase0_addend_0_str(), + vec![Goldilocks::from(m)], + ); + phase0_values_map.insert( + AddInstruction::phase0_addend_1_str(), + vec![Goldilocks::from(1u64)], + ); + let range_values = + u64vec::<{ StackUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m + 1); let mut wit_phase0_instruction_add: Vec = vec![]; for i in 0..16 { wit_phase0_instruction_add.push(Goldilocks::from(range_values[i])) } wit_phase0_instruction_add.push(Goldilocks::from(1u64)); // carry is [1, 0, ...] phase0_values_map.insert( - "phase0_instruction_add".to_string(), + AddInstruction::phase0_instruction_add_str(), wit_phase0_instruction_add, ); @@ -335,7 +319,7 @@ mod test { let c = GoldilocksExt2::from(6u64); let circuit_witness_challenges = vec![c; 3]; - let circuit_witness = test_opcode_circuit( + let _ = test_opcode_circuit_v2( &inst_circuit, &phase0_idx_map, phase0_witness_size, diff --git a/singer/src/instructions/calldataload.rs b/singer/src/instructions/calldataload.rs index de1ee19d5..5a7a176d9 100644 --- a/singer/src/instructions/calldataload.rs +++ b/singer/src/instructions/calldataload.rs @@ -13,7 +13,7 @@ use singer_utils::{ structs::{PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt, UInt64}, uint::{UIntAddSub, UIntCmp}, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::error::ZKVMError; @@ -151,57 +151,25 @@ impl Instruction for CalldataloadInstruction { #[cfg(test)] mod test { use ark_std::test_rng; - use core::ops::Range; use ff::Field; use ff_ext::ExtensionField; use gkr::structs::LayerWitness; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; - use simple_frontend::structs::CellId; - use singer_utils::constants::RANGE_CHIP_BIT_WIDTH; - use singer_utils::structs::TSUInt; - use std::collections::BTreeMap; - use std::time::Instant; + use singer_utils::{constants::RANGE_CHIP_BIT_WIDTH, structs::TSUInt}; + use std::{collections::BTreeMap, time::Instant}; use transcript::Transcript; - use crate::instructions::{ - CalldataloadInstruction, ChipChallenges, Instruction, InstructionGraph, - SingerCircuitBuilder, + use crate::{ + instructions::{ + CalldataloadInstruction, ChipChallenges, Instruction, InstructionGraph, + SingerCircuitBuilder, + }, + scheme::GKRGraphProverState, + test::{get_uint_params, test_opcode_circuit}, + utils::u64vec, + CircuitWiresIn, SingerGraphBuilder, SingerParams, }; - use crate::scheme::GKRGraphProverState; - use crate::test::{get_uint_params, test_opcode_circuit, u2vec}; - use crate::{CircuitWiresIn, SingerGraphBuilder, SingerParams}; - - impl CalldataloadInstruction { - #[inline] - fn phase0_idxes_map() -> BTreeMap> { - let mut map = BTreeMap::new(); - - map.insert("phase0_pc".to_string(), Self::phase0_pc()); - map.insert("phase0_stack_ts".to_string(), Self::phase0_stack_ts()); - map.insert("phase0_memory_ts".to_string(), Self::phase0_memory_ts()); - map.insert("phase0_ts".to_string(), Self::phase0_ts()); - map.insert("phase0_stack_top".to_string(), Self::phase0_stack_top()); - map.insert("phase0_clk".to_string(), Self::phase0_clk()); - map.insert("phase0_pc_add".to_string(), Self::phase0_pc_add()); - map.insert( - "phase0_stack_ts_add".to_string(), - Self::phase0_stack_ts_add(), - ); - map.insert("phase0_data".to_string(), Self::phase0_data()); - map.insert("phase0_offset".to_string(), Self::phase0_offset()); - map.insert( - "phase0_old_stack_ts".to_string(), - Self::phase0_old_stack_ts(), - ); - map.insert( - "phase0_old_stack_ts_lt".to_string(), - Self::phase0_old_stack_ts_lt(), - ); - - map - } - } #[test] fn test_calldataload_construct_circuit() { @@ -239,7 +207,8 @@ mod test { phase0_values_map.insert( "phase0_stack_ts_add".to_string(), vec![ - Goldilocks::from(4u64), // first TSUInt::N_RANGE_CHECK_CELLS = 1*(56/16) = 4 cells are range values, stack_ts + 1 = 4 + Goldilocks::from(4u64), /* first TSUInt::N_RANGE_CHECK_CELLS = 1*(56/16) = 4 + * cells are range values, stack_ts + 1 = 4 */ Goldilocks::from(0u64), Goldilocks::from(0u64), Goldilocks::from(0u64), @@ -251,7 +220,7 @@ mod test { vec![Goldilocks::from(2u64)], ); let m: u64 = (1 << get_uint_params::().1) - 1; - let range_values = u2vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( "phase0_old_stack_ts_lt".to_string(), vec![ diff --git a/singer/src/instructions/dup.rs b/singer/src/instructions/dup.rs index 6f4dc7e92..700e31557 100644 --- a/singer/src/instructions/dup.rs +++ b/singer/src/instructions/dup.rs @@ -13,7 +13,7 @@ use singer_utils::{ structs::{PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, uint::{UIntAddSub, UIntCmp}, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::error::ZKVMError; @@ -161,56 +161,24 @@ impl Instruction for DupInstruction { #[cfg(test)] mod test { use ark_std::test_rng; - use core::ops::Range; use ff::Field; use ff_ext::ExtensionField; use gkr::structs::LayerWitness; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; - use simple_frontend::structs::CellId; - use singer_utils::constants::RANGE_CHIP_BIT_WIDTH; - use singer_utils::structs::TSUInt; - use std::collections::BTreeMap; - use std::time::Instant; + use singer_utils::{constants::RANGE_CHIP_BIT_WIDTH, structs::TSUInt}; + use std::{collections::BTreeMap, time::Instant}; use transcript::Transcript; - use crate::instructions::{ - ChipChallenges, DupInstruction, Instruction, InstructionGraph, SingerCircuitBuilder, + use crate::{ + instructions::{ + ChipChallenges, DupInstruction, Instruction, InstructionGraph, SingerCircuitBuilder, + }, + scheme::GKRGraphProverState, + test::{get_uint_params, test_opcode_circuit}, + utils::u64vec, + CircuitWiresIn, SingerGraphBuilder, SingerParams, }; - use crate::scheme::GKRGraphProverState; - use crate::test::{get_uint_params, test_opcode_circuit, u2vec}; - use crate::{CircuitWiresIn, SingerGraphBuilder, SingerParams}; - - impl DupInstruction { - #[inline] - fn phase0_idxes_map() -> BTreeMap> { - let mut map = BTreeMap::new(); - map.insert("phase0_pc".to_string(), Self::phase0_pc()); - map.insert("phase0_stack_ts".to_string(), Self::phase0_stack_ts()); - map.insert("phase0_memory_ts".to_string(), Self::phase0_memory_ts()); - map.insert("phase0_stack_top".to_string(), Self::phase0_stack_top()); - map.insert("phase0_clk".to_string(), Self::phase0_clk()); - map.insert("phase0_pc_add".to_string(), Self::phase0_pc_add()); - map.insert( - "phase0_stack_ts_add".to_string(), - Self::phase0_stack_ts_add(), - ); - map.insert( - "phase0_stack_values".to_string(), - Self::phase0_stack_values(), - ); - map.insert( - "phase0_old_stack_ts".to_string(), - Self::phase0_old_stack_ts(), - ); - map.insert( - "phase0_old_stack_ts_lt".to_string(), - Self::phase0_old_stack_ts_lt(), - ); - - map - } - } #[test] fn test_dup1_construct_circuit() { @@ -247,7 +215,8 @@ mod test { phase0_values_map.insert( "phase0_stack_ts_add".to_string(), vec![ - Goldilocks::from(3u64), // first TSUInt::N_RANGE_CHECK_CELLS = 1*(56/16) = 4 cells are range values, stack_ts + 1 = 4 + Goldilocks::from(3u64), /* first TSUInt::N_RANGE_CHECK_CELLS = 1*(56/16) = 4 + * cells are range values, stack_ts + 1 = 4 */ Goldilocks::from(0u64), Goldilocks::from(0u64), Goldilocks::from(0u64), @@ -272,7 +241,7 @@ mod test { vec![Goldilocks::from(1u64)], ); let m: u64 = (1 << get_uint_params::().1) - 1; - let range_values = u2vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( "phase0_old_stack_ts_lt".to_string(), vec![ diff --git a/singer/src/instructions/gt.rs b/singer/src/instructions/gt.rs index 2b0165531..857181378 100644 --- a/singer/src/instructions/gt.rs +++ b/singer/src/instructions/gt.rs @@ -13,7 +13,7 @@ use singer_utils::{ structs::{PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, uint::{UIntAddSub, UIntCmp}, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::error::ZKVMError; @@ -176,66 +176,24 @@ impl Instruction for GtInstruction { #[cfg(test)] mod test { use ark_std::test_rng; - use core::ops::Range; use ff::Field; use ff_ext::ExtensionField; use gkr::structs::LayerWitness; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; - use simple_frontend::structs::CellId; - use singer_utils::constants::RANGE_CHIP_BIT_WIDTH; - use singer_utils::structs::TSUInt; - use std::collections::BTreeMap; - use std::time::Instant; + use singer_utils::{constants::RANGE_CHIP_BIT_WIDTH, structs::TSUInt}; + use std::{collections::BTreeMap, time::Instant}; use transcript::Transcript; - use crate::instructions::{ - ChipChallenges, GtInstruction, Instruction, InstructionGraph, SingerCircuitBuilder, + use crate::{ + instructions::{ + ChipChallenges, GtInstruction, Instruction, InstructionGraph, SingerCircuitBuilder, + }, + scheme::GKRGraphProverState, + test::{get_uint_params, test_opcode_circuit}, + utils::u64vec, + CircuitWiresIn, SingerGraphBuilder, SingerParams, }; - use crate::scheme::GKRGraphProverState; - use crate::test::{get_uint_params, test_opcode_circuit, u2vec}; - use crate::{CircuitWiresIn, SingerGraphBuilder, SingerParams}; - - impl GtInstruction { - #[inline] - fn phase0_idxes_map() -> BTreeMap> { - let mut map = BTreeMap::new(); - map.insert("phase0_pc".to_string(), Self::phase0_pc()); - map.insert("phase0_stack_ts".to_string(), Self::phase0_stack_ts()); - map.insert("phase0_memory_ts".to_string(), Self::phase0_memory_ts()); - map.insert("phase0_stack_top".to_string(), Self::phase0_stack_top()); - map.insert("phase0_clk".to_string(), Self::phase0_clk()); - map.insert("phase0_pc_add".to_string(), Self::phase0_pc_add()); - map.insert( - "phase0_stack_ts_add".to_string(), - Self::phase0_stack_ts_add(), - ); - map.insert( - "phase0_old_stack_ts0".to_string(), - Self::phase0_old_stack_ts0(), - ); - map.insert( - "phase0_old_stack_ts_lt0".to_string(), - Self::phase0_old_stack_ts_lt0(), - ); - map.insert( - "phase0_old_stack_ts1".to_string(), - Self::phase0_old_stack_ts1(), - ); - map.insert( - "phase0_old_stack_ts_lt1".to_string(), - Self::phase0_old_stack_ts_lt1(), - ); - map.insert("phase0_oprand_0".to_string(), Self::phase0_oprand_0()); - map.insert("phase0_oprand_1".to_string(), Self::phase0_oprand_1()); - map.insert( - "phase0_instruction_gt".to_string(), - Self::phase0_instruction_gt(), - ); - - map - } - } #[test] fn test_gt_construct_circuit() { @@ -272,7 +230,8 @@ mod test { phase0_values_map.insert( "phase0_stack_ts_add".to_string(), vec![ - Goldilocks::from(4u64), // first TSUInt::N_RANGE_CHECK_CELLS = 1*(56/16) = 4 cells are range values, stack_ts + 1 = 4 + Goldilocks::from(4u64), /* first TSUInt::N_RANGE_CHECK_CELLS = 1*(56/16) = 4 + * cells are range values, stack_ts + 1 = 4 */ Goldilocks::from(0u64), Goldilocks::from(0u64), Goldilocks::from(0u64), @@ -284,7 +243,7 @@ mod test { vec![Goldilocks::from(2u64)], ); let m: u64 = (1 << get_uint_params::().1) - 1; - let range_values = u2vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( "phase0_old_stack_ts_lt0".to_string(), vec![ @@ -300,7 +259,7 @@ mod test { vec![Goldilocks::from(1u64)], ); let m: u64 = (1 << get_uint_params::().1) - 2; - let range_values = u2vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( "phase0_old_stack_ts_lt1".to_string(), vec![ diff --git a/singer/src/instructions/jump.rs b/singer/src/instructions/jump.rs index 35c1175a4..dc44bf631 100644 --- a/singer/src/instructions/jump.rs +++ b/singer/src/instructions/jump.rs @@ -15,6 +15,7 @@ use singer_utils::{ structs::{PCUInt, RAMHandler, ROMHandler, TSUInt}, uint::UIntCmp, }; +use std::collections::BTreeMap; use crate::error::ZKVMError; @@ -125,48 +126,26 @@ impl Instruction for JumpInstruction { #[cfg(test)] mod test { - use crate::instructions::{ChipChallenges, Instruction, JumpInstruction}; - use crate::test::{get_uint_params, test_opcode_circuit, u2vec}; + use crate::{ + instructions::{ChipChallenges, Instruction, JumpInstruction}, + test::{get_uint_params, test_opcode_circuit}, + utils::u64vec, + }; use ark_std::test_rng; - use core::ops::Range; use ff::Field; use ff_ext::ExtensionField; use gkr::structs::LayerWitness; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; - use simple_frontend::structs::CellId; - use singer_utils::constants::RANGE_CHIP_BIT_WIDTH; - use singer_utils::structs::TSUInt; - use std::collections::BTreeMap; - use std::time::Instant; + use singer_utils::{constants::RANGE_CHIP_BIT_WIDTH, structs::TSUInt}; + use std::{collections::BTreeMap, time::Instant}; use transcript::Transcript; - use crate::instructions::{InstructionGraph, SingerCircuitBuilder}; - use crate::scheme::GKRGraphProverState; - use crate::{CircuitWiresIn, SingerGraphBuilder, SingerParams}; - - impl JumpInstruction { - #[inline] - fn phase0_idxes_map() -> BTreeMap> { - let mut map = BTreeMap::new(); - map.insert("phase0_pc".to_string(), Self::phase0_pc()); - map.insert("phase0_stack_ts".to_string(), Self::phase0_stack_ts()); - map.insert("phase0_memory_ts".to_string(), Self::phase0_memory_ts()); - map.insert("phase0_stack_top".to_string(), Self::phase0_stack_top()); - map.insert("phase0_clk".to_string(), Self::phase0_clk()); - map.insert("phase0_next_pc".to_string(), Self::phase0_next_pc()); - map.insert( - "phase0_old_stack_ts".to_string(), - Self::phase0_old_stack_ts(), - ); - map.insert( - "phase0_old_stack_ts_lt".to_string(), - Self::phase0_old_stack_ts_lt(), - ); - - map - } - } + use crate::{ + instructions::{InstructionGraph, SingerCircuitBuilder}, + scheme::GKRGraphProverState, + CircuitWiresIn, SingerGraphBuilder, SingerParams, + }; #[test] fn test_jump_construct_circuit() { @@ -205,7 +184,7 @@ mod test { vec![Goldilocks::from(1u64)], ); let m: u64 = (1 << get_uint_params::().1) - 1; - let range_values = u2vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( "phase0_old_stack_ts_lt".to_string(), vec![ diff --git a/singer/src/instructions/jumpdest.rs b/singer/src/instructions/jumpdest.rs index 99cd73343..ede7b3e05 100644 --- a/singer/src/instructions/jumpdest.rs +++ b/singer/src/instructions/jumpdest.rs @@ -12,7 +12,7 @@ use singer_utils::{ structs::{PCUInt, RAMHandler, ROMHandler, TSUInt}, uint::UIntAddSub, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::error::ZKVMError; @@ -100,38 +100,23 @@ impl Instruction for JumpdestInstruction { #[cfg(test)] mod test { use ark_std::test_rng; - use core::ops::Range; use ff::Field; use ff_ext::ExtensionField; use gkr::structs::LayerWitness; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; - use simple_frontend::structs::CellId; - use std::collections::BTreeMap; - use std::time::Instant; + use std::{collections::BTreeMap, time::Instant}; use transcript::Transcript; - use crate::instructions::{ - ChipChallenges, Instruction, InstructionGraph, JumpdestInstruction, SingerCircuitBuilder, + use crate::{ + instructions::{ + ChipChallenges, Instruction, InstructionGraph, JumpdestInstruction, + SingerCircuitBuilder, + }, + scheme::GKRGraphProverState, + test::test_opcode_circuit, + CircuitWiresIn, SingerGraphBuilder, SingerParams, }; - use crate::scheme::GKRGraphProverState; - use crate::test::test_opcode_circuit; - use crate::{CircuitWiresIn, SingerGraphBuilder, SingerParams}; - - impl JumpdestInstruction { - #[inline] - fn phase0_idxes_map() -> BTreeMap> { - let mut map = BTreeMap::new(); - map.insert("phase0_pc".to_string(), Self::phase0_pc()); - map.insert("phase0_stack_ts".to_string(), Self::phase0_stack_ts()); - map.insert("phase0_memory_ts".to_string(), Self::phase0_memory_ts()); - map.insert("phase0_stack_top".to_string(), Self::phase0_stack_top()); - map.insert("phase0_clk".to_string(), Self::phase0_clk()); - map.insert("phase0_pc_add".to_string(), Self::phase0_pc_add()); - - map - } - } #[test] fn test_jumpdest_construct_circuit() { diff --git a/singer/src/instructions/jumpi.rs b/singer/src/instructions/jumpi.rs index d3d4c4103..887070dc9 100644 --- a/singer/src/instructions/jumpi.rs +++ b/singer/src/instructions/jumpi.rs @@ -14,7 +14,7 @@ use singer_utils::{ structs::{PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, uint::{UIntAddSub, UIntCmp}, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::error::ZKVMError; diff --git a/singer/src/instructions/mstore.rs b/singer/src/instructions/mstore.rs index 7185454d9..e8d7fcfe1 100644 --- a/singer/src/instructions/mstore.rs +++ b/singer/src/instructions/mstore.rs @@ -15,7 +15,7 @@ use singer_utils::{ structs::{PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, uint::{UIntAddSub, UIntCmp}, }; -use std::{mem, sync::Arc}; +use std::{collections::BTreeMap, mem, sync::Arc}; use crate::{error::ZKVMError, utils::add_assign_each_cell, CircuitWiresIn, SingerParams}; @@ -378,7 +378,9 @@ impl MstoreAccessory { #[cfg(test)] mod test { - use crate::{instructions::InstructionGraph, scheme::GKRGraphProverState, SingerParams}; + use crate::{ + instructions::InstructionGraph, scheme::GKRGraphProverState, utils::u64vec, SingerParams, + }; use ark_std::test_rng; use ff::Field; use ff_ext::ExtensionField; @@ -397,52 +399,11 @@ mod test { CircuitWiresIn, SingerGraphBuilder, }; - use crate::test::{get_uint_params, test_opcode_circuit, u2vec}; - use core::ops::Range; + use crate::test::{get_uint_params, test_opcode_circuit}; use goldilocks::Goldilocks; - use simple_frontend::structs::CellId; - use singer_utils::constants::RANGE_CHIP_BIT_WIDTH; - use singer_utils::structs::TSUInt; + use singer_utils::{constants::RANGE_CHIP_BIT_WIDTH, structs::TSUInt}; use std::collections::BTreeMap; - impl MstoreInstruction { - #[inline] - fn phase0_idxes_map() -> BTreeMap> { - let mut map = BTreeMap::new(); - - map.insert("phase0_pc".to_string(), Self::phase0_pc()); - map.insert("phase0_stack_ts".to_string(), Self::phase0_stack_ts()); - map.insert("phase0_memory_ts".to_string(), Self::phase0_memory_ts()); - map.insert("phase0_stack_top".to_string(), Self::phase0_stack_top()); - map.insert("phase0_clk".to_string(), Self::phase0_clk()); - map.insert("phase0_pc_add".to_string(), Self::phase0_pc_add()); - map.insert( - "phase0_memory_ts_add".to_string(), - Self::phase0_memory_ts_add(), - ); - map.insert("phase0_offset".to_string(), Self::phase0_offset()); - map.insert("phase0_mem_bytes".to_string(), Self::phase0_mem_bytes()); - map.insert( - "phase0_old_stack_ts_offset".to_string(), - Self::phase0_old_stack_ts_offset(), - ); - map.insert( - "phase0_old_stack_ts_lt_offset".to_string(), - Self::phase0_old_stack_ts_lt_offset(), - ); - map.insert( - "phase0_old_stack_ts_value".to_string(), - Self::phase0_old_stack_ts_value(), - ); - map.insert( - "phase0_old_stack_ts_lt_value".to_string(), - Self::phase0_old_stack_ts_lt_value(), - ); - - map - } - } - #[test] fn test_mstore_construct_circuit() { let challenges = ChipChallenges::default(); @@ -478,7 +439,8 @@ mod test { phase0_values_map.insert( "phase0_memory_ts_add".to_string(), vec![ - Goldilocks::from(4u64), // first TSUInt::N_RANGE_CHECK_CELLS = 1*(56/16) = 4 cells are range values, memory_ts + 1 = 4 + Goldilocks::from(4u64), /* first TSUInt::N_RANGE_CHECK_CELLS = 1*(56/16) = 4 + * cells are range values, memory_ts + 1 = 4 */ Goldilocks::from(0u64), Goldilocks::from(0u64), Goldilocks::from(0u64), @@ -491,7 +453,7 @@ mod test { vec![Goldilocks::from(2u64)], ); let m: u64 = (1 << get_uint_params::().1) - 1; - let range_values = u2vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( "phase0_old_stack_ts_lt_offset".to_string(), vec![ @@ -511,7 +473,7 @@ mod test { vec![Goldilocks::from(1u64)], ); let m: u64 = (1 << get_uint_params::().1) - 2; - let range_values = u2vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( "phase0_old_stack_ts_lt_value".to_string(), vec![ diff --git a/singer/src/instructions/pop.rs b/singer/src/instructions/pop.rs index 32f083671..da34c2a99 100644 --- a/singer/src/instructions/pop.rs +++ b/singer/src/instructions/pop.rs @@ -13,7 +13,7 @@ use singer_utils::{ structs::{PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, uint::{UIntAddSub, UIntCmp}, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::error::ZKVMError; @@ -126,7 +126,6 @@ impl Instruction for PopInstruction { #[cfg(test)] mod test { use ark_std::test_rng; - use core::ops::Range; use ff::Field; use ff_ext::ExtensionField; use gkr::structs::LayerWitness; @@ -134,44 +133,20 @@ mod test { use itertools::Itertools; use std::collections::BTreeMap; - use crate::instructions::{ChipChallenges, Instruction, PopInstruction}; - use crate::test::{get_uint_params, test_opcode_circuit, u2vec}; - use simple_frontend::structs::CellId; - use singer_utils::constants::RANGE_CHIP_BIT_WIDTH; - use singer_utils::structs::TSUInt; + use crate::{ + instructions::{ChipChallenges, Instruction, PopInstruction}, + test::{get_uint_params, test_opcode_circuit}, + utils::u64vec, + }; + use singer_utils::{constants::RANGE_CHIP_BIT_WIDTH, structs::TSUInt}; use std::time::Instant; use transcript::Transcript; - use crate::instructions::{InstructionGraph, SingerCircuitBuilder}; - use crate::scheme::GKRGraphProverState; - use crate::{CircuitWiresIn, SingerGraphBuilder, SingerParams}; - - impl PopInstruction { - #[inline] - fn phase0_idxes_map() -> BTreeMap> { - let mut map = BTreeMap::new(); - map.insert("phase0_pc".to_string(), Self::phase0_pc()); - map.insert("phase0_stack_ts".to_string(), Self::phase0_stack_ts()); - map.insert("phase0_memory_ts".to_string(), Self::phase0_memory_ts()); - map.insert("phase0_stack_top".to_string(), Self::phase0_stack_top()); - map.insert("phase0_clk".to_string(), Self::phase0_clk()); - map.insert("phase0_pc_add".to_string(), Self::phase0_pc_add()); - map.insert( - "phase0_old_stack_ts".to_string(), - Self::phase0_old_stack_ts(), - ); - map.insert( - "phase0_old_stack_ts_lt".to_string(), - Self::phase0_old_stack_ts_lt(), - ); - map.insert( - "phase0_stack_values".to_string(), - Self::phase0_stack_values(), - ); - - map - } - } + use crate::{ + instructions::{InstructionGraph, SingerCircuitBuilder}, + scheme::GKRGraphProverState, + CircuitWiresIn, SingerGraphBuilder, SingerParams, + }; #[test] fn test_pop_construct_circuit() { @@ -210,7 +185,7 @@ mod test { vec![Goldilocks::from(1u64)], ); let m: u64 = (1 << get_uint_params::().1) - 1; - let range_values = u2vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( "phase0_old_stack_ts_lt".to_string(), vec![ diff --git a/singer/src/instructions/push.rs b/singer/src/instructions/push.rs index 6dfe7084f..3eb2fad05 100644 --- a/singer/src/instructions/push.rs +++ b/singer/src/instructions/push.rs @@ -13,7 +13,7 @@ use singer_utils::{ structs::{PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, uint::UIntAddSub, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::error::ZKVMError; @@ -146,46 +146,22 @@ impl Instruction for PushInstruction { #[cfg(test)] mod test { use ark_std::test_rng; - use core::ops::Range; use ff::Field; use ff_ext::ExtensionField; use gkr::structs::LayerWitness; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; - use simple_frontend::structs::CellId; - use std::collections::BTreeMap; - use std::time::Instant; + use std::{collections::BTreeMap, time::Instant}; use transcript::Transcript; - use crate::instructions::{ - ChipChallenges, Instruction, InstructionGraph, PushInstruction, SingerCircuitBuilder, + use crate::{ + instructions::{ + ChipChallenges, Instruction, InstructionGraph, PushInstruction, SingerCircuitBuilder, + }, + scheme::GKRGraphProverState, + test::test_opcode_circuit, + CircuitWiresIn, SingerGraphBuilder, SingerParams, }; - use crate::scheme::GKRGraphProverState; - use crate::test::test_opcode_circuit; - use crate::{CircuitWiresIn, SingerGraphBuilder, SingerParams}; - - impl PushInstruction { - #[inline] - fn phase0_idxes_map() -> BTreeMap> { - let mut map = BTreeMap::new(); - map.insert("phase0_pc".to_string(), Self::phase0_pc()); - map.insert("phase0_stack_ts".to_string(), Self::phase0_stack_ts()); - map.insert("phase0_memory_ts".to_string(), Self::phase0_memory_ts()); - map.insert("phase0_stack_top".to_string(), Self::phase0_stack_top()); - map.insert("phase0_clk".to_string(), Self::phase0_clk()); - map.insert( - "phase0_pc_add_i_plus_1".to_string(), - Self::phase0_pc_add_i_plus_1(), - ); - map.insert( - "phase0_stack_ts_add".to_string(), - Self::phase0_stack_ts_add(), - ); - map.insert("phase0_stack_bytes".to_string(), Self::phase0_stack_bytes()); - - map - } - } #[test] fn test_push1_construct_circuit() { @@ -221,7 +197,8 @@ mod test { phase0_values_map.insert( "phase0_stack_ts_add".to_string(), vec![ - Goldilocks::from(2u64), // first TSUInt::N_RANGE_CHECK_CELLS = 1*(56/16) = 4 cells are range values, stack_ts + 1 = 4 + Goldilocks::from(2u64), /* first TSUInt::N_RANGE_CHECK_CELLS = 1*(56/16) = 4 + * cells are range values, stack_ts + 1 = 4 */ Goldilocks::from(0u64), Goldilocks::from(0u64), Goldilocks::from(0u64), diff --git a/singer/src/instructions/ret.rs b/singer/src/instructions/ret.rs index ac24c2efa..4215564fe 100644 --- a/singer/src/instructions/ret.rs +++ b/singer/src/instructions/ret.rs @@ -15,7 +15,7 @@ use singer_utils::{ structs::{PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, uint::UIntAddSub, }; -use std::{mem, sync::Arc}; +use std::{collections::BTreeMap, mem, sync::Arc}; use crate::{error::ZKVMError, utils::add_assign_each_cell, CircuitWiresIn, SingerParams}; diff --git a/singer/src/instructions/swap.rs b/singer/src/instructions/swap.rs index 5c944652a..c1f00cf49 100644 --- a/singer/src/instructions/swap.rs +++ b/singer/src/instructions/swap.rs @@ -13,7 +13,7 @@ use singer_utils::{ structs::{PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, uint::{UIntAddSub, UIntCmp}, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::error::ZKVMError; @@ -181,68 +181,24 @@ impl Instruction for SwapInstruction { #[cfg(test)] mod test { use ark_std::test_rng; - use core::ops::Range; use ff::Field; use ff_ext::ExtensionField; use gkr::structs::LayerWitness; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; - use simple_frontend::structs::CellId; - use singer_utils::constants::RANGE_CHIP_BIT_WIDTH; - use singer_utils::structs::TSUInt; - use std::collections::BTreeMap; - use std::time::Instant; + use singer_utils::{constants::RANGE_CHIP_BIT_WIDTH, structs::TSUInt}; + use std::{collections::BTreeMap, time::Instant}; use transcript::Transcript; - use crate::instructions::{ - ChipChallenges, Instruction, InstructionGraph, SingerCircuitBuilder, SwapInstruction, + use crate::{ + instructions::{ + ChipChallenges, Instruction, InstructionGraph, SingerCircuitBuilder, SwapInstruction, + }, + scheme::GKRGraphProverState, + test::{get_uint_params, test_opcode_circuit}, + utils::u64vec, + CircuitWiresIn, SingerGraphBuilder, SingerParams, }; - use crate::scheme::GKRGraphProverState; - use crate::test::{get_uint_params, test_opcode_circuit, u2vec}; - use crate::{CircuitWiresIn, SingerGraphBuilder, SingerParams}; - - impl SwapInstruction { - #[inline] - fn phase0_idxes_map() -> BTreeMap> { - let mut map = BTreeMap::new(); - map.insert("phase0_pc".to_string(), Self::phase0_pc()); - map.insert("phase0_stack_ts".to_string(), Self::phase0_stack_ts()); - map.insert("phase0_memory_ts".to_string(), Self::phase0_memory_ts()); - map.insert("phase0_stack_top".to_string(), Self::phase0_stack_top()); - map.insert("phase0_clk".to_string(), Self::phase0_clk()); - map.insert("phase0_pc_add".to_string(), Self::phase0_pc_add()); - map.insert( - "phase0_stack_ts_add".to_string(), - Self::phase0_stack_ts_add(), - ); - map.insert( - "phase0_old_stack_ts_1".to_string(), - Self::phase0_old_stack_ts_1(), - ); - map.insert( - "phase0_old_stack_ts_lt_1".to_string(), - Self::phase0_old_stack_ts_lt_1(), - ); - map.insert( - "phase0_old_stack_ts_n_plus_1".to_string(), - Self::phase0_old_stack_ts_n_plus_1(), - ); - map.insert( - "phase0_old_stack_ts_lt_n_plus_1".to_string(), - Self::phase0_old_stack_ts_lt_n_plus_1(), - ); - map.insert( - "phase0_stack_values_1".to_string(), - Self::phase0_stack_values_1(), - ); - map.insert( - "phase0_stack_values_n_plus_1".to_string(), - Self::phase0_stack_values_n_plus_1(), - ); - - map - } - } #[test] fn test_swap2_construct_circuit() { @@ -279,7 +235,8 @@ mod test { phase0_values_map.insert( "phase0_stack_ts_add".to_string(), vec![ - Goldilocks::from(5u64), // first TSUInt::N_RANGE_CHECK_CELLS = 1*(56/16) = 4 cells are range values, stack_ts + 1 = 4 + Goldilocks::from(5u64), /* first TSUInt::N_RANGE_CHECK_CELLS = 1*(56/16) = 4 + * cells are range values, stack_ts + 1 = 4 */ Goldilocks::from(0u64), Goldilocks::from(0u64), Goldilocks::from(0u64), @@ -291,7 +248,7 @@ mod test { vec![Goldilocks::from(3u64)], ); let m: u64 = (1 << get_uint_params::().1) - 1; - let range_values = u2vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( "phase0_old_stack_ts_lt_1".to_string(), vec![ @@ -307,7 +264,7 @@ mod test { vec![Goldilocks::from(1u64)], ); let m: u64 = (1 << get_uint_params::().1) - 3; - let range_values = u2vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( "phase0_old_stack_ts_lt_n_plus_1".to_string(), vec![ diff --git a/singer/src/lib.rs b/singer/src/lib.rs index 614ae606f..69d07cd8b 100644 --- a/singer/src/lib.rs +++ b/singer/src/lib.rs @@ -18,14 +18,13 @@ pub mod instructions; pub mod scheme; #[cfg(test)] pub mod test; +pub use utils::u64vec; mod utils; // Process sketch: // 1. Construct instruction circuits and circuit gadgets => circuit gadgets -// 2. (bytecode + input) => Run revm interpreter, generate all wires in -// 2.1 phase 0 wire in + commitment -// 2.2 phase 1 wire in + commitment -// 2.3 phase 2 wire in + commitment +// 2. (bytecode + input) => Run revm interpreter, generate all wires in 2.1 phase 0 wire in + +// commitment 2.2 phase 1 wire in + commitment 2.3 phase 2 wire in + commitment // 3. (circuit gadgets + wires in) => gkr graph + gkr witness // 4. (gkr graph + gkr witness) => (gkr proof + point) // 5. (commitments + point) => pcs proof diff --git a/singer/src/test.rs b/singer/src/test.rs index 849189379..68daf2621 100644 --- a/singer/src/test.rs +++ b/singer/src/test.rs @@ -2,10 +2,6 @@ use core::ops::Range; use ff::Field; use ff_ext::ExtensionField; use gkr::structs::CircuitWitness; -use gkr::structs::IOPProverState; -use gkr::utils::MultilinearExtensionFromVectors; -use goldilocks::SmallField; -use itertools::Itertools; use simple_frontend::structs::CellId; use singer_utils::structs::UInt; use std::collections::BTreeMap; @@ -26,21 +22,11 @@ pub(crate) fn get_uint_params() -> (usize, usize) { (T::BITS, T::CELL_BIT_WIDTH) } -pub(crate) fn u2vec(x: u64) -> [u64; W] { - let mut x = x; - let mut ret = [0; W]; - for i in 0..ret.len() { - ret[i] = x & ((1 << C) - 1); - x >>= C; - } - ret -} - -pub(crate) fn test_opcode_circuit( +pub(crate) fn test_opcode_circuit_v2( inst_circuit: &InstCircuit, - phase0_idx_map: &BTreeMap>, + phase0_idx_map: &BTreeMap<&'static str, Range>, phase0_witness_size: usize, - phase0_values_map: &BTreeMap>, + phase0_values_map: &BTreeMap<&'static str, Vec>, circuit_witness_challenges: Vec, ) -> CircuitWitness<::BaseField> { // configure circuit @@ -65,7 +51,9 @@ pub(crate) fn test_opcode_circuit( .unwrap() .clone() .collect::>(); - let values = phase0_values_map.get(key).unwrap(); + let values = phase0_values_map + .get(key) + .expect(&("unknown key ".to_owned() + key)); for (value_idx, cell_idx) in range.into_iter().enumerate() { if value_idx < values.len() { witness_in[phase0_input_idx as usize][cell_idx] = values[value_idx]; @@ -89,67 +77,86 @@ pub(crate) fn test_opcode_circuit( circuit_witness - /*let instance_num_vars = circuit_witness.instance_num_vars(); - let (proof, output_num_vars, output_eval) = { - let mut prover_transcript = Transcript::::new(b"example"); - let output_num_vars = instance_num_vars + circuit.first_layer_ref().num_vars(); - let output_point = (0..output_num_vars) - .map(|_| { - prover_transcript - .get_and_append_challenge(b"output point") - .elements - }) - .collect_vec(); - let output_eval = circuit_witness - .layer_poly(0, circuit.first_layer_ref().num_vars()) - .evaluate(&output_point); - ( - IOPProverState::prove_parallel( - &circuit, - &circuit_witness, - &[(output_point, output_eval)], - &[], - &mut prover_transcript, - ), - output_num_vars, - output_eval, - ) - };*/ - /* - let gkr_input_claims = { - let mut verifier_transcript = &mut Transcript::::new(b"example"); - let output_point = (0..output_num_vars) - .map(|_| { - verifier_transcript - .get_and_append_challenge(b"output point") - .elements - }) - .collect_vec(); - IOPVerifierState::verify_parallel( - &circuit, - circuit_witness.challenges(), - &[(output_point, output_eval)], - &[], - &proof, - instance_num_vars, - &mut verifier_transcript, - ) - .expect("verification failed") - }; - let expected_values = circuit_witness - .witness_in_ref() + // let instance_num_vars = circuit_witness.instance_num_vars(); + // let (proof, output_num_vars, output_eval) = { + // let mut prover_transcript = Transcript::::new(b"example"); + // let output_num_vars = instance_num_vars + circuit.first_layer_ref().num_vars(); + // let output_point = (0..output_num_vars) + // .map(|_| { + // prover_transcript + // .get_and_append_challenge(b"output point") + // .elements + // }) + // .collect_vec(); + // let output_eval = circuit_witness + // .layer_poly(0, circuit.first_layer_ref().num_vars()) + // .evaluate(&output_point); + // ( + // IOPProverState::prove_parallel( + // &circuit, + // &circuit_witness, + // &[(output_point, output_eval)], + // &[], + // &mut prover_transcript, + // ), + // output_num_vars, + // output_eval, + // ) + // }; + // let gkr_input_claims = { + // let mut verifier_transcript = &mut Transcript::::new(b"example"); + // let output_point = (0..output_num_vars) + // .map(|_| { + // verifier_transcript + // .get_and_append_challenge(b"output point") + // .elements + // }) + // .collect_vec(); + // IOPVerifierState::verify_parallel( + // &circuit, + // circuit_witness.challenges(), + // &[(output_point, output_eval)], + // &[], + // &proof, + // instance_num_vars, + // &mut verifier_transcript, + // ) + // .expect("verification failed") + // }; + // let expected_values = circuit_witness + // .witness_in_ref() + // .iter() + // .map(|witness| { + // witness + // .instances + // .as_slice() + // .mle(circuit.max_wit_in_num_vars.expect("REASON"), instance_num_vars) + // .evaluate(&gkr_input_claims.point_and_evals) + // }) + // .collect_vec(); + // for i in 0..gkr_input_claims.point_and_evals.len() { + // assert_eq!(expected_values[i], gkr_input_claims.point_and_evals[i]); + // } + // println!("verification succeeded"); +} + +#[deprecated(note = "deprecated and use test_opcode_circuit_v2 instead")] +pub(crate) fn test_opcode_circuit( + inst_circuit: &InstCircuit, + phase0_idx_map: &BTreeMap<&'static str, Range>, + phase0_witness_size: usize, + phase0_values_map: &BTreeMap>, + circuit_witness_challenges: Vec, +) -> CircuitWitness<::BaseField> { + let phase0_values_map = phase0_values_map .iter() - .map(|witness| { - witness - .instances - .as_slice() - .mle(circuit.max_wit_in_num_vars.expect("REASON"), instance_num_vars) - .evaluate(&gkr_input_claims.point_and_evals) - }) - .collect_vec(); - for i in 0..gkr_input_claims.point_and_evals.len() { - assert_eq!(expected_values[i], gkr_input_claims.point_and_evals[i]); - } - println!("verification succeeded"); - */ + .map(|(key, value)| (key.clone().leak() as &'static str, value.clone())) + .collect::>>(); + test_opcode_circuit_v2( + inst_circuit, + phase0_idx_map, + phase0_witness_size, + &phase0_values_map, + circuit_witness_challenges, + ) } diff --git a/singer/src/utils.rs b/singer/src/utils.rs index d194edc8c..d1314c56b 100644 --- a/singer/src/utils.rs +++ b/singer/src/utils.rs @@ -21,3 +21,16 @@ pub(crate) fn add_assign_each_cell( circuit_builder.add(*dest, *src, E::BaseField::ONE); } } + +// split single u64 value into W slices, each slice got C bits. +// all the rest slices will be filled with 0 if W x C > 64 +pub fn u64vec(x: u64) -> [u64; W] { + assert!(C <= 64); + let mut x = x; + let mut ret = [0; W]; + for i in 0..ret.len() { + ret[i] = x & ((1 << C) - 1); + x >>= C; + } + ret +} diff --git a/sumcheck/src/verifier.rs b/sumcheck/src/verifier.rs index c554dc12d..b4e119d3d 100644 --- a/sumcheck/src/verifier.rs +++ b/sumcheck/src/verifier.rs @@ -155,8 +155,10 @@ impl IOPVerifierState { // 1. check if the received 'P(0) + P(1) = expected`. if evaluations[0] + evaluations[1] != expected { panic!( - "{}th round's prover message is not consistent with the claim. {:?} {:?} {:?}", - i, evaluations[0], evaluations[1], expected + "{}th round's prover message is not consistent with the claim. {:?} {:?}", + i, + evaluations[0] + evaluations[1], + expected ); } }