From 580159132709fe4b24cb2f5ed530969c9018042f Mon Sep 17 00:00:00 2001 From: Ming Date: Mon, 15 Jul 2024 21:30:20 +0800 Subject: [PATCH] devirgo style on phase 1 (#83) * devirgo style on phase 1 instruction add example with prove/verify * optimize sumcheck algo * fix comment --- Cargo.lock | 3 + gkr-graph/src/circuit_graph_builder.rs | 1 + gkr-graph/src/prover.rs | 93 ++++-- gkr-graph/src/verifier.rs | 90 +++--- gkr/benches/keccak256.rs | 4 +- gkr/examples/keccak256.rs | 3 +- gkr/src/circuit/circuit_layout.rs | 13 +- gkr/src/circuit/circuit_witness.rs | 32 +- gkr/src/lib.rs | 2 + gkr/src/prover.rs | 74 +++-- gkr/src/prover/phase1.rs | 321 ++++++++------------ gkr/src/prover/phase1_output.rs | 38 ++- gkr/src/prover/phase2_input.rs | 1 - gkr/src/structs.rs | 13 +- gkr/src/utils.rs | 2 +- gkr/src/verifier.rs | 33 +- gkr/src/verifier/phase1.rs | 113 +++---- multilinear_extensions/src/mle.rs | 7 +- multilinear_extensions/src/virtual_poly.rs | 27 +- rustfmt.toml | 5 +- singer-pro/src/basic_block/bb_final.rs | 2 +- singer-pro/src/basic_block/bb_ret.rs | 2 +- singer-pro/src/instructions/add.rs | 2 +- singer-pro/src/instructions/calldataload.rs | 2 +- singer-pro/src/instructions/gt.rs | 2 +- singer-pro/src/instructions/jumpi.rs | 2 +- singer-pro/src/instructions/mstore.rs | 2 +- singer-pro/src/instructions/ret.rs | 2 +- singer-utils/src/macros.rs | 32 +- singer/Cargo.toml | 4 +- singer/benches/add.rs | 135 ++++---- singer/examples/add.rs | 205 +++++++++++-- singer/src/instructions.rs | 1 + singer/src/instructions/add.rs | 107 +++---- singer/src/instructions/calldataload.rs | 63 +--- singer/src/instructions/dup.rs | 61 +--- singer/src/instructions/gt.rs | 73 +---- singer/src/instructions/jump.rs | 51 +--- singer/src/instructions/jumpdest.rs | 35 +-- singer/src/instructions/jumpi.rs | 2 +- singer/src/instructions/mstore.rs | 52 +--- singer/src/instructions/pop.rs | 53 +--- singer/src/instructions/push.rs | 47 +-- singer/src/instructions/ret.rs | 2 +- singer/src/instructions/swap.rs | 75 +---- singer/src/lib.rs | 1 + singer/src/test.rs | 167 +++++----- singer/src/utils.rs | 13 + sumcheck/src/prover.rs | 62 ++-- sumcheck/src/verifier.rs | 6 +- 50 files changed, 1004 insertions(+), 1134 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1aaab2229..3c1c14bc4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1275,6 +1275,9 @@ dependencies = [ "singer-utils", "strum 0.25.0", "strum_macros 0.25.3", + "tracing", + "tracing-flame", + "tracing-subscriber", "transcript", ] diff --git a/gkr-graph/src/circuit_graph_builder.rs b/gkr-graph/src/circuit_graph_builder.rs index d12d54c41..5ae1854cc 100644 --- a/gkr-graph/src/circuit_graph_builder.rs +++ b/gkr-graph/src/circuit_graph_builder.rs @@ -76,6 +76,7 @@ impl CircuitGraphBuilder { ), }; let old_num_instances = self.witness.node_witnesses[*id].n_instances(); + // TODO find way to avoid expensive clone for wit_in let new_instances = match pred { PredType::PredWire(_) => { let new_size = (old_num_instances * out[0].len()) / num_instances; diff --git a/gkr-graph/src/prover.rs b/gkr-graph/src/prover.rs index df7ef8f26..74cbcf0eb 100644 --- a/gkr-graph/src/prover.rs +++ b/gkr-graph/src/prover.rs @@ -21,6 +21,7 @@ impl IOPProverState { expected_max_thread_id: usize, ) -> Result, GKRGraphError> { assert_eq!(target_evals.0.len(), circuit.targets.len()); + assert_eq!(circuit_witness.node_witnesses.len(), circuit.nodes.len()); let mut output_evals = vec![vec![]; circuit.nodes.len()]; let mut wit_out_evals = circuit @@ -36,10 +37,42 @@ impl IOPProverState { let gkr_proofs = izip!(&circuit.nodes, &circuit_witness.node_witnesses) .rev() .map(|(node, witness)| { - // println!("expected_max_thread_id {:?}", expected_max_thread_id); let max_thread_id = witness.n_instances().min(expected_max_thread_id); - // println!("max_thread_id {:?}", max_thread_id); - let timer = std::time::Instant::now(); + + // sanity check for witness poly evaluation + if cfg!(debug_assertions) { + + // TODO figure out a way to do sanity check on output_evals + // it doens't work for now because output evaluation + // might only take partial range of output layer witness + // assert!(output_evals[node.id].len() <= 1); + // if !output_evals[node.id].is_empty() { + // debug_assert_eq!( + // witness + // .output_layer_witness_ref() + // .instances + // .as_slice() + // .original_mle() + // .evaluate(&point_and_eval.point), + // point_and_eval.eval, + // "node_id {} output eval failed", + // node.id, + // ); + // } + + for (witness_id, point_and_eval) in wit_out_evals[node.id].iter().enumerate() { + let mle = witness.witness_out_ref()[witness_id] + .instances + .as_slice() + .original_mle(); + debug_assert_eq!( + mle.evaluate(&point_and_eval.point), + point_and_eval.eval, + "node_id {} output eval failed", + node.id, + ); + } + } let (proof, input_claim) = GKRProverState::prove_parallel( &node.circuit, witness, @@ -48,6 +81,7 @@ impl IOPProverState { max_thread_id, transcript, ); + // println!( // "Proving node {}, label {}, num_instances:{}, took {}s", // node.id, @@ -56,24 +90,32 @@ impl IOPProverState { // timer.elapsed().as_secs_f64() // ); - izip!(&node.preds, input_claim.point_and_evals) + izip!(&node.preds, &input_claim.point_and_evals) .enumerate() - .for_each(|(wire_id, (pred, point_and_eval))| match pred { + .for_each(|(wire_id, (pred_type, point_and_eval))| match pred_type { PredType::Source => { - debug_assert_eq!( - witness.witness_in_ref()[wire_id as usize] + // sanity check for input poly evaluation + if cfg!(debug_assertions) { + let input_layer_poly = witness.witness_in_ref()[wire_id] .instances .as_slice() - .original_mle() - .evaluate(&point_and_eval.point), - point_and_eval.eval - ); + .original_mle(); + debug_assert_eq!( + input_layer_poly.evaluate(&point_and_eval.point), + point_and_eval.eval, + "mismatch at node.id {:?} wire_id {:?}, input_claim.point_and_evals.point {:?}, node.preds {:?}", + node.id, + wire_id, + input_claim.point_and_evals[0].point, + node.preds + ); + } } - PredType::PredWire(out) | PredType::PredWireDup(out) => { - let point = match pred { + PredType::PredWire(pred_out) | PredType::PredWireDup(pred_out) => { + let point = match pred_type { PredType::PredWire(_) => point_and_eval.point.clone(), PredType::PredWireDup(out) => { - let node_id = match out { + let pred_node_id = match out { NodeOutputType::OutputLayer(id) => id, NodeOutputType::WireOut(id, _) => id, }; @@ -81,27 +123,28 @@ impl IOPProverState { // [single_instance_slice || // new_instance_index_slice]. The old point // is [single_instance_slices || - // new_instance_index_slices[(new_instance_num_vars - // - old_instance_num_vars)..]] - let old_instance_num_vars = circuit_witness.node_witnesses - [*node_id] + // new_instance_index_slices[(instance_num_vars + // - pred_instance_num_vars)..]] + let pred_instance_num_vars = circuit_witness.node_witnesses + [*pred_node_id] .instance_num_vars(); - let new_instance_num_vars = witness.instance_num_vars(); - let num_vars = - point_and_eval.point.len() - new_instance_num_vars; + let instance_num_vars = witness.instance_num_vars(); + let num_vars = point_and_eval.point.len() - instance_num_vars; [ point_and_eval.point[..num_vars].to_vec(), point_and_eval.point[num_vars - + (new_instance_num_vars - old_instance_num_vars)..] + + (instance_num_vars - pred_instance_num_vars)..] .to_vec(), ] .concat() } _ => unreachable!(), }; - match out { - NodeOutputType::OutputLayer(id) => output_evals[*id] - .push(PointAndEval::new_from_ref(&point, &point_and_eval.eval)), + match pred_out { + NodeOutputType::OutputLayer(id) => { + output_evals[*id] + .push(PointAndEval::new_from_ref(&point, &point_and_eval.eval)) + }, NodeOutputType::WireOut(id, wire_id) => { let evals = &mut wit_out_evals[*id][*wire_id as usize]; assert!( diff --git a/gkr-graph/src/verifier.rs b/gkr-graph/src/verifier.rs index 4aabbd7a4..7094dfb85 100644 --- a/gkr-graph/src/verifier.rs +++ b/gkr-graph/src/verifier.rs @@ -50,52 +50,56 @@ impl IOPVerifierState { let new_instance_num_vars = aux_info.instance_num_vars[node.id]; - izip!(&node.preds, input_claim.point_and_evals).for_each(|(pred, point_and_eval)| { - match pred { - PredType::Source => { - // TODO: collect `(proof.point.clone(), *eval)` as `TargetEvaluations` for later PCS open? - } - PredType::PredWire(out) | PredType::PredWireDup(out) => { - let old_point = match pred { - PredType::PredWire(_) => point_and_eval.point.clone(), - PredType::PredWireDup(out) => { - let node_id = match out { - NodeOutputType::OutputLayer(id) => *id, - NodeOutputType::WireOut(id, _) => *id, - }; - // Suppose the new point is - // [single_instance_slice || - // new_instance_index_slice]. The old point - // is [single_instance_slices || - // new_instance_index_slices[(new_instance_num_vars - // - old_instance_num_vars)..]] - let old_instance_num_vars = aux_info.instance_num_vars[node_id]; - let num_vars = point_and_eval.point.len() - new_instance_num_vars; - [ - point_and_eval.point[..num_vars].to_vec(), - point_and_eval.point[num_vars - + (new_instance_num_vars - old_instance_num_vars)..] - .to_vec(), - ] - .concat() - } - _ => unreachable!(), - }; - match out { - NodeOutputType::OutputLayer(id) => output_evals[*id] - .push(PointAndEval::new_from_ref(&old_point, &point_and_eval.eval)), - NodeOutputType::WireOut(id, wire_id) => { - let evals = &mut wit_out_evals[*id][*wire_id as usize]; - assert!( - evals.point.is_empty() && evals.eval.is_zero_vartime(), - "unimplemented", - ); - *evals = PointAndEval::new(old_point, point_and_eval.eval); + izip!(&node.preds, input_claim.point_and_evals).for_each( + |(pred_type, point_and_eval)| { + match pred_type { + PredType::Source => { + // TODO: collect `(proof.point.clone(), *eval)` as `TargetEvaluations` + // for later PCS open? + } + PredType::PredWire(pred_out) | PredType::PredWireDup(pred_out) => { + let point = match pred_type { + PredType::PredWire(_) => point_and_eval.point.clone(), + PredType::PredWireDup(out) => { + let node_id = match out { + NodeOutputType::OutputLayer(id) => *id, + NodeOutputType::WireOut(id, _) => *id, + }; + // Suppose the new point is + // [single_instance_slice || + // new_instance_index_slice]. The old point + // is [single_instance_slices || + // new_instance_index_slices[(new_instance_num_vars + // - old_instance_num_vars)..]] + let old_instance_num_vars = aux_info.instance_num_vars[node_id]; + let num_vars = + point_and_eval.point.len() - new_instance_num_vars; + [ + point_and_eval.point[..num_vars].to_vec(), + point_and_eval.point[num_vars + + (new_instance_num_vars - old_instance_num_vars)..] + .to_vec(), + ] + .concat() + } + _ => unreachable!(), + }; + match pred_out { + NodeOutputType::OutputLayer(id) => output_evals[*id] + .push(PointAndEval::new_from_ref(&point, &point_and_eval.eval)), + NodeOutputType::WireOut(id, wire_id) => { + let evals = &mut wit_out_evals[*id][*wire_id as usize]; + assert!( + evals.point.is_empty() && evals.eval.is_zero_vartime(), + "unimplemented", + ); + *evals = PointAndEval::new(point, point_and_eval.eval); + } } } } - } - }); + }, + ); } Ok(()) diff --git a/gkr/benches/keccak256.rs b/gkr/benches/keccak256.rs index c7c261e78..b27b37e14 100644 --- a/gkr/benches/keccak256.rs +++ b/gkr/benches/keccak256.rs @@ -10,7 +10,6 @@ use gkr::gadgets::keccak256::{keccak256_circuit, prove_keccak256, verify_keccak2 use goldilocks::GoldilocksExt2; use sumcheck::util::is_power_of_2; -// cargo bench --bench keccak256 --features parallel --features flamegraph --package gkr -- --profile-time cfg_if::cfg_if! { if #[cfg(feature = "flamegraph")] { criterion_group! { @@ -48,8 +47,7 @@ fn bench_keccak256(c: &mut Criterion) { #[cfg(feature = "non_pow2_rayon_thread")] { - use sumcheck::local_thread_pool::create_local_pool_once; - use sumcheck::util::ceil_log2; + use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); max_thread_id diff --git a/gkr/examples/keccak256.rs b/gkr/examples/keccak256.rs index d8a1433ac..a105e0930 100644 --- a/gkr/examples/keccak256.rs +++ b/gkr/examples/keccak256.rs @@ -36,8 +36,7 @@ fn main() { #[cfg(feature = "non_pow2_rayon_thread")] { - use sumcheck::local_thread_pool::create_local_pool_once; - use sumcheck::util::ceil_log2; + use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; max_thread_id = 1 << ceil_log2(max_thread_id); create_local_pool_once(max_thread_id, true); } diff --git a/gkr/src/circuit/circuit_layout.rs b/gkr/src/circuit/circuit_layout.rs index 3fb7d7787..8e71bd4cb 100644 --- a/gkr/src/circuit/circuit_layout.rs +++ b/gkr/src/circuit/circuit_layout.rs @@ -132,7 +132,8 @@ impl Circuit { }); let segment = ( wire_ids_in_layer[in_cell_ids[0]], - wire_ids_in_layer[in_cell_ids[in_cell_ids.len() - 1]] + 1, + wire_ids_in_layer[in_cell_ids[in_cell_ids.len() - 1]] + 1, /* + 1 for exclusive + * last index */ ); match ty { InType::Witness(wit_id) => { @@ -258,9 +259,10 @@ impl Circuit { .push(output_subsets.update_wire_id(old_layer_id, old_wire_id)); } OutType::AssertConst(constant) => { + let new_wire_id = output_subsets.update_wire_id(old_layer_id, old_wire_id); output_assert_const.push(GateCIn { idx_in: [], - idx_out: output_subsets.update_wire_id(old_layer_id, old_wire_id), + idx_out: new_wire_id, scalar: ConstantType::Field(i64_to_field(constant)), }); } @@ -288,8 +290,7 @@ impl Circuit { } else { let last_layer = &layers[(layer_id - 1) as usize]; if !last_layer.is_linear() || !layer.copy_to.is_empty() { - curr_sc_steps - .extend([SumcheckStepType::Phase1Step1, SumcheckStepType::Phase1Step2]); + curr_sc_steps.extend([SumcheckStepType::Phase1Step1]); } } @@ -900,7 +901,7 @@ mod tests { // Single input witness, therefore no input phase 2 steps. assert_eq!( circuit.layers[2].sumcheck_steps, - vec![SumcheckStepType::Phase1Step1, SumcheckStepType::Phase1Step2,] + vec![SumcheckStepType::Phase1Step1] ); // There are only one incoming evals since the last layer is linear, and // no subset evals. Therefore, there are no phase1 steps. @@ -931,7 +932,7 @@ mod tests { // Single input witness, therefore no input phase 2 steps. assert_eq!( circuit.layers[1].sumcheck_steps, - vec![SumcheckStepType::Phase1Step1, SumcheckStepType::Phase1Step2] + vec![SumcheckStepType::Phase1Step1] ); // Output layer, single output witness, therefore no output phase 1 steps. assert_eq!( diff --git a/gkr/src/circuit/circuit_witness.rs b/gkr/src/circuit/circuit_witness.rs index ee692b41e..f7351e652 100644 --- a/gkr/src/circuit/circuit_witness.rs +++ b/gkr/src/circuit/circuit_witness.rs @@ -53,6 +53,7 @@ impl CircuitWitness { let mut layer_wit = vec![vec![F::ZERO; circuit.layers[n_layers - 1].size()]; n_instances]; for instance_id in 0..n_instances { + assert_eq!(wits_in.len(), circuit.paste_from_wits_in.len()); for (wit_id, (l, r)) in circuit.paste_from_wits_in.iter().enumerate() { for i in *l..*r { layer_wit[instance_id][i] = @@ -175,34 +176,45 @@ impl CircuitWitness { pub fn add_instances( &mut self, circuit: &Circuit, - wits_in: Vec>, + new_wits_in: Vec>, n_instances: usize, ) where E: ExtensionField, { - assert_eq!(wits_in.len(), circuit.n_witness_in); + assert_eq!(new_wits_in.len(), circuit.n_witness_in); assert!(n_instances.is_power_of_two()); - assert!(!wits_in + assert!(!new_wits_in .iter() .any(|wit_in| wit_in.instances.len() != n_instances)); - let (new_layer_wits, new_wits_out) = - CircuitWitness::new_instances(circuit, &wits_in, &self.challenges, n_instances); + let (inferred_layer_wits, inferred_wits_out) = + CircuitWitness::new_instances(circuit, &new_wits_in, &self.challenges, n_instances); // Merge self and circuit_witness. - for (layer_wit, new_layer_wit) in self.layers.iter_mut().zip(new_layer_wits.into_iter()) { - layer_wit.instances.extend(new_layer_wit.instances); + for (layer_wit, inferred_layer_wit) in + self.layers.iter_mut().zip(inferred_layer_wits.into_iter()) + { + layer_wit.instances.extend(inferred_layer_wit.instances); } - for (wit_out, new_wit_out) in self.witness_out.iter_mut().zip(new_wits_out.into_iter()) { - wit_out.instances.extend(new_wit_out.instances); + for (wit_out, inferred_wits_out) in self + .witness_out + .iter_mut() + .zip(inferred_wits_out.into_iter()) + { + wit_out.instances.extend(inferred_wits_out.instances); } - for (wit_in, new_wit_in) in self.witness_in.iter_mut().zip(wits_in.into_iter()) { + for (wit_in, new_wit_in) in self.witness_in.iter_mut().zip(new_wits_in.into_iter()) { wit_in.instances.extend(new_wit_in.instances); } self.n_instances += n_instances; + + // check correctness in debug build + if cfg!(debug_assertions) { + self.check_correctness(circuit); + } } pub fn instance_num_vars(&self) -> usize { diff --git a/gkr/src/lib.rs b/gkr/src/lib.rs index 7ca4d1284..214126a5b 100644 --- a/gkr/src/lib.rs +++ b/gkr/src/lib.rs @@ -11,5 +11,7 @@ pub mod unsafe_utils; pub mod utils; mod verifier; +pub use sumcheck::util; + #[cfg(test)] mod test; diff --git a/gkr/src/prover.rs b/gkr/src/prover.rs index 26a8c6c87..ee1e6890b 100644 --- a/gkr/src/prover.rs +++ b/gkr/src/prover.rs @@ -8,7 +8,8 @@ use multilinear_extensions::{ virtual_poly::{build_eq_x_r_vec, VirtualPolynomial}, }; use rayon::iter::{ - IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator, + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, + IntoParallelRefMutIterator, ParallelIterator, }; use simple_frontend::structs::LayerId; use transcript::Transcript; @@ -85,21 +86,44 @@ impl IOPProverState { transcript, )].to_vec() }, - (SumcheckStepType::Phase1Step1, SumcheckStepType::Phase1Step2, _) => - [ - prover_state - .prove_and_update_state_phase1_step1( - circuit, - circuit_witness, - transcript, - ), - prover_state - .prove_and_update_state_phase1_step2( - circuit, - circuit_witness, - transcript, - ), - ].to_vec() + (SumcheckStepType::Phase1Step1, _, _) => { + let alpha = transcript + .get_and_append_challenge(b"combine subset evals") + .elements; + let hi_num_vars = circuit_witness.instance_num_vars(); + let eq_t = prover_state.to_next_phase_point_and_evals.par_iter().chain(prover_state.subset_point_and_evals[layer_id as usize].par_iter().map(|(_, point_and_eval)| point_and_eval)).map(|point_and_eval|{ + let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; + build_eq_x_r_vec(&point_and_eval.point[point_lo_num_vars..]) + }).collect::>>(); + let virtual_polys: Vec> = (0..max_thread_id).into_par_iter().map(|thread_id| { + let span = entered_span!("build_poly"); + let virtual_poly = IOPProverState::build_phase1_step1_sumcheck_poly( + &prover_state, + layer_id, + alpha, + &eq_t, + circuit, + circuit_witness, + (thread_id, max_thread_id), + ); + exit_span!(span); + virtual_poly + }).collect(); + + let (sumcheck_proof, sumcheck_prover_state) = sumcheck::structs::IOPProverState::::prove_batch_polys( + max_thread_id, + virtual_polys.try_into().unwrap(), + transcript, + ); + + let prover_msg = prover_state.combine_phase1_step1_evals( + sumcheck_proof, + sumcheck_prover_state, + ); + + vec![prover_msg] + + } , (SumcheckStepType::Phase2Step1, step2, _) => { let span = entered_span!("phase2_gkr"); @@ -261,19 +285,11 @@ impl IOPProverState { }) .collect_vec(); let to_next_phase_point_and_evals = output_evals; - subset_point_and_evals[0] = wires_out_evals.into_iter().map(|p| (0, p)).collect(); - - let phase1_layer_polys = (0..n_layers) - .into_par_iter() - .map(|layer_id| { - let num_vars = circuit.layers[layer_id].num_vars; - mem::take(&mut circuit_witness.layer_poly( - layer_id.try_into().unwrap(), - num_vars, - (0, 1), - )) - }) + subset_point_and_evals[0] = wires_out_evals + .into_iter() + .map(|p| (0 as LayerId, p)) .collect(); + Self { to_next_phase_point_and_evals, subset_point_and_evals, @@ -282,7 +298,7 @@ impl IOPProverState { assert_point, // Default layer_id: 0, - phase1_layer_polys, + phase1_layer_poly: ArcDenseMultilinearExtension::default(), g1_values: vec![], } } diff --git a/gkr/src/prover/phase1.rs b/gkr/src/prover/phase1.rs index 3e0bb117b..04dd3170f 100644 --- a/gkr/src/prover/phase1.rs +++ b/gkr/src/prover/phase1.rs @@ -1,45 +1,42 @@ use ark_std::{end_timer, start_timer}; use ff::Field; use ff_ext::ExtensionField; -use itertools::{chain, Itertools}; +use itertools::{izip, Itertools}; use multilinear_extensions::{ mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension}, - virtual_poly::{build_eq_x_r_vec, VirtualPolynomial}, + virtual_poly::{build_eq_x_r_vec_sequential, VirtualPolynomial}, }; -use std::mem; -use sumcheck::entered_span; -use transcript::Transcript; - -#[cfg(feature = "parallel")] -use rayon::iter::{IndexedParallelIterator, ParallelIterator}; +use simple_frontend::structs::LayerId; +use std::sync::Arc; +use sumcheck::{entered_span, util::ceil_log2}; use crate::{ - exit_span, izip_parallizable, - prover::SumcheckState, - structs::{Circuit, CircuitWitness, IOPProverState, IOPProverStepMessage, Point, PointAndEval}, - tracing_span, - utils::MatrixMLERowFirst, + exit_span, + structs::{ + Circuit, CircuitWitness, IOPProverState, IOPProverStepMessage, PointAndEval, SumcheckProof, + }, + utils::{tensor_product, MatrixMLERowFirst}, }; // Prove the items copied from the current layer to later layers for data parallel circuits. -// \sum_j( \alpha^j * subset[i][j](rt_j || ry_j) ) -// = \sum_y( \sum_j( \alpha^j copy_to[j](ry_j, y) \sum_t( eq(rt_j, t) * layers[i](t || y) ) ) ) impl IOPProverState { - /// Sumcheck 1: sigma = \sum_y( \sum_j f1^{(j)}(y) * g1^{(j)}(y) ) + /// Sumcheck 1: sigma = \sum_{t || y}(f1({t || y}) * (\sum_j g1^{(j)}({t || y}))) /// sigma = \sum_j( \alpha^j * subset[i][j](rt_j || ry_j) ) - /// f1^{(j)}(y) = \sum_t( eq(rt_j, t) * layers[i](t || y) ) - /// g1^{(j)}(y) = \alpha^j copy_to[j](ry_j, y) - #[tracing::instrument(skip_all, name = "prove_and_update_state_phase1_step1")] - pub(super) fn prove_and_update_state_phase1_step1( - &mut self, + /// f1^{(j)}(y) = layers[i](t || y) + /// g1^{(j)}(y) = \alpha^j * eq(rt_j, t) * eq(ry_j, y) + /// g1^{(j)}(y) = \alpha^j * eq(rt_j, t) * copy_to[j](ry_j, y) + pub(super) fn build_phase1_step1_sumcheck_poly( + &self, + layer_id: LayerId, + alpha: E, + eq_t: &Vec>, circuit: &Circuit, circuit_witness: &CircuitWitness, - transcript: &mut Transcript, - ) -> IOPProverStepMessage { + multi_threads_meta: (usize, usize), + ) -> VirtualPolynomial { + let span = entered_span!("preparation"); let timer = start_timer!(|| "Prover sumcheck phase 1 step 1"); - let alpha = transcript - .get_and_append_challenge(b"combine subset evals") - .elements; + let total_length = self.to_next_phase_point_and_evals.len() + self.subset_point_and_evals[self.layer_id as usize].len() + 1; @@ -54,198 +51,136 @@ impl IOPProverState { let lo_num_vars = circuit.layers[self.layer_id as usize].num_vars; let hi_num_vars = circuit_witness.instance_num_vars(); - // sigma = \sum_j( \alpha^j * subset[i][j](rt_j || ry_j) ) - // f1^{(j)}(y) = \sum_t( eq(rt_j, t) * layers[i](t || y) ) - // g1^{(j)}(y) = \alpha^j copy_to[j](ry_j, y) - let span = entered_span!("fg"); - let (mut f1, mut g1): ( - Vec>, - Vec>, - ) = tracing_span!("f1g1").in_scope(|| { - izip_parallizable!(&self.to_next_phase_point_and_evals, &alpha_pows) - .map(|(point_and_eval, alpha_pow)| { - let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; + // parallel unit logic handling + let (thread_id, max_thread_id) = multi_threads_meta; + let log2_max_thread_id = ceil_log2(max_thread_id); - let f1_j = self.phase1_layer_polys[self.layer_id as usize] - .fix_high_variables(&point_and_eval.point[point_lo_num_vars..]); + exit_span!(span); - let g1_j = build_eq_x_r_vec(&point_and_eval.point[..point_lo_num_vars]) - .into_iter() - .take(1 << lo_num_vars) - .map(|eq| *alpha_pow * eq) - .collect_vec(); + // f1^{(j)}(y) = layers[i](t || y) + let f1: Arc> = circuit_witness + .layer_poly::( + (layer_id).try_into().unwrap(), + lo_num_vars, + multi_threads_meta, + ) + .into(); - assert_eq!(g1_j.len(), 1 << lo_num_vars); - ( - f1_j.into(), - DenseMultilinearExtension::from_evaluations_ext_vec(lo_num_vars, g1_j) - .into(), - ) - }) - .unzip() - }); + assert_eq!( + f1.evaluations.len(), + 1 << (hi_num_vars + lo_num_vars - log2_max_thread_id) + ); + let span = entered_span!("g1"); + // g1^{(j)}(y) = \alpha^j * eq(rt_j, t) * eq(ry_j, y) + // g1^{(j)}(y) = \alpha^j * eq(rt_j, t) * copy_to[j](ry_j, y) let copy_to_matrices = &circuit.layers[self.layer_id as usize].copy_to; - let (f1_subset_point_and_evals, g1_subset_point_and_evals): ( - Vec>, - Vec>, - ) = tracing_span!("f1_j_g1_j").in_scope(|| { - izip_parallizable!( - &self.subset_point_and_evals[self.layer_id as usize], - &alpha_pows[self.to_next_phase_point_and_evals.len()..] - ) - .map(|((new_layer_id, point_and_eval), alpha_pow)| { - let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; - let copy_to = ©_to_matrices[new_layer_id]; - let lo_eq_w_p = build_eq_x_r_vec(&point_and_eval.point[..point_lo_num_vars]); - - let f1_j = self.phase1_layer_polys[self.layer_id as usize] - .fix_high_variables(&point_and_eval.point[point_lo_num_vars..]); - - assert!(copy_to.len() <= lo_eq_w_p.len()); - let g1_j = copy_to.as_slice().fix_row_row_first_with_scalar( - &lo_eq_w_p, - lo_num_vars, - alpha_pow, - ); - ( - f1_j.into(), - DenseMultilinearExtension::from_evaluations_ext_vec(lo_num_vars, g1_j).into(), - ) - }) - .unzip() - }); - exit_span!(span); + let g1: ArcDenseMultilinearExtension = { + let gs = izip!(&self.to_next_phase_point_and_evals, &alpha_pows, eq_t) + .map(|(point_and_eval, alpha_pow, eq_t)| { + // g1^{(j)}(y) = \alpha^j * eq(rt_j, t) * eq(ry_j, y) + let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; - f1.extend(f1_subset_point_and_evals); - g1.extend(g1_subset_point_and_evals); + let eq_y = + build_eq_x_r_vec_sequential(&point_and_eval.point[..point_lo_num_vars]) + .into_iter() + .take(1 << lo_num_vars) + .map(|eq| *alpha_pow * eq) + .collect_vec(); - // sumcheck: sigma = \sum_y( \sum_j f1^{(j)}(y) * g1^{(j)}(y) ) - let virtual_poly_1 = tracing_span!("virtual_poly").in_scope(|| { - let mut virtual_poly_1 = VirtualPolynomial::new(lo_num_vars); - for (f1_j, g1_j) in f1.into_iter().zip(g1.into_iter()) { - let mut tmp = VirtualPolynomial::new_from_mle(f1_j, E::BaseField::ONE); - tmp.mul_by_mle(g1_j, E::BaseField::ONE); - virtual_poly_1.merge(&tmp); - } - virtual_poly_1 - }); + let eq_t_unit_len = eq_t.len() / max_thread_id; + let start_index = thread_id * eq_t_unit_len; + let g1_j = + tensor_product(&eq_t[start_index..(start_index + eq_t_unit_len)], &eq_y); - let (sumcheck_proof_1, prover_state) = - SumcheckState::prove_parallel(virtual_poly_1, transcript); - let (f1, g1): (Vec<_>, Vec<_>) = prover_state - .get_mle_final_evaluations() - .into_iter() - .enumerate() - .partition(|(i, _)| i % 2 == 0); - let eval_value_1 = f1.into_iter().map(|(_, f1_j)| f1_j).collect_vec(); + assert_eq!( + g1_j.len(), + (1 << (hi_num_vars + lo_num_vars - log2_max_thread_id)) + ); - self.to_next_step_point = sumcheck_proof_1.point.clone(); - self.g1_values = g1.into_iter().map(|(_, g1_j)| g1_j).collect_vec(); + g1_j + }) + .chain( + izip!( + &self.subset_point_and_evals[self.layer_id as usize], + &alpha_pows[self.to_next_phase_point_and_evals.len()..], + eq_t.iter().skip(self.to_next_phase_point_and_evals.len()) + ) + .map( + |((new_layer_id, point_and_eval), alpha_pow, eq_t)| { + let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; + let copy_to = ©_to_matrices[new_layer_id]; + let lo_eq_w_p = build_eq_x_r_vec_sequential( + &point_and_eval.point[..point_lo_num_vars], + ); + + // g2^{(j)}(y) = \alpha^j * eq(rt_j, t) * copy_to[j](ry_j, y) + let eq_t_unit_len = eq_t.len() / max_thread_id; + let start_index = thread_id * eq_t_unit_len; + let g2_j = tensor_product( + &eq_t[start_index..(start_index + eq_t_unit_len)], + ©_to.as_slice().fix_row_row_first_with_scalar( + &lo_eq_w_p, + lo_num_vars, + alpha_pow, + ), + ); + + assert_eq!( + g2_j.len(), + (1 << (hi_num_vars + lo_num_vars - log2_max_thread_id)) + ); + g2_j + }, + ), + ) + .collect::>>(); + + DenseMultilinearExtension::from_evaluations_ext_vec( + hi_num_vars + lo_num_vars - log2_max_thread_id, + gs.into_iter() + .fold(vec![E::ZERO; 1 << f1.num_vars], |mut acc, g| { + assert_eq!(1 << f1.num_vars, g.len()); + acc.iter_mut().enumerate().for_each(|(i, v)| *v += g[i]); + acc + }), + ) + .into() + }; + exit_span!(span); + // sumcheck: sigma = \sum_{s || y}(f1({s || y}) * (\sum_j g1^{(j)}({s || y}))) + let span = entered_span!("virtual_poly"); + let mut virtual_poly_1 = VirtualPolynomial::new_from_mle(f1, E::BaseField::ONE); + virtual_poly_1.mul_by_mle(g1, E::BaseField::ONE); + exit_span!(span); end_timer!(timer); - IOPProverStepMessage { - sumcheck_proof: sumcheck_proof_1, - sumcheck_eval_values: eval_value_1, - } + + virtual_poly_1 } - /// Sumcheck 2: sigma = \sum_t( \sum_j( f2^{(j)}(t) ) ) * g2(t) - /// sigma = \sum_j( f1^{(j)}(ry) * g1^{(j)}(ry) ) - /// f2(t) = layers[i](t || ry) - /// g2^{(j)}(t) = \alpha^j copy_to[j](ry_j, ry) eq(rt_j, t) - #[tracing::instrument(skip_all, name = "prove_and_update_state_phase1_step2")] - pub(super) fn prove_and_update_state_phase1_step2( + pub(super) fn combine_phase1_step1_evals( &mut self, - _: &Circuit, - circuit_witness: &CircuitWitness, - transcript: &mut Transcript, + sumcheck_proof_1: SumcheckProof, + prover_state: sumcheck::structs::IOPProverState, ) -> IOPProverStepMessage { - let timer = start_timer!(|| "Prover sumcheck phase 1 step 2"); - let hi_num_vars = circuit_witness.instance_num_vars(); - - let span = entered_span!("f2_fix_variables"); - // f2(t) = layers[i](t || ry) - let f2 = mem::take(&mut self.phase1_layer_polys[self.layer_id as usize]) - .fix_variables_parallel(&self.to_next_step_point) - .into(); - exit_span!(span); - - // g2^{(j)}(t) = \alpha^j copy_to[j](ry_j, ry) eq(rt_j, t) - let output_points: Vec<&Point> = chain![ - self.to_next_phase_point_and_evals.iter().map(|x| &x.point), - self.subset_point_and_evals[self.layer_id as usize] - .iter() - .map(|x| &x.1.point), - ] - .collect(); - let span = entered_span!("g2"); - - #[cfg(not(feature = "parallel"))] - let zeros = vec![E::ZERO; 1 << hi_num_vars]; - - #[cfg(feature = "parallel")] - let zeros = || vec![E::ZERO; 1 << hi_num_vars]; - - let g2 = izip_parallizable!(output_points, &self.g1_values) - .map(|(point, &g1_value)| { - let point_lo_num_vars = point.len() - hi_num_vars; - build_eq_x_r_vec(&point[point_lo_num_vars..]) - .into_iter() - .map(|eq| g1_value * eq) - .collect_vec() - }) - .fold(zeros, |acc, nxt| { - acc.into_iter() - .zip(nxt.into_iter()) - .map(|(a, b)| a + b) - .collect_vec() - }); - - #[cfg(not(feature = "parallel"))] - let g2 = DenseMultilinearExtension::from_evaluations_ext_vec(hi_num_vars, g2); - - // When rayon is used, the `fold` operation results in a iterator of `Vec` rather than a single `Vec`. In this case, we simply need to sum them. - #[cfg(feature = "parallel")] - let g2 = DenseMultilinearExtension::from_evaluations_ext_vec( - hi_num_vars, - g2.reduce(zeros, |acc, nxt| { - acc.into_iter() - .zip(nxt.into_iter()) - .map(|(a, b)| a + b) - .collect_vec() - }), - ); - - exit_span!(span); - - // sumcheck: sigma = \sum_t( \sum_j( g2^{(j)}(t) ) ) * f2(t) - let mut virtual_poly_2 = VirtualPolynomial::new_from_mle(f2, E::BaseField::ONE); - virtual_poly_2.mul_by_mle(g2.into(), E::BaseField::ONE); - - let (sumcheck_proof_2, prover_state) = - SumcheckState::prove_parallel(virtual_poly_2, transcript); - let (mut f2, _): (Vec<_>, Vec<_>) = prover_state + let (mut f1, _): (Vec<_>, Vec<_>) = prover_state .get_mle_final_evaluations() .into_iter() .enumerate() .partition(|(i, _)| i % 2 == 0); - let eval_value_2 = f2.remove(0).1; - self.to_next_step_point = [ - mem::take(&mut self.to_next_step_point), - sumcheck_proof_2.point.clone(), - ] - .concat(); + let eval_value_1 = f1.remove(0).1; + + self.to_next_step_point = sumcheck_proof_1.point.clone(); self.to_next_phase_point_and_evals = vec![PointAndEval::new_from_ref( &self.to_next_step_point, - &eval_value_2, + &eval_value_1, )]; self.subset_point_and_evals[self.layer_id as usize].clear(); - end_timer!(timer); IOPProverStepMessage { - sumcheck_proof: sumcheck_proof_2, - sumcheck_eval_values: vec![eval_value_2], + sumcheck_proof: sumcheck_proof_1, + sumcheck_eval_values: vec![eval_value_1], } } } diff --git a/gkr/src/prover/phase1_output.rs b/gkr/src/prover/phase1_output.rs index 3c26eb112..3dbce07d2 100644 --- a/gkr/src/prover/phase1_output.rs +++ b/gkr/src/prover/phase1_output.rs @@ -1,12 +1,12 @@ use ark_std::{end_timer, start_timer}; use ff::Field; use ff_ext::ExtensionField; -use itertools::{chain, Itertools}; +use itertools::{chain, izip, Itertools}; use multilinear_extensions::{ mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension}, virtual_poly::{build_eq_x_r_vec, VirtualPolynomial}, }; -use std::{iter, mem}; +use std::{iter, mem, sync::Arc}; use transcript::Transcript; use crate::{ @@ -26,7 +26,7 @@ impl IOPProverState { /// Sumcheck 1: sigma = \sum_y( \sum_j f1^{(j)}(y) * g1^{(j)}(y) ) /// sigma = \sum_j( \alpha^j * wit_out_eval[j](rt_j || ry_j) ) /// + \alpha^{wit_out_eval[j].len()} * assert_const(rt || ry) ) - /// f1^{(j)}(y) = \sum_t( eq(rt_j, t) * layers[i](t || y) ) + /// f1^{(j)}(y) = layers[i](rt_j || y) /// g1^{(j)}(y) = \alpha^j eq(ry_j, y) // or \alpha^j copy_to[j](ry_j, y) // or \alpha^j assert_subset_eq(ry, y) @@ -41,6 +41,7 @@ impl IOPProverState { let alpha = transcript .get_and_append_challenge(b"combine subset evals") .elements; + let total_length = self.to_next_phase_point_and_evals.len() + self.subset_point_and_evals[self.layer_id as usize].len() + 1; @@ -55,11 +56,15 @@ impl IOPProverState { let lo_num_vars = circuit.layers[self.layer_id as usize].num_vars; let hi_num_vars = circuit_witness.instance_num_vars(); + self.phase1_layer_poly = circuit_witness + .layer_poly::((self.layer_id).try_into().unwrap(), lo_num_vars, (0, 1)) + .into(); + // sigma = \sum_j( \alpha^j * subset[i][j](rt_j || ry_j) ) - // f1^{(j)}(y) = \sum_t( eq(rt_j, t) * layers[i](t || y) ) + // f1^{(j)}(y) = layers[i](rt_j || y) // g1^{(j)}(y) = \alpha^j eq(ry_j, y) // or \alpha^j copy_to[j](ry_j, y) - // or \alpha^j assert_subset_eq[j](ry, y) + // or \alpha^j assert_subset_eq(ry, y) // TODO: Double check the soundness here. let (mut f1, mut g1): ( Vec>, @@ -70,7 +75,8 @@ impl IOPProverState { let point = &point_and_eval.point; let lo_eq_w_p = build_eq_x_r_vec(&point_and_eval.point[..point_lo_num_vars]); - let f1_j = self.phase1_layer_polys[self.layer_id as usize] + let f1_j = self + .phase1_layer_poly .fix_high_variables(&point[point_lo_num_vars..]); let g1_j = lo_eq_w_p @@ -88,18 +94,20 @@ impl IOPProverState { let (f1_copy_to, g1_copy_to): ( Vec>, Vec>, - ) = izip_parallizable!( + ) = izip!( &circuit.copy_to_wits_out, &self.subset_point_and_evals[self.layer_id as usize], &alpha_pows[self.to_next_phase_point_and_evals.len()..] ) .map(|(copy_to, (_, point_and_eval), alpha_pow)| { - let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; let point = &point_and_eval.point; - let lo_eq_w_p = build_eq_x_r_vec(&point_and_eval.point[..point_lo_num_vars]); + let point_lo_num_vars = point.len() - hi_num_vars; + + let lo_eq_w_p = build_eq_x_r_vec(&point[..point_lo_num_vars]); assert!(copy_to.len() <= lo_eq_w_p.len()); - let f1_j = self.phase1_layer_polys[self.layer_id as usize] + let f1_j = self + .phase1_layer_poly .fix_high_variables(&point[point_lo_num_vars..]); let g1_j = copy_to.as_slice().fix_row_row_first_with_scalar( @@ -118,12 +126,14 @@ impl IOPProverState { f1.extend(f1_copy_to); g1.extend(g1_copy_to); - let f1_j = self.phase1_layer_polys[self.layer_id as usize] + let f1_j = self + .phase1_layer_poly .fix_high_variables(&self.assert_point[lo_num_vars..]); f1.push(f1_j.into()); let alpha_pow = alpha_pows.last().unwrap(); let lo_eq_w_p = build_eq_x_r_vec(&self.assert_point[..lo_num_vars]); + let mut g_last = vec![E::ZERO; 1 << lo_num_vars]; circuit.assert_consts.iter().for_each(|gate| { g_last[gate.idx_out as usize] = lo_eq_w_p[gate.idx_out as usize] * alpha_pow; @@ -175,9 +185,9 @@ impl IOPProverState { let hi_num_vars = circuit_witness.instance_num_vars(); // f2(t) = layers[i](t || ry) - let f2 = self.phase1_layer_polys[self.layer_id as usize] - .fix_variables_parallel(&self.to_next_step_point) - .into(); + let mut f2 = mem::take(&mut self.phase1_layer_poly); + + Arc::make_mut(&mut f2).fix_variables_in_place_parallel(&self.to_next_step_point); // g2(t) = \sum_j \alpha^j (eq or copy_to[j] or assert_subset)(ry_j, ry) eq(rt_j, t) let output_points = chain![ diff --git a/gkr/src/prover/phase2_input.rs b/gkr/src/prover/phase2_input.rs index c3bfbe4db..350e0c644 100644 --- a/gkr/src/prover/phase2_input.rs +++ b/gkr/src/prover/phase2_input.rs @@ -65,7 +65,6 @@ impl IOPProverState { } g[subset_wire_id] = eq_y_ry[new_wire_id]; } - ( { let mut f = DenseMultilinearExtension::from_evaluations_vec( diff --git a/gkr/src/structs.rs b/gkr/src/structs.rs index ec4489551..3f667a7a6 100644 --- a/gkr/src/structs.rs +++ b/gkr/src/structs.rs @@ -62,14 +62,12 @@ pub struct IOPProverState { /// The point to the next step. pub(crate) to_next_step_point: Point, - /// layer poly for phase 1 which point to current layer - pub(crate) phase1_layer_polys: Vec>, // Especially for output phase1. + pub(crate) phase1_layer_poly: ArcDenseMultilinearExtension, pub(crate) assert_point: Point, // Especially for phase1. pub(crate) g1_values: Vec, - // Especially for phase2. } /// Represent the verifier state for each layer in the IOP protocol. @@ -126,7 +124,6 @@ pub(crate) enum SumcheckStepType { OutputPhase1Step1, OutputPhase1Step2, Phase1Step1, - Phase1Step2, Phase2Step1, Phase2Step2, Phase2Step2NoStep3, @@ -151,11 +148,11 @@ pub struct Layer { // Gates. Should be all None if it's the input layer. pub(crate) add_consts: Vec>>, pub(crate) adds: Vec>>, - pub(crate) adds_fanin_mapping: [BTreeMap>>>; 1], // grouping for 1 fanins + pub(crate) adds_fanin_mapping: [BTreeMap>>>; 1], /* grouping for 1 fanins */ pub(crate) mul2s: Vec>>, - pub(crate) mul2s_fanin_mapping: [BTreeMap>>>; 2], // grouping for 2 fanins + pub(crate) mul2s_fanin_mapping: [BTreeMap>>>; 2], /* grouping for 2 fanins */ pub(crate) mul3s: Vec>>, - pub(crate) mul3s_fanin_mapping: [BTreeMap>>>; 3], // grouping for 3 fanins + pub(crate) mul3s_fanin_mapping: [BTreeMap>>>; 3], /* grouping for 3 fanins */ /// The corresponding wires copied from this layer to later layers. It is /// (later layer id -> current wire id to be copied). It stores the non-zero @@ -201,7 +198,7 @@ pub struct Circuit { pub paste_from_wits_in: Vec<(CellId, CellId)>, /// The endpoints in the input layer copied from counter. pub paste_from_counter_in: Vec<(usize, (CellId, CellId))>, - /// The endpoints in the output layer copied to each output witness. + /// The endpoints in the input layer copied from constants pub paste_from_consts_in: Vec<(i64, (CellId, CellId))>, /// The wires copied to the output witness pub copy_to_wits_out: Vec>, diff --git a/gkr/src/utils.rs b/gkr/src/utils.rs index 284e9f09d..2670c68f6 100644 --- a/gkr/src/utils.rs +++ b/gkr/src/utils.rs @@ -176,7 +176,7 @@ pub fn eq4_eval(x: &[E], y: &[E], z: &[E], w: &[E]) -> E { res } -pub fn tensor_product(a: &[F], b: &[F]) -> Vec { +pub fn tensor_product(a: &[F], b: &[F]) -> Vec { let mut res = vec![F::ZERO; a.len() * b.len()]; for i in 0..a.len() { for j in 0..b.len() { diff --git a/gkr/src/verifier.rs b/gkr/src/verifier.rs index b197e9e58..d84ff6ce2 100644 --- a/gkr/src/verifier.rs +++ b/gkr/src/verifier.rs @@ -1,6 +1,6 @@ use ark_std::{end_timer, start_timer}; use ff_ext::ExtensionField; -use itertools::Itertools; +use itertools::{izip, Itertools}; use simple_frontend::structs::{ChallengeConst, LayerId}; use std::collections::HashMap; use transcript::Transcript; @@ -8,7 +8,8 @@ use transcript::Transcript; use crate::{ error::GKRError, structs::{ - Circuit, GKRInputClaims, IOPProof, IOPVerifierState, PointAndEval, SumcheckStepType, + Circuit, GKRInputClaims, IOPProof, IOPProverStepMessage, IOPVerifierState, PointAndEval, + SumcheckStepType, }, }; @@ -44,48 +45,44 @@ impl IOPVerifierState { circuit.layers[0].num_vars + instance_num_vars, ); - let mut step_proof_iter = proof.sumcheck_proofs.into_iter(); + let mut sumcheck_proofs_iter = proof.sumcheck_proofs.into_iter(); for layer_id in 0..circuit.layers.len() { let timer = start_timer!(|| format!("Verify layer {}", layer_id)); verifier_state.layer_id = layer_id as LayerId; let layer = &circuit.layers[layer_id as usize]; - for step in layer.sumcheck_steps.iter() { - let step_msg = step_proof_iter - .next() - .ok_or(GKRError::VerifyError("Wrong number of step proofs"))?; + for (step, step_proof) in izip!(layer.sumcheck_steps.iter(), &mut sumcheck_proofs_iter) + { match step { SumcheckStepType::OutputPhase1Step1 => verifier_state .verify_and_update_state_output_phase1_step1( - circuit, step_msg, transcript, + circuit, step_proof, transcript, )?, SumcheckStepType::OutputPhase1Step2 => verifier_state .verify_and_update_state_output_phase1_step2( - circuit, step_msg, transcript, + circuit, step_proof, transcript, )?, SumcheckStepType::Phase1Step1 => verifier_state - .verify_and_update_state_phase1_step1(circuit, step_msg, transcript)?, - SumcheckStepType::Phase1Step2 => verifier_state - .verify_and_update_state_phase1_step2(circuit, step_msg, transcript)?, + .verify_and_update_state_phase1_step1(circuit, step_proof, transcript)?, SumcheckStepType::Phase2Step1 => verifier_state - .verify_and_update_state_phase2_step1(circuit, step_msg, transcript)?, + .verify_and_update_state_phase2_step1(circuit, step_proof, transcript)?, SumcheckStepType::Phase2Step2 => verifier_state .verify_and_update_state_phase2_step2( - circuit, step_msg, transcript, false, + circuit, step_proof, transcript, false, )?, SumcheckStepType::Phase2Step2NoStep3 => verifier_state .verify_and_update_state_phase2_step2( - circuit, step_msg, transcript, true, + circuit, step_proof, transcript, true, )?, SumcheckStepType::Phase2Step3 => verifier_state - .verify_and_update_state_phase2_step3(circuit, step_msg, transcript)?, + .verify_and_update_state_phase2_step3(circuit, step_proof, transcript)?, SumcheckStepType::LinearPhase2Step1 => verifier_state .verify_and_update_state_linear_phase2_step1( - circuit, step_msg, transcript, + circuit, step_proof, transcript, )?, SumcheckStepType::InputPhase2Step1 => verifier_state .verify_and_update_state_input_phase2_step1( - circuit, step_msg, transcript, + circuit, step_proof, transcript, )?, _ => unreachable!(), } diff --git a/gkr/src/verifier/phase1.rs b/gkr/src/verifier/phase1.rs index 6d30c4e08..3bc6c13a3 100644 --- a/gkr/src/verifier/phase1.rs +++ b/gkr/src/verifier/phase1.rs @@ -2,7 +2,7 @@ use ark_std::{end_timer, start_timer}; use ff_ext::ExtensionField; use itertools::{chain, izip, Itertools}; use multilinear_extensions::virtual_poly::{build_eq_x_r_vec, eq_eval, VPAuxInfo}; -use std::{marker::PhantomData, mem}; +use std::marker::PhantomData; use transcript::Transcript; use crate::{ @@ -52,27 +52,42 @@ impl IOPVerifierState { .fold(E::ZERO, |acc, ((_, point_and_eval), alpha_pow)| { acc + point_and_eval.eval * alpha_pow }); - // Sumcheck 1: sigma = \sum_y( \sum_j f1^{(j)}(y) * g1^{(j)}(y) ) - // f1^{(j)}(y) = \sum_t( eq(rt_j, t) * layers[i](t || y) ) - // g1^{(j)}(y) = \alpha^j copy_to[j](ry_j, y) + + // Sumcheck: sigma = \sum_{t || y}(f1({t || y}) * (\sum_j g1^{(j)}({t || y}))) + // f1^{(j)}(y) = layers[i](t || y) + // g1^{(j)}(t || y) = \alpha^j * eq(rt_j, t) * eq(ry_j, y) + // g1^{(j)}(t || y) = \alpha^j * eq(rt_j, t) * copy_to[j](ry_j, y) let claim_1 = SumcheckState::verify( sigma_1, &step_msg.sumcheck_proof, &VPAuxInfo { max_degree: 2, - num_variables: lo_num_vars, + num_variables: lo_num_vars + hi_num_vars, phantom: PhantomData, }, transcript, ); + let claim1_point = claim_1.point.iter().map(|x| x.elements).collect_vec(); - let eq_y_ry = build_eq_x_r_vec(&claim1_point); + let claim1_point_lo_num_vars = claim1_point.len() - hi_num_vars; + let eq_y_ry = build_eq_x_r_vec(&claim1_point[..claim1_point_lo_num_vars]); - self.g1_values = chain![ + assert_eq!(step_msg.sumcheck_eval_values.len(), 1); + let f_value = step_msg.sumcheck_eval_values[0]; + + let g_value: E = chain![ izip!(self.to_next_phase_point_and_evals.iter(), alpha_pows.iter()).map( |(point_and_eval, alpha_pow)| { let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; - eq_eval(&point_and_eval.point[..point_lo_num_vars], &claim1_point) * alpha_pow + let eq_t = eq_eval( + &point_and_eval.point[point_lo_num_vars..], + &claim1_point[(claim1_point.len() - hi_num_vars)..], + ); + let eq_y = eq_eval( + &point_and_eval.point[..point_lo_num_vars], + &claim1_point[..point_lo_num_vars], + ); + eq_t * eq_y * alpha_pow } ), izip!( @@ -83,88 +98,28 @@ impl IOPVerifierState { ) .map(|((new_layer_id, point_and_eval), alpha_pow)| { let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; + let eq_t = eq_eval( + &point_and_eval.point[point_lo_num_vars..], + &claim1_point[(claim1_point.len() - hi_num_vars)..], + ); let eq_yj_ryj = build_eq_x_r_vec(&point_and_eval.point[..point_lo_num_vars]); - circuit.layers[self.layer_id as usize].copy_to[new_layer_id] + eq_t * circuit.layers[self.layer_id as usize].copy_to[new_layer_id] .as_slice() .eval_row_first(&eq_yj_ryj, &eq_y_ry) * alpha_pow }), ] - .collect_vec(); - - let got_value_1 = izip!(step_msg.sumcheck_eval_values.iter(), self.g1_values.iter()) - .fold(E::ZERO, |acc, (&f1, g1)| acc + f1 * g1); - end_timer!(timer); - - if claim_1.expected_evaluation != got_value_1 { - return Err(GKRError::VerifyError("phase1 step1 failed")); - } - - self.to_next_step_point_and_eval = - PointAndEval::new(claim1_point, claim_1.expected_evaluation); - - Ok(()) - } - - pub(super) fn verify_and_update_state_phase1_step2( - &mut self, - _: &Circuit, - step_msg: IOPProverStepMessage, - transcript: &mut Transcript, - ) -> Result<(), GKRError> { - let timer = start_timer!(|| "Verifier sumcheck phase 1 step 2"); - let hi_num_vars = self.instance_num_vars; + .sum(); - // sigma = \sum_j( f1^{(j)}(ry) * g1^{(j)}(ry) ) - // Sumcheck 2: sigma = \sum_t( \sum_j( g2^{(j)}(t) ) ) * f2(t) - // f2(t) = layers[i](t || ry) - // g2^{(j)}(t) = \alpha^j copy_to[j](ry_j, r_y) eq(rt_j, t) - let claim_2 = SumcheckState::verify( - self.to_next_step_point_and_eval.eval, - &step_msg.sumcheck_proof, - &VPAuxInfo { - max_degree: 2, - num_variables: hi_num_vars, - phantom: PhantomData, - }, - transcript, - ); - let claim2_point = claim_2.point.iter().map(|x| x.elements).collect_vec(); - - let output_points = chain![ - self.to_next_phase_point_and_evals.iter().map(|x| &x.point), - self.subset_point_and_evals[self.layer_id as usize] - .iter() - .map(|x| &x.1.point), - ]; - let f2_value = step_msg.sumcheck_eval_values[0]; - let g2_value = output_points - .zip(self.g1_values.iter()) - .map(|(point, g1_value)| { - let point_lo_num_vars = point.len() - hi_num_vars; - *g1_value * eq_eval(&point[point_lo_num_vars..], &claim2_point) - }) - .fold(E::ZERO, |acc, value| acc + value); - - let got_value_2 = f2_value * g2_value; + let got_value = f_value * g_value; end_timer!(timer); - if claim_2.expected_evaluation != got_value_2 { - return Err(GKRError::VerifyError("output phase1 step2 failed")); + if claim_1.expected_evaluation != got_value { + return Err(GKRError::VerifyError("phase1 step1 failed")); } + self.to_next_step_point_and_eval = PointAndEval::new_from_ref(&claim1_point, &f_value); + self.to_next_phase_point_and_evals = vec![self.to_next_step_point_and_eval.clone()]; - self.to_next_step_point_and_eval = PointAndEval::new( - [ - mem::take(&mut self.to_next_step_point_and_eval.point), - claim2_point, - ] - .concat(), - f2_value, - ); - self.to_next_phase_point_and_evals = vec![PointAndEval::new_from_ref( - &self.to_next_step_point_and_eval.point, - &f2_value, - )]; self.subset_point_and_evals[self.layer_id as usize].clear(); Ok(()) diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index c2f3d5515..4b8a9baec 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -8,7 +8,6 @@ use ff_ext::ExtensionField; use rayon::iter::IntoParallelRefIterator; use serde::{Deserialize, Serialize}; -#[cfg(feature = "parallel")] use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; #[derive(Clone, PartialEq, Eq, Hash, Default, Debug, Serialize, Deserialize)] @@ -68,12 +67,14 @@ impl DenseMultilinearExtension { } } - /// Identical to [`from_evaluations_slice`], with and exception that evaluation vector is in extension field + /// Identical to [`from_evaluations_slice`], with and exception that evaluation vector is in + /// extension field pub fn from_evaluations_ext_slice(num_vars: usize, evaluations: &[E]) -> Self { Self::from_evaluations_ext_vec(num_vars, evaluations.to_vec()) } - /// Identical to [`from_evaluations_vec`], with and exception that evaluation vector is in extension field + /// Identical to [`from_evaluations_vec`], with and exception that evaluation vector is in + /// extension field pub fn from_evaluations_ext_vec(num_vars: usize, evaluations: Vec) -> Self { // assert that the number of variables matches the size of evaluations // TODO: return error. diff --git a/multilinear_extensions/src/virtual_poly.rs b/multilinear_extensions/src/virtual_poly.rs index 34b96a047..eb3edac92 100644 --- a/multilinear_extensions/src/virtual_poly.rs +++ b/multilinear_extensions/src/virtual_poly.rs @@ -1,18 +1,17 @@ -use std::cmp::max; -use std::mem; -use std::{collections::HashMap, marker::PhantomData, sync::Arc}; - -use crate::mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension}; -use crate::util::bit_decompose; -use ark_std::rand::Rng; -use ark_std::{end_timer, start_timer}; -use ff::Field; -use ff::PrimeField; +use std::{cmp::max, collections::HashMap, marker::PhantomData, mem, sync::Arc}; + +use crate::{ + mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension}, + util::bit_decompose, +}; +use ark_std::{end_timer, rand::Rng, start_timer}; +use ff::{Field, PrimeField}; use ff_ext::ExtensionField; -use rayon::iter::IntoParallelIterator; -use rayon::iter::IntoParallelRefIterator; -use rayon::prelude::{IndexedParallelIterator, ParallelIterator}; -use rayon::slice::ParallelSliceMut; +use rayon::{ + iter::{IntoParallelIterator, IntoParallelRefIterator}, + prelude::{IndexedParallelIterator, ParallelIterator}, + slice::ParallelSliceMut, +}; use serde::{Deserialize, Serialize}; #[rustfmt::skip] diff --git a/rustfmt.toml b/rustfmt.toml index 1b31b7d5f..835c6b277 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,10 +1,9 @@ edition = "2021" -comment_width = 100 +wrap_comments = false +comment_width = 300 imports_granularity = "Crate" max_width = 100 newline_style = "Unix" normalize_comments = true reorder_imports = true -wrap_comments = true - diff --git a/singer-pro/src/basic_block/bb_final.rs b/singer-pro/src/basic_block/bb_final.rs index 7d3b1fef6..3fc0e4531 100644 --- a/singer-pro/src/basic_block/bb_final.rs +++ b/singer-pro/src/basic_block/bb_final.rs @@ -14,7 +14,7 @@ use singer_utils::{ register_witness, structs::{ChipChallenges, InstOutChipType, PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::{ component::{BBFinalCircuit, BBFinalLayout, FromBBStart, FromPredInst, FromWitness}, diff --git a/singer-pro/src/basic_block/bb_ret.rs b/singer-pro/src/basic_block/bb_ret.rs index 2eafaab73..e351fca36 100644 --- a/singer-pro/src/basic_block/bb_ret.rs +++ b/singer-pro/src/basic_block/bb_ret.rs @@ -9,7 +9,7 @@ use singer_utils::{ register_witness, structs::{ChipChallenges, InstOutChipType, RAMHandler, ROMHandler, StackUInt, TSUInt}, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::{ component::{ diff --git a/singer-pro/src/instructions/add.rs b/singer-pro/src/instructions/add.rs index b9fa98d25..14b919cd0 100644 --- a/singer-pro/src/instructions/add.rs +++ b/singer-pro/src/instructions/add.rs @@ -10,7 +10,7 @@ use singer_utils::{ register_witness, structs::{ChipChallenges, InstOutChipType, ROMHandler, StackUInt, TSUInt}, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::{ component::{FromPredInst, FromWitness, InstCircuit, InstLayout, ToSuccInst}, diff --git a/singer-pro/src/instructions/calldataload.rs b/singer-pro/src/instructions/calldataload.rs index 0da68e250..152788138 100644 --- a/singer-pro/src/instructions/calldataload.rs +++ b/singer-pro/src/instructions/calldataload.rs @@ -9,7 +9,7 @@ use singer_utils::{ register_witness, structs::{ChipChallenges, InstOutChipType, ROMHandler, StackUInt, TSUInt, UInt64}, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::{ component::{FromPredInst, FromWitness, InstCircuit, InstLayout, ToSuccInst}, diff --git a/singer-pro/src/instructions/gt.rs b/singer-pro/src/instructions/gt.rs index ea55a3dc6..1bc74c1f6 100644 --- a/singer-pro/src/instructions/gt.rs +++ b/singer-pro/src/instructions/gt.rs @@ -10,7 +10,7 @@ use singer_utils::{ register_witness, structs::{ChipChallenges, InstOutChipType, ROMHandler, StackUInt, TSUInt}, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::{ component::{FromPredInst, FromWitness, InstCircuit, InstLayout, ToSuccInst}, diff --git a/singer-pro/src/instructions/jumpi.rs b/singer-pro/src/instructions/jumpi.rs index a9855e4d4..34a2a0c55 100644 --- a/singer-pro/src/instructions/jumpi.rs +++ b/singer-pro/src/instructions/jumpi.rs @@ -11,7 +11,7 @@ use singer_utils::{ register_witness, structs::{ChipChallenges, InstOutChipType, PCUInt, ROMHandler, StackUInt, TSUInt}, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::{ component::{FromPredInst, FromWitness, InstCircuit, InstLayout, ToSuccInst}, diff --git a/singer-pro/src/instructions/mstore.rs b/singer-pro/src/instructions/mstore.rs index 31110f067..c8ea30089 100644 --- a/singer-pro/src/instructions/mstore.rs +++ b/singer-pro/src/instructions/mstore.rs @@ -12,7 +12,7 @@ use singer_utils::{ structs::{ChipChallenges, InstOutChipType, RAMHandler, ROMHandler, StackUInt, TSUInt}, uint::constants::AddSubConstants, }; -use std::{mem, sync::Arc}; +use std::{collections::BTreeMap, mem, sync::Arc}; use crate::{ component::{ diff --git a/singer-pro/src/instructions/ret.rs b/singer-pro/src/instructions/ret.rs index 84ea49685..e319e5055 100644 --- a/singer-pro/src/instructions/ret.rs +++ b/singer-pro/src/instructions/ret.rs @@ -11,7 +11,7 @@ use singer_utils::{ register_witness, structs::{ChipChallenges, InstOutChipType, RAMHandler, ROMHandler, StackUInt, TSUInt}, }; -use std::{mem, sync::Arc}; +use std::{collections::BTreeMap, mem, sync::Arc}; use crate::{ component::{ diff --git a/singer-utils/src/macros.rs b/singer-utils/src/macros.rs index 612de1966..b869e3ee5 100644 --- a/singer-utils/src/macros.rs +++ b/singer-utils/src/macros.rs @@ -1,5 +1,6 @@ #[macro_export] macro_rules! register_witness { + // phaseX_size() implementation ($struct_name:ident, $($wire_name:ident { $($slice_name:ident => $length:expr),* }),*) => { paste! { impl $struct_name { @@ -10,11 +11,24 @@ macro_rules! register_witness { } register_witness!(@internal $wire_name, 0usize; $($slice_name => $length),*); + + #[inline] + pub fn [<$wire_name _ idxes_map>]() -> BTreeMap<&'static str, std::ops::Range> { + let mut map = BTreeMap::new(); + + $( + map.insert(stringify!([<$wire_name _ $slice_name>]), Self::[<$wire_name _ $slice_name>]()); + )* + + map + } + )* } } }; + ($struct_name:ident, $($wire_name:ident { $($slice_name:ident => $length:expr),* }),*) => { paste! { impl $struct_name { @@ -25,6 +39,18 @@ macro_rules! register_witness { } register_witness!(@internal $wire_name, 0usize; $($slice_name => $length),*); + + #[inline] + pub fn [<$wire_name _ idxes_map>]() -> BTreeMap<&'static str, std::ops::Range> { + let mut map = BTreeMap::new(); + + $( + map.insert(stringify!([<$wire_name _ $slice_name>]), Self::[<$wire_name _ $slice_name>]()); + )* + + map + } + )* } } @@ -32,9 +58,13 @@ macro_rules! register_witness { (@internal $wire_name:ident, $offset:expr; $name:ident => $length:expr $(, $rest:ident => $rest_length:expr)*) => { paste! { - fn [<$wire_name _ $name>]() -> std::ops::Range { + pub fn [<$wire_name _ $name>]() -> std::ops::Range { $offset..$offset + $length } + + pub fn [<$wire_name _ $name _ str>]() -> &'static str { + stringify!([<$wire_name _ $name>]) + } register_witness!(@internal $wire_name, $offset + $length; $($rest => $rest_length),*); } }; diff --git a/singer/Cargo.toml b/singer/Cargo.toml index ea3a9dd08..e81d0bb27 100644 --- a/singer/Cargo.toml +++ b/singer/Cargo.toml @@ -26,7 +26,9 @@ itertools = "0.12.0" strum = "0.25.0" strum_macros = "0.25.3" paste = "1.0.14" - +tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } +tracing-flame = "0.2.0" +tracing = "0.1.40" [dev-dependencies] pprof = { version = "0.13", features = ["flamegraph"]} diff --git a/singer/benches/add.rs b/singer/benches/add.rs index b55578198..d674f9795 100644 --- a/singer/benches/add.rs +++ b/singer/benches/add.rs @@ -7,8 +7,7 @@ use ark_std::test_rng; use const_env::from_env; use criterion::*; -use ff_ext::ff::Field; -use ff_ext::ExtensionField; +use ff_ext::{ff::Field, ExtensionField}; use gkr::structs::LayerWitness; use goldilocks::GoldilocksExt2; use itertools::Itertools; @@ -57,8 +56,7 @@ fn bench_add(c: &mut Criterion) { #[cfg(feature = "non_pow2_rayon_thread")] { - use sumcheck::local_thread_pool::create_local_pool_once; - use sumcheck::util::ceil_log2; + use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); max_thread_id @@ -71,85 +69,80 @@ fn bench_add(c: &mut Criterion) { let circuit_builder = SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); - for instance_num_vars in 11..12 { + for instance_num_vars in 10..14 { // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("add_op_{}", instance_num_vars)); group.sample_size(NUM_SAMPLES); // Benchmark the proving time group.bench_function( - BenchmarkId::new("prove_keccak256", format!("keccak256_log2_{}", instance_num_vars)), + BenchmarkId::new("prove_add", format!("prove_add_log2_{}", instance_num_vars)), |b| { b.iter_with_setup( || { let mut rng = test_rng(); let singer_builder = SingerGraphBuilder::::new(); - - let real_challenges = vec![E::random(&mut rng), E::random(&mut rng)]; - (rng, singer_builder, real_challenges) + let real_challenges = vec![E::random(&mut rng), E::random(&mut rng)]; + (rng, singer_builder, real_challenges) }, - | (mut rng,mut singer_builder, real_challenges)| { - - let size = AddInstruction::phase0_size(); - - let phase0: CircuitWiresIn< - ::BaseField, - > = vec![LayerWitness { - instances: (0..(1 << instance_num_vars)) - .map(|_| { - (0..size) - .map(|_| { - ::BaseField::random( - &mut rng, - ) - }) - .collect_vec() - }) - .collect_vec(), - }]; - - - let timer = Instant::now(); - - let _ = AddInstruction::construct_graph_and_witness( - &mut singer_builder.graph_builder, - &mut singer_builder.chip_builder, - &circuit_builder.insts_circuits - [>::OPCODE as usize], - vec![phase0], - &real_challenges, - 1 << instance_num_vars, - &SingerParams::default(), - ) - .expect("gkr graph construction failed"); - - let (graph, wit) = singer_builder.graph_builder.finalize_graph_and_witness(); - - println!( - "AddInstruction::construct_graph_and_witness, instance_num_vars = {}, time = {}", - instance_num_vars, - timer.elapsed().as_secs_f64() - ); - - let point = vec![E::random(&mut rng), E::random(&mut rng)]; - let target_evals = graph.target_evals(&wit, &point); - - let mut prover_transcript = &mut Transcript::new(b"Singer"); - - let timer = Instant::now(); - let _ = GKRGraphProverState::prove( - &graph, - &wit, - &target_evals, - &mut prover_transcript, - (1 << instance_num_vars).min(max_thread_id), - ) - .expect("prove failed"); - println!( - "AddInstruction::prove, instance_num_vars = {}, time = {}", - instance_num_vars, - timer.elapsed().as_secs_f64() - ); + |(mut rng,mut singer_builder, real_challenges)| { + let size = AddInstruction::phase0_size(); + let phase0: CircuitWiresIn<::BaseField> = vec![LayerWitness { + instances: (0..(1 << instance_num_vars)) + .map(|_| { + (0..size) + .map(|_| { + ::BaseField::random( + &mut rng, + ) + }) + .collect_vec() + }) + .collect_vec(), + }]; + + + let timer = Instant::now(); + + let _ = AddInstruction::construct_graph_and_witness( + &mut singer_builder.graph_builder, + &mut singer_builder.chip_builder, + &circuit_builder.insts_circuits + [>::OPCODE as usize], + vec![phase0], + &real_challenges, + 1 << instance_num_vars, + &SingerParams::default(), + ) + .expect("gkr graph construction failed"); + + let (graph, wit) = singer_builder.graph_builder.finalize_graph_and_witness(); + + println!( + "AddInstruction::construct_graph_and_witness, instance_num_vars = {}, time = {}", + instance_num_vars, + timer.elapsed().as_secs_f64() + ); + + let point = vec![E::random(&mut rng), E::random(&mut rng)]; + let target_evals = graph.target_evals(&wit, &point); + + let mut prover_transcript = &mut Transcript::new(b"Singer"); + + let timer = Instant::now(); + let _ = GKRGraphProverState::prove( + &graph, + &wit, + &target_evals, + &mut prover_transcript, + (1 << instance_num_vars).min(max_thread_id), + ) + .expect("prove failed"); + println!( + "AddInstruction::prove, instance_num_vars = {}, time = {}", + instance_num_vars, + timer.elapsed().as_secs_f64() + ); }); }, ); diff --git a/singer/examples/add.rs b/singer/examples/add.rs index 28a4f0a96..4d32f63c2 100644 --- a/singer/examples/add.rs +++ b/singer/examples/add.rs @@ -1,26 +1,119 @@ -use std::time::{Duration, Instant}; +use std::{collections::BTreeMap, time::Instant}; use ark_std::test_rng; -use const_env::from_env; -use criterion::*; - -use ff_ext::ff::Field; -use ff_ext::ExtensionField; +use ff_ext::{ff::Field, ExtensionField}; use gkr::structs::LayerWitness; -use goldilocks::GoldilocksExt2; +use gkr_graph::structs::CircuitGraphAuxInfo; +use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; +use simple_frontend::structs::CellId; use singer::{ instructions::{add::AddInstruction, Instruction, InstructionGraph, SingerCircuitBuilder}, - scheme::GKRGraphProverState, - CircuitWiresIn, SingerGraphBuilder, SingerParams, + scheme::{GKRGraphProverState, GKRGraphVerifierState}, + u64vec, CircuitWiresIn, SingerGraphBuilder, SingerParams, +}; +use singer_utils::{ + constants::RANGE_CHIP_BIT_WIDTH, + structs::{ChipChallenges, StackUInt, TSUInt}, }; -use singer_utils::structs::ChipChallenges; +use tracing_flame::FlameLayer; +use tracing_subscriber::{fmt, layer::SubscriberExt, EnvFilter, Registry}; use transcript::Transcript; +fn get_single_instance_values_map() -> BTreeMap<&'static str, Vec> { + let mut phase0_values_map = BTreeMap::<&'static str, Vec>::new(); + phase0_values_map.insert( + AddInstruction::phase0_pc_str(), + vec![Goldilocks::from(1u64)], + ); + phase0_values_map.insert( + AddInstruction::phase0_stack_ts_str(), + vec![Goldilocks::from(3u64)], + ); + phase0_values_map.insert( + AddInstruction::phase0_memory_ts_str(), + vec![Goldilocks::from(1u64)], + ); + phase0_values_map.insert( + AddInstruction::phase0_stack_top_str(), + vec![Goldilocks::from(100u64)], + ); + phase0_values_map.insert( + AddInstruction::phase0_clk_str(), + vec![Goldilocks::from(1u64)], + ); + phase0_values_map.insert( + AddInstruction::phase0_pc_add_str(), + vec![], // carry is 0, may test carry using larger values in PCUInt + ); + phase0_values_map.insert( + AddInstruction::phase0_stack_ts_add_str(), + vec![ + Goldilocks::from(4u64), /* first TSUInt::N_RANGE_CHECK_CELLS = 1*(56/16) = 4 + * cells are range values, stack_ts + 1 = 4 */ + Goldilocks::from(0u64), + Goldilocks::from(0u64), + Goldilocks::from(0u64), + // no place for carry + ], + ); + phase0_values_map.insert( + AddInstruction::phase0_old_stack_ts0_str(), + vec![Goldilocks::from(2u64)], + ); + let m: u64 = (1 << TSUInt::C) - 1; + let range_values = u64vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + phase0_values_map.insert( + AddInstruction::phase0_old_stack_ts_lt0_str(), + vec![ + Goldilocks::from(range_values[0]), + Goldilocks::from(range_values[1]), + Goldilocks::from(range_values[2]), + Goldilocks::from(range_values[3]), + Goldilocks::from(1u64), // borrow + ], + ); + phase0_values_map.insert( + AddInstruction::phase0_old_stack_ts1_str(), + vec![Goldilocks::from(1u64)], + ); + let m: u64 = (1 << TSUInt::C) - 2; + let range_values = u64vec::<{ TSUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + phase0_values_map.insert( + AddInstruction::phase0_old_stack_ts_lt1_str(), + vec![ + Goldilocks::from(range_values[0]), + Goldilocks::from(range_values[1]), + Goldilocks::from(range_values[2]), + Goldilocks::from(range_values[3]), + Goldilocks::from(1u64), // borrow + ], + ); + let m: u64 = (1 << StackUInt::C) - 1; + phase0_values_map.insert( + AddInstruction::phase0_addend_0_str(), + vec![Goldilocks::from(m)], + ); + phase0_values_map.insert( + AddInstruction::phase0_addend_1_str(), + vec![Goldilocks::from(1u64)], + ); + let range_values = u64vec::<{ StackUInt::N_RANGE_CHECK_CELLS }, RANGE_CHIP_BIT_WIDTH>(m + 1); + let mut wit_phase0_instruction_add: Vec = vec![]; + for i in 0..16 { + wit_phase0_instruction_add.push(Goldilocks::from(range_values[i])) + } + wit_phase0_instruction_add.push(Goldilocks::from(1u64)); // carry is [1, 0, ...] + phase0_values_map.insert( + AddInstruction::phase0_instruction_add_str(), + wit_phase0_instruction_add, + ); + phase0_values_map +} fn main() { - let max_thread_id = 1; - let instance_num_vars = 9; + let max_thread_id = 8; + let instance_num_vars = 11; type E = GoldilocksExt2; let chip_challenges = ChipChallenges::default(); let circuit_builder = @@ -29,14 +122,31 @@ fn main() { let mut rng = test_rng(); let size = AddInstruction::phase0_size(); + let phase0_values_map = get_single_instance_values_map(); + let phase0_idx_map = AddInstruction::phase0_idxes_map(); + + let mut single_witness_in = vec![::BaseField::ZERO; size]; + + for key in phase0_idx_map.keys() { + let range = phase0_idx_map + .get(key) + .unwrap() + .clone() + .collect::>(); + let values = phase0_values_map + .get(key) + .expect(&("unknown key ".to_owned() + key)); + for (value_idx, cell_idx) in range.into_iter().enumerate() { + if value_idx < values.len() { + single_witness_in[cell_idx] = values[value_idx]; + } + } + } + let phase0: CircuitWiresIn<::BaseField> = vec![LayerWitness { instances: (0..(1 << instance_num_vars)) - .map(|_| { - (0..size) - .map(|_| ::BaseField::random(&mut rng)) - .collect_vec() - }) + .map(|_| single_witness_in.clone()) .collect_vec(), }]; @@ -63,23 +173,52 @@ fn main() { timer.elapsed().as_secs_f64() ); + let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap(); + let subscriber = Registry::default() + .with( + fmt::layer() + .compact() + .with_thread_ids(false) + .with_thread_names(false), + ) + .with(EnvFilter::from_default_env()) + .with(flame_layer.with_threads_collapsed(true)); + tracing::subscriber::set_global_default(subscriber).unwrap(); + let point = vec![E::random(&mut rng), E::random(&mut rng)]; let target_evals = graph.target_evals(&wit, &point); - let mut prover_transcript = &mut Transcript::new(b"Singer"); - - let timer = Instant::now(); - let _ = GKRGraphProverState::prove( - &graph, - &wit, - &target_evals, - &mut prover_transcript, - (1 << instance_num_vars).min(max_thread_id), - ) - .expect("prove failed"); - println!( - "AddInstruction::prove, instance_num_vars = {}, time = {}", - instance_num_vars, - timer.elapsed().as_secs_f64() - ); + for _ in 0..5 { + let mut prover_transcript = &mut Transcript::new(b"Singer"); + let timer = Instant::now(); + let proof = GKRGraphProverState::prove( + &graph, + &wit, + &target_evals, + &mut prover_transcript, + (1 << instance_num_vars).min(max_thread_id), + ) + .expect("prove failed"); + println!( + "AddInstruction::prove, instance_num_vars = {}, time = {}", + instance_num_vars, + timer.elapsed().as_secs_f64() + ); + let mut verifier_transcript = Transcript::new(b"Singer"); + let _ = GKRGraphVerifierState::verify( + &graph, + &real_challenges, + &target_evals, + proof, + &CircuitGraphAuxInfo { + instance_num_vars: wit + .node_witnesses + .iter() + .map(|witness| witness.instance_num_vars()) + .collect(), + }, + &mut verifier_transcript, + ) + .expect("verify failed"); + } } diff --git a/singer/src/instructions.rs b/singer/src/instructions.rs index 863fcc3ef..772c233f0 100644 --- a/singer/src/instructions.rs +++ b/singer/src/instructions.rs @@ -221,6 +221,7 @@ pub trait InstructionGraph { real_n_instances: usize, _: &SingerParams, ) -> Result, ZKVMError> { + assert_eq!(sources.len(), 1, "unknown source length"); let inst_circuit = &inst_circuits[0]; let inst_wires_in = mem::take(&mut sources[0]); let node_id = graph_builder.add_node_with_witness( diff --git a/singer/src/instructions/add.rs b/singer/src/instructions/add.rs index 6bdef4942..747799759 100644 --- a/singer/src/instructions/add.rs +++ b/singer/src/instructions/add.rs @@ -13,7 +13,7 @@ use singer_utils::{ structs::{PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, uint::constants::AddSubConstants, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::error::ZKVMError; @@ -54,6 +54,7 @@ impl Instruction for AddInstruction { fn construct_circuit(challenges: ChipChallenges) -> Result, ZKVMError> { let mut circuit_builder = CircuitBuilder::new(); let (phase0_wire_id, phase0) = circuit_builder.create_witness_in(Self::phase0_size()); + let mut ram_handler = RAMHandler::new(&challenges); let mut rom_handler = ROMHandler::new(&challenges); @@ -123,6 +124,7 @@ impl Instruction for AddInstruction { &stack_ts, &phase0[Self::phase0_old_stack_ts_lt0()], )?; + ram_handler.stack_pop( &mut circuit_builder, stack_top_expr.sub(E::BaseField::from(1)), @@ -180,17 +182,14 @@ impl Instruction for AddInstruction { #[cfg(test)] mod test { use ark_std::test_rng; - use core::ops::Range; use ff::Field; use ff_ext::ExtensionField; use gkr::structs::LayerWitness; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; - use simple_frontend::structs::CellId; use singer_utils::{ constants::RANGE_CHIP_BIT_WIDTH, structs::{StackUInt, TSUInt}, - uint::constants::AddSubConstants, }; use std::{collections::BTreeMap, time::Instant}; use transcript::Transcript; @@ -200,51 +199,11 @@ mod test { AddInstruction, ChipChallenges, Instruction, InstructionGraph, SingerCircuitBuilder, }, scheme::GKRGraphProverState, - test::{get_uint_params, test_opcode_circuit, u2vec}, + test::{get_uint_params, test_opcode_circuit_v2}, + utils::u64vec, CircuitWiresIn, SingerGraphBuilder, SingerParams, }; - impl AddInstruction { - #[inline] - fn phase0_idxes_map() -> BTreeMap> { - let mut map = BTreeMap::new(); - map.insert("phase0_pc".to_string(), Self::phase0_pc()); - map.insert("phase0_stack_ts".to_string(), Self::phase0_stack_ts()); - map.insert("phase0_memory_ts".to_string(), Self::phase0_memory_ts()); - map.insert("phase0_stack_top".to_string(), Self::phase0_stack_top()); - map.insert("phase0_clk".to_string(), Self::phase0_clk()); - map.insert("phase0_pc_add".to_string(), Self::phase0_pc_add()); - map.insert( - "phase0_stack_ts_add".to_string(), - Self::phase0_stack_ts_add(), - ); - map.insert( - "phase0_old_stack_ts0".to_string(), - Self::phase0_old_stack_ts0(), - ); - map.insert( - "phase0_old_stack_ts_lt0".to_string(), - Self::phase0_old_stack_ts_lt0(), - ); - map.insert( - "phase0_old_stack_ts1".to_string(), - Self::phase0_old_stack_ts1(), - ); - map.insert( - "phase0_old_stack_ts_lt1".to_string(), - Self::phase0_old_stack_ts_lt1(), - ); - map.insert("phase0_addend_0".to_string(), Self::phase0_addend_0()); - map.insert("phase0_addend_1".to_string(), Self::phase0_addend_1()); - map.insert( - "phase0_instruction_add".to_string(), - Self::phase0_instruction_add(), - ); - - map - } - } - #[test] fn test_add_construct_circuit() { let challenges = ChipChallenges::default(); @@ -264,21 +223,33 @@ mod test { #[cfg(feature = "test-dbg")] println!("{:?}", inst_circuit); - let mut phase0_values_map = BTreeMap::>::new(); - phase0_values_map.insert("phase0_pc".to_string(), vec![Goldilocks::from(1u64)]); - phase0_values_map.insert("phase0_stack_ts".to_string(), vec![Goldilocks::from(3u64)]); - phase0_values_map.insert("phase0_memory_ts".to_string(), vec![Goldilocks::from(1u64)]); + let mut phase0_values_map = BTreeMap::<&'static str, Vec>::new(); phase0_values_map.insert( - "phase0_stack_top".to_string(), + AddInstruction::phase0_pc_str(), + vec![Goldilocks::from(1u64)], + ); + phase0_values_map.insert( + AddInstruction::phase0_stack_ts_str(), + vec![Goldilocks::from(3u64)], + ); + phase0_values_map.insert( + AddInstruction::phase0_memory_ts_str(), + vec![Goldilocks::from(1u64)], + ); + phase0_values_map.insert( + AddInstruction::phase0_stack_top_str(), vec![Goldilocks::from(100u64)], ); - phase0_values_map.insert("phase0_clk".to_string(), vec![Goldilocks::from(1u64)]); phase0_values_map.insert( - "phase0_pc_add".to_string(), + AddInstruction::phase0_clk_str(), + vec![Goldilocks::from(1u64)], + ); + phase0_values_map.insert( + AddInstruction::phase0_pc_add_str(), vec![], // carry is 0, may test carry using larger values in PCUInt ); phase0_values_map.insert( - "phase0_stack_ts_add".to_string(), + AddInstruction::phase0_stack_ts_add_str(), vec![ Goldilocks::from(4u64), /* first TSUInt::N_RANGE_CELLS = 1*(48/16) = 3 cells are * range values, stack_ts + 1 = 4 */ @@ -288,13 +259,13 @@ mod test { ], ); phase0_values_map.insert( - "phase0_old_stack_ts0".to_string(), + AddInstruction::phase0_old_stack_ts0_str(), vec![Goldilocks::from(2u64)], ); let m: u64 = (1 << get_uint_params::().1) - 1; - let range_values = u2vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( - "phase0_old_stack_ts_lt0".to_string(), + AddInstruction::phase0_old_stack_ts_lt0_str(), vec![ Goldilocks::from(range_values[0]), Goldilocks::from(range_values[1]), @@ -303,13 +274,13 @@ mod test { ], ); phase0_values_map.insert( - "phase0_old_stack_ts1".to_string(), + AddInstruction::phase0_old_stack_ts1_str(), vec![Goldilocks::from(1u64)], ); let m: u64 = (1 << get_uint_params::().1) - 2; - let range_values = u2vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( - "phase0_old_stack_ts_lt1".to_string(), + AddInstruction::phase0_old_stack_ts_lt1_str(), vec![ Goldilocks::from(range_values[0]), Goldilocks::from(range_values[1]), @@ -318,16 +289,22 @@ mod test { ], ); let m: u64 = (1 << get_uint_params::().1) - 1; - phase0_values_map.insert("phase0_addend_0".to_string(), vec![Goldilocks::from(m)]); - phase0_values_map.insert("phase0_addend_1".to_string(), vec![Goldilocks::from(1u64)]); - let range_values = u2vec::<{ StackUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m + 1); + phase0_values_map.insert( + AddInstruction::phase0_addend_0_str(), + vec![Goldilocks::from(m)], + ); + phase0_values_map.insert( + AddInstruction::phase0_addend_1_str(), + vec![Goldilocks::from(1u64)], + ); + let range_values = u64vec::<{ StackUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m + 1); let mut wit_phase0_instruction_add: Vec = vec![]; for i in 0..16 { wit_phase0_instruction_add.push(Goldilocks::from(range_values[i])) } wit_phase0_instruction_add.push(Goldilocks::from(1u64)); // carry is [1, 0, ...] phase0_values_map.insert( - "phase0_instruction_add".to_string(), + AddInstruction::phase0_instruction_add_str(), wit_phase0_instruction_add, ); @@ -337,7 +314,7 @@ mod test { let c = GoldilocksExt2::from(6u64); let circuit_witness_challenges = vec![c; 3]; - let circuit_witness = test_opcode_circuit( + let _ = test_opcode_circuit_v2( &inst_circuit, &phase0_idx_map, phase0_witness_size, diff --git a/singer/src/instructions/calldataload.rs b/singer/src/instructions/calldataload.rs index 6d65d74cf..449be855d 100644 --- a/singer/src/instructions/calldataload.rs +++ b/singer/src/instructions/calldataload.rs @@ -3,7 +3,6 @@ use ff_ext::ExtensionField; use gkr::structs::Circuit; use paste::paste; use simple_frontend::structs::{CircuitBuilder, MixedCell}; -use singer_utils::uint::constants::AddSubConstants; use singer_utils::{ chip_handler::{ BytecodeChipOperations, CalldataChipOperations, GlobalStateChipOperations, OAMOperations, @@ -12,8 +11,9 @@ use singer_utils::{ constants::OpcodeType, register_witness, structs::{PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt, UInt64}, + uint::constants::AddSubConstants, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::error::ZKVMError; @@ -151,57 +151,25 @@ impl Instruction for CalldataloadInstruction { #[cfg(test)] mod test { use ark_std::test_rng; - use core::ops::Range; use ff::Field; use ff_ext::ExtensionField; use gkr::structs::LayerWitness; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; - use simple_frontend::structs::CellId; - use singer_utils::constants::RANGE_CHIP_BIT_WIDTH; - use singer_utils::structs::TSUInt; - use std::collections::BTreeMap; - use std::time::Instant; + use singer_utils::{constants::RANGE_CHIP_BIT_WIDTH, structs::TSUInt}; + use std::{collections::BTreeMap, time::Instant}; use transcript::Transcript; - use crate::instructions::{ - CalldataloadInstruction, ChipChallenges, Instruction, InstructionGraph, - SingerCircuitBuilder, + use crate::{ + instructions::{ + CalldataloadInstruction, ChipChallenges, Instruction, InstructionGraph, + SingerCircuitBuilder, + }, + scheme::GKRGraphProverState, + test::{get_uint_params, test_opcode_circuit}, + utils::u64vec, + CircuitWiresIn, SingerGraphBuilder, SingerParams, }; - use crate::scheme::GKRGraphProverState; - use crate::test::{get_uint_params, test_opcode_circuit, u2vec}; - use crate::{CircuitWiresIn, SingerGraphBuilder, SingerParams}; - - impl CalldataloadInstruction { - #[inline] - fn phase0_idxes_map() -> BTreeMap> { - let mut map = BTreeMap::new(); - - map.insert("phase0_pc".to_string(), Self::phase0_pc()); - map.insert("phase0_stack_ts".to_string(), Self::phase0_stack_ts()); - map.insert("phase0_memory_ts".to_string(), Self::phase0_memory_ts()); - map.insert("phase0_ts".to_string(), Self::phase0_ts()); - map.insert("phase0_stack_top".to_string(), Self::phase0_stack_top()); - map.insert("phase0_clk".to_string(), Self::phase0_clk()); - map.insert("phase0_pc_add".to_string(), Self::phase0_pc_add()); - map.insert( - "phase0_stack_ts_add".to_string(), - Self::phase0_stack_ts_add(), - ); - map.insert("phase0_data".to_string(), Self::phase0_data()); - map.insert("phase0_offset".to_string(), Self::phase0_offset()); - map.insert( - "phase0_old_stack_ts".to_string(), - Self::phase0_old_stack_ts(), - ); - map.insert( - "phase0_old_stack_ts_lt".to_string(), - Self::phase0_old_stack_ts_lt(), - ); - - map - } - } #[test] fn test_calldataload_construct_circuit() { @@ -239,7 +207,8 @@ mod test { phase0_values_map.insert( "phase0_stack_ts_add".to_string(), vec![ - Goldilocks::from(4u64), // first TSUInt::N_RANGE_CELLS = 1*(56/16) = 4 cells are range values, stack_ts + 1 = 4 + Goldilocks::from(4u64), /* first TSUInt::N_RANGE_CELLS = 1*(56/16) = 4 cells are + * range values, stack_ts + 1 = 4 */ Goldilocks::from(0u64), Goldilocks::from(0u64), Goldilocks::from(0u64), @@ -251,7 +220,7 @@ mod test { vec![Goldilocks::from(2u64)], ); let m: u64 = (1 << get_uint_params::().1) - 1; - let range_values = u2vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( "phase0_old_stack_ts_lt".to_string(), vec![ diff --git a/singer/src/instructions/dup.rs b/singer/src/instructions/dup.rs index f69637c3c..a250eeb30 100644 --- a/singer/src/instructions/dup.rs +++ b/singer/src/instructions/dup.rs @@ -3,7 +3,6 @@ use ff_ext::ExtensionField; use gkr::structs::Circuit; use paste::paste; use simple_frontend::structs::{CircuitBuilder, MixedCell}; -use singer_utils::uint::constants::AddSubConstants; use singer_utils::{ chip_handler::{ BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, @@ -12,8 +11,9 @@ use singer_utils::{ constants::OpcodeType, register_witness, structs::{PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, + uint::constants::AddSubConstants, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::error::ZKVMError; @@ -161,56 +161,24 @@ impl Instruction for DupInstruction { #[cfg(test)] mod test { use ark_std::test_rng; - use core::ops::Range; use ff::Field; use ff_ext::ExtensionField; use gkr::structs::LayerWitness; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; - use simple_frontend::structs::CellId; - use singer_utils::constants::RANGE_CHIP_BIT_WIDTH; - use singer_utils::structs::TSUInt; - use std::collections::BTreeMap; - use std::time::Instant; + use singer_utils::{constants::RANGE_CHIP_BIT_WIDTH, structs::TSUInt}; + use std::{collections::BTreeMap, time::Instant}; use transcript::Transcript; - use crate::instructions::{ - ChipChallenges, DupInstruction, Instruction, InstructionGraph, SingerCircuitBuilder, + use crate::{ + instructions::{ + ChipChallenges, DupInstruction, Instruction, InstructionGraph, SingerCircuitBuilder, + }, + scheme::GKRGraphProverState, + test::{get_uint_params, test_opcode_circuit}, + utils::u64vec, + CircuitWiresIn, SingerGraphBuilder, SingerParams, }; - use crate::scheme::GKRGraphProverState; - use crate::test::{get_uint_params, test_opcode_circuit, u2vec}; - use crate::{CircuitWiresIn, SingerGraphBuilder, SingerParams}; - - impl DupInstruction { - #[inline] - fn phase0_idxes_map() -> BTreeMap> { - let mut map = BTreeMap::new(); - map.insert("phase0_pc".to_string(), Self::phase0_pc()); - map.insert("phase0_stack_ts".to_string(), Self::phase0_stack_ts()); - map.insert("phase0_memory_ts".to_string(), Self::phase0_memory_ts()); - map.insert("phase0_stack_top".to_string(), Self::phase0_stack_top()); - map.insert("phase0_clk".to_string(), Self::phase0_clk()); - map.insert("phase0_pc_add".to_string(), Self::phase0_pc_add()); - map.insert( - "phase0_stack_ts_add".to_string(), - Self::phase0_stack_ts_add(), - ); - map.insert( - "phase0_stack_values".to_string(), - Self::phase0_stack_values(), - ); - map.insert( - "phase0_old_stack_ts".to_string(), - Self::phase0_old_stack_ts(), - ); - map.insert( - "phase0_old_stack_ts_lt".to_string(), - Self::phase0_old_stack_ts_lt(), - ); - - map - } - } #[test] fn test_dup1_construct_circuit() { @@ -247,7 +215,8 @@ mod test { phase0_values_map.insert( "phase0_stack_ts_add".to_string(), vec![ - Goldilocks::from(3u64), // first TSUInt::N_RANGE_CELLS = 1*(56/16) = 4 cells are range values, stack_ts + 1 = 4 + Goldilocks::from(3u64), /* first TSUInt::N_RANGE_CELLS = 1*(56/16) = 4 cells are + * range values, stack_ts + 1 = 4 */ Goldilocks::from(0u64), Goldilocks::from(0u64), Goldilocks::from(0u64), @@ -272,7 +241,7 @@ mod test { vec![Goldilocks::from(1u64)], ); let m: u64 = (1 << get_uint_params::().1) - 1; - let range_values = u2vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( "phase0_old_stack_ts_lt".to_string(), vec![ diff --git a/singer/src/instructions/gt.rs b/singer/src/instructions/gt.rs index 7e0be45dd..da2f1dab2 100644 --- a/singer/src/instructions/gt.rs +++ b/singer/src/instructions/gt.rs @@ -3,7 +3,6 @@ use ff_ext::ExtensionField; use gkr::structs::Circuit; use paste::paste; use simple_frontend::structs::{CircuitBuilder, MixedCell}; -use singer_utils::uint::constants::AddSubConstants; use singer_utils::{ chip_handler::{ BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, @@ -12,8 +11,9 @@ use singer_utils::{ constants::OpcodeType, register_witness, structs::{PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, + uint::constants::AddSubConstants, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::error::ZKVMError; @@ -176,66 +176,24 @@ impl Instruction for GtInstruction { #[cfg(test)] mod test { use ark_std::test_rng; - use core::ops::Range; use ff::Field; use ff_ext::ExtensionField; use gkr::structs::LayerWitness; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; - use simple_frontend::structs::CellId; - use singer_utils::constants::RANGE_CHIP_BIT_WIDTH; - use singer_utils::structs::TSUInt; - use std::collections::BTreeMap; - use std::time::Instant; + use singer_utils::{constants::RANGE_CHIP_BIT_WIDTH, structs::TSUInt}; + use std::{collections::BTreeMap, time::Instant}; use transcript::Transcript; - use crate::instructions::{ - ChipChallenges, GtInstruction, Instruction, InstructionGraph, SingerCircuitBuilder, + use crate::{ + instructions::{ + ChipChallenges, GtInstruction, Instruction, InstructionGraph, SingerCircuitBuilder, + }, + scheme::GKRGraphProverState, + test::{get_uint_params, test_opcode_circuit}, + utils::u64vec, + CircuitWiresIn, SingerGraphBuilder, SingerParams, }; - use crate::scheme::GKRGraphProverState; - use crate::test::{get_uint_params, test_opcode_circuit, u2vec}; - use crate::{CircuitWiresIn, SingerGraphBuilder, SingerParams}; - - impl GtInstruction { - #[inline] - fn phase0_idxes_map() -> BTreeMap> { - let mut map = BTreeMap::new(); - map.insert("phase0_pc".to_string(), Self::phase0_pc()); - map.insert("phase0_stack_ts".to_string(), Self::phase0_stack_ts()); - map.insert("phase0_memory_ts".to_string(), Self::phase0_memory_ts()); - map.insert("phase0_stack_top".to_string(), Self::phase0_stack_top()); - map.insert("phase0_clk".to_string(), Self::phase0_clk()); - map.insert("phase0_pc_add".to_string(), Self::phase0_pc_add()); - map.insert( - "phase0_stack_ts_add".to_string(), - Self::phase0_stack_ts_add(), - ); - map.insert( - "phase0_old_stack_ts0".to_string(), - Self::phase0_old_stack_ts0(), - ); - map.insert( - "phase0_old_stack_ts_lt0".to_string(), - Self::phase0_old_stack_ts_lt0(), - ); - map.insert( - "phase0_old_stack_ts1".to_string(), - Self::phase0_old_stack_ts1(), - ); - map.insert( - "phase0_old_stack_ts_lt1".to_string(), - Self::phase0_old_stack_ts_lt1(), - ); - map.insert("phase0_oprand_0".to_string(), Self::phase0_oprand_0()); - map.insert("phase0_oprand_1".to_string(), Self::phase0_oprand_1()); - map.insert( - "phase0_instruction_gt".to_string(), - Self::phase0_instruction_gt(), - ); - - map - } - } #[test] fn test_gt_construct_circuit() { @@ -272,7 +230,8 @@ mod test { phase0_values_map.insert( "phase0_stack_ts_add".to_string(), vec![ - Goldilocks::from(4u64), // first TSUInt::N_RANGE_CELLS = 1*(56/16) = 4 cells are range values, stack_ts + 1 = 4 + Goldilocks::from(4u64), /* first TSUInt::N_RANGE_CELLS = 1*(56/16) = 4 cells are + * range values, stack_ts + 1 = 4 */ Goldilocks::from(0u64), Goldilocks::from(0u64), Goldilocks::from(0u64), @@ -284,7 +243,7 @@ mod test { vec![Goldilocks::from(2u64)], ); let m: u64 = (1 << get_uint_params::().1) - 1; - let range_values = u2vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( "phase0_old_stack_ts_lt0".to_string(), vec![ @@ -300,7 +259,7 @@ mod test { vec![Goldilocks::from(1u64)], ); let m: u64 = (1 << get_uint_params::().1) - 2; - let range_values = u2vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( "phase0_old_stack_ts_lt1".to_string(), vec![ diff --git a/singer/src/instructions/jump.rs b/singer/src/instructions/jump.rs index 7a13838f1..c12c44fca 100644 --- a/singer/src/instructions/jump.rs +++ b/singer/src/instructions/jump.rs @@ -5,7 +5,6 @@ use ff_ext::ExtensionField; use gkr::structs::Circuit; use paste::paste; use simple_frontend::structs::{CircuitBuilder, MixedCell}; -use singer_utils::uint::constants::AddSubConstants; use singer_utils::{ chip_handler::{ BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, @@ -14,7 +13,9 @@ use singer_utils::{ constants::OpcodeType, register_witness, structs::{PCUInt, RAMHandler, ROMHandler, TSUInt}, + uint::constants::AddSubConstants, }; +use std::collections::BTreeMap; use crate::error::ZKVMError; @@ -125,48 +126,26 @@ impl Instruction for JumpInstruction { #[cfg(test)] mod test { - use crate::instructions::{ChipChallenges, Instruction, JumpInstruction}; - use crate::test::{get_uint_params, test_opcode_circuit, u2vec}; + use crate::{ + instructions::{ChipChallenges, Instruction, JumpInstruction}, + test::{get_uint_params, test_opcode_circuit}, + utils::u64vec, + }; use ark_std::test_rng; - use core::ops::Range; use ff::Field; use ff_ext::ExtensionField; use gkr::structs::LayerWitness; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; - use simple_frontend::structs::CellId; - use singer_utils::constants::RANGE_CHIP_BIT_WIDTH; - use singer_utils::structs::TSUInt; - use std::collections::BTreeMap; - use std::time::Instant; + use singer_utils::{constants::RANGE_CHIP_BIT_WIDTH, structs::TSUInt}; + use std::{collections::BTreeMap, time::Instant}; use transcript::Transcript; - use crate::instructions::{InstructionGraph, SingerCircuitBuilder}; - use crate::scheme::GKRGraphProverState; - use crate::{CircuitWiresIn, SingerGraphBuilder, SingerParams}; - - impl JumpInstruction { - #[inline] - fn phase0_idxes_map() -> BTreeMap> { - let mut map = BTreeMap::new(); - map.insert("phase0_pc".to_string(), Self::phase0_pc()); - map.insert("phase0_stack_ts".to_string(), Self::phase0_stack_ts()); - map.insert("phase0_memory_ts".to_string(), Self::phase0_memory_ts()); - map.insert("phase0_stack_top".to_string(), Self::phase0_stack_top()); - map.insert("phase0_clk".to_string(), Self::phase0_clk()); - map.insert("phase0_next_pc".to_string(), Self::phase0_next_pc()); - map.insert( - "phase0_old_stack_ts".to_string(), - Self::phase0_old_stack_ts(), - ); - map.insert( - "phase0_old_stack_ts_lt".to_string(), - Self::phase0_old_stack_ts_lt(), - ); - - map - } - } + use crate::{ + instructions::{InstructionGraph, SingerCircuitBuilder}, + scheme::GKRGraphProverState, + CircuitWiresIn, SingerGraphBuilder, SingerParams, + }; #[test] fn test_jump_construct_circuit() { @@ -205,7 +184,7 @@ mod test { vec![Goldilocks::from(1u64)], ); let m: u64 = (1 << get_uint_params::().1) - 1; - let range_values = u2vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( "phase0_old_stack_ts_lt".to_string(), vec![ diff --git a/singer/src/instructions/jumpdest.rs b/singer/src/instructions/jumpdest.rs index f21874575..dd5b85215 100644 --- a/singer/src/instructions/jumpdest.rs +++ b/singer/src/instructions/jumpdest.rs @@ -12,7 +12,7 @@ use singer_utils::{ register_witness, structs::{PCUInt, RAMHandler, ROMHandler, TSUInt}, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::error::ZKVMError; @@ -100,38 +100,23 @@ impl Instruction for JumpdestInstruction { #[cfg(test)] mod test { use ark_std::test_rng; - use core::ops::Range; use ff::Field; use ff_ext::ExtensionField; use gkr::structs::LayerWitness; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; - use simple_frontend::structs::CellId; - use std::collections::BTreeMap; - use std::time::Instant; + use std::{collections::BTreeMap, time::Instant}; use transcript::Transcript; - use crate::instructions::{ - ChipChallenges, Instruction, InstructionGraph, JumpdestInstruction, SingerCircuitBuilder, + use crate::{ + instructions::{ + ChipChallenges, Instruction, InstructionGraph, JumpdestInstruction, + SingerCircuitBuilder, + }, + scheme::GKRGraphProverState, + test::test_opcode_circuit, + CircuitWiresIn, SingerGraphBuilder, SingerParams, }; - use crate::scheme::GKRGraphProverState; - use crate::test::test_opcode_circuit; - use crate::{CircuitWiresIn, SingerGraphBuilder, SingerParams}; - - impl JumpdestInstruction { - #[inline] - fn phase0_idxes_map() -> BTreeMap> { - let mut map = BTreeMap::new(); - map.insert("phase0_pc".to_string(), Self::phase0_pc()); - map.insert("phase0_stack_ts".to_string(), Self::phase0_stack_ts()); - map.insert("phase0_memory_ts".to_string(), Self::phase0_memory_ts()); - map.insert("phase0_stack_top".to_string(), Self::phase0_stack_top()); - map.insert("phase0_clk".to_string(), Self::phase0_clk()); - map.insert("phase0_pc_add".to_string(), Self::phase0_pc_add()); - - map - } - } #[test] fn test_jumpdest_construct_circuit() { diff --git a/singer/src/instructions/jumpi.rs b/singer/src/instructions/jumpi.rs index b007d4f76..62e34d21f 100644 --- a/singer/src/instructions/jumpi.rs +++ b/singer/src/instructions/jumpi.rs @@ -14,7 +14,7 @@ use singer_utils::{ structs::{PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, uint::constants::AddSubConstants, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::error::ZKVMError; diff --git a/singer/src/instructions/mstore.rs b/singer/src/instructions/mstore.rs index 14f515933..38f7e1010 100644 --- a/singer/src/instructions/mstore.rs +++ b/singer/src/instructions/mstore.rs @@ -15,7 +15,7 @@ use singer_utils::{ structs::{PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, uint::constants::AddSubConstants, }; -use std::{mem, sync::Arc}; +use std::{collections::BTreeMap, mem, sync::Arc}; use crate::{error::ZKVMError, utils::add_assign_each_cell, CircuitWiresIn, SingerParams}; @@ -378,7 +378,9 @@ impl MstoreAccessory { #[cfg(test)] mod test { - use crate::{instructions::InstructionGraph, scheme::GKRGraphProverState, SingerParams}; + use crate::{ + instructions::InstructionGraph, scheme::GKRGraphProverState, utils::u64vec, SingerParams, + }; use ark_std::test_rng; use ff::Field; use ff_ext::ExtensionField; @@ -397,51 +399,11 @@ mod test { CircuitWiresIn, SingerGraphBuilder, }; - use crate::test::{get_uint_params, test_opcode_circuit, u2vec}; - use core::ops::Range; + use crate::test::{get_uint_params, test_opcode_circuit}; use goldilocks::Goldilocks; - use simple_frontend::structs::CellId; use singer_utils::{constants::RANGE_CHIP_BIT_WIDTH, structs::TSUInt}; use std::collections::BTreeMap; - impl MstoreInstruction { - #[inline] - fn phase0_idxes_map() -> BTreeMap> { - let mut map = BTreeMap::new(); - - map.insert("phase0_pc".to_string(), Self::phase0_pc()); - map.insert("phase0_stack_ts".to_string(), Self::phase0_stack_ts()); - map.insert("phase0_memory_ts".to_string(), Self::phase0_memory_ts()); - map.insert("phase0_stack_top".to_string(), Self::phase0_stack_top()); - map.insert("phase0_clk".to_string(), Self::phase0_clk()); - map.insert("phase0_pc_add".to_string(), Self::phase0_pc_add()); - map.insert( - "phase0_memory_ts_add".to_string(), - Self::phase0_memory_ts_add(), - ); - map.insert("phase0_offset".to_string(), Self::phase0_offset()); - map.insert("phase0_mem_bytes".to_string(), Self::phase0_mem_bytes()); - map.insert( - "phase0_old_stack_ts_offset".to_string(), - Self::phase0_old_stack_ts_offset(), - ); - map.insert( - "phase0_old_stack_ts_lt_offset".to_string(), - Self::phase0_old_stack_ts_lt_offset(), - ); - map.insert( - "phase0_old_stack_ts_value".to_string(), - Self::phase0_old_stack_ts_value(), - ); - map.insert( - "phase0_old_stack_ts_lt_value".to_string(), - Self::phase0_old_stack_ts_lt_value(), - ); - - map - } - } - #[test] fn test_mstore_construct_circuit() { let challenges = ChipChallenges::default(); @@ -491,7 +453,7 @@ mod test { vec![Goldilocks::from(2u64)], ); let m: u64 = (1 << get_uint_params::().1) - 1; - let range_values = u2vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( "phase0_old_stack_ts_lt_offset".to_string(), vec![ @@ -511,7 +473,7 @@ mod test { vec![Goldilocks::from(1u64)], ); let m: u64 = (1 << get_uint_params::().1) - 2; - let range_values = u2vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( "phase0_old_stack_ts_lt_value".to_string(), vec![ diff --git a/singer/src/instructions/pop.rs b/singer/src/instructions/pop.rs index c82bd6bc6..0ca25a5a5 100644 --- a/singer/src/instructions/pop.rs +++ b/singer/src/instructions/pop.rs @@ -3,7 +3,6 @@ use ff_ext::ExtensionField; use gkr::structs::Circuit; use paste::paste; use simple_frontend::structs::{CircuitBuilder, MixedCell}; -use singer_utils::uint::constants::AddSubConstants; use singer_utils::{ chip_handler::{ BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, @@ -12,8 +11,9 @@ use singer_utils::{ constants::OpcodeType, register_witness, structs::{PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, + uint::constants::AddSubConstants, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::error::ZKVMError; @@ -126,7 +126,6 @@ impl Instruction for PopInstruction { #[cfg(test)] mod test { use ark_std::test_rng; - use core::ops::Range; use ff::Field; use ff_ext::ExtensionField; use gkr::structs::LayerWitness; @@ -134,44 +133,20 @@ mod test { use itertools::Itertools; use std::collections::BTreeMap; - use crate::instructions::{ChipChallenges, Instruction, PopInstruction}; - use crate::test::{get_uint_params, test_opcode_circuit, u2vec}; - use simple_frontend::structs::CellId; - use singer_utils::constants::RANGE_CHIP_BIT_WIDTH; - use singer_utils::structs::TSUInt; + use crate::{ + instructions::{ChipChallenges, Instruction, PopInstruction}, + test::{get_uint_params, test_opcode_circuit}, + utils::u64vec, + }; + use singer_utils::{constants::RANGE_CHIP_BIT_WIDTH, structs::TSUInt}; use std::time::Instant; use transcript::Transcript; - use crate::instructions::{InstructionGraph, SingerCircuitBuilder}; - use crate::scheme::GKRGraphProverState; - use crate::{CircuitWiresIn, SingerGraphBuilder, SingerParams}; - - impl PopInstruction { - #[inline] - fn phase0_idxes_map() -> BTreeMap> { - let mut map = BTreeMap::new(); - map.insert("phase0_pc".to_string(), Self::phase0_pc()); - map.insert("phase0_stack_ts".to_string(), Self::phase0_stack_ts()); - map.insert("phase0_memory_ts".to_string(), Self::phase0_memory_ts()); - map.insert("phase0_stack_top".to_string(), Self::phase0_stack_top()); - map.insert("phase0_clk".to_string(), Self::phase0_clk()); - map.insert("phase0_pc_add".to_string(), Self::phase0_pc_add()); - map.insert( - "phase0_old_stack_ts".to_string(), - Self::phase0_old_stack_ts(), - ); - map.insert( - "phase0_old_stack_ts_lt".to_string(), - Self::phase0_old_stack_ts_lt(), - ); - map.insert( - "phase0_stack_values".to_string(), - Self::phase0_stack_values(), - ); - - map - } - } + use crate::{ + instructions::{InstructionGraph, SingerCircuitBuilder}, + scheme::GKRGraphProverState, + CircuitWiresIn, SingerGraphBuilder, SingerParams, + }; #[test] fn test_pop_construct_circuit() { @@ -210,7 +185,7 @@ mod test { vec![Goldilocks::from(1u64)], ); let m: u64 = (1 << get_uint_params::().1) - 1; - let range_values = u2vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( "phase0_old_stack_ts_lt".to_string(), vec![ diff --git a/singer/src/instructions/push.rs b/singer/src/instructions/push.rs index e42661b5a..d320aa88e 100644 --- a/singer/src/instructions/push.rs +++ b/singer/src/instructions/push.rs @@ -3,7 +3,6 @@ use ff_ext::ExtensionField; use gkr::structs::Circuit; use paste::paste; use simple_frontend::structs::{CircuitBuilder, MixedCell}; -use singer_utils::uint::constants::AddSubConstants; use singer_utils::{ chip_handler::{ BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, @@ -12,8 +11,9 @@ use singer_utils::{ constants::OpcodeType, register_witness, structs::{PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, + uint::constants::AddSubConstants, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::error::ZKVMError; @@ -146,46 +146,22 @@ impl Instruction for PushInstruction { #[cfg(test)] mod test { use ark_std::test_rng; - use core::ops::Range; use ff::Field; use ff_ext::ExtensionField; use gkr::structs::LayerWitness; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; - use simple_frontend::structs::CellId; - use std::collections::BTreeMap; - use std::time::Instant; + use std::{collections::BTreeMap, time::Instant}; use transcript::Transcript; - use crate::instructions::{ - ChipChallenges, Instruction, InstructionGraph, PushInstruction, SingerCircuitBuilder, + use crate::{ + instructions::{ + ChipChallenges, Instruction, InstructionGraph, PushInstruction, SingerCircuitBuilder, + }, + scheme::GKRGraphProverState, + test::test_opcode_circuit, + CircuitWiresIn, SingerGraphBuilder, SingerParams, }; - use crate::scheme::GKRGraphProverState; - use crate::test::test_opcode_circuit; - use crate::{CircuitWiresIn, SingerGraphBuilder, SingerParams}; - - impl PushInstruction { - #[inline] - fn phase0_idxes_map() -> BTreeMap> { - let mut map = BTreeMap::new(); - map.insert("phase0_pc".to_string(), Self::phase0_pc()); - map.insert("phase0_stack_ts".to_string(), Self::phase0_stack_ts()); - map.insert("phase0_memory_ts".to_string(), Self::phase0_memory_ts()); - map.insert("phase0_stack_top".to_string(), Self::phase0_stack_top()); - map.insert("phase0_clk".to_string(), Self::phase0_clk()); - map.insert( - "phase0_pc_add_i_plus_1".to_string(), - Self::phase0_pc_add_i_plus_1(), - ); - map.insert( - "phase0_stack_ts_add".to_string(), - Self::phase0_stack_ts_add(), - ); - map.insert("phase0_stack_bytes".to_string(), Self::phase0_stack_bytes()); - - map - } - } #[test] fn test_push1_construct_circuit() { @@ -221,7 +197,8 @@ mod test { phase0_values_map.insert( "phase0_stack_ts_add".to_string(), vec![ - Goldilocks::from(2u64), // first TSUInt::N_RANGE_CELLS = 1*(56/16) = 4 cells are range values, stack_ts + 1 = 4 + Goldilocks::from(2u64), /* first TSUInt::N_RANGE_CELLS = 1*(56/16) = 4 cells are + * range values, stack_ts + 1 = 4 */ Goldilocks::from(0u64), Goldilocks::from(0u64), Goldilocks::from(0u64), diff --git a/singer/src/instructions/ret.rs b/singer/src/instructions/ret.rs index 4a160031f..7ba6c1abe 100644 --- a/singer/src/instructions/ret.rs +++ b/singer/src/instructions/ret.rs @@ -15,7 +15,7 @@ use singer_utils::{ register_witness, structs::{PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, }; -use std::{mem, sync::Arc}; +use std::{collections::BTreeMap, mem, sync::Arc}; use crate::{error::ZKVMError, utils::add_assign_each_cell, CircuitWiresIn, SingerParams}; diff --git a/singer/src/instructions/swap.rs b/singer/src/instructions/swap.rs index fdf52aa17..cbff6a1fd 100644 --- a/singer/src/instructions/swap.rs +++ b/singer/src/instructions/swap.rs @@ -3,7 +3,6 @@ use ff_ext::ExtensionField; use gkr::structs::Circuit; use paste::paste; use simple_frontend::structs::{CircuitBuilder, MixedCell}; -use singer_utils::uint::constants::AddSubConstants; use singer_utils::{ chip_handler::{ BytecodeChipOperations, GlobalStateChipOperations, OAMOperations, ROMOperations, @@ -12,8 +11,9 @@ use singer_utils::{ constants::OpcodeType, register_witness, structs::{PCUInt, RAMHandler, ROMHandler, StackUInt, TSUInt}, + uint::constants::AddSubConstants, }; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use crate::error::ZKVMError; @@ -181,68 +181,24 @@ impl Instruction for SwapInstruction { #[cfg(test)] mod test { use ark_std::test_rng; - use core::ops::Range; use ff::Field; use ff_ext::ExtensionField; use gkr::structs::LayerWitness; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; - use simple_frontend::structs::CellId; - use singer_utils::constants::RANGE_CHIP_BIT_WIDTH; - use singer_utils::structs::TSUInt; - use std::collections::BTreeMap; - use std::time::Instant; + use singer_utils::{constants::RANGE_CHIP_BIT_WIDTH, structs::TSUInt}; + use std::{collections::BTreeMap, time::Instant}; use transcript::Transcript; - use crate::instructions::{ - ChipChallenges, Instruction, InstructionGraph, SingerCircuitBuilder, SwapInstruction, + use crate::{ + instructions::{ + ChipChallenges, Instruction, InstructionGraph, SingerCircuitBuilder, SwapInstruction, + }, + scheme::GKRGraphProverState, + test::{get_uint_params, test_opcode_circuit}, + utils::u64vec, + CircuitWiresIn, SingerGraphBuilder, SingerParams, }; - use crate::scheme::GKRGraphProverState; - use crate::test::{get_uint_params, test_opcode_circuit, u2vec}; - use crate::{CircuitWiresIn, SingerGraphBuilder, SingerParams}; - - impl SwapInstruction { - #[inline] - fn phase0_idxes_map() -> BTreeMap> { - let mut map = BTreeMap::new(); - map.insert("phase0_pc".to_string(), Self::phase0_pc()); - map.insert("phase0_stack_ts".to_string(), Self::phase0_stack_ts()); - map.insert("phase0_memory_ts".to_string(), Self::phase0_memory_ts()); - map.insert("phase0_stack_top".to_string(), Self::phase0_stack_top()); - map.insert("phase0_clk".to_string(), Self::phase0_clk()); - map.insert("phase0_pc_add".to_string(), Self::phase0_pc_add()); - map.insert( - "phase0_stack_ts_add".to_string(), - Self::phase0_stack_ts_add(), - ); - map.insert( - "phase0_old_stack_ts_1".to_string(), - Self::phase0_old_stack_ts_1(), - ); - map.insert( - "phase0_old_stack_ts_lt_1".to_string(), - Self::phase0_old_stack_ts_lt_1(), - ); - map.insert( - "phase0_old_stack_ts_n_plus_1".to_string(), - Self::phase0_old_stack_ts_n_plus_1(), - ); - map.insert( - "phase0_old_stack_ts_lt_n_plus_1".to_string(), - Self::phase0_old_stack_ts_lt_n_plus_1(), - ); - map.insert( - "phase0_stack_values_1".to_string(), - Self::phase0_stack_values_1(), - ); - map.insert( - "phase0_stack_values_n_plus_1".to_string(), - Self::phase0_stack_values_n_plus_1(), - ); - - map - } - } #[test] fn test_swap2_construct_circuit() { @@ -279,7 +235,8 @@ mod test { phase0_values_map.insert( "phase0_stack_ts_add".to_string(), vec![ - Goldilocks::from(5u64), // first TSUInt::N_RANGE_CELLS = 1*(56/16) = 4 cells are range values, stack_ts + 1 = 4 + Goldilocks::from(5u64), /* first TSUInt::N_RANGE_CELLS = 1*(56/16) = 4 cells are + * range values, stack_ts + 1 = 4 */ Goldilocks::from(0u64), Goldilocks::from(0u64), Goldilocks::from(0u64), @@ -291,7 +248,7 @@ mod test { vec![Goldilocks::from(3u64)], ); let m: u64 = (1 << get_uint_params::().1) - 1; - let range_values = u2vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( "phase0_old_stack_ts_lt_1".to_string(), vec![ @@ -306,7 +263,7 @@ mod test { vec![Goldilocks::from(1u64)], ); let m: u64 = (1 << get_uint_params::().1) - 3; - let range_values = u2vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); + let range_values = u64vec::<{ TSUInt::N_RANGE_CELLS }, RANGE_CHIP_BIT_WIDTH>(m); phase0_values_map.insert( "phase0_old_stack_ts_lt_n_plus_1".to_string(), vec![ diff --git a/singer/src/lib.rs b/singer/src/lib.rs index 614ae606f..aa829c07f 100644 --- a/singer/src/lib.rs +++ b/singer/src/lib.rs @@ -18,6 +18,7 @@ pub mod instructions; pub mod scheme; #[cfg(test)] pub mod test; +pub use utils::u64vec; mod utils; // Process sketch: diff --git a/singer/src/test.rs b/singer/src/test.rs index 30c775fb7..563cc2b4d 100644 --- a/singer/src/test.rs +++ b/singer/src/test.rs @@ -2,10 +2,6 @@ use core::ops::Range; use ff::Field; use ff_ext::ExtensionField; use gkr::structs::CircuitWitness; -use gkr::structs::IOPProverState; -use gkr::utils::MultilinearExtensionFromVectors; -use goldilocks::SmallField; -use itertools::Itertools; use simple_frontend::structs::CellId; use singer_utils::uint::UInt; use std::collections::BTreeMap; @@ -26,21 +22,11 @@ pub(crate) fn get_uint_params() -> (usize, usize) { (T::BITS, T::CELL_BIT_WIDTH) } -pub(crate) fn u2vec(x: u64) -> [u64; W] { - let mut x = x; - let mut ret = [0; W]; - for i in 0..ret.len() { - ret[i] = x & ((1 << C) - 1); - x >>= C; - } - ret -} - -pub(crate) fn test_opcode_circuit( +pub(crate) fn test_opcode_circuit_v2( inst_circuit: &InstCircuit, - phase0_idx_map: &BTreeMap>, + phase0_idx_map: &BTreeMap<&'static str, Range>, phase0_witness_size: usize, - phase0_values_map: &BTreeMap>, + phase0_values_map: &BTreeMap<&'static str, Vec>, circuit_witness_challenges: Vec, ) -> CircuitWitness<::BaseField> { // configure circuit @@ -65,7 +51,9 @@ pub(crate) fn test_opcode_circuit( .unwrap() .clone() .collect::>(); - let values = phase0_values_map.get(key).unwrap(); + let values = phase0_values_map + .get(key) + .expect(&("unknown key ".to_owned() + key)); for (value_idx, cell_idx) in range.into_iter().enumerate() { if value_idx < values.len() { witness_in[phase0_input_idx as usize][cell_idx] = values[value_idx]; @@ -90,67 +78,86 @@ pub(crate) fn test_opcode_circuit( circuit_witness - /*let instance_num_vars = circuit_witness.instance_num_vars(); - let (proof, output_num_vars, output_eval) = { - let mut prover_transcript = Transcript::::new(b"example"); - let output_num_vars = instance_num_vars + circuit.first_layer_ref().num_vars(); - let output_point = (0..output_num_vars) - .map(|_| { - prover_transcript - .get_and_append_challenge(b"output point") - .elements - }) - .collect_vec(); - let output_eval = circuit_witness - .layer_poly(0, circuit.first_layer_ref().num_vars()) - .evaluate(&output_point); - ( - IOPProverState::prove_parallel( - &circuit, - &circuit_witness, - &[(output_point, output_eval)], - &[], - &mut prover_transcript, - ), - output_num_vars, - output_eval, - ) - };*/ - /* - let gkr_input_claims = { - let mut verifier_transcript = &mut Transcript::::new(b"example"); - let output_point = (0..output_num_vars) - .map(|_| { - verifier_transcript - .get_and_append_challenge(b"output point") - .elements - }) - .collect_vec(); - IOPVerifierState::verify_parallel( - &circuit, - circuit_witness.challenges(), - &[(output_point, output_eval)], - &[], - &proof, - instance_num_vars, - &mut verifier_transcript, - ) - .expect("verification failed") - }; - let expected_values = circuit_witness - .witness_in_ref() + // let instance_num_vars = circuit_witness.instance_num_vars(); + // let (proof, output_num_vars, output_eval) = { + // let mut prover_transcript = Transcript::::new(b"example"); + // let output_num_vars = instance_num_vars + circuit.first_layer_ref().num_vars(); + // let output_point = (0..output_num_vars) + // .map(|_| { + // prover_transcript + // .get_and_append_challenge(b"output point") + // .elements + // }) + // .collect_vec(); + // let output_eval = circuit_witness + // .layer_poly(0, circuit.first_layer_ref().num_vars()) + // .evaluate(&output_point); + // ( + // IOPProverState::prove_parallel( + // &circuit, + // &circuit_witness, + // &[(output_point, output_eval)], + // &[], + // &mut prover_transcript, + // ), + // output_num_vars, + // output_eval, + // ) + // }; + // let gkr_input_claims = { + // let mut verifier_transcript = &mut Transcript::::new(b"example"); + // let output_point = (0..output_num_vars) + // .map(|_| { + // verifier_transcript + // .get_and_append_challenge(b"output point") + // .elements + // }) + // .collect_vec(); + // IOPVerifierState::verify_parallel( + // &circuit, + // circuit_witness.challenges(), + // &[(output_point, output_eval)], + // &[], + // &proof, + // instance_num_vars, + // &mut verifier_transcript, + // ) + // .expect("verification failed") + // }; + // let expected_values = circuit_witness + // .witness_in_ref() + // .iter() + // .map(|witness| { + // witness + // .instances + // .as_slice() + // .mle(circuit.max_wit_in_num_vars.expect("REASON"), instance_num_vars) + // .evaluate(&gkr_input_claims.point_and_evals) + // }) + // .collect_vec(); + // for i in 0..gkr_input_claims.point_and_evals.len() { + // assert_eq!(expected_values[i], gkr_input_claims.point_and_evals[i]); + // } + // println!("verification succeeded"); +} + +#[deprecated(note = "deprecated and use test_opcode_circuit_v2 instead")] +pub(crate) fn test_opcode_circuit( + inst_circuit: &InstCircuit, + phase0_idx_map: &BTreeMap<&'static str, Range>, + phase0_witness_size: usize, + phase0_values_map: &BTreeMap>, + circuit_witness_challenges: Vec, +) -> CircuitWitness<::BaseField> { + let phase0_values_map = phase0_values_map .iter() - .map(|witness| { - witness - .instances - .as_slice() - .mle(circuit.max_wit_in_num_vars.expect("REASON"), instance_num_vars) - .evaluate(&gkr_input_claims.point_and_evals) - }) - .collect_vec(); - for i in 0..gkr_input_claims.point_and_evals.len() { - assert_eq!(expected_values[i], gkr_input_claims.point_and_evals[i]); - } - println!("verification succeeded"); - */ + .map(|(key, value)| (key.clone().leak() as &'static str, value.clone())) + .collect::>>(); + test_opcode_circuit_v2( + inst_circuit, + phase0_idx_map, + phase0_witness_size, + &phase0_values_map, + circuit_witness_challenges, + ) } diff --git a/singer/src/utils.rs b/singer/src/utils.rs index d194edc8c..d1314c56b 100644 --- a/singer/src/utils.rs +++ b/singer/src/utils.rs @@ -21,3 +21,16 @@ pub(crate) fn add_assign_each_cell( circuit_builder.add(*dest, *src, E::BaseField::ONE); } } + +// split single u64 value into W slices, each slice got C bits. +// all the rest slices will be filled with 0 if W x C > 64 +pub fn u64vec(x: u64) -> [u64; W] { + assert!(C <= 64); + let mut x = x; + let mut ret = [0; W]; + for i in 0..ret.len() { + ret[i] = x & ((1 << C) - 1); + x >>= C; + } + ret +} diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index 5af55e749..d32423ee7 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -1,4 +1,4 @@ -use std::{mem, sync::Arc}; +use std::{array, mem, sync::Arc}; use ark_std::{end_timer, start_timer}; use crossbeam_channel::bounded; @@ -8,8 +8,7 @@ use rayon::{ iter::{IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator}, prelude::{IntoParallelIterator, ParallelIterator}, }; -use transcript::Challenge; -use transcript::{Transcript, TranscriptSyncronized}; +use transcript::{Challenge, Transcript, TranscriptSyncronized}; #[cfg(feature = "non_pow2_rayon_thread")] use crate::local_thread_pool::{create_local_pool_once, LOCAL_THREAD_POOL}; @@ -26,7 +25,8 @@ use crate::{ impl IOPProverState { /// Given a virtual polynomial, generate an IOP proof. /// multi-threads model follow https://arxiv.org/pdf/2210.00264#page=8 "distributed sumcheck" - /// This is experiment features. It's preferable that we move parallel level up more to "bould_poly" so it can be more isolation + /// This is experiment features. It's preferable that we move parallel level up more to + /// "bould_poly" so it can be more isolation #[tracing::instrument(skip_all, name = "sumcheck::prove_batch_polys")] pub fn prove_batch_polys( max_thread_id: usize, @@ -72,7 +72,8 @@ impl IOPProverState { }) .collect::>(); - // spawn extra #(max_thread_id - 1) work threads, whereas the main-thread be the last work thread + // spawn extra #(max_thread_id - 1) work threads, whereas the main-thread be the last work + // thread for thread_id in 0..(max_thread_id - 1) { let mut prover_state = Self::prover_init_with_extrapolation_aux( mem::take(&mut polys[thread_id]), @@ -357,8 +358,9 @@ impl IOPProverState { self.poly .flattened_ml_extensions .iter_mut() - // benchmark result indicate make_mut achieve better performange than get_mut, which can be +5% overhead - // rust docs doen't explain the reason + // benchmark result indicate make_mut achieve better performange than get_mut, + // which can be +5% overhead rust docs doen't explain the + // reason .map(Arc::make_mut) .for_each(|f| { f.fix_variables_in_place(&[r.elements]); @@ -382,16 +384,16 @@ impl IOPProverState { 1 => { let f = &self.poly.flattened_ml_extensions[products[0]]; op_mle! { - |f| (0..f.len()) - .into_iter() - .step_by(2) - .map(|b| { - AdditiveArray([ - f[b], - f[b + 1] - ]) - }) - .sum::>(), + |f| { + (0..f.len()) + .into_iter() + .step_by(2) + .fold(AdditiveArray::(array::from_fn(|_| 0.into())), |mut acc, b| { + acc.0[0] += f[b]; + acc.0[1] += f[b+1]; + acc + }) + }, |sum| AdditiveArray(sum.0.map(E::from)) } .to_vec() @@ -402,17 +404,16 @@ impl IOPProverState { &self.poly.flattened_ml_extensions[products[1]], ); commutative_op_mle_pair!( - |f, g| (0..f.len()) - .into_iter() - .step_by(2) - .map(|b| { - AdditiveArray([ - f[b] * g[b], - f[b + 1] * g[b + 1], - (f[b + 1] + f[b + 1] - f[b]) * (g[b + 1] + g[b + 1] - g[b]), - ]) - }) - .sum::>(), + |f, g| (0..f.len()).into_iter().step_by(2).fold( + AdditiveArray::(array::from_fn(|_| 0.into())), + |mut acc, b| { + acc.0[0] += f[b] * g[b]; + acc.0[1] += f[b + 1] * g[b + 1]; + acc.0[2] += + (f[b + 1] + f[b + 1] - f[b]) * (g[b + 1] + g[b + 1] - g[b]); + acc + } + ), |sum| AdditiveArray(sum.0.map(E::from)) ) .to_vec() @@ -623,8 +624,9 @@ impl IOPProverState { self.poly .flattened_ml_extensions .par_iter_mut() - // benchmark result indicate make_mut achieve better performange than get_mut, which can be +5% overhead - // rust docs doen't explain the reason + // benchmark result indicate make_mut achieve better performange than get_mut, + // which can be +5% overhead rust docs doen't explain the + // reason .map(Arc::make_mut) .for_each(|f| { f.fix_variables_in_place_parallel(&[r.elements]); diff --git a/sumcheck/src/verifier.rs b/sumcheck/src/verifier.rs index c554dc12d..b4e119d3d 100644 --- a/sumcheck/src/verifier.rs +++ b/sumcheck/src/verifier.rs @@ -155,8 +155,10 @@ impl IOPVerifierState { // 1. check if the received 'P(0) + P(1) = expected`. if evaluations[0] + evaluations[1] != expected { panic!( - "{}th round's prover message is not consistent with the claim. {:?} {:?} {:?}", - i, evaluations[0], evaluations[1], expected + "{}th round's prover message is not consistent with the claim. {:?} {:?}", + i, + evaluations[0] + evaluations[1], + expected ); } }