diff --git a/gkr-graph/src/circuit_graph_builder.rs b/gkr-graph/src/circuit_graph_builder.rs index 0d3177f57..da39f89d3 100644 --- a/gkr-graph/src/circuit_graph_builder.rs +++ b/gkr-graph/src/circuit_graph_builder.rs @@ -2,7 +2,7 @@ use std::{collections::BTreeSet, sync::Arc}; use ark_std::Zero; use ff_ext::ExtensionField; -use gkr::structs::{Circuit, CircuitWitnessV2}; +use gkr::structs::{Circuit, CircuitWitness}; use itertools::{chain, izip, Itertools}; use multilinear_extensions::{ mle::DenseMultilinearExtension, virtual_poly_v2::ArcMultilinearExtension, @@ -60,7 +60,7 @@ impl<'a, E: ExtensionField> CircuitGraphBuilder<'a, E> { .collect_vec() ); - let mut witness = CircuitWitnessV2::new(circuit, challenges); + let mut witness = CircuitWitness::new(circuit, challenges); let wits_in = izip!(preds.iter(), sources.into_iter()) .map(|(pred, source)| match pred { PredType::Source => source.into(), diff --git a/gkr-graph/src/structs.rs b/gkr-graph/src/structs.rs index 2352d02f9..5987acf43 100644 --- a/gkr-graph/src/structs.rs +++ b/gkr-graph/src/structs.rs @@ -1,10 +1,10 @@ use ff_ext::ExtensionField; -use gkr::structs::{Circuit, CircuitWitnessV2, PointAndEval}; +use gkr::structs::{Circuit, CircuitWitness, PointAndEval}; use simple_frontend::structs::WitnessId; use std::{marker::PhantomData, sync::Arc}; -pub(crate) type GKRProverState = gkr::structs::IOPProverStateV2; -pub(crate) type GKRVerifierState = gkr::structs::IOPVerifierStateV2; +pub(crate) type GKRProverState = gkr::structs::IOPProverState; +pub(crate) type GKRVerifierState = gkr::structs::IOPVerifierState; pub(crate) type GKRProof = gkr::structs::IOPProof; /// Corresponds to the `output_evals` and `wires_out_evals` in gkr @@ -60,7 +60,7 @@ pub struct CircuitGraph { #[derive(Default)] pub struct CircuitGraphWitness<'a, E: ExtensionField> { - pub node_witnesses: Vec>>, + pub node_witnesses: Vec>>, } pub struct CircuitGraphBuilder<'a, E: ExtensionField> { diff --git a/gkr/examples/keccak256.rs b/gkr/examples/keccak256.rs index 71a54892d..73b660e74 100644 --- a/gkr/examples/keccak256.rs +++ b/gkr/examples/keccak256.rs @@ -7,7 +7,7 @@ use ff::Field; use ff_ext::ExtensionField; use gkr::{ gadgets::keccak256::{keccak256_circuit, prove_keccak256, verify_keccak256}, - structs::CircuitWitnessV2, + structs::CircuitWitness, util::ceil_log2, }; use goldilocks::GoldilocksExt2; @@ -75,7 +75,7 @@ fn main() { DenseMultilinearExtension::from_evaluations_vec(ceil_log2(wit_in.len()), wit_in) }) .collect(); - let mut witness = CircuitWitnessV2::new(&circuit, Vec::new()); + let mut witness = CircuitWitness::new(&circuit, Vec::new()); witness.add_instance(&circuit, all_zero); witness.add_instance(&circuit, all_one); diff --git a/gkr/src/circuit.rs b/gkr/src/circuit.rs index efd2c0381..80d3defab 100644 --- a/gkr/src/circuit.rs +++ b/gkr/src/circuit.rs @@ -7,7 +7,6 @@ use crate::structs::{Gate1In, Gate2In, Gate3In, GateCIn}; mod circuit_layout; mod circuit_witness; -mod circuit_witness_v2; pub trait EvaluateGateCIn where diff --git a/gkr/src/circuit/circuit_witness.rs b/gkr/src/circuit/circuit_witness.rs index f7351e652..b31bee97e 100644 --- a/gkr/src/circuit/circuit_witness.rs +++ b/gkr/src/circuit/circuit_witness.rs @@ -1,29 +1,38 @@ -use std::{collections::HashMap, fmt::Debug}; +use std::{collections::HashMap, sync::Arc}; +use crate::circuit::EvaluateConstant; +use ff::Field; use ff_ext::ExtensionField; -use goldilocks::SmallField; use itertools::{izip, Itertools}; -use multilinear_extensions::mle::ArcDenseMultilinearExtension; -use simple_frontend::structs::{ChallengeConst, ConstantType, LayerId}; +use multilinear_extensions::{ + mle::{DenseMultilinearExtension, MultilinearExtension}, + virtual_poly_v2::ArcMultilinearExtension, +}; +use simple_frontend::structs::{ChallengeConst, LayerId}; +use std::fmt::Debug; use sumcheck::util::ceil_log2; use crate::{ - structs::{Circuit, CircuitWitness, LayerWitness}, - utils::{i64_to_field, MultilinearExtensionFromVectors}, + structs::{Circuit, CircuitWitness}, + utils::i64_to_field, }; -use super::EvaluateConstant; - -impl CircuitWitness { +impl<'a, E: ExtensionField> CircuitWitness<'a, E> { /// Initialize the structure of the circuit witness. - pub fn new(circuit: &Circuit, challenges: Vec) -> Self - where - E: ExtensionField, - { + pub fn new(circuit: &Circuit, challenges: Vec) -> Self { + let create_default = |size| { + (0..size) + .map(|_| { + let a: ArcMultilinearExtension = + Arc::new(DenseMultilinearExtension::default()); + a + }) + .collect::>>() + }; Self { - layers: vec![LayerWitness::default(); circuit.layers.len()], - witness_in: vec![LayerWitness::default(); circuit.n_witness_in], - witness_out: vec![LayerWitness::default(); circuit.n_witness_out], + layers: create_default(circuit.layers.len()), + witness_in: create_default(circuit.n_witness_in), + witness_out: create_default(circuit.n_witness_out), n_instances: 0, challenges: circuit.generate_basefield_challenges(&challenges), } @@ -31,123 +40,129 @@ impl CircuitWitness { /// Generate a fresh instance for the circuit, return layer witnesses and /// wire out witnesses. - fn new_instances( + fn new_instances( circuit: &Circuit, - wits_in: &[LayerWitness], - challenges: &HashMap>, + wits_in: &[DenseMultilinearExtension], + challenges: &HashMap>, n_instances: usize, - ) -> (Vec>, Vec>) - where - E: ExtensionField, - { + ) -> ( + Vec>, + Vec>, + ) { let n_layers = circuit.layers.len(); - let mut layer_wits = vec![ - LayerWitness { - instances: vec![vec![]; n_instances] - }; - n_layers - ]; + let mut layer_wits = vec![DenseMultilinearExtension::default(); n_layers]; // The first layer. layer_wits[n_layers - 1] = { let mut layer_wit = - vec![vec![F::ZERO; circuit.layers[n_layers - 1].size()]; n_instances]; + vec![E::BaseField::ZERO; circuit.layers[n_layers - 1].size() * n_instances]; for instance_id in 0..n_instances { + let instance_start_index = instance_id * circuit.layers[n_layers - 1].size(); assert_eq!(wits_in.len(), circuit.paste_from_wits_in.len()); for (wit_id, (l, r)) in circuit.paste_from_wits_in.iter().enumerate() { for i in *l..*r { - layer_wit[instance_id][i] = - wits_in[wit_id as usize].instances[instance_id][i - *l]; + let wit_in = wits_in[wit_id as usize].get_base_field_vec(); + layer_wit[instance_start_index + i] = wit_in[instance_start_index + i - *l]; } } for (constant, (l, r)) in circuit.paste_from_consts_in.iter() { for i in *l..*r { - layer_wit[instance_id][i] = i64_to_field(*constant); + layer_wit[instance_start_index + i] = i64_to_field(*constant); } } for (num_vars, (l, r)) in circuit.paste_from_counter_in.iter() { for i in *l..*r { - layer_wit[instance_id][i] = - F::from(((instance_id << num_vars) ^ (i - *l)) as u64); + layer_wit[instance_start_index + i] = + E::BaseField::from(((instance_id << num_vars) ^ (i - *l)) as u64) } } } - LayerWitness { - instances: layer_wit, - } + + DenseMultilinearExtension::from_evaluations_vec( + ceil_log2(circuit.layers[n_layers - 1].size() * n_instances), + layer_wit, + ) }; for (layer_id, layer) in circuit.layers.iter().enumerate().rev().skip(1) { let size = circuit.layers[layer_id].size(); - let mut current_layer_wits = vec![vec![F::ZERO; size]; n_instances]; + let mut current_layer_wit = vec![E::BaseField::ZERO; size * n_instances]; - izip!((0..n_instances), current_layer_wits.iter_mut()).for_each( - |(instance_id, current_layer_wit)| { - layer - .paste_from - .iter() - .for_each(|(old_layer_id, new_wire_ids)| { - new_wire_ids.iter().enumerate().for_each( - |(subset_wire_id, new_wire_id)| { - let old_wire_id = circuit.layers[*old_layer_id as usize] - .copy_to - .get(&(layer_id as LayerId)) - .unwrap()[subset_wire_id]; - current_layer_wit[*new_wire_id] = layer_wits - [*old_layer_id as usize] - .instances[instance_id][old_wire_id]; - }, - ); - }); + izip!(0..n_instances).for_each(|instance_id| { + let new_layer_instance_start_index = + instance_id * circuit.layers[layer_id as usize].size(); + layer + .paste_from + .iter() + .for_each(|(old_layer_id, new_wire_ids)| { + let layer_wits = layer_wits[*old_layer_id as usize].get_base_field_vec(); + let old_layer_instance_start_index = + instance_id * circuit.layers[*old_layer_id as usize].size(); + new_wire_ids.iter().enumerate().for_each( + |(subset_wire_id, new_wire_id)| { + let old_wire_id = circuit.layers[*old_layer_id as usize] + .copy_to + .get(&(layer_id as LayerId)) + .unwrap()[subset_wire_id]; + current_layer_wit[new_layer_instance_start_index + *new_wire_id] = + layer_wits[old_layer_instance_start_index + old_wire_id]; + }, + ); + }); - let last_layer_wit = &layer_wits[layer_id + 1].instances[instance_id]; - for add_const in layer.add_consts.iter() { - current_layer_wit[add_const.idx_out] += add_const.scalar.eval(&challenges); - } + let last_layer_wit = layer_wits[layer_id + 1].get_base_field_vec(); + let last_layer_instance_start_index = + instance_id * circuit.layers[layer_id as usize + 1].size(); + for add_const in layer.add_consts.iter() { + current_layer_wit[new_layer_instance_start_index + add_const.idx_out] += + add_const.scalar.eval(&challenges); + } - for add in layer.adds.iter() { - current_layer_wit[add.idx_out] += - last_layer_wit[add.idx_in[0]] * add.scalar.eval(&challenges); - } + for add in layer.adds.iter() { + current_layer_wit[new_layer_instance_start_index + add.idx_out] += + last_layer_wit[last_layer_instance_start_index + add.idx_in[0]] + * add.scalar.eval(&challenges); + } - for mul2 in layer.mul2s.iter() { - current_layer_wit[mul2.idx_out] += last_layer_wit[mul2.idx_in[0]] - * last_layer_wit[mul2.idx_in[1]] + for mul2 in layer.mul2s.iter() { + current_layer_wit[new_layer_instance_start_index + mul2.idx_out] += + last_layer_wit[last_layer_instance_start_index + mul2.idx_in[0]] + * last_layer_wit[last_layer_instance_start_index + mul2.idx_in[1]] * mul2.scalar.eval(&challenges); - } + } - for mul3 in layer.mul3s.iter() { - current_layer_wit[mul3.idx_out] += last_layer_wit[mul3.idx_in[0]] - * last_layer_wit[mul3.idx_in[1]] - * last_layer_wit[mul3.idx_in[2]] + for mul3 in layer.mul3s.iter() { + current_layer_wit[new_layer_instance_start_index + mul3.idx_out] += + last_layer_wit[last_layer_instance_start_index + mul3.idx_in[0]] + * last_layer_wit[last_layer_instance_start_index + mul3.idx_in[1]] + * last_layer_wit[last_layer_instance_start_index + mul3.idx_in[2]] * mul3.scalar.eval(&challenges); - } - }, - ); + } + }); - layer_wits[layer_id] = LayerWitness { - instances: current_layer_wits, - }; + layer_wits[layer_id] = DenseMultilinearExtension::from_evaluations_vec( + ceil_log2(size * n_instances), + current_layer_wit, + ); } - let mut wits_out = vec![ - LayerWitness { - instances: vec![vec![]; n_instances] - }; - circuit.n_witness_out - ]; + let mut wits_out = vec![vec![]; circuit.n_witness_out]; for instance_id in 0..n_instances { + let last_layer_instance_start_index = instance_id * circuit.layers[0].size(); circuit .copy_to_wits_out .iter() .enumerate() .for_each(|(wit_id, old_wire_ids)| { + let layer_wit = layer_wits[0].get_base_field_vec(); let mut wit_out = old_wire_ids .iter() - .map(|old_wire_id| layer_wits[0].instances[instance_id][*old_wire_id]) + .map(|old_wire_id| { + layer_wit[last_layer_instance_start_index + *old_wire_id] + }) .collect_vec(); let length = wit_out.len().next_power_of_two(); - wit_out.resize(length, F::ZERO); - wits_out[wit_id].instances[instance_id] = wit_out; + wit_out.resize(length, E::BaseField::ZERO); + wits_out[wit_id].extend(wit_out); }); // #[cfg(debug_assertions)] @@ -157,56 +172,237 @@ impl CircuitWitness { // } // }); } + let wits_out = wits_out + .into_iter() + .map(|wit_out| { + let num_vars = ceil_log2(wit_out.len()); + DenseMultilinearExtension::from_evaluations_vec(num_vars, wit_out) + }) + .collect(); (layer_wits, wits_out) } - pub fn add_instance(&mut self, circuit: &Circuit, wits_in: Vec>) - where - E: ExtensionField, - { - let wits_in = wits_in + /// Generate a fresh instance for the circuit, return layer witnesses and + /// wire out witnesses. + fn new_instances_v2( + circuit: &Circuit, + wits_in: &[ArcMultilinearExtension<'a, E>], + challenges: &HashMap>, + n_instances: usize, + ) -> ( + Vec>, + Vec>, + ) { + let n_layers = circuit.layers.len(); + let mut layer_wits = vec![DenseMultilinearExtension::default(); n_layers]; + + // The first layer. + layer_wits[n_layers - 1] = { + let mut layer_wit = + vec![E::BaseField::ZERO; circuit.layers[n_layers - 1].size() * n_instances]; + for instance_id in 0..n_instances { + let layer_wit_instance_start_index = + instance_id * circuit.layers[n_layers - 1].size(); + assert_eq!(wits_in.len(), circuit.paste_from_wits_in.len()); + for (wit_id, (l, r)) in circuit.paste_from_wits_in.iter().enumerate() { + let wit_in = wits_in[wit_id as usize].get_base_field_vec(); + let wit_in_instance_start_index = instance_id * wit_in.len() / n_instances; + for i in *l..*r { + layer_wit[layer_wit_instance_start_index + i] = + wit_in[wit_in_instance_start_index + i - *l]; + } + } + for (constant, (l, r)) in circuit.paste_from_consts_in.iter() { + for i in *l..*r { + layer_wit[layer_wit_instance_start_index + i] = i64_to_field(*constant); + } + } + for (num_vars, (l, r)) in circuit.paste_from_counter_in.iter() { + for i in *l..*r { + layer_wit[layer_wit_instance_start_index + i] = + E::BaseField::from(((instance_id << num_vars) ^ (i - *l)) as u64) + } + } + } + + DenseMultilinearExtension::from_evaluations_vec( + ceil_log2(circuit.layers[n_layers - 1].size() * n_instances), + layer_wit, + ) + }; + + for (layer_id, layer) in circuit.layers.iter().enumerate().rev().skip(1) { + let size = circuit.layers[layer_id].size(); + let mut current_layer_wit = vec![E::BaseField::ZERO; size * n_instances]; + + izip!(0..n_instances).for_each(|instance_id| { + let new_layer_instance_start_index = + instance_id * circuit.layers[layer_id as usize].size(); + layer + .paste_from + .iter() + .for_each(|(old_layer_id, new_wire_ids)| { + let layer_wits = layer_wits[*old_layer_id as usize].get_base_field_vec(); + let old_layer_instance_start_index = + instance_id * circuit.layers[*old_layer_id as usize].size(); + new_wire_ids.iter().enumerate().for_each( + |(subset_wire_id, new_wire_id)| { + let old_wire_id = circuit.layers[*old_layer_id as usize] + .copy_to + .get(&(layer_id as LayerId)) + .unwrap()[subset_wire_id]; + current_layer_wit[new_layer_instance_start_index + *new_wire_id] = + layer_wits[old_layer_instance_start_index + old_wire_id]; + }, + ); + }); + + let last_layer_wit = layer_wits[layer_id + 1].get_base_field_vec(); + let last_layer_instance_start_index = + instance_id * circuit.layers[layer_id as usize + 1].size(); + for add_const in layer.add_consts.iter() { + current_layer_wit[new_layer_instance_start_index + add_const.idx_out] += + add_const.scalar.eval(&challenges); + } + + for add in layer.adds.iter() { + current_layer_wit[new_layer_instance_start_index + add.idx_out] += + last_layer_wit[last_layer_instance_start_index + add.idx_in[0]] + * add.scalar.eval(&challenges); + } + + for mul2 in layer.mul2s.iter() { + current_layer_wit[new_layer_instance_start_index + mul2.idx_out] += + last_layer_wit[last_layer_instance_start_index + mul2.idx_in[0]] + * last_layer_wit[last_layer_instance_start_index + mul2.idx_in[1]] + * mul2.scalar.eval(&challenges); + } + + for mul3 in layer.mul3s.iter() { + current_layer_wit[new_layer_instance_start_index + mul3.idx_out] += + last_layer_wit[last_layer_instance_start_index + mul3.idx_in[0]] + * last_layer_wit[last_layer_instance_start_index + mul3.idx_in[1]] + * last_layer_wit[last_layer_instance_start_index + mul3.idx_in[2]] + * mul3.scalar.eval(&challenges); + } + }); + + layer_wits[layer_id] = DenseMultilinearExtension::from_evaluations_vec( + ceil_log2(size * n_instances), + current_layer_wit, + ); + } + let mut wits_out = vec![vec![]; circuit.n_witness_out]; + for instance_id in 0..n_instances { + let last_layer_instance_start_index = instance_id * circuit.layers[0].size(); + circuit + .copy_to_wits_out + .iter() + .enumerate() + .for_each(|(wit_id, old_wire_ids)| { + let layer_wit = layer_wits[0].get_base_field_vec(); + let mut wit_out = old_wire_ids + .iter() + .map(|old_wire_id| { + layer_wit[last_layer_instance_start_index + *old_wire_id] + }) + .collect_vec(); + let length = wit_out.len().next_power_of_two(); + wit_out.resize(length, E::BaseField::ZERO); + wits_out[wit_id].extend(wit_out); + }); + + // #[cfg(debug_assertions)] + // circuit.assert_consts.iter().for_each(|gate| { + // if let ConstantType::Field(constant) = gate.scalar { + // assert_eq!(layer_wits[0].instances[instance_id][gate.idx_out], constant); + // } + // }); + } + let wits_out = wits_out .into_iter() - .map(|wit_in| LayerWitness { - instances: vec![wit_in], + .map(|wit_out| { + let num_vars = ceil_log2(wit_out.len()); + DenseMultilinearExtension::from_evaluations_vec(num_vars, wit_out) }) - .collect_vec(); + .collect(); + (layer_wits, wits_out) + } + + pub fn add_instance( + &mut self, + circuit: &Circuit, + wits_in: Vec>, + ) { self.add_instances(circuit, wits_in, 1); } - pub fn add_instances( + pub fn set_instances( &mut self, circuit: &Circuit, - new_wits_in: Vec>, + new_wits_in: Vec>, n_instances: usize, - ) where - E: ExtensionField, - { + ) { assert_eq!(new_wits_in.len(), circuit.n_witness_in); assert!(n_instances.is_power_of_two()); - assert!(!new_wits_in - .iter() - .any(|wit_in| wit_in.instances.len() != n_instances)); + assert!( + new_wits_in + .iter() + .all(|wit_in| wit_in.evaluations().len() % n_instances == 0) + ); let (inferred_layer_wits, inferred_wits_out) = - CircuitWitness::new_instances(circuit, &new_wits_in, &self.challenges, n_instances); + CircuitWitness::new_instances_v2(circuit, &new_wits_in, &self.challenges, n_instances); - // Merge self and circuit_witness. - for (layer_wit, inferred_layer_wit) in - self.layers.iter_mut().zip(inferred_layer_wits.into_iter()) - { - layer_wit.instances.extend(inferred_layer_wit.instances); + assert_eq!(self.layers.len(), inferred_layer_wits.len()); + self.layers = inferred_layer_wits.into_iter().map(|n| n.into()).collect(); + assert_eq!(self.witness_out.len(), inferred_wits_out.len()); + self.witness_out = inferred_wits_out.into_iter().map(|n| n.into()).collect(); + assert_eq!(self.witness_in.len(), new_wits_in.len()); + self.witness_in = new_wits_in; + + self.n_instances = n_instances; + + // check correctness in debug build + if cfg!(debug_assertions) { + self.check_correctness(circuit); } + } + + pub fn add_instances( + &mut self, + circuit: &Circuit, + new_wits_in: Vec>, + n_instances: usize, + ) { + assert_eq!(new_wits_in.len(), circuit.n_witness_in); + assert!(n_instances.is_power_of_two()); + assert!( + new_wits_in + .iter() + .all(|wit_in| wit_in.evaluations().len() % n_instances == 0) + ); + + let (inferred_layer_wits, inferred_wits_out) = + CircuitWitness::new_instances(circuit, &new_wits_in[..], &self.challenges, n_instances); for (wit_out, inferred_wits_out) in self .witness_out .iter_mut() .zip(inferred_wits_out.into_iter()) { - wit_out.instances.extend(inferred_wits_out.instances); + Arc::get_mut(wit_out).unwrap().merge(inferred_wits_out); } for (wit_in, new_wit_in) in self.witness_in.iter_mut().zip(new_wits_in.into_iter()) { - wit_in.instances.extend(new_wit_in.instances); + Arc::get_mut(wit_in).unwrap().merge(new_wit_in); + } + + // Merge self and circuit_witness. + for (layer_wit, inferred_layer_wit) in + self.layers.iter_mut().zip(inferred_layer_wits.into_iter()) + { + Arc::get_mut(layer_wit).unwrap().merge(inferred_layer_wit); } self.n_instances += n_instances; @@ -221,172 +417,170 @@ impl CircuitWitness { ceil_log2(self.n_instances) } - pub fn check_correctness(&self, circuit: &Circuit) - where - Ext: ExtensionField, - { + pub fn check_correctness(&self, _circuit: &Circuit) { // Check input. - - let input_layer_wits = self.layers.last().unwrap(); - let wits_in = self.witness_in_ref(); - for copy_id in 0..self.n_instances { - for (wit_id, (l, r)) in circuit.paste_from_wits_in.iter().enumerate() { - for (subset_wire_id, new_wire_id) in (*l..*r).enumerate() { - assert_eq!( - input_layer_wits.instances[copy_id][new_wire_id], - wits_in[wit_id].instances[copy_id][subset_wire_id], - "input layer: {}, copy_id: {}, wire_id: {}, got != expected: {:?} != {:?}", - circuit.layers.len() - 1, - copy_id, - new_wire_id, - input_layer_wits.instances[copy_id][new_wire_id], - wits_in[wit_id].instances[copy_id][subset_wire_id] - ); - } - } - for (constant, (l, r)) in circuit.paste_from_consts_in.iter() { - for (_subset_wire_id, new_wire_id) in (*l..*r).enumerate() { - assert_eq!( - input_layer_wits.instances[copy_id][new_wire_id], - i64_to_field(*constant), - "input layer: {}, copy_id: {}, wire_id: {}, got != expected: {:?} != {:?}", - circuit.layers.len() - 1, - copy_id, - new_wire_id, - input_layer_wits.instances[copy_id][new_wire_id], - constant - ); - } - } - for (num_vars, (l, r)) in circuit.paste_from_counter_in.iter() { - for (subset_wire_id, new_wire_id) in (*l..*r).enumerate() { - assert_eq!( - input_layer_wits.instances[copy_id][new_wire_id], - i64_to_field(((copy_id << num_vars) ^ subset_wire_id) as i64), - "input layer: {}, copy_id: {}, wire_id: {}, got != expected: {:?} != {:?}", - circuit.layers.len() - 1, - copy_id, - new_wire_id, - input_layer_wits.instances[copy_id][new_wire_id], - (copy_id << num_vars) ^ subset_wire_id - ); - } - } - } - - for (layer_id, (layer_witnesses, layer)) in self - .layers - .iter() - .zip(circuit.layers.iter()) - .enumerate() - .rev() - .skip(1) - { - let prev_layer_wits = &self.layers[layer_id + 1]; - for (copy_id, (prev, curr)) in prev_layer_wits - .instances - .iter() - .zip(layer_witnesses.instances.iter()) - .enumerate() - { - let mut expected = vec![F::ZERO; curr.len()]; - for add_const in layer.add_consts.iter() { - expected[add_const.idx_out] += add_const.scalar.eval(&self.challenges); - } - for add in layer.adds.iter() { - expected[add.idx_out] += - prev[add.idx_in[0]] * add.scalar.eval(&self.challenges); - } - for mul2 in layer.mul2s.iter() { - expected[mul2.idx_out] += prev[mul2.idx_in[0]] - * prev[mul2.idx_in[1]] - * mul2.scalar.eval(&self.challenges); - } - for mul3 in layer.mul3s.iter() { - expected[mul3.idx_out] += prev[mul3.idx_in[0]] - * prev[mul3.idx_in[1]] - * prev[mul3.idx_in[2]] - * mul3.scalar.eval(&self.challenges); - } - - let mut expected_max_previous_size = prev.len(); - for (old_layer_id, new_wire_ids) in layer.paste_from.iter() { - expected_max_previous_size = expected_max_previous_size.max(new_wire_ids.len()); - for (subset_wire_id, new_wire_id) in new_wire_ids.iter().enumerate() { - let old_wire_id = circuit.layers[*old_layer_id as usize] - .copy_to - .get(&(layer_id as LayerId)) - .unwrap()[subset_wire_id]; - expected[*new_wire_id] = - self.layers[*old_layer_id as usize].instances[copy_id][old_wire_id]; - } - } - assert_eq!( - ceil_log2(expected_max_previous_size), - layer.max_previous_num_vars, - "layer: {}, expected_max_previous_size: {}, got: {}", - layer_id, - expected_max_previous_size, - layer.max_previous_num_vars - ); - for (wire_id, (got, expected)) in curr.iter().zip(expected.iter()).enumerate() { - assert_eq!( - *got, *expected, - "layer: {}, copy_id: {}, wire_id: {}, got != expected: {:?} != {:?}", - layer_id, copy_id, wire_id, got, expected - ); - } - - if layer_id != 0 { - for (new_layer_id, old_wire_ids) in layer.copy_to.iter() { - for (subset_wire_id, old_wire_id) in old_wire_ids.iter().enumerate() { - let new_wire_id = circuit.layers[*new_layer_id as usize] - .paste_from - .get(&(layer_id as LayerId)) - .unwrap()[subset_wire_id]; - assert_eq!( - curr[*old_wire_id], - self.layers[*new_layer_id as usize].instances[copy_id][new_wire_id], - "copy_to check: layer: {}, copy_id: {}, wire_id: {}, got != expected: {:?} != {:?}", - layer_id, - copy_id, - old_wire_id, - curr[*old_wire_id], - self.layers[*new_layer_id as usize].instances[copy_id][new_wire_id] - ) - } - } - } - } - } - - let output_layer_witness = &self.layers[0]; - let wits_out = self.witness_out_ref(); - for (wit_id, old_wire_ids) in circuit.copy_to_wits_out.iter().enumerate() { - for copy_id in 0..self.n_instances { - for (new_wire_id, old_wire_id) in old_wire_ids.iter().enumerate() { - assert_eq!( - output_layer_witness.instances[copy_id][*old_wire_id], - wits_out[wit_id].instances[copy_id][new_wire_id] - ); - } - } - } - for gate in circuit.assert_consts.iter() { - if let ConstantType::Field(constant) = gate.scalar { - for copy_id in 0..self.n_instances { - assert_eq!( - output_layer_witness.instances[copy_id][gate.idx_out], - constant - ); - } - } - } + return; + + // let input_layer_wits = self.layers.last().unwrap(); + // let wits_in = self.witness_in_ref(); + // for copy_id in 0..self.n_instances { + // for (wit_id, (l, r)) in circuit.paste_from_wits_in.iter().enumerate() { + // for (subset_wire_id, new_wire_id) in (*l..*r).enumerate() { + // assert_eq!( + // input_layer_wits.instances[copy_id][new_wire_id], + // wits_in[wit_id].instances[copy_id][subset_wire_id], + // "input layer: {}, copy_id: {}, wire_id: {}, got != expected: {:?} != + // {:?}", circuit.layers.len() - 1, + // copy_id, + // new_wire_id, + // input_layer_wits.instances[copy_id][new_wire_id], + // wits_in[wit_id].instances[copy_id][subset_wire_id] + // ); + // } + // } + // for (constant, (l, r)) in circuit.paste_from_consts_in.iter() { + // for (_subset_wire_id, new_wire_id) in (*l..*r).enumerate() { + // assert_eq!( + // input_layer_wits.instances[copy_id][new_wire_id], + // i64_to_field(*constant), + // "input layer: {}, copy_id: {}, wire_id: {}, got != expected: {:?} != + // {:?}", circuit.layers.len() - 1, + // copy_id, + // new_wire_id, + // input_layer_wits.instances[copy_id][new_wire_id], + // constant + // ); + // } + // } + // for (num_vars, (l, r)) in circuit.paste_from_counter_in.iter() { + // for (subset_wire_id, new_wire_id) in (*l..*r).enumerate() { + // assert_eq!( + // input_layer_wits.instances[copy_id][new_wire_id], + // i64_to_field(((copy_id << num_vars) ^ subset_wire_id) as i64), + // "input layer: {}, copy_id: {}, wire_id: {}, got != expected: {:?} != + // {:?}", circuit.layers.len() - 1, + // copy_id, + // new_wire_id, + // input_layer_wits.instances[copy_id][new_wire_id], + // (copy_id << num_vars) ^ subset_wire_id + // ); + // } + // } + // } + + // for (layer_id, (layer_witnesses, layer)) in self + // .layers + // .iter() + // .zip(circuit.layers.iter()) + // .enumerate() + // .rev() + // .skip(1) + // { + // let prev_layer_wits = &self.layers[layer_id + 1]; + // for (copy_id, (prev, curr)) in prev_layer_wits + // .instances + // .iter() + // .zip(layer_witnesses.instances.iter()) + // .enumerate() + // { + // let mut expected = vec![E::ZERO; curr.len()]; + // for add_const in layer.add_consts.iter() { + // expected[add_const.idx_out] += add_const.scalar.eval(&self.challenges); + // } + // for add in layer.adds.iter() { + // expected[add.idx_out] += + // prev[add.idx_in[0]] * add.scalar.eval(&self.challenges); + // } + // for mul2 in layer.mul2s.iter() { + // expected[mul2.idx_out] += prev[mul2.idx_in[0]] + // * prev[mul2.idx_in[1]] + // * mul2.scalar.eval(&self.challenges); + // } + // for mul3 in layer.mul3s.iter() { + // expected[mul3.idx_out] += prev[mul3.idx_in[0]] + // * prev[mul3.idx_in[1]] + // * prev[mul3.idx_in[2]] + // * mul3.scalar.eval(&self.challenges); + // } + + // let mut expected_max_previous_size = prev.len(); + // for (old_layer_id, new_wire_ids) in layer.paste_from.iter() { + // expected_max_previous_size = + // expected_max_previous_size.max(new_wire_ids.len()); for + // (subset_wire_id, new_wire_id) in new_wire_ids.iter().enumerate() { + // let old_wire_id = circuit.layers[*old_layer_id as usize] + // .copy_to .get(&(layer_id as LayerId)) + // .unwrap()[subset_wire_id]; + // expected[*new_wire_id] = + // self.layers[*old_layer_id as usize].instances[copy_id][old_wire_id]; + // } + // } + // assert_eq!( + // ceil_log2(expected_max_previous_size), + // layer.max_previous_num_vars, + // "layer: {}, expected_max_previous_size: {}, got: {}", + // layer_id, + // expected_max_previous_size, + // layer.max_previous_num_vars + // ); + // for (wire_id, (got, expected)) in curr.iter().zip(expected.iter()).enumerate() { + // assert_eq!( + // *got, *expected, + // "layer: {}, copy_id: {}, wire_id: {}, got != expected: {:?} != {:?}", + // layer_id, copy_id, wire_id, got, expected + // ); + // } + + // if layer_id != 0 { + // for (new_layer_id, old_wire_ids) in layer.copy_to.iter() { + // for (subset_wire_id, old_wire_id) in old_wire_ids.iter().enumerate() { + // let new_wire_id = circuit.layers[*new_layer_id as usize] + // .paste_from + // .get(&(layer_id as LayerId)) + // .unwrap()[subset_wire_id]; + // assert_eq!( + // curr[*old_wire_id], + // self.layers[*new_layer_id as + // usize].instances[copy_id][new_wire_id], "copy_to check: + // layer: {}, copy_id: {}, wire_id: {}, got != expected: {:?} != {:?}", + // layer_id, copy_id, + // old_wire_id, + // curr[*old_wire_id], + // self.layers[*new_layer_id as + // usize].instances[copy_id][new_wire_id] ) + // } + // } + // } + // } + // } + + // let output_layer_witness = &self.layers[0]; + // let wits_out = self.witness_out_ref(); + // for (wit_id, old_wire_ids) in circuit.copy_to_wits_out.iter().enumerate() { + // for copy_id in 0..self.n_instances { + // for (new_wire_id, old_wire_id) in old_wire_ids.iter().enumerate() { + // assert_eq!( + // output_layer_witness.instances[copy_id][*old_wire_id], + // wits_out[wit_id].instances[copy_id][new_wire_id] + // ); + // } + // } + // } + // for gate in circuit.assert_consts.iter() { + // if let ConstantType::Field(constant) = gate.scalar { + // for copy_id in 0..self.n_instances { + // assert_eq!( + // output_layer_witness.instances[copy_id][gate.idx_out], + // constant + // ); + // } + // } + // } } } -impl CircuitWitness { - pub fn output_layer_witness_ref(&self) -> &LayerWitness { +impl<'a, E: ExtensionField> CircuitWitness<'a, E> { + pub fn output_layer_witness_ref(&self) -> &ArcMultilinearExtension<'a, E> { self.layers.first().unwrap() } @@ -394,658 +588,59 @@ impl CircuitWitness { self.n_instances } - pub fn witness_in_ref(&self) -> &[LayerWitness] { + pub fn witness_in_ref(&self) -> &[ArcMultilinearExtension<'a, E>] { &self.witness_in } - pub fn witness_out_ref(&self) -> &[LayerWitness] { + pub fn witness_out_ref(&self) -> &[ArcMultilinearExtension<'a, E>] { &self.witness_out } - pub fn challenges(&self) -> &HashMap> { + pub fn challenges(&self) -> &HashMap> { &self.challenges } - pub fn layers_ref(&self) -> &[LayerWitness] { + pub fn layers_ref(&self) -> &[ArcMultilinearExtension<'a, E>] { &self.layers } } -impl CircuitWitness { - pub fn layer_poly>( - &self, - layer_id: LayerId, - single_num_vars: usize, - multi_threads_meta: (usize, usize), - ) -> ArcDenseMultilinearExtension { - self.layers[layer_id as usize] - .instances - .as_slice() - .mle_with_meta( - single_num_vars, - self.instance_num_vars(), - multi_threads_meta, - ) - } -} - -impl Debug for CircuitWitness { +// impl CircuitWitnessV2 { +// pub fn layer_poly( +// &self, +// layer_id: LayerId, +// single_num_vars: usize, +// multi_threads_meta: (usize, usize), +// ) -> ArcDenseMultilinearExtension { +// self.layers[layer_id as usize] +// .instances +// .as_slice() +// .mle_with_meta( +// single_num_vars, +// self.instance_num_vars(), +// multi_threads_meta, +// ) +// } +// } + +impl<'a, F: ExtensionField> Debug for CircuitWitness<'a, F> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { writeln!(f, "CircuitWitness {{")?; writeln!(f, " n_instances: {}", self.n_instances)?; writeln!(f, " layers: ")?; - for (i, layer) in self.layers.iter().enumerate() { - writeln!(f, " {}: {:?}", i, layer)?; - } + // for (i, layer) in self.layers.iter().enumerate() { + // writeln!(f, " {}: {:?}", i, layer)?; + // } writeln!(f, " wires_in: ")?; for (i, wire) in self.witness_in.iter().enumerate() { - writeln!(f, " {}: {:?}", i, wire)?; + // TODO figure out how to print dyn trait + writeln!(f, " {}: {:?}", i, &wire.name())?; } writeln!(f, " wires_out: ")?; - for (i, wire) in self.witness_out.iter().enumerate() { - writeln!(f, " {}: {:?}", i, wire)?; - } + // for (i, wire) in self.witness_out.iter().enumerate() { + // writeln!(f, " {}: {:?}", i, wire)?; + // } writeln!(f, " challenges: {:?}", self.challenges)?; writeln!(f, "}}") } } - -#[cfg(test)] -mod test { - use std::{collections::HashMap, ops::Neg}; - - use ff::Field; - use ff_ext::ExtensionField; - use goldilocks::GoldilocksExt2; - use itertools::Itertools; - use simple_frontend::structs::{ChallengeConst, ChallengeId, CircuitBuilder, ConstantType}; - - use crate::{ - structs::{Circuit, CircuitWitness, LayerWitness}, - utils::i64_to_field, - }; - - fn copy_and_paste_circuit() -> Circuit { - let mut circuit_builder = CircuitBuilder::::new(); - // Layer 3 - let (_, input) = circuit_builder.create_witness_in(4); - - // Layer 2 - let mul_01 = circuit_builder.create_cell(); - circuit_builder.mul2(mul_01, input[0], input[1], Ext::BaseField::ONE); - - // Layer 1 - let mul_012 = circuit_builder.create_cell(); - circuit_builder.mul2(mul_012, mul_01, input[2], Ext::BaseField::ONE); - - // Layer 0 - let (_, mul_001123) = circuit_builder.create_witness_out(1); - circuit_builder.mul3( - mul_001123[0], - mul_01, - mul_012, - input[3], - Ext::BaseField::ONE, - ); - - circuit_builder.configure(); - let circuit = Circuit::new(&circuit_builder); - - circuit - } - - fn copy_and_paste_witness() -> ( - Vec>, - CircuitWitness, - ) { - // witness_in, single instance - let inputs = vec![vec![ - i64_to_field(5), - i64_to_field(7), - i64_to_field(11), - i64_to_field(13), - ]]; - let witness_in = vec![LayerWitness { instances: inputs }]; - - let layers = vec![ - LayerWitness { - instances: vec![vec![i64_to_field(175175)]], - }, - LayerWitness { - instances: vec![vec![ - i64_to_field(385), - i64_to_field(35), - i64_to_field(13), - i64_to_field(0), // pad - ]], - }, - LayerWitness { - instances: vec![vec![i64_to_field(35), i64_to_field(11)]], - }, - LayerWitness { - instances: vec![vec![ - i64_to_field(5), - i64_to_field(7), - i64_to_field(11), - i64_to_field(13), - ]], - }, - ]; - - let outputs = vec![vec![i64_to_field(175175)]]; - let witness_out = vec![LayerWitness { instances: outputs }]; - - ( - witness_in.clone(), - CircuitWitness { - layers, - witness_in, - witness_out, - n_instances: 1, - challenges: HashMap::new(), - }, - ) - } - - fn paste_from_wit_in_circuit() -> Circuit { - let mut circuit_builder = CircuitBuilder::::new(); - - // Layer 2 - let (_leaf_id1, leaves1) = circuit_builder.create_witness_in(3); - let (_leaf_id2, leaves2) = circuit_builder.create_witness_in(3); - // Unused input elements should also be in the circuit. - let (_dummy_id, _) = circuit_builder.create_witness_in(3); - let _ = circuit_builder.create_counter_in(1); - let _ = circuit_builder.create_constant_in(2, 1); - - // Layer 1 - let (_, inners) = circuit_builder.create_witness_out(2); - circuit_builder.mul2(inners[0], leaves1[0], leaves1[1], Ext::BaseField::ONE); - circuit_builder.mul2(inners[1], leaves1[2], leaves2[0], Ext::BaseField::ONE); - - // Layer 0 - let (_, root) = circuit_builder.create_witness_out(1); - circuit_builder.mul2(root[0], inners[0], inners[1], Ext::BaseField::ONE); - - circuit_builder.configure(); - let circuit = Circuit::new(&circuit_builder); - circuit - } - - fn paste_from_wit_in_witness() -> ( - Vec>, - CircuitWitness, - ) { - // witness_in, single instance - let leaves1 = vec![vec![i64_to_field(5), i64_to_field(7), i64_to_field(11)]]; - let leaves2 = vec![vec![i64_to_field(13), i64_to_field(17), i64_to_field(19)]]; - let dummy = vec![vec![i64_to_field(13), i64_to_field(17), i64_to_field(19)]]; - let witness_in = vec![ - LayerWitness { instances: leaves1 }, - LayerWitness { instances: leaves2 }, - LayerWitness { instances: dummy }, - ]; - - let layers = vec![ - LayerWitness { - instances: vec![vec![ - i64_to_field(5005), - i64_to_field(35), - i64_to_field(143), - i64_to_field(0), // pad - ]], - }, - LayerWitness { - instances: vec![vec![i64_to_field(35), i64_to_field(143)]], - }, - LayerWitness { - instances: vec![vec![ - i64_to_field(5), // leaves1 - i64_to_field(7), - i64_to_field(11), - i64_to_field(13), // leaves2 - i64_to_field(17), - i64_to_field(19), - i64_to_field(13), // dummy - i64_to_field(17), - i64_to_field(19), - i64_to_field(0), // counter - i64_to_field(1), - i64_to_field(1), // constant - i64_to_field(1), - i64_to_field(0), // pad - i64_to_field(0), - i64_to_field(0), - ]], - }, - ]; - - let outputs1 = vec![vec![i64_to_field(35), i64_to_field(143)]]; - let outputs2 = vec![vec![i64_to_field(5005)]]; - let witness_out = vec![ - LayerWitness { - instances: outputs1, - }, - LayerWitness { - instances: outputs2, - }, - ]; - - ( - witness_in.clone(), - CircuitWitness { - layers, - witness_in, - witness_out, - n_instances: 1, - challenges: HashMap::new(), - }, - ) - } - - fn copy_to_wit_out_circuit() -> Circuit { - let mut circuit_builder = CircuitBuilder::::new(); - // Layer 2 - let (_, leaves) = circuit_builder.create_witness_in(4); - - // Layer 1 - let (_inner_id, inners) = circuit_builder.create_witness_out(2); - circuit_builder.mul2(inners[0], leaves[0], leaves[1], Ext::BaseField::ONE); - circuit_builder.mul2(inners[1], leaves[2], leaves[3], Ext::BaseField::ONE); - - // Layer 0 - let root = circuit_builder.create_cell(); - circuit_builder.mul2(root, inners[0], inners[1], Ext::BaseField::ONE); - circuit_builder.assert_const(root, 5005); - - circuit_builder.configure(); - let circuit = Circuit::new(&circuit_builder); - - circuit - } - - fn copy_to_wit_out_witness() -> ( - Vec>, - CircuitWitness, - ) { - // witness_in, single instance - let leaves = vec![vec![ - i64_to_field(5), - i64_to_field(7), - i64_to_field(11), - i64_to_field(13), - ]]; - let witness_in = vec![LayerWitness { instances: leaves }]; - - let layers = vec![ - LayerWitness { - instances: vec![vec![ - i64_to_field(5005), - i64_to_field(35), - i64_to_field(143), - i64_to_field(0), // pad - ]], - }, - LayerWitness { - instances: vec![vec![i64_to_field(35), i64_to_field(143)]], - }, - LayerWitness { - instances: vec![vec![ - i64_to_field(5), - i64_to_field(7), - i64_to_field(11), - i64_to_field(13), - ]], - }, - ]; - - let outputs = vec![vec![i64_to_field(35), i64_to_field(143)]]; - let witness_out = vec![LayerWitness { instances: outputs }]; - - ( - witness_in.clone(), - CircuitWitness { - layers, - witness_in, - witness_out, - n_instances: 1, - challenges: HashMap::new(), - }, - ) - } - - fn copy_to_wit_out_witness_2() -> ( - Vec>, - CircuitWitness, - ) { - // witness_in, 2 instances - let leaves = vec![ - vec![ - i64_to_field(5), - i64_to_field(7), - i64_to_field(11), - i64_to_field(13), - ], - vec![ - i64_to_field(5), - i64_to_field(13), - i64_to_field(11), - i64_to_field(7), - ], - ]; - let witness_in = vec![LayerWitness { instances: leaves }]; - - let layers = vec![ - LayerWitness { - instances: vec![ - vec![ - i64_to_field(5005), - i64_to_field(35), - i64_to_field(143), - i64_to_field(0), // pad - ], - vec![ - i64_to_field(5005), - i64_to_field(65), - i64_to_field(77), - i64_to_field(0), // pad - ], - ], - }, - LayerWitness { - instances: vec![ - vec![i64_to_field(35), i64_to_field(143)], - vec![i64_to_field(65), i64_to_field(77)], - ], - }, - LayerWitness { - instances: vec![ - vec![ - i64_to_field(5), - i64_to_field(7), - i64_to_field(11), - i64_to_field(13), - ], - vec![ - i64_to_field(5), - i64_to_field(13), - i64_to_field(11), - i64_to_field(7), - ], - ], - }, - ]; - - let outputs = vec![ - vec![i64_to_field(35), i64_to_field(143)], - vec![i64_to_field(65), i64_to_field(77)], - ]; - let witness_out = vec![LayerWitness { instances: outputs }]; - - ( - witness_in.clone(), - CircuitWitness { - layers, - witness_in, - witness_out, - n_instances: 2, - challenges: HashMap::new(), - }, - ) - } - - fn rlc_circuit() -> Circuit { - let mut circuit_builder = CircuitBuilder::::new(); - // Layer 2 - let (_, leaves) = circuit_builder.create_witness_in(4); - - // Layer 1 - let inners = circuit_builder.create_ext_cells(2); - circuit_builder.rlc(&inners[0], &[leaves[0], leaves[1]], 0 as ChallengeId); - circuit_builder.rlc(&inners[1], &[leaves[2], leaves[3]], 1 as ChallengeId); - - // Layer 0 - let (_root_id, roots) = circuit_builder.create_ext_witness_out(1); - circuit_builder.mul2_ext(&roots[0], &inners[0], &inners[1], Ext::BaseField::ONE); - - circuit_builder.configure(); - let circuit = Circuit::new(&circuit_builder); - - circuit - } - - fn rlc_witness_2() -> ( - Vec>, - CircuitWitness, - Vec, - ) - where - Ext: ExtensionField, - { - let challenges = vec![ - Ext::from_bases(&[i64_to_field(31), i64_to_field(37)]), - Ext::from_bases(&[i64_to_field(97), i64_to_field(23)]), - ]; - let challenge_pows = challenges - .iter() - .enumerate() - .map(|(i, x)| { - (0..3) - .map(|j| { - ( - ChallengeConst { - challenge: i as u8, - exp: j as u64, - }, - x.pow(&[j as u64]), - ) - }) - .collect_vec() - }) - .collect_vec(); - - // witness_in, double instances - let leaves = vec![ - vec![ - i64_to_field(5), - i64_to_field(7), - i64_to_field(11), - i64_to_field(13), - ], - vec![ - i64_to_field(5), - i64_to_field(13), - i64_to_field(11), - i64_to_field(7), - ], - ]; - let witness_in = vec![LayerWitness { - instances: leaves.clone(), - }]; - - let inner00: Ext = challenge_pows[0][0].1 * (&leaves[0][0]) - + challenge_pows[0][1].1 * (&leaves[0][1]) - + challenge_pows[0][2].1; - let inner01: Ext = challenge_pows[1][0].1 * (&leaves[0][2]) - + challenge_pows[1][1].1 * (&leaves[0][3]) - + challenge_pows[1][2].1; - let inner10: Ext = challenge_pows[0][0].1 * (&leaves[1][0]) - + challenge_pows[0][1].1 * (&leaves[1][1]) - + challenge_pows[0][2].1; - let inner11: Ext = challenge_pows[1][0].1 * (&leaves[1][2]) - + challenge_pows[1][1].1 * (&leaves[1][3]) - + challenge_pows[1][2].1; - - let inners = vec![ - [ - inner00.clone().as_bases().to_vec(), - inner01.clone().as_bases().to_vec(), - ] - .concat(), - [ - inner10.clone().as_bases().to_vec(), - inner11.clone().as_bases().to_vec(), - ] - .concat(), - ]; - - let root_tmp0 = vec![ - inners[0][0] * inners[0][2], - inners[0][0] * inners[0][3], - inners[0][1] * inners[0][2], - inners[0][1] * inners[0][3], - ]; - let root_tmp1 = vec![ - inners[1][0] * inners[1][2], - inners[1][0] * inners[1][3], - inners[1][1] * inners[1][2], - inners[1][1] * inners[1][3], - ]; - let root_tmps = vec![root_tmp0, root_tmp1]; - - let root0 = inner00 * inner01; - let root1 = inner10 * inner11; - let roots = vec![root0.as_bases().to_vec(), root1.as_bases().to_vec()]; - - let layers = vec![ - LayerWitness { - instances: roots.clone(), - }, - LayerWitness { - instances: root_tmps, - }, - LayerWitness { instances: inners }, - LayerWitness { instances: leaves }, - ]; - - let outputs = roots; - let witness_out = vec![LayerWitness { instances: outputs }]; - - ( - witness_in.clone(), - CircuitWitness { - layers, - witness_in, - witness_out, - n_instances: 2, - challenges: challenge_pows - .iter() - .flatten() - .cloned() - .map(|(k, v)| (k, v.as_bases().to_vec())) - .collect::>(), - }, - challenges, - ) - } - - #[test] - fn test_add_instances() { - let circuit = copy_and_paste_circuit::(); - let (wits_in, expect_circuit_wits) = copy_and_paste_witness::(); - - let mut circuit_wits = CircuitWitness::new(&circuit, vec![]); - circuit_wits.add_instances(&circuit, wits_in, 1); - - assert_eq!(circuit_wits, expect_circuit_wits); - - let circuit = paste_from_wit_in_circuit::(); - let (wits_in, expect_circuit_wits) = paste_from_wit_in_witness::(); - - let mut circuit_wits = CircuitWitness::new(&circuit, vec![]); - circuit_wits.add_instances(&circuit, wits_in, 1); - - assert_eq!(circuit_wits, expect_circuit_wits); - - let circuit = copy_to_wit_out_circuit::(); - let (wits_in, expect_circuit_wits) = copy_to_wit_out_witness::(); - - let mut circuit_wits = CircuitWitness::new(&circuit, vec![]); - circuit_wits.add_instances(&circuit, wits_in, 1); - - assert_eq!(circuit_wits, expect_circuit_wits); - - let (wits_in, expect_circuit_wits) = copy_to_wit_out_witness_2::(); - let mut circuit_wits = CircuitWitness::new(&circuit, vec![]); - circuit_wits.add_instances(&circuit, wits_in, 2); - - assert_eq!(circuit_wits, expect_circuit_wits); - } - - #[test] - fn test_check_correctness() { - let circuit = copy_to_wit_out_circuit::(); - let (_wits_in, expect_circuit_wits) = copy_to_wit_out_witness_2::(); - - expect_circuit_wits.check_correctness(&circuit); - } - - #[test] - fn test_challenges() { - let circuit = rlc_circuit::(); - let (wits_in, expect_circuit_wits, challenges) = rlc_witness_2::(); - let mut circuit_wits = CircuitWitness::new(&circuit, challenges); - circuit_wits.add_instances(&circuit, wits_in, 2); - - assert_eq!(circuit_wits, expect_circuit_wits); - } - - #[test] - fn test_orphan_const_input() { - // create circuit - let mut circuit_builder = CircuitBuilder::::new(); - - let (_, leaves) = circuit_builder.create_witness_in(3); - let mul_0_1_res = circuit_builder.create_cell(); - - // 2 * 3 = 6 - circuit_builder.mul2( - mul_0_1_res, - leaves[0], - leaves[1], - ::BaseField::ONE, - ); - - let (_, out) = circuit_builder.create_witness_out(2); - // like a bypass gate, passing 6 to output out[0] - circuit_builder.add( - out[0], - mul_0_1_res, - ::BaseField::ONE, - ); - - // assert const 2 - circuit_builder.assert_const(leaves[2], 5); - - // 5 + -5 = 0, put in out[1] - circuit_builder.add( - out[1], - leaves[2], - ::BaseField::ONE, - ); - circuit_builder.add_const( - out[1], - ::BaseField::from(5).neg(), // -5 - ); - - // assert out[1] == 0 - circuit_builder.assert_const(out[1], 0); - - circuit_builder.configure(); - let circuit = Circuit::new(&circuit_builder); - - let mut circuit_wits = CircuitWitness::new(&circuit, vec![]); - let witness_in = vec![LayerWitness { - instances: vec![vec![i64_to_field(2), i64_to_field(3), i64_to_field(5)]], - }]; - circuit_wits.add_instances(&circuit, witness_in, 1); - - println!("circuit_wits {:?}", circuit_wits); - let output_layer_witness = &circuit_wits.layers[0]; - for gate in circuit.assert_consts.iter() { - if let ConstantType::Field(constant) = gate.scalar { - assert_eq!(output_layer_witness.instances[0][gate.idx_out], constant); - } - } - } -} diff --git a/gkr/src/circuit/circuit_witness_v2.rs b/gkr/src/circuit/circuit_witness_v2.rs deleted file mode 100644 index f6afa45e4..000000000 --- a/gkr/src/circuit/circuit_witness_v2.rs +++ /dev/null @@ -1,654 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use crate::circuit::EvaluateConstant; -use ff::Field; -use ff_ext::ExtensionField; -use itertools::{izip, Itertools}; -use multilinear_extensions::{ - mle::{DenseMultilinearExtension, MultilinearExtension}, - virtual_poly_v2::ArcMultilinearExtension, -}; -use simple_frontend::structs::{ChallengeConst, LayerId}; -use std::fmt::Debug; -use sumcheck::util::ceil_log2; - -use crate::{ - structs::{Circuit, CircuitWitnessV2}, - utils::i64_to_field, -}; - -impl<'a, E: ExtensionField> CircuitWitnessV2<'a, E> { - /// Initialize the structure of the circuit witness. - pub fn new(circuit: &Circuit, challenges: Vec) -> Self { - let create_default = |size| { - (0..size) - .map(|_| { - let a: ArcMultilinearExtension = - Arc::new(DenseMultilinearExtension::default()); - a - }) - .collect::>>() - }; - Self { - layers: create_default(circuit.layers.len()), - witness_in: create_default(circuit.n_witness_in), - witness_out: create_default(circuit.n_witness_out), - n_instances: 0, - challenges: circuit.generate_basefield_challenges(&challenges), - } - } - - /// Generate a fresh instance for the circuit, return layer witnesses and - /// wire out witnesses. - fn new_instances( - circuit: &Circuit, - wits_in: &[DenseMultilinearExtension], - challenges: &HashMap>, - n_instances: usize, - ) -> ( - Vec>, - Vec>, - ) { - let n_layers = circuit.layers.len(); - let mut layer_wits = vec![DenseMultilinearExtension::default(); n_layers]; - - // The first layer. - layer_wits[n_layers - 1] = { - let mut layer_wit = - vec![E::BaseField::ZERO; circuit.layers[n_layers - 1].size() * n_instances]; - for instance_id in 0..n_instances { - let instance_start_index = instance_id * circuit.layers[n_layers - 1].size(); - assert_eq!(wits_in.len(), circuit.paste_from_wits_in.len()); - for (wit_id, (l, r)) in circuit.paste_from_wits_in.iter().enumerate() { - for i in *l..*r { - let wit_in = wits_in[wit_id as usize].get_base_field_vec(); - layer_wit[instance_start_index + i] = wit_in[instance_start_index + i - *l]; - } - } - for (constant, (l, r)) in circuit.paste_from_consts_in.iter() { - for i in *l..*r { - layer_wit[instance_start_index + i] = i64_to_field(*constant); - } - } - for (num_vars, (l, r)) in circuit.paste_from_counter_in.iter() { - for i in *l..*r { - layer_wit[instance_start_index + i] = - E::BaseField::from(((instance_id << num_vars) ^ (i - *l)) as u64) - } - } - } - - DenseMultilinearExtension::from_evaluations_vec( - ceil_log2(circuit.layers[n_layers - 1].size() * n_instances), - layer_wit, - ) - }; - - for (layer_id, layer) in circuit.layers.iter().enumerate().rev().skip(1) { - let size = circuit.layers[layer_id].size(); - let mut current_layer_wit = vec![E::BaseField::ZERO; size * n_instances]; - - izip!(0..n_instances).for_each(|instance_id| { - let new_layer_instance_start_index = - instance_id * circuit.layers[layer_id as usize].size(); - layer - .paste_from - .iter() - .for_each(|(old_layer_id, new_wire_ids)| { - let layer_wits = layer_wits[*old_layer_id as usize].get_base_field_vec(); - let old_layer_instance_start_index = - instance_id * circuit.layers[*old_layer_id as usize].size(); - new_wire_ids.iter().enumerate().for_each( - |(subset_wire_id, new_wire_id)| { - let old_wire_id = circuit.layers[*old_layer_id as usize] - .copy_to - .get(&(layer_id as LayerId)) - .unwrap()[subset_wire_id]; - current_layer_wit[new_layer_instance_start_index + *new_wire_id] = - layer_wits[old_layer_instance_start_index + old_wire_id]; - }, - ); - }); - - let last_layer_wit = layer_wits[layer_id + 1].get_base_field_vec(); - let last_layer_instance_start_index = - instance_id * circuit.layers[layer_id as usize + 1].size(); - for add_const in layer.add_consts.iter() { - current_layer_wit[new_layer_instance_start_index + add_const.idx_out] += - add_const.scalar.eval(&challenges); - } - - for add in layer.adds.iter() { - current_layer_wit[new_layer_instance_start_index + add.idx_out] += - last_layer_wit[last_layer_instance_start_index + add.idx_in[0]] - * add.scalar.eval(&challenges); - } - - for mul2 in layer.mul2s.iter() { - current_layer_wit[new_layer_instance_start_index + mul2.idx_out] += - last_layer_wit[last_layer_instance_start_index + mul2.idx_in[0]] - * last_layer_wit[last_layer_instance_start_index + mul2.idx_in[1]] - * mul2.scalar.eval(&challenges); - } - - for mul3 in layer.mul3s.iter() { - current_layer_wit[new_layer_instance_start_index + mul3.idx_out] += - last_layer_wit[last_layer_instance_start_index + mul3.idx_in[0]] - * last_layer_wit[last_layer_instance_start_index + mul3.idx_in[1]] - * last_layer_wit[last_layer_instance_start_index + mul3.idx_in[2]] - * mul3.scalar.eval(&challenges); - } - }); - - layer_wits[layer_id] = DenseMultilinearExtension::from_evaluations_vec( - ceil_log2(size * n_instances), - current_layer_wit, - ); - } - let mut wits_out = vec![vec![]; circuit.n_witness_out]; - for instance_id in 0..n_instances { - let last_layer_instance_start_index = instance_id * circuit.layers[0].size(); - circuit - .copy_to_wits_out - .iter() - .enumerate() - .for_each(|(wit_id, old_wire_ids)| { - let layer_wit = layer_wits[0].get_base_field_vec(); - let mut wit_out = old_wire_ids - .iter() - .map(|old_wire_id| { - layer_wit[last_layer_instance_start_index + *old_wire_id] - }) - .collect_vec(); - let length = wit_out.len().next_power_of_two(); - wit_out.resize(length, E::BaseField::ZERO); - wits_out[wit_id].extend(wit_out); - }); - - // #[cfg(debug_assertions)] - // circuit.assert_consts.iter().for_each(|gate| { - // if let ConstantType::Field(constant) = gate.scalar { - // assert_eq!(layer_wits[0].instances[instance_id][gate.idx_out], constant); - // } - // }); - } - let wits_out = wits_out - .into_iter() - .map(|wit_out| { - let num_vars = ceil_log2(wit_out.len()); - DenseMultilinearExtension::from_evaluations_vec(num_vars, wit_out) - }) - .collect(); - (layer_wits, wits_out) - } - - /// Generate a fresh instance for the circuit, return layer witnesses and - /// wire out witnesses. - fn new_instances_v2( - circuit: &Circuit, - wits_in: &[ArcMultilinearExtension<'a, E>], - challenges: &HashMap>, - n_instances: usize, - ) -> ( - Vec>, - Vec>, - ) { - let n_layers = circuit.layers.len(); - let mut layer_wits = vec![DenseMultilinearExtension::default(); n_layers]; - - // The first layer. - layer_wits[n_layers - 1] = { - let mut layer_wit = - vec![E::BaseField::ZERO; circuit.layers[n_layers - 1].size() * n_instances]; - for instance_id in 0..n_instances { - let layer_wit_instance_start_index = - instance_id * circuit.layers[n_layers - 1].size(); - assert_eq!(wits_in.len(), circuit.paste_from_wits_in.len()); - for (wit_id, (l, r)) in circuit.paste_from_wits_in.iter().enumerate() { - let wit_in = wits_in[wit_id as usize].get_base_field_vec(); - let wit_in_instance_start_index = instance_id * wit_in.len() / n_instances; - for i in *l..*r { - layer_wit[layer_wit_instance_start_index + i] = - wit_in[wit_in_instance_start_index + i - *l]; - } - } - for (constant, (l, r)) in circuit.paste_from_consts_in.iter() { - for i in *l..*r { - layer_wit[layer_wit_instance_start_index + i] = i64_to_field(*constant); - } - } - for (num_vars, (l, r)) in circuit.paste_from_counter_in.iter() { - for i in *l..*r { - layer_wit[layer_wit_instance_start_index + i] = - E::BaseField::from(((instance_id << num_vars) ^ (i - *l)) as u64) - } - } - } - - DenseMultilinearExtension::from_evaluations_vec( - ceil_log2(circuit.layers[n_layers - 1].size() * n_instances), - layer_wit, - ) - }; - - for (layer_id, layer) in circuit.layers.iter().enumerate().rev().skip(1) { - let size = circuit.layers[layer_id].size(); - let mut current_layer_wit = vec![E::BaseField::ZERO; size * n_instances]; - - izip!(0..n_instances).for_each(|instance_id| { - let new_layer_instance_start_index = - instance_id * circuit.layers[layer_id as usize].size(); - layer - .paste_from - .iter() - .for_each(|(old_layer_id, new_wire_ids)| { - let layer_wits = layer_wits[*old_layer_id as usize].get_base_field_vec(); - let old_layer_instance_start_index = - instance_id * circuit.layers[*old_layer_id as usize].size(); - new_wire_ids.iter().enumerate().for_each( - |(subset_wire_id, new_wire_id)| { - let old_wire_id = circuit.layers[*old_layer_id as usize] - .copy_to - .get(&(layer_id as LayerId)) - .unwrap()[subset_wire_id]; - current_layer_wit[new_layer_instance_start_index + *new_wire_id] = - layer_wits[old_layer_instance_start_index + old_wire_id]; - }, - ); - }); - - let last_layer_wit = layer_wits[layer_id + 1].get_base_field_vec(); - let last_layer_instance_start_index = - instance_id * circuit.layers[layer_id as usize + 1].size(); - for add_const in layer.add_consts.iter() { - current_layer_wit[new_layer_instance_start_index + add_const.idx_out] += - add_const.scalar.eval(&challenges); - } - - for add in layer.adds.iter() { - current_layer_wit[new_layer_instance_start_index + add.idx_out] += - last_layer_wit[last_layer_instance_start_index + add.idx_in[0]] - * add.scalar.eval(&challenges); - } - - for mul2 in layer.mul2s.iter() { - current_layer_wit[new_layer_instance_start_index + mul2.idx_out] += - last_layer_wit[last_layer_instance_start_index + mul2.idx_in[0]] - * last_layer_wit[last_layer_instance_start_index + mul2.idx_in[1]] - * mul2.scalar.eval(&challenges); - } - - for mul3 in layer.mul3s.iter() { - current_layer_wit[new_layer_instance_start_index + mul3.idx_out] += - last_layer_wit[last_layer_instance_start_index + mul3.idx_in[0]] - * last_layer_wit[last_layer_instance_start_index + mul3.idx_in[1]] - * last_layer_wit[last_layer_instance_start_index + mul3.idx_in[2]] - * mul3.scalar.eval(&challenges); - } - }); - - layer_wits[layer_id] = DenseMultilinearExtension::from_evaluations_vec( - ceil_log2(size * n_instances), - current_layer_wit, - ); - } - let mut wits_out = vec![vec![]; circuit.n_witness_out]; - for instance_id in 0..n_instances { - let last_layer_instance_start_index = instance_id * circuit.layers[0].size(); - circuit - .copy_to_wits_out - .iter() - .enumerate() - .for_each(|(wit_id, old_wire_ids)| { - let layer_wit = layer_wits[0].get_base_field_vec(); - let mut wit_out = old_wire_ids - .iter() - .map(|old_wire_id| { - layer_wit[last_layer_instance_start_index + *old_wire_id] - }) - .collect_vec(); - let length = wit_out.len().next_power_of_two(); - wit_out.resize(length, E::BaseField::ZERO); - wits_out[wit_id].extend(wit_out); - }); - - // #[cfg(debug_assertions)] - // circuit.assert_consts.iter().for_each(|gate| { - // if let ConstantType::Field(constant) = gate.scalar { - // assert_eq!(layer_wits[0].instances[instance_id][gate.idx_out], constant); - // } - // }); - } - let wits_out = wits_out - .into_iter() - .map(|wit_out| { - let num_vars = ceil_log2(wit_out.len()); - DenseMultilinearExtension::from_evaluations_vec(num_vars, wit_out) - }) - .collect(); - (layer_wits, wits_out) - } - - pub fn add_instance( - &mut self, - circuit: &Circuit, - wits_in: Vec>, - ) { - self.add_instances(circuit, wits_in, 1); - } - - pub fn set_instances( - &mut self, - circuit: &Circuit, - new_wits_in: Vec>, - n_instances: usize, - ) { - assert_eq!(new_wits_in.len(), circuit.n_witness_in); - assert!(n_instances.is_power_of_two()); - assert!( - new_wits_in - .iter() - .all(|wit_in| wit_in.evaluations().len() % n_instances == 0) - ); - - let (inferred_layer_wits, inferred_wits_out) = CircuitWitnessV2::new_instances_v2( - circuit, - &new_wits_in, - &self.challenges, - n_instances, - ); - - assert_eq!(self.layers.len(), inferred_layer_wits.len()); - self.layers = inferred_layer_wits.into_iter().map(|n| n.into()).collect(); - assert_eq!(self.witness_out.len(), inferred_wits_out.len()); - self.witness_out = inferred_wits_out.into_iter().map(|n| n.into()).collect(); - assert_eq!(self.witness_in.len(), new_wits_in.len()); - self.witness_in = new_wits_in; - - self.n_instances = n_instances; - - // check correctness in debug build - if cfg!(debug_assertions) { - self.check_correctness(circuit); - } - } - - pub fn add_instances( - &mut self, - circuit: &Circuit, - new_wits_in: Vec>, - n_instances: usize, - ) { - assert_eq!(new_wits_in.len(), circuit.n_witness_in); - assert!(n_instances.is_power_of_two()); - assert!( - new_wits_in - .iter() - .all(|wit_in| wit_in.evaluations().len() % n_instances == 0) - ); - - let (inferred_layer_wits, inferred_wits_out) = CircuitWitnessV2::new_instances( - circuit, - &new_wits_in[..], - &self.challenges, - n_instances, - ); - - for (wit_out, inferred_wits_out) in self - .witness_out - .iter_mut() - .zip(inferred_wits_out.into_iter()) - { - Arc::get_mut(wit_out).unwrap().merge(inferred_wits_out); - } - - for (wit_in, new_wit_in) in self.witness_in.iter_mut().zip(new_wits_in.into_iter()) { - Arc::get_mut(wit_in).unwrap().merge(new_wit_in); - } - - // Merge self and circuit_witness. - for (layer_wit, inferred_layer_wit) in - self.layers.iter_mut().zip(inferred_layer_wits.into_iter()) - { - Arc::get_mut(layer_wit).unwrap().merge(inferred_layer_wit); - } - - self.n_instances += n_instances; - - // check correctness in debug build - if cfg!(debug_assertions) { - self.check_correctness(circuit); - } - } - - pub fn instance_num_vars(&self) -> usize { - ceil_log2(self.n_instances) - } - - pub fn check_correctness(&self, _circuit: &Circuit) { - // Check input. - return; - - // let input_layer_wits = self.layers.last().unwrap(); - // let wits_in = self.witness_in_ref(); - // for copy_id in 0..self.n_instances { - // for (wit_id, (l, r)) in circuit.paste_from_wits_in.iter().enumerate() { - // for (subset_wire_id, new_wire_id) in (*l..*r).enumerate() { - // assert_eq!( - // input_layer_wits.instances[copy_id][new_wire_id], - // wits_in[wit_id].instances[copy_id][subset_wire_id], - // "input layer: {}, copy_id: {}, wire_id: {}, got != expected: {:?} != - // {:?}", circuit.layers.len() - 1, - // copy_id, - // new_wire_id, - // input_layer_wits.instances[copy_id][new_wire_id], - // wits_in[wit_id].instances[copy_id][subset_wire_id] - // ); - // } - // } - // for (constant, (l, r)) in circuit.paste_from_consts_in.iter() { - // for (_subset_wire_id, new_wire_id) in (*l..*r).enumerate() { - // assert_eq!( - // input_layer_wits.instances[copy_id][new_wire_id], - // i64_to_field(*constant), - // "input layer: {}, copy_id: {}, wire_id: {}, got != expected: {:?} != - // {:?}", circuit.layers.len() - 1, - // copy_id, - // new_wire_id, - // input_layer_wits.instances[copy_id][new_wire_id], - // constant - // ); - // } - // } - // for (num_vars, (l, r)) in circuit.paste_from_counter_in.iter() { - // for (subset_wire_id, new_wire_id) in (*l..*r).enumerate() { - // assert_eq!( - // input_layer_wits.instances[copy_id][new_wire_id], - // i64_to_field(((copy_id << num_vars) ^ subset_wire_id) as i64), - // "input layer: {}, copy_id: {}, wire_id: {}, got != expected: {:?} != - // {:?}", circuit.layers.len() - 1, - // copy_id, - // new_wire_id, - // input_layer_wits.instances[copy_id][new_wire_id], - // (copy_id << num_vars) ^ subset_wire_id - // ); - // } - // } - // } - - // for (layer_id, (layer_witnesses, layer)) in self - // .layers - // .iter() - // .zip(circuit.layers.iter()) - // .enumerate() - // .rev() - // .skip(1) - // { - // let prev_layer_wits = &self.layers[layer_id + 1]; - // for (copy_id, (prev, curr)) in prev_layer_wits - // .instances - // .iter() - // .zip(layer_witnesses.instances.iter()) - // .enumerate() - // { - // let mut expected = vec![E::ZERO; curr.len()]; - // for add_const in layer.add_consts.iter() { - // expected[add_const.idx_out] += add_const.scalar.eval(&self.challenges); - // } - // for add in layer.adds.iter() { - // expected[add.idx_out] += - // prev[add.idx_in[0]] * add.scalar.eval(&self.challenges); - // } - // for mul2 in layer.mul2s.iter() { - // expected[mul2.idx_out] += prev[mul2.idx_in[0]] - // * prev[mul2.idx_in[1]] - // * mul2.scalar.eval(&self.challenges); - // } - // for mul3 in layer.mul3s.iter() { - // expected[mul3.idx_out] += prev[mul3.idx_in[0]] - // * prev[mul3.idx_in[1]] - // * prev[mul3.idx_in[2]] - // * mul3.scalar.eval(&self.challenges); - // } - - // let mut expected_max_previous_size = prev.len(); - // for (old_layer_id, new_wire_ids) in layer.paste_from.iter() { - // expected_max_previous_size = - // expected_max_previous_size.max(new_wire_ids.len()); for - // (subset_wire_id, new_wire_id) in new_wire_ids.iter().enumerate() { - // let old_wire_id = circuit.layers[*old_layer_id as usize] - // .copy_to .get(&(layer_id as LayerId)) - // .unwrap()[subset_wire_id]; - // expected[*new_wire_id] = - // self.layers[*old_layer_id as usize].instances[copy_id][old_wire_id]; - // } - // } - // assert_eq!( - // ceil_log2(expected_max_previous_size), - // layer.max_previous_num_vars, - // "layer: {}, expected_max_previous_size: {}, got: {}", - // layer_id, - // expected_max_previous_size, - // layer.max_previous_num_vars - // ); - // for (wire_id, (got, expected)) in curr.iter().zip(expected.iter()).enumerate() { - // assert_eq!( - // *got, *expected, - // "layer: {}, copy_id: {}, wire_id: {}, got != expected: {:?} != {:?}", - // layer_id, copy_id, wire_id, got, expected - // ); - // } - - // if layer_id != 0 { - // for (new_layer_id, old_wire_ids) in layer.copy_to.iter() { - // for (subset_wire_id, old_wire_id) in old_wire_ids.iter().enumerate() { - // let new_wire_id = circuit.layers[*new_layer_id as usize] - // .paste_from - // .get(&(layer_id as LayerId)) - // .unwrap()[subset_wire_id]; - // assert_eq!( - // curr[*old_wire_id], - // self.layers[*new_layer_id as - // usize].instances[copy_id][new_wire_id], "copy_to check: - // layer: {}, copy_id: {}, wire_id: {}, got != expected: {:?} != {:?}", - // layer_id, copy_id, - // old_wire_id, - // curr[*old_wire_id], - // self.layers[*new_layer_id as - // usize].instances[copy_id][new_wire_id] ) - // } - // } - // } - // } - // } - - // let output_layer_witness = &self.layers[0]; - // let wits_out = self.witness_out_ref(); - // for (wit_id, old_wire_ids) in circuit.copy_to_wits_out.iter().enumerate() { - // for copy_id in 0..self.n_instances { - // for (new_wire_id, old_wire_id) in old_wire_ids.iter().enumerate() { - // assert_eq!( - // output_layer_witness.instances[copy_id][*old_wire_id], - // wits_out[wit_id].instances[copy_id][new_wire_id] - // ); - // } - // } - // } - // for gate in circuit.assert_consts.iter() { - // if let ConstantType::Field(constant) = gate.scalar { - // for copy_id in 0..self.n_instances { - // assert_eq!( - // output_layer_witness.instances[copy_id][gate.idx_out], - // constant - // ); - // } - // } - // } - } -} - -impl<'a, E: ExtensionField> CircuitWitnessV2<'a, E> { - pub fn output_layer_witness_ref(&self) -> &ArcMultilinearExtension<'a, E> { - self.layers.first().unwrap() - } - - pub fn n_instances(&self) -> usize { - self.n_instances - } - - pub fn witness_in_ref(&self) -> &[ArcMultilinearExtension<'a, E>] { - &self.witness_in - } - - pub fn witness_out_ref(&self) -> &[ArcMultilinearExtension<'a, E>] { - &self.witness_out - } - - pub fn challenges(&self) -> &HashMap> { - &self.challenges - } - - pub fn layers_ref(&self) -> &[ArcMultilinearExtension<'a, E>] { - &self.layers - } -} - -// impl CircuitWitnessV2 { -// pub fn layer_poly( -// &self, -// layer_id: LayerId, -// single_num_vars: usize, -// multi_threads_meta: (usize, usize), -// ) -> ArcDenseMultilinearExtension { -// self.layers[layer_id as usize] -// .instances -// .as_slice() -// .mle_with_meta( -// single_num_vars, -// self.instance_num_vars(), -// multi_threads_meta, -// ) -// } -// } - -impl<'a, F: ExtensionField> Debug for CircuitWitnessV2<'a, F> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - writeln!(f, "CircuitWitness {{")?; - writeln!(f, " n_instances: {}", self.n_instances)?; - writeln!(f, " layers: ")?; - // for (i, layer) in self.layers.iter().enumerate() { - // writeln!(f, " {}: {:?}", i, layer)?; - // } - writeln!(f, " wires_in: ")?; - for (i, wire) in self.witness_in.iter().enumerate() { - // TODO figure out how to print dyn trait - writeln!(f, " {}: {:?}", i, &wire.name())?; - } - writeln!(f, " wires_out: ")?; - // for (i, wire) in self.witness_out.iter().enumerate() { - // writeln!(f, " {}: {:?}", i, wire)?; - // } - writeln!(f, " challenges: {:?}", self.challenges)?; - writeln!(f, "}}") - } -} diff --git a/gkr/src/gadgets/keccak256.rs b/gkr/src/gadgets/keccak256.rs index e361b684e..7af1c0608 100644 --- a/gkr/src/gadgets/keccak256.rs +++ b/gkr/src/gadgets/keccak256.rs @@ -3,9 +3,7 @@ use crate::{ error::GKRError, - structs::{ - Circuit, CircuitWitnessV2, GKRInputClaims, IOPProof, IOPProverStateV2, PointAndEval, - }, + structs::{Circuit, CircuitWitness, GKRInputClaims, IOPProof, IOPProverState, PointAndEval}, }; use ark_std::rand::{ rngs::{OsRng, StdRng}, @@ -457,7 +455,7 @@ pub fn prove_keccak256<'a, E: ExtensionField>( instance_num_vars: usize, circuit: &Circuit, max_thread_id: usize, -) -> Option<(IOPProof, CircuitWitnessV2)> { +) -> Option<(IOPProof, CircuitWitness)> { assert!( ceil_log2(max_thread_id) <= instance_num_vars, "ceil_log2(N) {} > instance_num_vars {}", @@ -476,7 +474,7 @@ pub fn prove_keccak256<'a, E: ExtensionField>( .map(|mut wit_in| { let next_pow_2 = ceil_log2(wit_in.len()); wit_in.resize(1 << next_pow_2, E::BaseField::ZERO); - wit_in + DenseMultilinearExtension::from_evaluations_vec(ceil_log2(wit_in.len()), wit_in) }) .collect(); let all_one = vec![ @@ -487,7 +485,7 @@ pub fn prove_keccak256<'a, E: ExtensionField>( .map(|mut wit_in| { let next_pow_2 = ceil_log2(wit_in.len()); wit_in.resize(1 << next_pow_2, E::BaseField::ZERO); - wit_in + DenseMultilinearExtension::from_evaluations_vec(ceil_log2(wit_in.len()), wit_in) }) .collect(); let mut witness = CircuitWitness::new(&circuit, Vec::new()); @@ -495,7 +493,9 @@ pub fn prove_keccak256<'a, E: ExtensionField>( witness.add_instance(&circuit, all_one); izip!( - &witness.witness_out_ref()[0].instances, + witness.witness_out_ref()[0] + .get_base_field_vec() + .chunks(256), [[0; 25], [u64::MAX; 25]] ) .for_each(|(wire_out, state)| { @@ -517,7 +517,7 @@ pub fn prove_keccak256<'a, E: ExtensionField>( } let mut rng = StdRng::seed_from_u64(OsRng.next_u64()); - let mut witness = CircuitWitnessV2::new(&circuit, Vec::new()); + let mut witness = CircuitWitness::new(&circuit, Vec::new()); for _ in 0..1 << instance_num_vars { let [rand_state, rand_input] = [25 * 64, 17 * 64].map(|n| { let mut data = vec![E::BaseField::ZERO; 1 << ceil_log2(n)]; @@ -554,7 +554,7 @@ pub fn prove_keccak256<'a, E: ExtensionField>( let output_eval = output_mle.evaluate(&output_point); let start = std::time::Instant::now(); - let (proof, _) = IOPProverStateV2::prove_parallel( + let (proof, _) = IOPProverState::prove_parallel( &circuit, &witness, vec![], @@ -581,7 +581,7 @@ pub fn verify_keccak256( .take(output_mle.num_vars()) .collect_vec(); let output_eval = output_mle.evaluate(&output_point); - crate::structs::IOPVerifierStateV2::verify_parallel( + crate::structs::IOPVerifierState::verify_parallel( &circuit, &[], vec![], diff --git a/gkr/src/lib.rs b/gkr/src/lib.rs index f9228ff8e..214126a5b 100644 --- a/gkr/src/lib.rs +++ b/gkr/src/lib.rs @@ -10,7 +10,6 @@ pub mod structs; pub mod unsafe_utils; pub mod utils; mod verifier; -mod verifier_v2; pub use sumcheck::util; diff --git a/gkr/src/prover.rs b/gkr/src/prover.rs index 77367689f..8c8d14c1a 100644 --- a/gkr/src/prover.rs +++ b/gkr/src/prover.rs @@ -1,26 +1,19 @@ -use std::mem; - use ark_std::{end_timer, start_timer}; use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::{ - mle::ArcDenseMultilinearExtension, - virtual_poly::{build_eq_x_r_vec, VirtualPolynomial}, - virtual_poly_v2::VirtualPolynomialV2, + virtual_poly::build_eq_x_r_vec, virtual_poly_v2::VirtualPolynomialV2, }; -use rayon::iter::{ - IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, - IntoParallelRefMutIterator, ParallelIterator, -}; +use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; use simple_frontend::structs::LayerId; use transcript::Transcript; use crate::{ entered_span, exit_span, structs::{ - Circuit, CircuitWitness, CircuitWitnessV2, GKRInputClaims, IOPProof, IOPProverState, - IOPProverStateV2, IOPProverStepMessage, PointAndEval, SumcheckStepType, + Circuit, CircuitWitness, GKRInputClaims, IOPProof, IOPProverState, + IOPProverStepMessage, PointAndEval, SumcheckStepType, }, tracing_span, }; @@ -34,15 +27,14 @@ mod phase2_linear; #[cfg(test)] mod test; -type SumcheckState = sumcheck::structs::IOPProverState; type SumcheckStateV2<'a, F> = sumcheck::structs::IOPProverStateV2<'a, F>; -impl IOPProverStateV2 { +impl IOPProverState { /// Prove process for data parallel circuits. #[tracing::instrument(skip_all, name = "gkr::prove_parallel")] pub fn prove_parallel<'a>( circuit: &Circuit, - circuit_witness: &CircuitWitnessV2, + circuit_witness: &CircuitWitness, output_evals: Vec>, wires_out_evals: Vec>, max_thread_id: usize, @@ -359,273 +351,3 @@ impl IOPProverStateV2 { } } } - -impl IOPProverState { - /// Prove process for data parallel circuits. - #[tracing::instrument(skip_all, name = "gkr::prove_parallel")] - pub fn prove_parallel( - circuit: &Circuit, - circuit_witness: &CircuitWitness, - output_evals: Vec>, - wires_out_evals: Vec>, - max_thread_id: usize, - transcript: &mut Transcript, - ) -> (IOPProof, GKRInputClaims) { - let timer = start_timer!(|| "Proving"); - let span = entered_span!("Proving"); - // TODO: Currently haven't support non-power-of-two number of instances. - assert!(circuit_witness.n_instances == 1 << circuit_witness.instance_num_vars()); - - let mut prover_state = tracing_span!("prover_init_parallel").in_scope(|| { - Self::prover_init_parallel( - circuit, - circuit_witness, - output_evals, - wires_out_evals, - transcript, - ) - }); - - let sumcheck_proofs = (0..circuit.layers.len() as LayerId) - .map(|layer_id| { - let timer = start_timer!(|| format!("Prove layer {}", layer_id)); - - prover_state.layer_id = layer_id; - - let dummy_step = SumcheckStepType::Undefined; - let proofs = circuit.layers[layer_id as usize] - .sumcheck_steps - .iter().chain(vec![&dummy_step, &dummy_step]) - .tuple_windows() - .flat_map(|steps| match steps { - (SumcheckStepType::OutputPhase1Step1, SumcheckStepType::OutputPhase1Step2, _) => { - [prover_state - .prove_and_update_state_output_phase1_step1( - circuit, - circuit_witness, - transcript, - ), - prover_state - .prove_and_update_state_output_phase1_step2( - circuit, - circuit_witness, - transcript, - )].to_vec() - }, - (SumcheckStepType::Phase1Step1, _, _) => { - let alpha = transcript - .get_and_append_challenge(b"combine subset evals") - .elements; - let hi_num_vars = circuit_witness.instance_num_vars(); - let eq_t = prover_state.to_next_phase_point_and_evals.par_iter().chain(prover_state.subset_point_and_evals[layer_id as usize].par_iter().map(|(_, point_and_eval)| point_and_eval)).map(|point_and_eval|{ - let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; - build_eq_x_r_vec(&point_and_eval.point[point_lo_num_vars..]) - }).collect::>>(); - let virtual_polys: Vec> = (0..max_thread_id).into_par_iter().map(|thread_id| { - let span = entered_span!("build_poly"); - let virtual_poly = IOPProverState::build_phase1_step1_sumcheck_poly( - &prover_state, - layer_id, - alpha, - &eq_t, - circuit, - circuit_witness, - (thread_id, max_thread_id), - ); - exit_span!(span); - virtual_poly - }).collect(); - - let (sumcheck_proof, sumcheck_prover_state) = sumcheck::structs::IOPProverState::::prove_batch_polys( - max_thread_id, - virtual_polys.try_into().unwrap(), - transcript, - ); - - let prover_msg = prover_state.combine_phase1_step1_evals( - sumcheck_proof, - sumcheck_prover_state, - ); - - vec![prover_msg] - - } - , - (SumcheckStepType::Phase2Step1, step2, _) => { - let span = entered_span!("phase2_gkr"); - let max_steps = match step2 { - SumcheckStepType::Phase2Step2 => 3, - SumcheckStepType::Phase2Step2NoStep3 => 2, - _ => unreachable!(), - }; - - let mut eqs = vec![]; - let mut layer_polys = (0..max_thread_id).map(|_| ArcDenseMultilinearExtension::default()).collect::>>(); - let mut res = vec![]; - for step in 0..max_steps { - let bounded_eval_point = prover_state.to_next_step_point.clone(); - eqs.push(build_eq_x_r_vec(&bounded_eval_point)); - // build step round poly - let virtual_polys: Vec> = (0..max_thread_id).into_par_iter().zip(layer_polys.par_iter_mut()).map(|(thread_id, layer_poly)| { - let span = entered_span!("build_poly"); - let (next_layer_poly_step1, virtual_poly) = match step { - 0 => { - let (next_layer_poly, virtual_poly) = IOPProverState::build_phase2_step1_sumcheck_poly( - eqs.as_slice().try_into().unwrap(), - layer_id, - circuit, - circuit_witness, - (thread_id, max_thread_id), - ); - (Some(next_layer_poly), virtual_poly) - }, - 1 => { - let virtual_poly = IOPProverState::build_phase2_step2_sumcheck_poly( - &layer_poly, - layer_id, - eqs.as_slice().try_into().unwrap(), - circuit, - circuit_witness, - (thread_id, max_thread_id), - ); - (None, virtual_poly) - }, - 2 => { - let virtual_poly = IOPProverState::build_phase2_step3_sumcheck_poly( - &layer_poly, - layer_id, - eqs.as_slice().try_into().unwrap(), - circuit, - circuit_witness, - (thread_id, max_thread_id), - ); - (None, virtual_poly) - - }, - _ => unimplemented!(), - }; - if let Some(next_layer_poly_step1) = next_layer_poly_step1 { - let _ = mem::replace(layer_poly, next_layer_poly_step1); - } - exit_span!(span); - virtual_poly - }).collect(); - - let (sumcheck_proof, sumcheck_prover_state) = sumcheck::structs::IOPProverState::::prove_batch_polys( - max_thread_id, - virtual_polys.try_into().unwrap(), - transcript, - ); - - let iop_prover_step = - match step { - 0 => { - prover_state.combine_phase2_step1_evals( - circuit, - sumcheck_proof, - sumcheck_prover_state, - ) - }, - 1 => { - let no_step3: bool = max_steps == 2; - prover_state.combine_phase2_step2_evals( - circuit, - sumcheck_proof, - sumcheck_prover_state, - no_step3, - ) - }, - 2 => { - prover_state.combine_phase2_step3_evals( - circuit, - sumcheck_proof, - sumcheck_prover_state, - ) - }, - _ => unimplemented!() - }; - - res.push(iop_prover_step); - } - exit_span!(span); - res - }, - (SumcheckStepType::LinearPhase2Step1, _, _) => - [prover_state - .prove_and_update_state_linear_phase2_step1( - circuit, - circuit_witness, - transcript, - )].to_vec(), - (SumcheckStepType::InputPhase2Step1, _, _) => - [prover_state - .prove_and_update_state_input_phase2_step1( - circuit, - circuit_witness, - transcript, - ) - ].to_vec(), - _ => { - vec![] - } - }) - .collect_vec(); - end_timer!(timer); - - proofs - }) - .flatten() - .collect_vec(); - end_timer!(timer); - exit_span!(span); - - ( - IOPProof { sumcheck_proofs }, - GKRInputClaims { - point_and_evals: prover_state.to_next_phase_point_and_evals, - }, - ) - } - - /// Initialize proving state for data parallel circuits. - fn prover_init_parallel( - circuit: &Circuit, - circuit_witness: &CircuitWitness, - output_evals: Vec>, - wires_out_evals: Vec>, - transcript: &mut Transcript, - ) -> Self { - let n_layers = circuit.layers.len(); - let output_wit_num_vars = circuit.layers[0].num_vars + circuit_witness.instance_num_vars(); - let mut subset_point_and_evals = vec![vec![]; n_layers]; - let to_next_step_point = if !output_evals.is_empty() { - output_evals.last().unwrap().point.clone() - } else { - wires_out_evals.last().unwrap().point.clone() - }; - let assert_point = (0..output_wit_num_vars) - .map(|_| { - transcript - .get_and_append_challenge(b"assert_point") - .elements - }) - .collect_vec(); - let to_next_phase_point_and_evals = output_evals; - subset_point_and_evals[0] = wires_out_evals - .into_iter() - .map(|p| (0 as LayerId, p)) - .collect(); - - Self { - to_next_phase_point_and_evals, - subset_point_and_evals, - to_next_step_point, - - assert_point, - // Default - layer_id: 0, - phase1_layer_poly: ArcDenseMultilinearExtension::default(), - g1_values: vec![], - } - } -} diff --git a/gkr/src/prover/phase1.rs b/gkr/src/prover/phase1.rs index 4a942145d..fcc23b533 100644 --- a/gkr/src/prover/phase1.rs +++ b/gkr/src/prover/phase1.rs @@ -3,8 +3,8 @@ use ff::Field; use ff_ext::ExtensionField; use itertools::{izip, Itertools}; use multilinear_extensions::{ - mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension}, - virtual_poly::{build_eq_x_r_vec_sequential, VirtualPolynomial}, + mle::DenseMultilinearExtension, + virtual_poly::build_eq_x_r_vec_sequential, virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2}, }; use simple_frontend::structs::LayerId; @@ -14,14 +14,14 @@ use sumcheck::{entered_span, util::ceil_log2}; use crate::{ exit_span, structs::{ - Circuit, CircuitWitness, CircuitWitnessV2, IOPProverState, IOPProverStateV2, - IOPProverStepMessage, PointAndEval, SumcheckProof, + Circuit, CircuitWitness, IOPProverState, IOPProverStepMessage, PointAndEval, + SumcheckProof, }, utils::{tensor_product, MatrixMLERowFirst}, }; // Prove the items copied from the current layer to later layers for data parallel circuits. -impl IOPProverStateV2 { +impl IOPProverState { /// Sumcheck 1: sigma = \sum_{t || y} \sum_j ( f1^{(j)}(t || y) * g1^{(j)}(t || y) ) /// sigma = \sum_j( \alpha^j * subset[i][j](rt_j || ry_j) ) /// f1^{(j)}(y) = layers[i](t || y) @@ -34,7 +34,7 @@ impl IOPProverStateV2 { alpha: E, eq_t: &Vec>, circuit: &Circuit, - circuit_witness: &'a CircuitWitnessV2, + circuit_witness: &'a CircuitWitness, multi_threads_meta: (usize, usize), ) -> VirtualPolynomialV2<'a, E> { let span = entered_span!("preparation"); @@ -185,171 +185,3 @@ impl IOPProverStateV2 { } } } - -// Prove the items copied from the current layer to later layers for data parallel circuits. -impl IOPProverState { - /// Sumcheck 1: sigma = \sum_{t || y} \sum_j ( f1^{(j)}(t || y) * g1^{(j)}(t || y) ) - /// sigma = \sum_j( \alpha^j * subset[i][j](rt_j || ry_j) ) - /// f1^{(j)}(y) = layers[i](t || y) - /// g1^{(j)}(y) = \alpha^j * eq(rt_j, t) * eq(ry_j, y) - /// g1^{(j)}(y) = \alpha^j * eq(rt_j, t) * copy_to[j](ry_j, y) - #[tracing::instrument(skip_all, name = "build_phase1_step1_sumcheck_poly")] - pub(super) fn build_phase1_step1_sumcheck_poly( - &self, - layer_id: LayerId, - alpha: E, - eq_t: &Vec>, - circuit: &Circuit, - circuit_witness: &CircuitWitness, - multi_threads_meta: (usize, usize), - ) -> VirtualPolynomial { - let span = entered_span!("preparation"); - let timer = start_timer!(|| "Prover sumcheck phase 1 step 1"); - - let total_length = self.to_next_phase_point_and_evals.len() - + self.subset_point_and_evals[self.layer_id as usize].len() - + 1; - let alpha_pows = { - let mut alpha_pows = vec![E::ONE; total_length]; - for i in 0..total_length.saturating_sub(1) { - alpha_pows[i + 1] = alpha_pows[i] * alpha; - } - alpha_pows - }; - - let lo_num_vars = circuit.layers[self.layer_id as usize].num_vars; - let hi_num_vars = circuit_witness.instance_num_vars(); - - // parallel unit logic handling - let (thread_id, max_thread_id) = multi_threads_meta; - let log2_max_thread_id = ceil_log2(max_thread_id); - - exit_span!(span); - - // f1^{(j)}(y) = layers[i](t || y) - let f1: Arc> = circuit_witness - .layer_poly::( - (layer_id).try_into().unwrap(), - lo_num_vars, - multi_threads_meta, - ) - .into(); - - assert_eq!( - f1.evaluations.len(), - 1 << (hi_num_vars + lo_num_vars - log2_max_thread_id) - ); - - let span = entered_span!("g1"); - // g1^{(j)}(y) = \alpha^j * eq(rt_j, t) * eq(ry_j, y) - // g1^{(j)}(y) = \alpha^j * eq(rt_j, t) * copy_to[j](ry_j, y) - let copy_to_matrices = &circuit.layers[self.layer_id as usize].copy_to; - let g1: ArcDenseMultilinearExtension = { - let gs = izip!(&self.to_next_phase_point_and_evals, &alpha_pows, eq_t) - .map(|(point_and_eval, alpha_pow, eq_t)| { - // g1^{(j)}(y) = \alpha^j * eq(rt_j, t) * eq(ry_j, y) - let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; - - let eq_y = - build_eq_x_r_vec_sequential(&point_and_eval.point[..point_lo_num_vars]) - .into_iter() - .take(1 << lo_num_vars) - .map(|eq| *alpha_pow * eq) - .collect_vec(); - - let eq_t_unit_len = eq_t.len() / max_thread_id; - let start_index = thread_id * eq_t_unit_len; - let g1_j = - tensor_product(&eq_t[start_index..(start_index + eq_t_unit_len)], &eq_y); - - assert_eq!( - g1_j.len(), - (1 << (hi_num_vars + lo_num_vars - log2_max_thread_id)) - ); - - g1_j - }) - .chain( - izip!( - &self.subset_point_and_evals[self.layer_id as usize], - &alpha_pows[self.to_next_phase_point_and_evals.len()..], - eq_t.iter().skip(self.to_next_phase_point_and_evals.len()) - ) - .map( - |((new_layer_id, point_and_eval), alpha_pow, eq_t)| { - let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; - let copy_to = ©_to_matrices[new_layer_id]; - let lo_eq_w_p = build_eq_x_r_vec_sequential( - &point_and_eval.point[..point_lo_num_vars], - ); - - // g2^{(j)}(y) = \alpha^j * eq(rt_j, t) * copy_to[j](ry_j, y) - let eq_t_unit_len = eq_t.len() / max_thread_id; - let start_index = thread_id * eq_t_unit_len; - let g2_j = tensor_product( - &eq_t[start_index..(start_index + eq_t_unit_len)], - ©_to.as_slice().fix_row_row_first_with_scalar( - &lo_eq_w_p, - lo_num_vars, - alpha_pow, - ), - ); - - assert_eq!( - g2_j.len(), - (1 << (hi_num_vars + lo_num_vars - log2_max_thread_id)) - ); - g2_j - }, - ), - ) - .collect::>>(); - - DenseMultilinearExtension::from_evaluations_ext_vec( - hi_num_vars + lo_num_vars - log2_max_thread_id, - gs.into_iter() - .fold(vec![E::ZERO; 1 << f1.num_vars], |mut acc, g| { - assert_eq!(1 << f1.num_vars, g.len()); - acc.iter_mut().enumerate().for_each(|(i, v)| *v += g[i]); - acc - }), - ) - .into() - }; - exit_span!(span); - - // sumcheck: sigma = \sum_y( \sum_j f1^{(j)}(y) * g1^{(j)}(y)) - let span = entered_span!("virtual_poly"); - let mut virtual_poly_1 = VirtualPolynomial::new_from_mle(f1, E::BaseField::ONE); - virtual_poly_1.mul_by_mle(g1, E::BaseField::ONE); - exit_span!(span); - end_timer!(timer); - - virtual_poly_1 - } - - pub(super) fn combine_phase1_step1_evals( - &mut self, - sumcheck_proof_1: SumcheckProof, - prover_state: sumcheck::structs::IOPProverState, - ) -> IOPProverStepMessage { - let (mut f1, _): (Vec<_>, Vec<_>) = prover_state - .get_mle_final_evaluations() - .into_iter() - .enumerate() - .partition(|(i, _)| i % 2 == 0); - let eval_value_1 = f1.remove(0).1; - - self.to_next_step_point = sumcheck_proof_1.point.clone(); - self.to_next_phase_point_and_evals = vec![PointAndEval::new_from_ref( - &self.to_next_step_point, - &eval_value_1, - )]; - self.subset_point_and_evals[self.layer_id as usize].clear(); - - IOPProverStepMessage { - sumcheck_proof: sumcheck_proof_1, - sumcheck_eval_values: vec![eval_value_1], - } - } -} diff --git a/gkr/src/prover/phase1_output.rs b/gkr/src/prover/phase1_output.rs index 0779ff4ed..519c1f1c9 100644 --- a/gkr/src/prover/phase1_output.rs +++ b/gkr/src/prover/phase1_output.rs @@ -1,34 +1,28 @@ use ark_std::{end_timer, start_timer}; use ff::Field; use ff_ext::ExtensionField; -use itertools::{chain, izip, Itertools}; +use itertools::{izip, Itertools}; use multilinear_extensions::{ - mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, MultilinearExtension}, + mle::DenseMultilinearExtension, util::ceil_log2, - virtual_poly::{build_eq_x_r_vec, build_eq_x_r_vec_sequential, VirtualPolynomial}, + virtual_poly::{build_eq_x_r_vec, build_eq_x_r_vec_sequential}, virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2}, }; -use std::{iter, mem, sync::Arc}; -use transcript::Transcript; +use std::{iter, sync::Arc}; use crate::{ - entered_span, exit_span, izip_parallizable, - prover::SumcheckState, + entered_span, exit_span, structs::{ - Circuit, CircuitWitness, CircuitWitnessV2, IOPProverState, IOPProverStateV2, - IOPProverStepMessage, PointAndEval, SumcheckProof, + Circuit, CircuitWitness, IOPProverState, IOPProverStepMessage, PointAndEval, SumcheckProof, }, utils::{tensor_product, MatrixMLERowFirst}, }; -#[cfg(feature = "parallel")] -use rayon::iter::{IndexedParallelIterator, ParallelIterator}; - // Prove the items copied from the output layer to the output witness for data parallel circuits. // \sum_j( \alpha^j * subset[i][j](rt_j || ry_j) ) // = \sum_{t || y} ( \sum_j( \alpha^j (eq or copy_to[j] or assert_subset_eq)(ry_j, y) eq(rt_j, // t) * layers[i](t || y) ) ) -impl IOPProverStateV2 { +impl IOPProverState { /// Sumcheck 1: sigma = \sum_{t || y} \sum_j ( f1^{(j)}(t || y) * g1^{(j)}(t || y) ) /// sigma = \sum_j( \alpha^j * subset[i][j](rt_j || ry_j) ) /// f1^{(j)}(y) = layers[i](t || y) @@ -41,7 +35,7 @@ impl IOPProverStateV2 { eq_t: &Vec>, alpha: E, circuit: &Circuit, - circuit_witness: &'a CircuitWitnessV2, + circuit_witness: &'a CircuitWitness, multi_threads_meta: (usize, usize), ) -> VirtualPolynomialV2<'a, E> { let timer = start_timer!(|| "Prover sumcheck output phase 1 step 1"); @@ -198,233 +192,3 @@ impl IOPProverStateV2 { } } } - -// Prove the items copied from the output layer to the output witness for data parallel circuits. -// \sum_j( \alpha^j * subset[i][j](rt_j || ry_j) ) -// = \sum_y( \sum_j( \alpha^j (eq or copy_to[j] or assert_subset_eq)(ry_j, y) \sum_t( eq(rt_j, -// t) * layers[i](t || y) ) ) ) -impl IOPProverState { - /// Sumcheck 1: sigma = \sum_y( \sum_j f1^{(j)}(y) * g1^{(j)}(y) ) - /// sigma = \sum_j( \alpha^j * wit_out_eval[j](rt_j || ry_j) ) - /// + \alpha^{wit_out_eval[j].len()} * assert_const(rt || ry) ) - /// f1^{(j)}(y) = layers[i](rt_j || y) - /// g1^{(j)}(y) = \alpha^j eq(ry_j, y) - // or \alpha^j copy_to[j](ry_j, y) - // or \alpha^j assert_subset_eq(ry, y) - #[tracing::instrument(skip_all, name = "prove_and_update_state_output_phase1_step1")] - pub(super) fn prove_and_update_state_output_phase1_step1( - &mut self, - circuit: &Circuit, - circuit_witness: &CircuitWitness, - transcript: &mut Transcript, - ) -> IOPProverStepMessage { - let timer = start_timer!(|| "Prover sumcheck output phase 1 step 1"); - let alpha = transcript - .get_and_append_challenge(b"combine subset evals") - .elements; - - let total_length = self.to_next_phase_point_and_evals.len() - + self.subset_point_and_evals[self.layer_id as usize].len() - + 1; - let alpha_pows = { - let mut alpha_pows = vec![E::ONE; total_length]; - for i in 0..total_length.saturating_sub(1) { - alpha_pows[i + 1] = alpha * &alpha_pows[i]; - } - alpha_pows - }; - - let lo_num_vars = circuit.layers[self.layer_id as usize].num_vars; - let hi_num_vars = circuit_witness.instance_num_vars(); - - self.phase1_layer_poly = circuit_witness.layer_poly::( - (self.layer_id).try_into().unwrap(), - lo_num_vars, - (0, 1), - ); - - // sigma = \sum_j( \alpha^j * subset[i][j](rt_j || ry_j) ) - // f1^{(j)}(y) = layers[i](rt_j || y) - // g1^{(j)}(y) = \alpha^j eq(ry_j, y) - // or \alpha^j copy_to[j](ry_j, y) - // or \alpha^j assert_subset_eq(ry, y) - // TODO: Double check the soundness here. - let (mut f1, mut g1): ( - Vec>, - Vec>, - ) = izip_parallizable!(&self.to_next_phase_point_and_evals, &alpha_pows) - .map(|(point_and_eval, alpha_pow)| { - let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; - let point = &point_and_eval.point; - let lo_eq_w_p = build_eq_x_r_vec(&point_and_eval.point[..point_lo_num_vars]); - - let f1_j = self - .phase1_layer_poly - .fix_high_variables(&point[point_lo_num_vars..]); - - let g1_j = lo_eq_w_p - .into_iter() - .map(|eq| *alpha_pow * eq) - .collect_vec(); - ( - f1_j.into(), - DenseMultilinearExtension::::from_evaluations_ext_vec(lo_num_vars, g1_j) - .into(), - ) - }) - .unzip(); - - let (f1_copy_to, g1_copy_to): ( - Vec>, - Vec>, - ) = izip!( - &circuit.copy_to_wits_out, - &self.subset_point_and_evals[self.layer_id as usize], - &alpha_pows[self.to_next_phase_point_and_evals.len()..] - ) - .map(|(copy_to, (_, point_and_eval), alpha_pow)| { - let point = &point_and_eval.point; - let point_lo_num_vars = point.len() - hi_num_vars; - - let lo_eq_w_p = build_eq_x_r_vec(&point[..point_lo_num_vars]); - assert!(copy_to.len() <= lo_eq_w_p.len()); - - let f1_j = self - .phase1_layer_poly - .fix_high_variables(&point[point_lo_num_vars..]); - - let g1_j = copy_to.as_slice().fix_row_row_first_with_scalar( - &lo_eq_w_p, - lo_num_vars, - alpha_pow, - ); - - ( - f1_j.into(), - DenseMultilinearExtension::from_evaluations_ext_vec(lo_num_vars, g1_j).into(), - ) - }) - .unzip(); - - f1.extend(f1_copy_to); - g1.extend(g1_copy_to); - - let f1_j = self - .phase1_layer_poly - .fix_high_variables(&self.assert_point[lo_num_vars..]); - f1.push(f1_j.into()); - - let alpha_pow = alpha_pows.last().unwrap(); - let lo_eq_w_p = build_eq_x_r_vec(&self.assert_point[..lo_num_vars]); - - let mut g_last = vec![E::ZERO; 1 << lo_num_vars]; - circuit.assert_consts.iter().for_each(|gate| { - g_last[gate.idx_out as usize] = lo_eq_w_p[gate.idx_out as usize] * alpha_pow; - }); - - g1.push(DenseMultilinearExtension::from_evaluations_ext_vec(lo_num_vars, g_last).into()); - - // sumcheck: sigma = \sum_y( \sum_j f1^{(j)}(y) * g1^{(j)}(y) ) - let mut virtual_poly_1 = VirtualPolynomial::new(lo_num_vars); - for (f1_j, g1_j) in f1.into_iter().zip(g1.into_iter()) { - let mut tmp = VirtualPolynomial::new_from_mle(f1_j, E::BaseField::ONE); - tmp.mul_by_mle(g1_j, E::BaseField::ONE); - virtual_poly_1.merge(&tmp); - } - - let (sumcheck_proof_1, prover_state) = - SumcheckState::prove_parallel(virtual_poly_1, transcript); - let (f1, g1): (Vec<_>, Vec<_>) = prover_state - .get_mle_final_evaluations() - .into_iter() - .enumerate() - .partition(|(i, _)| i % 2 == 0); - let eval_value_1 = f1.into_iter().map(|(_, f1_j)| f1_j).collect_vec(); - - self.to_next_step_point = sumcheck_proof_1.point.clone(); - self.g1_values = g1.into_iter().map(|(_, g1_j)| g1_j).collect_vec(); - - end_timer!(timer); - IOPProverStepMessage { - sumcheck_proof: sumcheck_proof_1, - sumcheck_eval_values: eval_value_1, - } - } - - /// Sumcheck 2: sigma = \sum_t( \sum_j( f2^{(j)}(t) ) ) * g2(t) - /// sigma = \sum_j( f1^{(j)}(ry) * g1^{(j)}(ry) ) - /// f2(t) = layers[i](t || ry) - /// g2^{(j)}(t) = \alpha^j eq(ry_j, ry) eq(rt_j, t) - // or \alpha^j copy_to[j](ry_j, ry) eq(rt_j, t) - // or \alpha^j assert_subset_eq(ry, ry) eq(rt, t) - #[tracing::instrument(skip_all, name = "prove_and_update_state_output_phase1_step2")] - pub(super) fn prove_and_update_state_output_phase1_step2( - &mut self, - _: &Circuit, - circuit_witness: &CircuitWitness, - transcript: &mut Transcript, - ) -> IOPProverStepMessage { - let timer = start_timer!(|| "Prover sumcheck output phase 1 step 2"); - let hi_num_vars = circuit_witness.instance_num_vars(); - - // f2(t) = layers[i](t || ry) - let mut f2 = mem::take(&mut self.phase1_layer_poly); - - Arc::make_mut(&mut f2).fix_variables_in_place_parallel(&self.to_next_step_point); - - // g2(t) = \sum_j \alpha^j (eq or copy_to[j] or assert_subset)(ry_j, ry) eq(rt_j, t) - let output_points = chain![ - self.to_next_phase_point_and_evals.iter().map(|x| &x.point), - self.subset_point_and_evals[self.layer_id as usize] - .iter() - .map(|x| &x.1.point), - iter::once(&self.assert_point), - ]; - let g2 = output_points - .zip(self.g1_values.iter()) - .map(|(point, &g1_value)| { - let point_lo_num_vars = point.len() - hi_num_vars; - build_eq_x_r_vec(&point[point_lo_num_vars..]) - .into_iter() - .map(|eq| g1_value * eq) - .collect_vec() - }) - .fold(vec![E::ZERO; 1 << hi_num_vars], |acc, nxt| { - acc.into_iter() - .zip(nxt.into_iter()) - .map(|(a, b)| a + b) - .collect_vec() - }); - let g2 = DenseMultilinearExtension::from_evaluations_ext_vec(hi_num_vars, g2); - // sumcheck: sigma = \sum_t( g2(t) * f2(t) ) - let mut virtual_poly_2 = VirtualPolynomial::new_from_mle(f2, E::BaseField::ONE); - virtual_poly_2.mul_by_mle(g2.into(), E::BaseField::ONE); - - let (sumcheck_proof_2, prover_state) = - SumcheckState::prove_parallel(virtual_poly_2, transcript); - let (mut f2, _): (Vec<_>, Vec<_>) = prover_state - .get_mle_final_evaluations() - .into_iter() - .enumerate() - .partition(|(i, _)| i % 2 == 0); - let eval_value_2 = f2.remove(0).1; - - self.to_next_step_point = [ - mem::take(&mut self.to_next_step_point), - sumcheck_proof_2.point.clone(), - ] - .concat(); - self.to_next_phase_point_and_evals = vec![PointAndEval::new_from_ref( - &self.to_next_step_point, - &eval_value_2, - )]; - - self.subset_point_and_evals[self.layer_id as usize].clear(); - - end_timer!(timer); - IOPProverStepMessage { - sumcheck_proof: sumcheck_proof_2, - sumcheck_eval_values: vec![eval_value_2], - } - } -} diff --git a/gkr/src/prover/phase2.rs b/gkr/src/prover/phase2.rs index eb262b9da..8b71cf645 100644 --- a/gkr/src/prover/phase2.rs +++ b/gkr/src/prover/phase2.rs @@ -4,7 +4,6 @@ use ff_ext::ExtensionField; use itertools::{izip, Itertools}; use multilinear_extensions::{ mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension}, - virtual_poly::VirtualPolynomial, virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2}, }; use simple_frontend::structs::LayerId; @@ -13,16 +12,14 @@ use std::sync::Arc; use sumcheck::{entered_span, exit_span, util::ceil_log2}; use crate::structs::{ - CircuitWitnessV2, IOPProverStateV2, + CircuitWitness, IOPProverState, Step::{Step1, Step2, Step3}, }; use multilinear_extensions::mle::MultilinearExtension; use crate::{ circuit::EvaluateConstant, - structs::{ - Circuit, CircuitWitness, IOPProverState, IOPProverStepMessage, PointAndEval, SumcheckProof, - }, + structs::{Circuit, IOPProverStepMessage, PointAndEval, SumcheckProof}, }; macro_rules! prepare_stepx_g_fn_v2 { @@ -82,7 +79,7 @@ macro_rules! prepare_stepx_g_fn { // ) ) + \sum_{s1}( \sum_{x1}( // \sum_j eq(rt, s1) paste_from[j](ry, x1) * subset[j][i](s1 || x1) // ) ) + add_const(ry) -impl IOPProverStateV2 { +impl IOPProverState { /// Sumcheck 1: sigma = \sum_{s1 || x1} f1(s1 || x1) * g1(s1 || x1) + \sum_j f1'_j(s1 || x1) * /// g1'_j(s1 || x1) sigma = layers[i](rt || ry) - add_const(ry), /// f1(s1 || x1) = layers[i + 1](s1 || x1) @@ -99,7 +96,7 @@ impl IOPProverStateV2 { eq: &[Vec; 1], layer_id: LayerId, circuit: &Circuit, - circuit_witness: &'a CircuitWitnessV2, + circuit_witness: &'a CircuitWitness, multi_threads_meta: (usize, usize), ) -> VirtualPolynomialV2<'a, E> { let timer = start_timer!(|| "Prover sumcheck phase 2 step 1"); @@ -288,7 +285,7 @@ impl IOPProverStateV2 { layer_id: LayerId, eqs: &[Vec; 2], circuit: &Circuit, - circuit_witness: &'a CircuitWitnessV2, + circuit_witness: &'a CircuitWitness, multi_threads_meta: (usize, usize), ) -> VirtualPolynomialV2<'a, E> { let timer = start_timer!(|| "Prover sumcheck phase 2 step 2"); @@ -393,7 +390,7 @@ impl IOPProverStateV2 { layer_id: LayerId, eqs: &[Vec; 3], circuit: &Circuit, - circuit_witness: &'a CircuitWitnessV2, + circuit_witness: &'a CircuitWitness, multi_threads_meta: (usize, usize), ) -> VirtualPolynomialV2<'a, E> { let timer = start_timer!(|| "Prover sumcheck phase 2 step 3"); @@ -467,376 +464,3 @@ impl IOPProverStateV2 { } } } - -// Prove the computation in the current layer for data parallel circuits. -// The number of terms depends on the gate. -// Here is an example of degree 3: -// layers[i](rt || ry) = \sum_{s1}( \sum_{s2}( \sum_{s3}( \sum_{x1}( \sum_{x2}( \sum_{x3}( -// eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s1 || x1) * layers[i + 1](s2 || x2) -// * layers[i + 1](s3 || x3) ) ) ) ) ) ) + sum_s1( sum_s2( sum_{x1}( sum_{x2}( eq(rt, s1, s2) * -// mul2(ry, x1, x2) * layers[i + 1](s1 || x1) * layers[i + 1](s2 || x2) -// ) ) ) ) + \sum_{s1}( \sum_{x1}( -// eq(rt, s1) * add(ry, x1) * layers[i + 1](s1 || x1) -// ) ) + \sum_{s1}( \sum_{x1}( -// \sum_j eq(rt, s1) paste_from[j](ry, x1) * subset[j][i](s1 || x1) -// ) ) + add_const(ry) -impl IOPProverState { - /// Sumcheck 1: sigma = \sum_{s1 || x1} f1(s1 || x1) * g1(s1 || x1) + \sum_j f1'_j(s1 || x1) * - /// g1'_j(s1 || x1) sigma = layers[i](rt || ry) - add_const(ry), - /// f1(s1 || x1) = layers[i + 1](s1 || x1) - /// g1(s1 || x1) = \sum_{s2}( \sum_{s3}( \sum_{x2}( \sum_{x3}( - /// eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s2 || x2) * layers[i + - /// 1](s3 || x3) ) ) ) ) + \sum_{s2}( \sum_{x2}( - /// eq(rt, s1, s2) * mul2(ry, x1, x2) * layers[i + 1](s2 || x2) - /// ) ) + eq(rt, s1) * add(ry, x1) - /// f1'^{(j)}(s1 || x1) = subset[j][i](s1 || x1) - /// g1'^{(j)}(s1 || x1) = eq(rt, s1) paste_from[j](ry, x1) - /// s1 || x1 || 0, s1 || x1 || 1 - #[tracing::instrument(skip_all, name = "build_phase2_step1_sumcheck_poly")] - pub(super) fn build_phase2_step1_sumcheck_poly( - eq: &[Vec; 1], - layer_id: LayerId, - circuit: &Circuit, - circuit_witness: &CircuitWitness, - multi_threads_meta: (usize, usize), - ) -> (ArcDenseMultilinearExtension, VirtualPolynomial) { - let timer = start_timer!(|| "Prover sumcheck phase 2 step 1"); - let layer = &circuit.layers[layer_id as usize]; - let lo_out_num_vars = layer.num_vars; - let lo_in_num_vars = layer.max_previous_num_vars; - let hi_num_vars = circuit_witness.instance_num_vars(); - let eq = &eq[0]; - - // parallel unit logic handling - let (thread_id, max_thread_id) = multi_threads_meta; - let log2_max_thread_id = ceil_log2(max_thread_id); - let threads_num_vars = hi_num_vars - log2_max_thread_id; - let thread_s = thread_id << threads_num_vars; - - let challenges = &circuit_witness.challenges; - - let span = entered_span!("f1_g1"); - // merge next_layer_vec with next_layer_poly - let next_layer_vec = circuit_witness.layers[layer_id as usize + 1] - .instances - .as_slice(); - let num_vars = circuit.layers[layer_id as usize].max_previous_num_vars(); - let phase2_next_layer_polys_v2: ArcDenseMultilinearExtension = circuit_witness - .layer_poly( - (layer_id + 1).try_into().unwrap(), - num_vars, - multi_threads_meta, - ) - .into(); - - // f1(s1 || x1) = layers[i + 1](s1 || x1) - let f1 = phase2_next_layer_polys_v2.clone(); - - // g1(s1 || x1) = \sum_{s2}( \sum_{s3}( \sum_{x2}( \sum_{x3}( - // eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s2 || x2) * layers[i + - // 1](s3 || x3) ) ) ) ) + \sum_{s2}( \sum_{x2}( - // eq(rt, s1, s2) * mul2(ry, x1, x2) * layers[i + 1](s2 || x2) - // ) ) + eq(rt, s1) * add(ry, x1) - let mut g1 = vec![E::ZERO; 1 << f1.num_vars]; - let mul3s_fanin_mapping = &layer.mul3s_fanin_mapping[Step1 as usize]; - let mul2s_fanin_mapping = &layer.mul2s_fanin_mapping[Step1 as usize]; - let adds_fanin_mapping = &layer.adds_fanin_mapping[Step1 as usize]; - - prepare_stepx_g_fn!( - &mut g1, - lo_in_num_vars, - thread_s, - mul3s_fanin_mapping, - |s, gate| { - eq[(s << lo_out_num_vars) ^ gate.idx_out] - * (&next_layer_vec[s][gate.idx_in[1]]) - * (&next_layer_vec[s][gate.idx_in[2]]) - * (&gate.scalar.eval(&challenges)) - }, - mul2s_fanin_mapping, - |s, gate| { - eq[(s << lo_out_num_vars) ^ gate.idx_out] - * (&next_layer_vec[s][gate.idx_in[1]]) - * (&gate.scalar.eval(&challenges)) - }, - adds_fanin_mapping, - |s, gate| { - eq[(s << lo_out_num_vars) ^ gate.idx_out] * (&gate.scalar.eval(&challenges)) - } - ); - let g1 = DenseMultilinearExtension::from_evaluations_ext_vec(f1.num_vars, g1).into(); - exit_span!(span); - - // f1'^{(j)}(s1 || x1) = subset[j][i](s1 || x1) - // g1'^{(j)}(s1 || x1) = eq(rt, s1) paste_from[j](ry, x1) - let span = entered_span!("f1j_g1j"); - let (f1_j, g1_j)= izip!(&layer.paste_from).map(|(j, paste_from)| { - let paste_from_sources = circuit_witness.layers_ref(); - let old_wire_id = |old_layer_id: usize, subset_wire_id: usize| -> usize { - circuit.layers[old_layer_id].copy_to[&(layer_id as u32)][subset_wire_id] - }; - - let mut f1_j = vec![0.into(); 1 << f1.num_vars]; - let mut g1_j = vec![E::ZERO; 1 << f1.num_vars]; - - paste_from - .iter() - .enumerate() - .for_each(|(subset_wire_id, &new_wire_id)| { - for s in 0..(1 << (hi_num_vars - log2_max_thread_id)) { - let global_s = thread_s + s; - f1_j[(s << lo_in_num_vars) ^ subset_wire_id] = - paste_from_sources[*j as usize].instances[global_s] - [old_wire_id(*j as usize, subset_wire_id)]; - g1_j[(s << lo_in_num_vars) ^ subset_wire_id] += eq[(global_s << lo_out_num_vars) ^ new_wire_id]; - } - }); - ( - DenseMultilinearExtension::from_evaluations_vec(f1.num_vars, f1_j).into(), - DenseMultilinearExtension::from_evaluations_ext_vec(f1.num_vars, g1_j).into() - ) - }) - .unzip::<_, _, Vec>, Vec>>(); - exit_span!(span); - - let (f, g): ( - Vec>, - Vec>, - ) = ([vec![f1], f1_j].concat(), [vec![g1], g1_j].concat()); - - // sumcheck: sigma = \sum_{s1 || x1} f1(s1 || x1) * g1(s1 || x1) + \sum_j f1'_j(s1 || x1) * - // g1'_j(s1 || x1) - let mut virtual_poly_1 = VirtualPolynomial::new(f[0].num_vars); - for (f, g) in f.into_iter().zip(g.into_iter()) { - let mut tmp = VirtualPolynomial::new_from_mle(f, E::BaseField::ONE); - tmp.mul_by_mle(g, E::BaseField::ONE); - virtual_poly_1.merge(&tmp); - } - end_timer!(timer); - - (phase2_next_layer_polys_v2, virtual_poly_1) - } - - pub(super) fn combine_phase2_step1_evals( - &mut self, - circuit: &Circuit, - sumcheck_proof_1: SumcheckProof, - prover_state: sumcheck::structs::IOPProverState, - ) -> IOPProverStepMessage { - let layer = &circuit.layers[self.layer_id as usize]; - let eval_point_1 = sumcheck_proof_1.point.clone(); - let (f1_vec, g1_vec): (Vec<_>, Vec<_>) = prover_state - .get_mle_final_evaluations() - .into_iter() - .enumerate() - .partition(|(i, _)| i % 2 == 0); - let f1_vec_len = f1_vec.len(); - // eval_values_f1 - let mut eval_values_1 = f1_vec.into_iter().map(|(_, f1_j)| f1_j).collect_vec(); - - // eval_values_g1[0] - eval_values_1.push(g1_vec[0].1); - - self.to_next_phase_point_and_evals = - vec![PointAndEval::new_from_ref(&eval_point_1, &eval_values_1[0])]; - izip!( - layer.paste_from.iter(), - eval_values_1[..f1_vec_len].iter().skip(1) - ) - .for_each(|((&old_layer_id, _), &subset_value)| { - self.subset_point_and_evals[old_layer_id as usize].push(( - self.layer_id, - PointAndEval::new_from_ref(&eval_point_1, &subset_value), - )); - }); - self.to_next_step_point = eval_point_1; - - IOPProverStepMessage { - sumcheck_proof: sumcheck_proof_1, - sumcheck_eval_values: eval_values_1, - } - } - - /// Sumcheck 2 sigma = \sum_{s2 || x2} f2(s2 || x2) * g2(s2 || x2) - /// sigma = g1(rs1 || rx1) - eq(rt, rs1) * add(ry, rx1) - /// f2(s2 || x2) = layers[i + 1](s2 || x2) - /// g2(s2 || x2) = \sum_{s3}( \sum_{x3}( - /// eq(rt, rs1, s2, s3) * mul3(ry, rx1, x2, x3) * layers[i + 1](s3 || x3) - /// ) ) + eq(rt, rs1, s2) * mul2(ry, rx1, x2) - #[tracing::instrument(skip_all, name = "build_phase2_step2_sumcheck_poly")] - pub(super) fn build_phase2_step2_sumcheck_poly( - layer_poly: &ArcDenseMultilinearExtension, - layer_id: LayerId, - eqs: &[Vec; 2], - circuit: &Circuit, - circuit_witness: &CircuitWitness, - multi_threads_meta: (usize, usize), - ) -> VirtualPolynomial { - let timer = start_timer!(|| "Prover sumcheck phase 2 step 2"); - let layer = &circuit.layers[layer_id as usize]; - let lo_out_num_vars = layer.num_vars; - let lo_in_num_vars = layer.max_previous_num_vars; - let (eq0, eq1) = (&eqs[0], &eqs[1]); - - // parallel unit logic handling - let hi_num_vars = circuit_witness.instance_num_vars(); - let (thread_id, max_thread_id) = multi_threads_meta; - let log2_max_thread_id = ceil_log2(max_thread_id); - let threads_num_vars = hi_num_vars - log2_max_thread_id; - let thread_s = thread_id << threads_num_vars; - - let phase2_next_layer_vec = circuit_witness.layers[layer_id as usize + 1] - .instances - .as_slice(); - - let challenges = &circuit_witness.challenges; - - let span = entered_span!("f2_g2"); - // f2(s2 || x2) = layers[i + 1](s2 || x2) - let f2 = layer_poly.clone(); - - // g2(s2 || x2) = \sum_{s3}( \sum_{x3}( - // eq(rt, rs1, s2, s3) * mul3(ry, rx1, x2, x3) * layers[i + 1](s3 || x3) - // ) ) + eq(rt, rs1, s2) * mul2(ry, rx1, x2) - let g2: ArcDenseMultilinearExtension = { - let mut g2 = vec![E::ZERO; 1 << (f2.num_vars)]; - let mul3s_fanin_mapping = &layer.mul3s_fanin_mapping[Step2 as usize]; - let mul2s_fanin_mapping = &layer.mul2s_fanin_mapping[Step2 as usize]; - prepare_stepx_g_fn!( - &mut g2, - lo_in_num_vars, - thread_s, - mul3s_fanin_mapping, - |s, gate| { - eq0[(s << lo_out_num_vars) ^ gate.idx_out] - * eq1[(s << lo_in_num_vars) ^ gate.idx_in[0]] - * (&phase2_next_layer_vec[s][gate.idx_in[2]]) - * (&gate.scalar.eval(&challenges)) - }, - mul2s_fanin_mapping, - |s, gate| { - eq0[(s << lo_out_num_vars) ^ gate.idx_out] - * eq1[(s << lo_in_num_vars) ^ gate.idx_in[0]] - * (&gate.scalar.eval(&challenges)) - }, - ); - DenseMultilinearExtension::from_evaluations_ext_vec(f2.num_vars, g2).into() - }; - exit_span!(span); - end_timer!(timer); - - // sumcheck: sigma = \sum_{s2 || x2} f2(s2 || x2) * g2(s2 || x2) - let mut virtual_poly_2 = VirtualPolynomial::new_from_mle(f2, E::BaseField::ONE); - virtual_poly_2.mul_by_mle(g2, E::BaseField::ONE); - virtual_poly_2 - } - - pub(super) fn combine_phase2_step2_evals( - &mut self, - _circuit: &Circuit, - sumcheck_proof_2: SumcheckProof, - prover_state: sumcheck::structs::IOPProverState, - no_step3: bool, - ) -> IOPProverStepMessage { - let eval_point_2 = sumcheck_proof_2.point.clone(); - let (f2, g2): (Vec<_>, Vec<_>) = prover_state - .get_mle_final_evaluations() - .into_iter() - .enumerate() - .partition(|(i, _)| i % 2 == 0); - let (eval_value_f2, eval_value_g2) = (f2[0].1, g2[0].1); - - self.to_next_phase_point_and_evals - .push(PointAndEval::new_from_ref(&eval_point_2, &eval_value_f2)); - self.to_next_step_point = eval_point_2; - if no_step3 { - IOPProverStepMessage { - sumcheck_proof: sumcheck_proof_2, - sumcheck_eval_values: vec![eval_value_f2], - } - } else { - IOPProverStepMessage { - sumcheck_proof: sumcheck_proof_2, - sumcheck_eval_values: vec![eval_value_f2, eval_value_g2], - } - } - } - - /// Sumcheck 3 sigma = \sum_{s3 || x3} f3(s3 || x3) * g3(s3 || x3) - /// sigma = g2(rs2 || rx2) - eq(rt, rs1, rs2) * mul2(ry, rx1, rx2) - /// f3(s3 || x3) = layers[i + 1](s3 || x3) - /// g3(s3 || x3) = eq(rt, rs1, rs2, s3) * mul3(ry, rx1, rx2, x3) - #[tracing::instrument(skip_all, name = "build_phase2_step3_sumcheck_poly")] - pub(super) fn build_phase2_step3_sumcheck_poly( - layer_poly: &ArcDenseMultilinearExtension, - layer_id: LayerId, - eqs: &[Vec; 3], - circuit: &Circuit, - circuit_witness: &CircuitWitness, - multi_threads_meta: (usize, usize), - ) -> VirtualPolynomial { - let timer = start_timer!(|| "Prover sumcheck phase 2 step 3"); - let layer = &circuit.layers[layer_id as usize]; - let lo_out_num_vars = layer.num_vars; - let lo_in_num_vars = layer.max_previous_num_vars; - let (eq0, eq1, eq2) = (&eqs[0], &eqs[1], &eqs[2]); - - // parallel unit logic handling - let hi_num_vars = circuit_witness.instance_num_vars(); - let (thread_id, max_thread_id) = multi_threads_meta; - let log2_max_thread_id = ceil_log2(max_thread_id); - let threads_num_vars = hi_num_vars - log2_max_thread_id; - let thread_s = thread_id << threads_num_vars; - - let challenges = &circuit_witness.challenges; - - let span = entered_span!("f3_g3"); - // f3(s3 || x3) = layers[i + 1](s3 || x3) - let f3: Arc> = layer_poly.clone(); - // g3(s3 || x3) = eq(rt, rs1, rs2, s3) * mul3(ry, rx1, rx2, x3) - let g3 = { - let mut g3 = vec![E::ZERO; 1 << (f3.num_vars)]; - let fanin_mapping = &layer.mul3s_fanin_mapping[Step3 as usize]; - prepare_stepx_g_fn!( - &mut g3, - lo_in_num_vars, - thread_s, - fanin_mapping, - |s, gate| eq0[(s << lo_out_num_vars) ^ gate.idx_out] - * eq1[(s << lo_in_num_vars) ^ gate.idx_in[0]] - * eq2[(s << lo_in_num_vars) ^ gate.idx_in[1]] - * (&gate.scalar.eval(&challenges)) - ); - DenseMultilinearExtension::from_evaluations_ext_vec(f3.num_vars, g3).into() - }; - - let mut virtual_poly_3 = VirtualPolynomial::new_from_mle(f3, E::BaseField::ONE); - virtual_poly_3.mul_by_mle(g3, E::BaseField::ONE); - - exit_span!(span); - end_timer!(timer); - virtual_poly_3 - } - - pub(super) fn combine_phase2_step3_evals( - &mut self, - _circuit: &Circuit, - sumcheck_proof_3: SumcheckProof, - prover_state: sumcheck::structs::IOPProverState, - ) -> IOPProverStepMessage { - let eval_point_3 = sumcheck_proof_3.point.clone(); - let (f3, _): (Vec<_>, Vec<_>) = prover_state - .get_mle_final_evaluations() - .into_iter() - .enumerate() - .partition(|(i, _)| i % 2 == 0); - let eval_values_3 = vec![f3[0].1]; - self.to_next_phase_point_and_evals - .push(PointAndEval::new_from_ref(&eval_point_3, &eval_values_3[0])); - self.to_next_step_point = eval_point_3; - IOPProverStepMessage { - sumcheck_proof: sumcheck_proof_3, - sumcheck_eval_values: eval_values_3, - } - } -} diff --git a/gkr/src/prover/phase2_input.rs b/gkr/src/prover/phase2_input.rs index af9f8e55d..fb869264a 100644 --- a/gkr/src/prover/phase2_input.rs +++ b/gkr/src/prover/phase2_input.rs @@ -4,7 +4,7 @@ use ff_ext::ExtensionField; use itertools::{izip, Itertools}; use multilinear_extensions::{ mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, MultilinearExtension}, - virtual_poly::{build_eq_x_r_vec, VirtualPolynomial}, + virtual_poly::build_eq_x_r_vec, virtual_poly_v2::VirtualPolynomialV2, }; #[cfg(feature = "parallel")] @@ -15,11 +15,8 @@ use transcript::Transcript; use crate::{ izip_parallizable, - prover::{SumcheckState, SumcheckStateV2}, - structs::{ - Circuit, CircuitWitness, CircuitWitnessV2, IOPProverState, IOPProverStateV2, - IOPProverStepMessage, PointAndEval, - }, + prover::SumcheckStateV2, + structs::{Circuit, CircuitWitness, IOPProverState, IOPProverStepMessage, PointAndEval}, }; // Prove the computation in the current layer for data parallel circuits. @@ -28,7 +25,7 @@ use crate::{ // layers[i](rt || ry) = \sum_x( // \sum_j paste_from[j](ry, x) * subset[j][i](rt || x) // ) + add_const(ry) -impl IOPProverStateV2 { +impl IOPProverState { /// Sumcheck 1: sigma = \sum_j f1'_j(x1) * g1'_j(x1) /// sigma = layers[i](rt || ry) - add_const(ry), /// f1'^{(j)}(x1) = subset[j][i](rt || x1) @@ -37,7 +34,7 @@ impl IOPProverStateV2 { pub(super) fn prove_and_update_state_input_phase2_step1( &mut self, circuit: &Circuit, - circuit_witness: &CircuitWitnessV2, + circuit_witness: &CircuitWitness, transcript: &mut Transcript, ) -> IOPProverStepMessage { let timer = start_timer!(|| "Prover sumcheck input phase 2 step 1"); @@ -169,147 +166,3 @@ impl IOPProverStateV2 { } } } - -// Prove the computation in the current layer for data parallel circuits. -// The number of terms depends on the gate. -// Here is an example of degree 3: -// layers[i](rt || ry) = \sum_x( -// \sum_j paste_from[j](ry, x) * subset[j][i](rt || x) -// ) + add_const(ry) -impl IOPProverState { - /// Sumcheck 1: sigma = \sum_j f1'_j(x1) * g1'_j(x1) - /// sigma = layers[i](rt || ry) - add_const(ry), - /// f1'^{(j)}(x1) = subset[j][i](rt || x1) - /// g1'^{(j)}(x1) = paste_from[j](ry, x1) - #[tracing::instrument(skip_all, name = "prove_and_update_state_input_phase2_step1")] - pub(super) fn prove_and_update_state_input_phase2_step1( - &mut self, - circuit: &Circuit, - circuit_witness: &CircuitWitness, - transcript: &mut Transcript, - ) -> IOPProverStepMessage { - let timer = start_timer!(|| "Prover sumcheck input phase 2 step 1"); - let layer = &circuit.layers[self.layer_id as usize]; - let lo_out_num_vars = layer.num_vars; - let max_lo_in_num_vars = circuit.max_wit_in_num_vars.unwrap_or(0); - let hi_num_vars = circuit_witness.instance_num_vars(); - let hi_point = &self.to_next_step_point[lo_out_num_vars..]; - - let eq_y_ry = build_eq_x_r_vec(&self.to_next_step_point[..lo_out_num_vars]); - - let paste_from_wit_in = &circuit.paste_from_wits_in; - let wits_in = circuit_witness.witness_in_ref(); - - let (mut f_vec, mut g_vec): ( - Vec>, - Vec>, - ) = izip_parallizable!(paste_from_wit_in) - .enumerate() - .map(|(j, (l, r))| { - let mut f = vec![0.into(); 1 << (max_lo_in_num_vars + hi_num_vars)]; - let mut g = vec![E::ZERO; 1 << max_lo_in_num_vars]; - - for new_wire_id in *l..*r { - let subset_wire_id = new_wire_id - l; - for s in 0..(1 << hi_num_vars) { - f[(s << max_lo_in_num_vars) ^ subset_wire_id] = - wits_in[j as usize].instances[s][subset_wire_id]; - } - g[subset_wire_id] = eq_y_ry[new_wire_id]; - } - ( - { - let mut f = DenseMultilinearExtension::from_evaluations_vec( - max_lo_in_num_vars + hi_num_vars, - f, - ); - f.fix_high_variables_in_place(hi_point); - f.into() - }, - DenseMultilinearExtension::from_evaluations_ext_vec(max_lo_in_num_vars, g) - .into(), - ) - }) - .unzip(); - - let paste_from_counter_in = &circuit.paste_from_counter_in; - let (f_vec_counter_in, g_vec_counter_in): ( - Vec>, - Vec>, - ) = izip_parallizable!(paste_from_counter_in) - .map(|(num_vars, (l, r))| { - let mut f = vec![0.into(); 1 << (max_lo_in_num_vars + hi_num_vars)]; - let mut g = vec![E::ZERO; 1 << max_lo_in_num_vars]; - - for new_wire_id in *l..*r { - let subset_wire_id = new_wire_id - l; - for s in 0..(1 << hi_num_vars) { - f[(s << max_lo_in_num_vars) ^ subset_wire_id] = - E::BaseField::from(((s << num_vars) ^ subset_wire_id) as u64); - } - g[subset_wire_id] = eq_y_ry[new_wire_id]; - } - ( - { - let mut f = DenseMultilinearExtension::from_evaluations_vec( - max_lo_in_num_vars + hi_num_vars, - f, - ); - f.fix_high_variables_in_place(&hi_point); - f.into() - }, - DenseMultilinearExtension::from_evaluations_ext_vec(max_lo_in_num_vars, g) - .into(), - ) - }) - .unzip(); - - f_vec.extend(f_vec_counter_in); - g_vec.extend(g_vec_counter_in); - - let mut virtual_poly = VirtualPolynomial::new(max_lo_in_num_vars); - for (f, g) in f_vec.into_iter().zip(g_vec.into_iter()) { - let mut tmp = VirtualPolynomial::new_from_mle(f, E::BaseField::ONE); - tmp.mul_by_mle(g, E::BaseField::ONE); - virtual_poly.merge(&tmp); - } - - let (sumcheck_proofs, prover_state) = - SumcheckState::prove_parallel(virtual_poly, transcript); - let eval_point = sumcheck_proofs.point.clone(); - let (f_vec, _): (Vec<_>, Vec<_>) = prover_state - .get_mle_final_evaluations() - .into_iter() - .enumerate() - .partition(|(i, _)| i % 2 == 0); - let eval_values_f = f_vec - .into_iter() - .map(|(_, f)| f) - .take(wits_in.len()) - .collect_vec(); - - self.to_next_phase_point_and_evals = izip!(paste_from_wit_in.iter(), eval_values_f.iter()) - .map(|((l, r), eval)| { - let num_vars = ceil_log2(*r - *l); - let point = [&eval_point[..num_vars], hi_point].concat(); - let wit_in_eval = *eval - * eval_point[num_vars..] - .iter() - .map(|x| E::ONE - *x) - .product::() - .invert() - .unwrap(); - PointAndEval::new_from_ref(&point, &wit_in_eval) - }) - .collect_vec(); - - self.to_next_step_point = [&eval_point, hi_point].concat(); - - end_timer!(timer); - - IOPProverStepMessage { - sumcheck_proof: sumcheck_proofs, - sumcheck_eval_values: eval_values_f, - } - } -} diff --git a/gkr/src/prover/phase2_linear.rs b/gkr/src/prover/phase2_linear.rs index 5409919f3..327c70f0a 100644 --- a/gkr/src/prover/phase2_linear.rs +++ b/gkr/src/prover/phase2_linear.rs @@ -6,7 +6,7 @@ use ff_ext::ExtensionField; use itertools::{izip, Itertools}; use multilinear_extensions::{ mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, MultilinearExtension}, - virtual_poly::{build_eq_x_r_vec, VirtualPolynomial}, + virtual_poly::build_eq_x_r_vec, virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2}, }; use transcript::Transcript; @@ -14,22 +14,16 @@ use transcript::Transcript; use crate::{ circuit::EvaluateConstant, prover::SumcheckStateV2, - structs::{ - Circuit, CircuitWitness, CircuitWitnessV2, IOPProverState, IOPProverStateV2, - IOPProverStepMessage, PointAndEval, - }, - utils::MultilinearExtensionFromVectors, + structs::{Circuit, CircuitWitness, IOPProverState, IOPProverStepMessage, PointAndEval}, }; -use super::SumcheckState; - // Prove the computation in the current layer for data parallel circuits. // The number of terms depends on the gate. // Here is an example of degree 3: // layers[i](rt || ry) = \sum_x( // add(ry, x) * layers[i + 1](rt || x) + \sum_j paste_from[j](ry, x) * subset[j][i](rt || x) // ) + add_const(ry) -impl IOPProverStateV2 { +impl IOPProverState { /// Sumcheck 1: sigma = \sum_{x1} f1(x1) * g1(x1) + \sum_j f1'_j(x1) * g1'_j(x1) /// sigma = layers[i](rt || ry) - add_const(ry), /// f1(x1) = layers[i + 1](rt || x1) @@ -40,7 +34,7 @@ impl IOPProverStateV2 { pub(super) fn prove_and_update_state_linear_phase2_step1( &mut self, circuit: &Circuit, - circuit_witness: &CircuitWitnessV2, + circuit_witness: &CircuitWitness, transcript: &mut Transcript, ) -> IOPProverStepMessage { let timer = start_timer!(|| "Prover sumcheck phase 2 step 1"); @@ -164,129 +158,3 @@ impl IOPProverStateV2 { } } } - -// Prove the computation in the current layer for data parallel circuits. -// The number of terms depends on the gate. -// Here is an example of degree 3: -// layers[i](rt || ry) = \sum_x( -// add(ry, x) * layers[i + 1](rt || x) + \sum_j paste_from[j](ry, x) * subset[j][i](rt || x) -// ) + add_const(ry) -impl IOPProverState { - /// Sumcheck 1: sigma = \sum_{x1} f1(x1) * g1(x1) + \sum_j f1'_j(x1) * g1'_j(x1) - /// sigma = layers[i](rt || ry) - add_const(ry), - /// f1(x1) = layers[i + 1](rt || x1) - /// g1(x1) = add(ry, x1) - /// f1'^{(j)}(x1) = subset[j][i](rt || x1) - /// g1'^{(j)}(x1) = paste_from[j](ry, x1) - #[tracing::instrument(skip_all, name = "prove_and_update_state_linear_phase2_step1")] - pub(super) fn prove_and_update_state_linear_phase2_step1( - &mut self, - circuit: &Circuit, - circuit_witness: &CircuitWitness, - transcript: &mut Transcript, - ) -> IOPProverStepMessage { - let timer = start_timer!(|| "Prover sumcheck phase 2 step 1"); - let layer = &circuit.layers[self.layer_id as usize]; - let lo_out_num_vars = layer.num_vars; - let lo_in_num_vars = layer.max_previous_num_vars; - let hi_num_vars = circuit_witness.instance_num_vars(); - let hi_point = &self.to_next_step_point[lo_out_num_vars..]; - - let eq_y_ry = build_eq_x_r_vec(&self.to_next_step_point[..lo_out_num_vars]); - - let challenges = &circuit_witness.challenges; - - let f1_g1 = || { - // f1(x1) = layers[i + 1](rt || x1) - let layer_in_vec = circuit_witness.layers[self.layer_id as usize + 1] - .instances - .as_slice(); - let mut f1 = layer_in_vec.mle(lo_in_num_vars, hi_num_vars); - Arc::make_mut(&mut f1).fix_high_variables_in_place(&hi_point); - - // g1(x1) = add(ry, x1) - let g1 = { - let mut g1 = vec![E::ZERO; 1 << lo_in_num_vars]; - layer.adds.iter().for_each(|gate| { - g1[gate.idx_in[0]] += eq_y_ry[gate.idx_out] * &gate.scalar.eval(&challenges); - }); - DenseMultilinearExtension::from_evaluations_ext_vec(lo_in_num_vars, g1) - }; - (vec![f1], vec![g1.into()]) - }; - - let (mut f1_vec, mut g1_vec) = f1_g1(); - - // f1'^{(j)}(x1) = subset[j][i](rt || x1) - // g1'^{(j)}(x1) = paste_from[j](ry, x1) - let paste_from_sources = circuit_witness.layers_ref(); - let old_wire_id = |old_layer_id: usize, subset_wire_id: usize| -> usize { - circuit.layers[old_layer_id].copy_to[&self.layer_id][subset_wire_id] - }; - layer.paste_from.iter().for_each(|(&j, paste_from)| { - let mut f1_j = vec![0.into(); 1 << (lo_in_num_vars + hi_num_vars)]; - let mut g1_j = vec![E::ZERO; 1 << lo_in_num_vars]; - - paste_from - .iter() - .enumerate() - .for_each(|(subset_wire_id, &new_wire_id)| { - for s in 0..(1 << hi_num_vars) { - f1_j[(s << lo_in_num_vars) ^ subset_wire_id] = paste_from_sources - [j as usize] - .instances[s][old_wire_id(j as usize, subset_wire_id)]; - } - - g1_j[subset_wire_id] += eq_y_ry[new_wire_id]; - }); - f1_vec.push({ - let mut f1_j = DenseMultilinearExtension::from_evaluations_vec( - lo_in_num_vars + hi_num_vars, - f1_j, - ); - f1_j.fix_high_variables_in_place(&hi_point); - f1_j.into() - }); - g1_vec.push( - DenseMultilinearExtension::from_evaluations_ext_vec(lo_in_num_vars, g1_j).into(), - ); - }); - - // sumcheck: sigma = \sum_{x1} f1(x1) * g1(x1) + \sum_j f1'_j(x1) * g1'_j(x1) - let mut virtual_poly_1 = VirtualPolynomial::new(lo_in_num_vars); - for (f1_j, g1_j) in izip!(f1_vec.into_iter(), g1_vec.into_iter()) { - let mut tmp = VirtualPolynomial::new_from_mle(f1_j, E::BaseField::ONE); - tmp.mul_by_mle(g1_j, E::BaseField::ONE); - virtual_poly_1.merge(&tmp); - } - - let (sumcheck_proof_1, prover_state) = - SumcheckState::prove_parallel(virtual_poly_1, transcript); - let eval_point_1 = sumcheck_proof_1.point.clone(); - let (f1_vec, _): (Vec<_>, Vec<_>) = prover_state - .get_mle_final_evaluations() - .into_iter() - .enumerate() - .partition(|(i, _)| i % 2 == 0); - let eval_values_f1 = f1_vec.into_iter().map(|(_, f1_j)| f1_j).collect_vec(); - - let new_point = [&eval_point_1, hi_point].concat(); - self.to_next_phase_point_and_evals = - vec![PointAndEval::new_from_ref(&new_point, &eval_values_f1[0])]; - izip!(layer.paste_from.iter(), eval_values_f1.iter().skip(1)).for_each( - |((&old_layer_id, _), &subset_value)| { - self.subset_point_and_evals[old_layer_id as usize].push(( - self.layer_id, - PointAndEval::new_from_ref(&new_point, &subset_value), - )); - }, - ); - self.to_next_step_point = new_point; - end_timer!(timer); - - IOPProverStepMessage { - sumcheck_proof: sumcheck_proof_1, - sumcheck_eval_values: eval_values_f1, - } - } -} diff --git a/gkr/src/prover/test.rs b/gkr/src/prover/test.rs index 5b38d1e22..f9dec122b 100644 --- a/gkr/src/prover/test.rs +++ b/gkr/src/prover/test.rs @@ -5,15 +5,13 @@ use ff::Field; use ff_ext::ExtensionField; use goldilocks::GoldilocksExt2; use itertools::{izip, Itertools}; -use multilinear_extensions::mle::MultilinearExtension; +use multilinear_extensions::mle::DenseMultilinearExtension; use simple_frontend::structs::{ChallengeConst, ChallengeId, CircuitBuilder, MixedCell}; use transcript::Transcript; use crate::{ - structs::{ - Circuit, CircuitWitness, IOPProverState, IOPVerifierState, LayerWitness, PointAndEval, - }, - utils::{i64_to_field, MultilinearExtensionFromVectors}, + structs::{Circuit, CircuitWitness, IOPProverState, IOPVerifierState, PointAndEval}, + utils::i64_to_field, }; fn copy_and_paste_circuit() -> Circuit { @@ -45,10 +43,8 @@ fn copy_and_paste_circuit() -> Circuit { circuit } -fn copy_and_paste_witness() -> ( - Vec>, - CircuitWitness, -) { +fn copy_and_paste_witness<'a, Ext: ExtensionField>() +-> (Vec>, CircuitWitness<'a, Ext>) { // witness_in, single instance let inputs = vec![vec![ i64_to_field(5), @@ -56,42 +52,36 @@ fn copy_and_paste_witness() -> ( i64_to_field(11), i64_to_field(13), ]]; - let witness_in = vec![LayerWitness { instances: inputs }]; + let witness_in: Vec> = vec![inputs.into()]; - let layers = vec![ - LayerWitness { - instances: vec![vec![i64_to_field(175175)]], - }, - LayerWitness { - instances: vec![vec![ - i64_to_field(385), - i64_to_field(35), - i64_to_field(13), - i64_to_field(0), // pad - ]], - }, - LayerWitness { - instances: vec![vec![i64_to_field(35), i64_to_field(11)]], - }, - LayerWitness { - instances: vec![vec![ - i64_to_field(5), - i64_to_field(7), - i64_to_field(11), - i64_to_field(13), - ]], - }, + let layers: Vec> = vec![ + vec![vec![i64_to_field(175175)]].into(), + vec![vec![ + i64_to_field(385), + i64_to_field(35), + i64_to_field(13), + i64_to_field(0), // pad + ]] + .into(), + vec![vec![i64_to_field(35), i64_to_field(11)]].into(), + vec![vec![ + i64_to_field(5), + i64_to_field(7), + i64_to_field(11), + i64_to_field(13), + ]] + .into(), ]; let outputs = vec![vec![i64_to_field(175175)]]; - let witness_out = vec![LayerWitness { instances: outputs }]; + let witness_out: Vec> = vec![outputs.into()]; ( witness_in.clone(), CircuitWitness { - layers, - witness_in, - witness_out, + layers: layers.into_iter().map(|w| w.into()).collect(), + witness_in: witness_in.into_iter().map(|w| w.into()).collect(), + witness_out: witness_out.into_iter().map(|w| w.into()).collect(), n_instances: 1, challenges: HashMap::new(), }, @@ -123,71 +113,54 @@ fn paste_from_wit_in_circuit() -> Circuit { circuit } -fn paste_from_wit_in_witness() -> ( - Vec>, - CircuitWitness, -) { +fn paste_from_wit_in_witness<'a, Ext: ExtensionField>() +-> (Vec>, CircuitWitness<'a, Ext>) { // witness_in, single instance let leaves1 = vec![vec![i64_to_field(5), i64_to_field(7), i64_to_field(11)]]; let leaves2 = vec![vec![i64_to_field(13), i64_to_field(17), i64_to_field(19)]]; let dummy = vec![vec![i64_to_field(13), i64_to_field(17), i64_to_field(19)]]; - let witness_in = vec![ - LayerWitness { instances: leaves1 }, - LayerWitness { instances: leaves2 }, - LayerWitness { instances: dummy }, - ]; - - let layers = vec![ - LayerWitness { - instances: vec![vec![ - i64_to_field(5005), - i64_to_field(35), - i64_to_field(143), - i64_to_field(0), // pad - ]], - }, - LayerWitness { - instances: vec![vec![i64_to_field(35), i64_to_field(143)]], - }, - LayerWitness { - instances: vec![vec![ - i64_to_field(5), // leaves1 - i64_to_field(7), - i64_to_field(11), - i64_to_field(13), // leaves2 - i64_to_field(17), - i64_to_field(19), - i64_to_field(13), // dummy - i64_to_field(17), - i64_to_field(19), - i64_to_field(0), // counter - i64_to_field(1), - i64_to_field(1), // constant - i64_to_field(1), - i64_to_field(0), // pad - i64_to_field(0), - i64_to_field(0), - ]], - }, + let witness_in = vec![leaves1.into(), leaves2.into(), dummy.into()]; + + let layers: Vec> = vec![ + vec![vec![ + i64_to_field(5005), + i64_to_field(35), + i64_to_field(143), + i64_to_field(0), // pad + ]] + .into(), + vec![vec![i64_to_field(35), i64_to_field(143)]].into(), + vec![vec![ + i64_to_field(5), // leaves1 + i64_to_field(7), + i64_to_field(11), + i64_to_field(13), // leaves2 + i64_to_field(17), + i64_to_field(19), + i64_to_field(13), // dummy + i64_to_field(17), + i64_to_field(19), + i64_to_field(0), // counter + i64_to_field(1), + i64_to_field(1), // constant + i64_to_field(1), + i64_to_field(0), // pad + i64_to_field(0), + i64_to_field(0), + ]] + .into(), ]; let outputs1 = vec![vec![i64_to_field(35), i64_to_field(143)]]; let outputs2 = vec![vec![i64_to_field(5005)]]; - let witness_out = vec![ - LayerWitness { - instances: outputs1, - }, - LayerWitness { - instances: outputs2, - }, - ]; + let witness_out: Vec> = vec![outputs1.into(), outputs2.into()]; ( witness_in.clone(), CircuitWitness { - layers, - witness_in, - witness_out, + layers: layers.into_iter().map(|w| w.into()).collect(), + witness_in: witness_in.into_iter().map(|w| w.into()).collect(), + witness_out: witness_out.into_iter().map(|w| w.into()).collect(), n_instances: 1, challenges: HashMap::new(), }, @@ -215,60 +188,53 @@ fn copy_to_wit_out_circuit() -> Circuit { circuit } -fn copy_to_wit_out_witness() -> ( - Vec>, - CircuitWitness, -) { +fn copy_to_wit_out_witness<'a, Ext: ExtensionField>() +-> (Vec>, CircuitWitness<'a, Ext>) { // witness_in, single instance let leaves = vec![vec![ i64_to_field(5), i64_to_field(7), i64_to_field(11), i64_to_field(13), - ]]; - let witness_in = vec![LayerWitness { instances: leaves }]; - - let layers = vec![ - LayerWitness { - instances: vec![vec![ - i64_to_field(5005), - i64_to_field(35), - i64_to_field(143), - i64_to_field(0), // pad - ]], - }, - LayerWitness { - instances: vec![vec![i64_to_field(35), i64_to_field(143)]], - }, - LayerWitness { - instances: vec![vec![ - i64_to_field(5), - i64_to_field(7), - i64_to_field(11), - i64_to_field(13), - ]], - }, + ]] + .into(); + let witness_in = vec![leaves]; + + let layers: Vec> = vec![ + vec![vec![ + i64_to_field(5005), + i64_to_field(35), + i64_to_field(143), + i64_to_field(0), // pad + ]] + .into(), + vec![vec![i64_to_field(35), i64_to_field(143)]].into(), + vec![vec![ + i64_to_field(5), + i64_to_field(7), + i64_to_field(11), + i64_to_field(13), + ]] + .into(), ]; let outputs = vec![vec![i64_to_field(35), i64_to_field(143)]]; - let witness_out = vec![LayerWitness { instances: outputs }]; + let witness_out: Vec> = vec![outputs.into()]; ( witness_in.clone(), CircuitWitness { - layers, - witness_in, - witness_out, + layers: layers.into_iter().map(|w| w.into()).collect(), + witness_in: witness_in.into_iter().map(|w| w.into()).collect(), + witness_out: witness_out.into_iter().map(|w| w.into()).collect(), n_instances: 1, challenges: HashMap::new(), }, ) } -fn copy_to_wit_out_witness_2() -> ( - Vec>, - CircuitWitness, -) { +fn copy_to_wit_out_witness_2<'a, Ext: ExtensionField>() +-> (Vec>, CircuitWitness<'a, Ext>) { // witness_in, 2 instances let leaves = vec![ vec![ @@ -284,61 +250,58 @@ fn copy_to_wit_out_witness_2() -> ( i64_to_field(7), ], ]; - let witness_in = vec![LayerWitness { instances: leaves }]; - - let layers = vec![ - LayerWitness { - instances: vec![ - vec![ - i64_to_field(5005), - i64_to_field(35), - i64_to_field(143), - i64_to_field(0), // pad - ], - vec![ - i64_to_field(5005), - i64_to_field(65), - i64_to_field(77), - i64_to_field(0), // pad - ], + let witness_in = vec![leaves.into()]; + + let layers: Vec> = vec![ + vec![ + vec![ + i64_to_field(5005), + i64_to_field(35), + i64_to_field(143), + i64_to_field(0), // pad ], - }, - LayerWitness { - instances: vec![ - vec![i64_to_field(35), i64_to_field(143)], - vec![i64_to_field(65), i64_to_field(77)], + vec![ + i64_to_field(5005), + i64_to_field(65), + i64_to_field(77), + i64_to_field(0), // pad ], - }, - LayerWitness { - instances: vec![ - vec![ - i64_to_field(5), - i64_to_field(7), - i64_to_field(11), - i64_to_field(13), - ], - vec![ - i64_to_field(5), - i64_to_field(13), - i64_to_field(11), - i64_to_field(7), - ], + ] + .into(), + vec![ + vec![i64_to_field(35), i64_to_field(143)], + vec![i64_to_field(65), i64_to_field(77)], + ] + .into(), + vec![ + vec![ + i64_to_field(5), + i64_to_field(7), + i64_to_field(11), + i64_to_field(13), ], - }, + vec![ + i64_to_field(5), + i64_to_field(13), + i64_to_field(11), + i64_to_field(7), + ], + ] + .into(), ]; let outputs = vec![ vec![i64_to_field(35), i64_to_field(143)], vec![i64_to_field(65), i64_to_field(77)], ]; - let witness_out = vec![LayerWitness { instances: outputs }]; + let witness_out: Vec> = vec![outputs.into()]; ( witness_in.clone(), CircuitWitness { - layers, - witness_in, - witness_out, + layers: layers.into_iter().map(|w| w.into()).collect(), + witness_in: witness_in.into_iter().map(|w| w.into()).collect(), + witness_out: witness_out.into_iter().map(|w| w.into()).collect(), n_instances: 2, challenges: HashMap::new(), }, @@ -365,9 +328,9 @@ fn rlc_circuit() -> Circuit { circuit } -fn rlc_witness() -> ( - Vec>, - CircuitWitness, +fn rlc_witness<'a, Ext>() -> ( + Vec>, + CircuitWitness<'a, Ext>, Vec, ) where @@ -410,9 +373,7 @@ where i64_to_field(7), ], ]; - let witness_in = vec![LayerWitness { - instances: leaves.clone(), - }]; + let witness_in = vec![leaves.clone().into()]; let inner00: Ext = challenge_pows[0][0].1 * (&leaves[0][0]) + challenge_pows[0][1].1 * (&leaves[0][1]) @@ -453,26 +414,22 @@ where root1.as_bases().into_iter().cloned().collect_vec(), ]; - let layers = vec![ - LayerWitness { - instances: roots.clone(), - }, - LayerWitness { - instances: root_tmps, - }, - LayerWitness { instances: inners }, - LayerWitness { instances: leaves }, + let layers: Vec> = vec![ + roots.clone().into(), + root_tmps.into(), + inners.into(), + leaves.into(), ]; let outputs = roots; - let witness_out = vec![LayerWitness { instances: outputs }]; + let witness_out: Vec> = vec![outputs.into()]; ( witness_in.clone(), CircuitWitness { - layers, - witness_in, - witness_out, + layers: layers.into_iter().map(|w| w.into()).collect(), + witness_in: witness_in.into_iter().map(|w| w.into()).collect(), + witness_out: witness_out.into_iter().map(|w| w.into()).collect(), n_instances: 2, challenges: challenge_pows .iter() @@ -512,7 +469,7 @@ fn inv_sum_circuit() -> Circuit { Circuit::new(&circuit_builder) } -fn inv_sum_witness_4_instances() -> CircuitWitness { +fn inv_sum_witness_4_instances<'a, Ext: ExtensionField>() -> CircuitWitness<'a, Ext> { let circuit = inv_sum_circuit::(); // witness_in, double instances let leaves = vec![ @@ -547,10 +504,7 @@ fn inv_sum_witness_4_instances() -> CircuitWitness() -> Circuit { Circuit::new(&circuit_builder) } -fn lookup_inner_witness_4_instances() -> CircuitWitness { +fn lookup_inner_witness_4_instances<'a, Ext: ExtensionField>() -> CircuitWitness<'a, Ext> { let circuit = lookup_inner_circuit::(); // witness_in, double instances let leaves = vec![ @@ -634,7 +588,7 @@ fn lookup_inner_witness_4_instances() -> CircuitWitness() -> Circuit { Circuit::new(&circuit_builder) } -fn mixed_in_witness_4_instances() -> CircuitWitness { +fn mixed_in_witness_4_instances<'a, Ext: ExtensionField>() -> CircuitWitness<'a, Ext> { let circuit = mixed_in_circuit::(); // witness_in, double instances let input = vec![ @@ -721,20 +675,15 @@ fn mixed_in_witness_4_instances() -> CircuitWitness( +fn prove_and_verify<'a, Ext: ExtensionField>( circuit: Circuit, - circuit_wits: CircuitWitness, + circuit_wits: CircuitWitness<'a, Ext>, challenges: Vec, ) { let mut rng = test_rng(); @@ -746,12 +695,7 @@ fn prove_and_verify( let out_point_and_evals = if circuit.n_witness_out == 0 { vec![PointAndEval::new( out_point.clone(), - circuit_wits - .output_layer_witness_ref() - .instances - .as_slice() - .mle(circuit.output_num_vars(), circuit_wits.instance_num_vars()) - .evaluate(&out_point), + circuit_wits.output_layer_witness_ref().evaluate(&out_point), )] } else { vec![] @@ -759,15 +703,7 @@ fn prove_and_verify( let wit_out_point_and_evals = circuit_wits .witness_out_ref() .iter() - .map(|wit| { - PointAndEval::new( - out_point.clone(), - wit.instances - .as_slice() - .mle(circuit.output_num_vars(), circuit_wits.instance_num_vars()) - .evaluate(&out_point), - ) - }) + .map(|wit| PointAndEval::new(out_point.clone(), wit.evaluate(&out_point))) .collect_vec(); let mut prover_transcript = Transcript::new(b"transcrhipt"); @@ -804,7 +740,7 @@ fn prove_and_verify( circuit_wits.witness_in.iter(), prover_input_claim.point_and_evals.iter() ) - .any(|(wit, p)| wit.instances.as_slice().original_mle().evaluate(&p.point) != p.eval) + .any(|(wit, p)| wit.evaluate(&p.point) != p.eval) ); } diff --git a/gkr/src/structs.rs b/gkr/src/structs.rs index 8a793725e..d7600162a 100644 --- a/gkr/src/structs.rs +++ b/gkr/src/structs.rs @@ -6,9 +6,7 @@ use std::{ use ff_ext::ExtensionField; use goldilocks::SmallField; use multilinear_extensions::{ - mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension}, - util::ceil_log2, - virtual_poly_v2::ArcMultilinearExtension, + mle::ArcDenseMultilinearExtension, virtual_poly_v2::ArcMultilinearExtension, }; use serde::{Deserialize, Serialize, Serializer}; use simple_frontend::structs::{CellId, ChallengeConst, ConstantType, LayerId}; @@ -51,26 +49,6 @@ impl PointAndEval { } } -/// Represent the prover state for each layer in the IOP protocol. To support -/// gates between non-adjacent layers, we leverage the techniques in -/// [Virgo++](https://eprint.iacr.org/2020/1247). -pub struct IOPProverStateV2 { - pub(crate) layer_id: LayerId, - /// Evaluations to the next phase. - pub(crate) to_next_phase_point_and_evals: Vec>, - /// Evaluations of subsets from layers __closer__ to the output. - /// __closer__ as in the layer that the subset elements lie in has not been processed. - /// - /// LayerId is the layer id of the incoming subset point and evaluation. - pub(crate) subset_point_and_evals: Vec)>>, - - /// The point to the next step. - pub(crate) to_next_step_point: Point, - - // Especially for output phase1. - pub(crate) assert_point: Point, -} - /// Represent the prover state for each layer in the IOP protocol. To support /// gates between non-adjacent layers, we leverage the techniques in /// [Virgo++](https://eprint.iacr.org/2020/1247). @@ -87,34 +65,8 @@ pub struct IOPProverState { /// The point to the next step. pub(crate) to_next_step_point: Point, - // Especially for output phase1. - pub(crate) phase1_layer_poly: ArcDenseMultilinearExtension, - pub(crate) assert_point: Point, - // Especially for phase1. - pub(crate) g1_values: Vec, -} - -/// Represent the verifier state for each layer in the IOP protocol. -pub struct IOPVerifierStateV2 { - pub(crate) layer_id: LayerId, - /// Evaluations from the next layer. - pub(crate) to_next_phase_point_and_evals: Vec>, - /// Evaluations of subsets from layers closer to the output. LayerId is the - /// layer id of the incoming subset point and evaluation. - pub(crate) subset_point_and_evals: Vec)>>, - - pub(crate) challenges: HashMap>, - pub(crate) instance_num_vars: usize, - - pub(crate) to_next_step_point_and_eval: PointAndEval, - // Especially for output phase1. pub(crate) assert_point: Point, - // Especially for phase2. - pub(crate) out_point: Point, - pub(crate) eq_y_ry: Vec, - pub(crate) eq_x1_rx1: Vec, - pub(crate) eq_x2_rx2: Vec, } /// Represent the verifier state for each layer in the IOP protocol. @@ -133,8 +85,6 @@ pub struct IOPVerifierState { // Especially for output phase1. pub(crate) assert_point: Point, - // Especially for phase1. - pub(crate) g1_values: Vec, // Especially for phase2. pub(crate) out_point: Point, pub(crate) eq_y_ry: Vec, @@ -278,7 +228,7 @@ impl Serialize for Gate { // TODO fix comments #[derive(Clone)] -pub struct CircuitWitnessV2<'a, E: ExtensionField> { +pub struct CircuitWitness<'a, E: ExtensionField> { /// Three vectors denote 1. layer_id, 2. instance_id || wire_id. pub(crate) layers: Vec>, /// Three vectors denote 1. wires_in id, 2. instance_id || wire_id. @@ -291,20 +241,6 @@ pub struct CircuitWitnessV2<'a, E: ExtensionField> { pub(crate) n_instances: usize, } -#[derive(Clone, PartialEq, Serialize)] -pub struct CircuitWitness { - /// Three vectors denote 1. layer_id, 2. instance_id, 3. wire_id. - pub(crate) layers: Vec>, - /// 1. wires_in id, 2. instance_id, 3. wire_id. - pub(crate) witness_in: Vec>, - /// 1. wires_out id, 2. instance_id, 3. wire_id. - pub(crate) witness_out: Vec>, - /// Challenges - pub(crate) challenges: HashMap>, - /// The number of instances for the same sub-circuit. - pub(crate) n_instances: usize, -} - #[derive(Clone, Debug, Default, PartialEq, Serialize)] pub struct LayerWitness { pub instances: Vec>, diff --git a/gkr/src/test/is_zero_gadget.rs b/gkr/src/test/is_zero_gadget.rs index 380e7595d..77de395f7 100644 --- a/gkr/src/test/is_zero_gadget.rs +++ b/gkr/src/test/is_zero_gadget.rs @@ -1,12 +1,9 @@ -use crate::{ - structs::{Circuit, CircuitWitness, IOPProverState, IOPVerifierState, PointAndEval}, - utils::MultilinearExtensionFromVectors, -}; +use crate::structs::{Circuit, CircuitWitness, IOPProverState, IOPVerifierState, PointAndEval}; use ff::Field; use ff_ext::ExtensionField; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; -use multilinear_extensions::mle::MultilinearExtension; +use multilinear_extensions::mle::{DenseMultilinearExtension, IntoMLE}; use simple_frontend::structs::{CellId, CircuitBuilder}; use std::{iter, time::Duration}; use transcript::Transcript; @@ -46,8 +43,13 @@ pub fn is_zero_gadget( #[test] fn test_gkr_circuit_is_zero_gadget_simple() { // input and output - let in_value = vec![Goldilocks::from(5)]; - let in_inv = vec![Goldilocks::from(5).invert().unwrap()]; + let in_value: Vec<::BaseField> = + vec![::BaseField::from(5)]; + let in_inv: Vec<::BaseField> = vec![ + ::BaseField::from(5) + .invert() + .unwrap(), + ]; let out_is_zero = Goldilocks::from(0); // build the circuit, only one cell for value, inv and value * inv etc @@ -66,9 +68,9 @@ fn test_gkr_circuit_is_zero_gadget_simple() { // assign wire in let n_wits_in = circuit.n_witness_in; - let mut wit_in = vec![vec![]; n_wits_in]; - wit_in[value_wire_in_id as usize] = in_value; - wit_in[inv_wire_in_id as usize] = in_inv; + let mut wit_in = vec![DenseMultilinearExtension::default(); n_wits_in]; + wit_in[value_wire_in_id as usize] = in_value.into_mle(); + wit_in[inv_wire_in_id as usize] = in_inv.into_mle(); let circuit_witness = { let challenges = vec![GoldilocksExt2::from(2)]; let mut circuit_witness = CircuitWitness::new(&circuit, challenges); @@ -92,10 +94,16 @@ fn test_gkr_circuit_is_zero_gadget_simple() { ); // cond1 and cond2 - assert_eq!(cond_wire_out_ref.instances[0][0], Goldilocks::from(0)); - assert_eq!(cond_wire_out_ref.instances[0][1], Goldilocks::from(0)); + assert_eq!( + cond_wire_out_ref.get_base_field_vec()[0], + Goldilocks::from(0) + ); + assert_eq!( + cond_wire_out_ref.get_base_field_vec()[1], + Goldilocks::from(0) + ); // is_zero - assert_eq!(is_zero_wire_out_ref.instances[0][0], out_is_zero); + assert_eq!(is_zero_wire_out_ref.get_base_field_vec()[0], out_is_zero); // add prover-verifier process let mut prover_transcript = @@ -107,27 +115,20 @@ fn test_gkr_circuit_is_zero_gadget_simple() { let mut verifier_wires_out_evals = vec![]; let instance_num_vars = 1_u32.ilog2() as usize; for wire_out_id in vec![cond_wire_out_id, is_zero_wire_out_id] { - let lo_num_vars = wits_out[wire_out_id as usize].instances[0] - .len() - .next_power_of_two() - .ilog2() as usize; - let output_mle = wits_out[wire_out_id as usize] - .instances - .as_slice() - .mle(lo_num_vars, instance_num_vars); + let output_mle = &wits_out[wire_out_id as usize]; let prover_output_point = iter::repeat_with(|| { prover_transcript .get_and_append_challenge(b"output_point_test_gkr_circuit_IsZeroGadget_simple") .elements }) - .take(output_mle.num_vars) + .take(output_mle.num_vars()) .collect_vec(); let verifier_output_point = iter::repeat_with(|| { verifier_transcript .get_and_append_challenge(b"output_point_test_gkr_circuit_IsZeroGadget_simple") .elements }) - .take(output_mle.num_vars) + .take(output_mle.num_vars()) .collect_vec(); let prover_output_eval = output_mle.evaluate(&prover_output_point); let verifier_output_eval = output_mle.evaluate(&verifier_output_point); @@ -223,9 +224,9 @@ fn test_gkr_circuit_is_zero_gadget_u256() { // assign wire in let n_wits_in = circuit.n_witness_in; - let mut wits_in = vec![vec![]; n_wits_in]; - wits_in[value_wire_in_id as usize] = in_value; - wits_in[inv_wire_in_id as usize] = in_inv; + let mut wits_in = vec![DenseMultilinearExtension::::default(); n_wits_in]; + wits_in[value_wire_in_id as usize] = in_value.into_mle(); + wits_in[inv_wire_in_id as usize] = in_inv.into_mle(); let circuit_witness = { let challenges = vec![GoldilocksExt2::from(2)]; let mut circuit_witness = CircuitWitness::new(&circuit, challenges); @@ -249,11 +250,11 @@ fn test_gkr_circuit_is_zero_gadget_u256() { ); // cond1 and cond2 - for cond_item in cond_wire_out_ref.instances[0].clone().into_iter() { - assert_eq!(cond_item, Goldilocks::from(0)); - } + // for cond_item in cond_wire_out_ref.instances[0].clone().into_iter() { + // assert_eq!(cond_item, Goldilocks::from(0)); + // } // is_zero - assert_eq!(is_zero_wire_out_ref.instances[0][0], out_is_zero); + assert_eq!(is_zero_wire_out_ref.get_base_field_vec()[0], out_is_zero); // add prover-verifier process let mut prover_transcript = @@ -265,27 +266,20 @@ fn test_gkr_circuit_is_zero_gadget_u256() { let mut verifier_wires_out_evals = vec![]; let instance_num_vars = 1_u32.ilog2() as usize; for wire_out_id in vec![cond_wire_out_id, is_zero_wire_out_id] { - let lo_num_vars = wits_out[wire_out_id as usize].instances[0] - .len() - .next_power_of_two() - .ilog2() as usize; - let output_mle = wits_out[wire_out_id as usize] - .instances - .as_slice() - .mle(lo_num_vars, instance_num_vars); + let output_mle = &wits_out[wire_out_id as usize]; let prover_output_point = iter::repeat_with(|| { prover_transcript .get_and_append_challenge(b"output_point_test_gkr_circuit_IsZeroGadget_simple") .elements }) - .take(output_mle.num_vars) + .take(output_mle.num_vars()) .collect_vec(); let verifier_output_point = iter::repeat_with(|| { verifier_transcript .get_and_append_challenge(b"output_point_test_gkr_circuit_IsZeroGadget_simple") .elements }) - .take(output_mle.num_vars) + .take(output_mle.num_vars()) .collect_vec(); let prover_output_eval = output_mle.evaluate(&prover_output_point); let verifier_output_eval = output_mle.evaluate(&verifier_output_point); diff --git a/gkr/src/verifier.rs b/gkr/src/verifier.rs index d84ff6ce2..16377efad 100644 --- a/gkr/src/verifier.rs +++ b/gkr/src/verifier.rs @@ -8,8 +8,7 @@ use transcript::Transcript; use crate::{ error::GKRError, structs::{ - Circuit, GKRInputClaims, IOPProof, IOPProverStepMessage, IOPVerifierState, PointAndEval, - SumcheckStepType, + Circuit, GKRInputClaims, IOPProof, IOPVerifierState, PointAndEval, SumcheckStepType, }, }; @@ -58,10 +57,7 @@ impl IOPVerifierState { .verify_and_update_state_output_phase1_step1( circuit, step_proof, transcript, )?, - SumcheckStepType::OutputPhase1Step2 => verifier_state - .verify_and_update_state_output_phase1_step2( - circuit, step_proof, transcript, - )?, + SumcheckStepType::OutputPhase1Step2 => Ok(())?, SumcheckStepType::Phase1Step1 => verifier_state .verify_and_update_state_phase1_step1(circuit, step_proof, transcript)?, SumcheckStepType::Phase2Step1 => verifier_state @@ -133,7 +129,6 @@ impl IOPVerifierState { assert_point, // Default layer_id: 0, - g1_values: vec![], out_point: vec![], eq_y_ry: vec![], eq_x1_rx1: vec![], diff --git a/gkr/src/verifier/phase1.rs b/gkr/src/verifier/phase1.rs index 3bc6c13a3..6a3b0c947 100644 --- a/gkr/src/verifier/phase1.rs +++ b/gkr/src/verifier/phase1.rs @@ -53,7 +53,7 @@ impl IOPVerifierState { acc + point_and_eval.eval * alpha_pow }); - // Sumcheck: sigma = \sum_{t || y}(f1({t || y}) * (\sum_j g1^{(j)}({t || y}))) + // Sumcheck: sigma = \sum_{t || y}( \sum_j f1^{(j)}( t || y) * g1^{(j)}(t || y) ) // f1^{(j)}(y) = layers[i](t || y) // g1^{(j)}(t || y) = \alpha^j * eq(rt_j, t) * eq(ry_j, y) // g1^{(j)}(t || y) = \alpha^j * eq(rt_j, t) * copy_to[j](ry_j, y) diff --git a/gkr/src/verifier/phase1_output.rs b/gkr/src/verifier/phase1_output.rs index 50aa74d67..8605a5863 100644 --- a/gkr/src/verifier/phase1_output.rs +++ b/gkr/src/verifier/phase1_output.rs @@ -2,7 +2,7 @@ use ark_std::{end_timer, start_timer}; use ff_ext::ExtensionField; use itertools::{chain, izip, Itertools}; use multilinear_extensions::virtual_poly::{build_eq_x_r_vec, eq_eval, VPAuxInfo}; -use std::{iter, marker::PhantomData, mem}; +use std::{iter, marker::PhantomData}; use transcript::Transcript; use crate::{ @@ -39,11 +39,8 @@ impl IOPVerifierState { let lo_num_vars = circuit.layers[self.layer_id as usize].num_vars; let hi_num_vars = self.instance_num_vars; - // TODO: Double check the soundness here. - let assert_eq_yj_ryj = build_eq_x_r_vec(&self.assert_point[..lo_num_vars]); - - let mut sigma_1 = E::ZERO; - sigma_1 += izip!(self.to_next_phase_point_and_evals.iter(), alpha_pows.iter()) + // sigma = \sum_j( \alpha^j * subset[i][j](rt_j || ry_j) ) + let mut sigma_1 = izip!(self.to_next_phase_point_and_evals.iter(), alpha_pows.iter()) .fold(E::ZERO, |acc, (point_and_eval, alpha_pow)| { acc + point_and_eval.eval * alpha_pow }); @@ -56,33 +53,50 @@ impl IOPVerifierState { .fold(E::ZERO, |acc, ((_, point_and_eval), alpha_pow)| { acc + point_and_eval.eval * alpha_pow }); + + let assert_eq_yj_ryj = build_eq_x_r_vec(&self.assert_point[..lo_num_vars]); sigma_1 += circuit .assert_consts .as_slice() .eval(&assert_eq_yj_ryj, &self.challenges) * alpha_pows.last().unwrap(); - // Sumcheck 1: sigma = \sum_y( \sum_j f1^{(j)}(y) * g1^{(j)}(y) ) - // f1^{(j)}(y) = layers[i](rt_j || y) - // g1^{(j)}(y) = \alpha^j copy_to_wits_out[j](ry_j, y) - // or \alpha^j assert_subset_eq[j](ry, y) + // Sumcheck: sigma = \sum_{t || y}( \sum_j f1^{(j)}( t || y) * g1^{(j)}(t || y) ) + // f1^{(j)}(y) = layers[i](t || y) + // g1^{(j)}(t || y) = \alpha^j * eq(rt_j, t) * eq(ry_j, y) + // g1^{(j)}(t || y) = \alpha^j * eq(rt_j, t) * copy_to[j](ry_j, y) + // g1^{(j)}(t || y) = \alpha^j * eq(rt_j, t) * assert_subset_eq(ry, y) let claim_1 = SumcheckState::verify( sigma_1, &step_msg.sumcheck_proof, &VPAuxInfo { max_degree: 2, - num_variables: lo_num_vars, + num_variables: lo_num_vars + hi_num_vars, phantom: PhantomData, }, transcript, ); + let claim1_point = claim_1.point.iter().map(|x| x.elements).collect_vec(); - let eq_y_ry = build_eq_x_r_vec(&claim1_point); - self.g1_values = chain![ + let claim1_point_lo_num_vars = claim1_point.len() - hi_num_vars; + let eq_y_ry = build_eq_x_r_vec(&claim1_point[..claim1_point_lo_num_vars]); + + assert_eq!(step_msg.sumcheck_eval_values.len(), 1); + let f_value = step_msg.sumcheck_eval_values[0]; + + let g_value: E = chain![ izip!(self.to_next_phase_point_and_evals.iter(), alpha_pows.iter()).map( |(point_and_eval, alpha_pow)| { let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; - eq_eval(&point_and_eval.point[..point_lo_num_vars], &claim1_point) * alpha_pow + let eq_t = eq_eval( + &point_and_eval.point[point_lo_num_vars..], + &claim1_point[(claim1_point.len() - hi_num_vars)..], + ); + let eq_y = eq_eval( + &point_and_eval.point[..point_lo_num_vars], + &claim1_point[..point_lo_num_vars], + ); + eq_t * eq_y * alpha_pow } ), izip!( @@ -94,8 +108,14 @@ impl IOPVerifierState { ) .map(|(copy_to, (_, point_and_eval), alpha_pow)| { let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; + let eq_t = eq_eval( + &point_and_eval.point[point_lo_num_vars..], + &claim1_point[(claim1_point.len() - hi_num_vars)..], + ); let eq_yj_ryj = build_eq_x_r_vec(&point_and_eval.point[..point_lo_num_vars]); - copy_to.as_slice().eval_row_first(&eq_yj_ryj, &eq_y_ry) * alpha_pow + eq_t * copy_to.as_slice().eval_row_first(&eq_yj_ryj, &eq_y_ry) + * alpha_pow + * alpha_pow }), iter::once( circuit @@ -105,82 +125,18 @@ impl IOPVerifierState { * alpha_pows.last().unwrap() ) ] - .collect_vec(); + .sum(); - let f1_values = step_msg.sumcheck_eval_values.to_vec(); - let got_value_1 = f1_values - .iter() - .zip(self.g1_values.iter()) - .fold(E::ZERO, |acc, (&f1, g1)| acc + f1 * g1); + let got_value = f_value * g_value; end_timer!(timer); - if claim_1.expected_evaluation != got_value_1 { - return Err(GKRError::VerifyError("output phase1 step1 failed")); + if claim_1.expected_evaluation != got_value { + return Err(GKRError::VerifyError("phase1 output step1 failed")); } + self.to_next_step_point_and_eval = PointAndEval::new_from_ref(&claim1_point, &f_value); + self.to_next_phase_point_and_evals = vec![self.to_next_step_point_and_eval.clone()]; - self.to_next_step_point_and_eval = - PointAndEval::new(claim1_point, claim_1.expected_evaluation); - - Ok(()) - } - - pub(super) fn verify_and_update_state_output_phase1_step2( - &mut self, - _: &Circuit, - step_msg: IOPProverStepMessage, - transcript: &mut Transcript, - ) -> Result<(), GKRError> { - let timer = start_timer!(|| "Verifier sumcheck phase 1 step 2"); - let hi_num_vars = self.instance_num_vars; - - // Sumcheck 2: sigma = \sum_t( \sum_j( g2^{(j)}(t) ) ) * f2(t) - // f2(t) = layers[i](t || ry) - // g2^{(j)}(t) = \alpha^j copy_to[j](ry_j, r_y) eq(rt_j, t) - let claim_2 = SumcheckState::verify( - self.to_next_step_point_and_eval.eval, - &step_msg.sumcheck_proof, - &VPAuxInfo { - max_degree: 2, - num_variables: hi_num_vars, - phantom: PhantomData, - }, - transcript, - ); - let claim2_point = claim_2.point.iter().map(|x| x.elements).collect_vec(); - - let output_points = chain![ - self.to_next_phase_point_and_evals.iter().map(|x| &x.point), - self.subset_point_and_evals[self.layer_id as usize] - .iter() - .map(|x| &x.1.point), - iter::once(&self.assert_point), - ]; - let f2_value = step_msg.sumcheck_eval_values[0]; - let g2_value = output_points - .zip(self.g1_values.iter()) - .map(|(point, g1_value)| { - let point_lo_num_vars = point.len() - hi_num_vars; - *g1_value * eq_eval(&point[point_lo_num_vars..], &claim2_point) - }) - .fold(E::ZERO, |acc, value| acc + value); - - let got_value_2 = f2_value * g2_value; - - end_timer!(timer); - if claim_2.expected_evaluation != got_value_2 { - return Err(GKRError::VerifyError("output phase1 step2 failed")); - } - - self.to_next_step_point_and_eval = PointAndEval::new( - [ - mem::take(&mut self.to_next_step_point_and_eval.point), - claim2_point, - ] - .concat(), - f2_value, - ); self.subset_point_and_evals[self.layer_id as usize].clear(); - Ok(()) } } diff --git a/gkr/src/verifier/phase2.rs b/gkr/src/verifier/phase2.rs index 095864930..2382744c2 100644 --- a/gkr/src/verifier/phase2.rs +++ b/gkr/src/verifier/phase2.rs @@ -41,11 +41,11 @@ impl IOPVerifierState { .as_slice() .eval(&self.eq_y_ry, &self.challenges); - // Sumcheck 1: sigma = \sum_{s1 || x1} f1(s1 || x1) * g1(s1 || x1) + \sum_j f1'_j(s1 || x1) * g1'_j(s1 || x1) - // f1(s1 || x1) = layers[i + 1](s1 || x1) - // g1(s1 || x1) = \sum_{s2}( \sum_{s3}( \sum_{x2}( \sum_{x3}( - // eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s2 || x2) * layers[i + 1](s3 || x3) - // ) ) ) ) + \sum_{s2}( \sum_{x2}( + // Sumcheck 1: sigma = \sum_{s1 || x1} f1(s1 || x1) * g1(s1 || x1) + \sum_j f1'_j(s1 || x1) + // * g1'_j(s1 || x1) f1(s1 || x1) = layers[i + 1](s1 || x1) g1(s1 || x1) = \sum_{s2}( + // \sum_{s3}( \sum_{x2}( \sum_{x3}( eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + + // 1](s2 || x2) * layers[i + + // 1](s3 || x3) ) ) ) ) + \sum_{s2}( \sum_{x2}( // eq(rt, s1, s2) * mul2(ry, x1, x2) * layers[i + 1](s2 || x2) // ) ) + eq(rt, s1) * add(ry, x1) // f1'^{(j)}(s1 || x1) = subset[j][i](s1 || x1) diff --git a/gkr/src/verifier/phase2_input.rs b/gkr/src/verifier/phase2_input.rs index 9f63b1913..3b5fca220 100644 --- a/gkr/src/verifier/phase2_input.rs +++ b/gkr/src/verifier/phase2_input.rs @@ -63,6 +63,7 @@ impl IOPVerifierState { } return Ok(()); } + let lo_in_num_vars = lo_in_num_vars.unwrap(); let claim = SumcheckState::verify( diff --git a/gkr/src/verifier_v2.rs b/gkr/src/verifier_v2.rs deleted file mode 100644 index 2d9e4c1d6..000000000 --- a/gkr/src/verifier_v2.rs +++ /dev/null @@ -1,138 +0,0 @@ -use ark_std::{end_timer, start_timer}; -use ff_ext::ExtensionField; -use itertools::{izip, Itertools}; -use simple_frontend::structs::{ChallengeConst, LayerId}; -use std::collections::HashMap; -use transcript::Transcript; - -use crate::{ - error::GKRError, - structs::{ - Circuit, GKRInputClaims, IOPProof, IOPVerifierStateV2, PointAndEval, SumcheckStepType, - }, -}; - -mod phase1; -mod phase1_output; -mod phase2; -mod phase2_input; -mod phase2_linear; - -type SumcheckState = sumcheck::structs::IOPVerifierState; - -impl IOPVerifierStateV2 { - /// Verify process for data parallel circuits. - pub fn verify_parallel( - circuit: &Circuit, - challenges: &[E], - output_evals: Vec>, - wires_out_evals: Vec>, - proof: IOPProof, - instance_num_vars: usize, - transcript: &mut Transcript, - ) -> Result, GKRError> { - let timer = start_timer!(|| "Verification"); - let challenges = circuit.generate_basefield_challenges(challenges); - - let mut verifier_state = Self::verifier_init_parallel( - circuit.layers.len(), - challenges, - output_evals, - wires_out_evals, - instance_num_vars, - transcript, - circuit.layers[0].num_vars + instance_num_vars, - ); - - let mut sumcheck_proofs_iter = proof.sumcheck_proofs.into_iter(); - for layer_id in 0..circuit.layers.len() { - let timer = start_timer!(|| format!("Verify layer {}", layer_id)); - verifier_state.layer_id = layer_id as LayerId; - - let layer = &circuit.layers[layer_id as usize]; - for (step, step_proof) in izip!(layer.sumcheck_steps.iter(), &mut sumcheck_proofs_iter) - { - match step { - SumcheckStepType::OutputPhase1Step1 => verifier_state - .verify_and_update_state_output_phase1_step1( - circuit, step_proof, transcript, - )?, - SumcheckStepType::OutputPhase1Step2 => Ok(())?, - SumcheckStepType::Phase1Step1 => verifier_state - .verify_and_update_state_phase1_step1(circuit, step_proof, transcript)?, - SumcheckStepType::Phase2Step1 => verifier_state - .verify_and_update_state_phase2_step1(circuit, step_proof, transcript)?, - SumcheckStepType::Phase2Step2 => verifier_state - .verify_and_update_state_phase2_step2( - circuit, step_proof, transcript, false, - )?, - SumcheckStepType::Phase2Step2NoStep3 => verifier_state - .verify_and_update_state_phase2_step2( - circuit, step_proof, transcript, true, - )?, - SumcheckStepType::Phase2Step3 => verifier_state - .verify_and_update_state_phase2_step3(circuit, step_proof, transcript)?, - SumcheckStepType::LinearPhase2Step1 => verifier_state - .verify_and_update_state_linear_phase2_step1( - circuit, step_proof, transcript, - )?, - SumcheckStepType::InputPhase2Step1 => verifier_state - .verify_and_update_state_input_phase2_step1( - circuit, step_proof, transcript, - )?, - _ => unreachable!(), - } - } - end_timer!(timer); - } - - end_timer!(timer); - - Ok(GKRInputClaims { - point_and_evals: verifier_state.to_next_phase_point_and_evals, - }) - } - - /// Initialize verifying state for data parallel circuits. - fn verifier_init_parallel( - n_layers: usize, - challenges: HashMap>, - output_evals: Vec>, - wires_out_evals: Vec>, - instance_num_vars: usize, - transcript: &mut Transcript, - output_wit_num_vars: usize, - ) -> Self { - let mut subset_point_and_evals = vec![vec![]; n_layers]; - let to_next_step_point_and_eval = if !output_evals.is_empty() { - output_evals.last().unwrap().clone() - } else { - wires_out_evals.last().unwrap().clone() - }; - let assert_point = (0..output_wit_num_vars) - .map(|_| { - transcript - .get_and_append_challenge(b"assert_point") - .elements - }) - .collect_vec(); - let to_next_phase_point_and_evals = output_evals; - subset_point_and_evals[0] = wires_out_evals.into_iter().map(|p| (0, p)).collect(); - Self { - to_next_phase_point_and_evals, - subset_point_and_evals, - to_next_step_point_and_eval, - - challenges, - instance_num_vars, - - assert_point, - // Default - layer_id: 0, - out_point: vec![], - eq_y_ry: vec![], - eq_x1_rx1: vec![], - eq_x2_rx2: vec![], - } - } -} diff --git a/gkr/src/verifier_v2/phase1.rs b/gkr/src/verifier_v2/phase1.rs deleted file mode 100644 index 98208aa73..000000000 --- a/gkr/src/verifier_v2/phase1.rs +++ /dev/null @@ -1,127 +0,0 @@ -use ark_std::{end_timer, start_timer}; -use ff_ext::ExtensionField; -use itertools::{chain, izip, Itertools}; -use multilinear_extensions::virtual_poly::{build_eq_x_r_vec, eq_eval, VPAuxInfo}; -use std::marker::PhantomData; -use transcript::Transcript; - -use crate::{ - error::GKRError, - structs::{Circuit, IOPProverStepMessage, IOPVerifierStateV2, PointAndEval}, - utils::MatrixMLERowFirst, -}; - -use super::SumcheckState; - -impl IOPVerifierStateV2 { - pub(super) fn verify_and_update_state_phase1_step1( - &mut self, - circuit: &Circuit, - step_msg: IOPProverStepMessage, - transcript: &mut Transcript, - ) -> Result<(), GKRError> { - let timer = start_timer!(|| "Verifier sumcheck phase 1 step 1"); - let alpha = transcript - .get_and_append_challenge(b"combine subset evals") - .elements; - let total_length = self.to_next_phase_point_and_evals.len() - + self.subset_point_and_evals[self.layer_id as usize].len() - + 1; - let alpha_pows = { - let mut alpha_pows = vec![E::ONE; total_length]; - for i in 0..total_length.saturating_sub(1) { - alpha_pows[i + 1] = alpha_pows[i] * alpha; - } - alpha_pows - }; - - let lo_num_vars = circuit.layers[self.layer_id as usize].num_vars; - let hi_num_vars = self.instance_num_vars; - - // sigma = \sum_j( \alpha^j * subset[i][j](rt_j || ry_j) ) - let mut sigma_1 = izip!(self.to_next_phase_point_and_evals.iter(), alpha_pows.iter()) - .fold(E::ZERO, |acc, (point_and_eval, alpha_pow)| { - acc + point_and_eval.eval * alpha_pow - }); - sigma_1 += izip!( - self.subset_point_and_evals[self.layer_id as usize].iter(), - alpha_pows - .iter() - .skip(self.to_next_phase_point_and_evals.len()) - ) - .fold(E::ZERO, |acc, ((_, point_and_eval), alpha_pow)| { - acc + point_and_eval.eval * alpha_pow - }); - - // Sumcheck: sigma = \sum_{t || y}( \sum_j f1^{(j)}( t || y) * g1^{(j)}(t || y) ) - // f1^{(j)}(y) = layers[i](t || y) - // g1^{(j)}(t || y) = \alpha^j * eq(rt_j, t) * eq(ry_j, y) - // g1^{(j)}(t || y) = \alpha^j * eq(rt_j, t) * copy_to[j](ry_j, y) - let claim_1 = SumcheckState::verify( - sigma_1, - &step_msg.sumcheck_proof, - &VPAuxInfo { - max_degree: 2, - num_variables: lo_num_vars + hi_num_vars, - phantom: PhantomData, - }, - transcript, - ); - - let claim1_point = claim_1.point.iter().map(|x| x.elements).collect_vec(); - let claim1_point_lo_num_vars = claim1_point.len() - hi_num_vars; - let eq_y_ry = build_eq_x_r_vec(&claim1_point[..claim1_point_lo_num_vars]); - - assert_eq!(step_msg.sumcheck_eval_values.len(), 1); - let f_value = step_msg.sumcheck_eval_values[0]; - - let g_value: E = chain![ - izip!(self.to_next_phase_point_and_evals.iter(), alpha_pows.iter()).map( - |(point_and_eval, alpha_pow)| { - let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; - let eq_t = eq_eval( - &point_and_eval.point[point_lo_num_vars..], - &claim1_point[(claim1_point.len() - hi_num_vars)..], - ); - let eq_y = eq_eval( - &point_and_eval.point[..point_lo_num_vars], - &claim1_point[..point_lo_num_vars], - ); - eq_t * eq_y * alpha_pow - } - ), - izip!( - self.subset_point_and_evals[self.layer_id as usize].iter(), - alpha_pows - .iter() - .skip(self.to_next_phase_point_and_evals.len()) - ) - .map(|((new_layer_id, point_and_eval), alpha_pow)| { - let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; - let eq_t = eq_eval( - &point_and_eval.point[point_lo_num_vars..], - &claim1_point[(claim1_point.len() - hi_num_vars)..], - ); - let eq_yj_ryj = build_eq_x_r_vec(&point_and_eval.point[..point_lo_num_vars]); - eq_t * circuit.layers[self.layer_id as usize].copy_to[new_layer_id] - .as_slice() - .eval_row_first(&eq_yj_ryj, &eq_y_ry) - * alpha_pow - }), - ] - .sum(); - - let got_value = f_value * g_value; - - end_timer!(timer); - if claim_1.expected_evaluation != got_value { - return Err(GKRError::VerifyError("phase1 step1 failed")); - } - self.to_next_step_point_and_eval = PointAndEval::new_from_ref(&claim1_point, &f_value); - self.to_next_phase_point_and_evals = vec![self.to_next_step_point_and_eval.clone()]; - - self.subset_point_and_evals[self.layer_id as usize].clear(); - - Ok(()) - } -} diff --git a/gkr/src/verifier_v2/phase1_output.rs b/gkr/src/verifier_v2/phase1_output.rs deleted file mode 100644 index 816e667f7..000000000 --- a/gkr/src/verifier_v2/phase1_output.rs +++ /dev/null @@ -1,142 +0,0 @@ -use ark_std::{end_timer, start_timer}; -use ff_ext::ExtensionField; -use itertools::{chain, izip, Itertools}; -use multilinear_extensions::virtual_poly::{build_eq_x_r_vec, eq_eval, VPAuxInfo}; -use std::{iter, marker::PhantomData}; -use transcript::Transcript; - -use crate::{ - circuit::EvaluateGateCIn, - error::GKRError, - structs::{Circuit, IOPProverStepMessage, IOPVerifierStateV2, PointAndEval}, - utils::MatrixMLERowFirst, -}; - -use super::SumcheckState; - -impl IOPVerifierStateV2 { - pub(super) fn verify_and_update_state_output_phase1_step1( - &mut self, - circuit: &Circuit, - step_msg: IOPProverStepMessage, - transcript: &mut Transcript, - ) -> Result<(), GKRError> { - let timer = start_timer!(|| "Verifier sumcheck phase 1 step 1"); - let alpha = transcript - .get_and_append_challenge(b"combine subset evals") - .elements; - let total_length = self.to_next_phase_point_and_evals.len() - + self.subset_point_and_evals[self.layer_id as usize].len() - + 1; - let alpha_pows = { - let mut alpha_pows = vec![E::ONE; total_length]; - for i in 0..total_length.saturating_sub(1) { - alpha_pows[i + 1] = alpha_pows[i] * alpha; - } - alpha_pows - }; - - let lo_num_vars = circuit.layers[self.layer_id as usize].num_vars; - let hi_num_vars = self.instance_num_vars; - - // sigma = \sum_j( \alpha^j * subset[i][j](rt_j || ry_j) ) - let mut sigma_1 = izip!(self.to_next_phase_point_and_evals.iter(), alpha_pows.iter()) - .fold(E::ZERO, |acc, (point_and_eval, alpha_pow)| { - acc + point_and_eval.eval * alpha_pow - }); - sigma_1 += izip!( - self.subset_point_and_evals[self.layer_id as usize].iter(), - alpha_pows - .iter() - .skip(self.to_next_phase_point_and_evals.len()) - ) - .fold(E::ZERO, |acc, ((_, point_and_eval), alpha_pow)| { - acc + point_and_eval.eval * alpha_pow - }); - - let assert_eq_yj_ryj = build_eq_x_r_vec(&self.assert_point[..lo_num_vars]); - sigma_1 += circuit - .assert_consts - .as_slice() - .eval(&assert_eq_yj_ryj, &self.challenges) - * alpha_pows.last().unwrap(); - - // Sumcheck: sigma = \sum_{t || y}( \sum_j f1^{(j)}( t || y) * g1^{(j)}(t || y) ) - // f1^{(j)}(y) = layers[i](t || y) - // g1^{(j)}(t || y) = \alpha^j * eq(rt_j, t) * eq(ry_j, y) - // g1^{(j)}(t || y) = \alpha^j * eq(rt_j, t) * copy_to[j](ry_j, y) - // g1^{(j)}(t || y) = \alpha^j * eq(rt_j, t) * assert_subset_eq(ry, y) - let claim_1 = SumcheckState::verify( - sigma_1, - &step_msg.sumcheck_proof, - &VPAuxInfo { - max_degree: 2, - num_variables: lo_num_vars + hi_num_vars, - phantom: PhantomData, - }, - transcript, - ); - - let claim1_point = claim_1.point.iter().map(|x| x.elements).collect_vec(); - let claim1_point_lo_num_vars = claim1_point.len() - hi_num_vars; - let eq_y_ry = build_eq_x_r_vec(&claim1_point[..claim1_point_lo_num_vars]); - - assert_eq!(step_msg.sumcheck_eval_values.len(), 1); - let f_value = step_msg.sumcheck_eval_values[0]; - - let g_value: E = chain![ - izip!(self.to_next_phase_point_and_evals.iter(), alpha_pows.iter()).map( - |(point_and_eval, alpha_pow)| { - let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; - let eq_t = eq_eval( - &point_and_eval.point[point_lo_num_vars..], - &claim1_point[(claim1_point.len() - hi_num_vars)..], - ); - let eq_y = eq_eval( - &point_and_eval.point[..point_lo_num_vars], - &claim1_point[..point_lo_num_vars], - ); - eq_t * eq_y * alpha_pow - } - ), - izip!( - circuit.copy_to_wits_out.iter(), - self.subset_point_and_evals[self.layer_id as usize].iter(), - alpha_pows - .iter() - .skip(self.to_next_phase_point_and_evals.len()) - ) - .map(|(copy_to, (_, point_and_eval), alpha_pow)| { - let point_lo_num_vars = point_and_eval.point.len() - hi_num_vars; - let eq_t = eq_eval( - &point_and_eval.point[point_lo_num_vars..], - &claim1_point[(claim1_point.len() - hi_num_vars)..], - ); - let eq_yj_ryj = build_eq_x_r_vec(&point_and_eval.point[..point_lo_num_vars]); - eq_t * copy_to.as_slice().eval_row_first(&eq_yj_ryj, &eq_y_ry) - * alpha_pow - * alpha_pow - }), - iter::once( - circuit - .assert_consts - .as_slice() - .eval_subset_eq(&assert_eq_yj_ryj, &eq_y_ry) - * alpha_pows.last().unwrap() - ) - ] - .sum(); - - let got_value = f_value * g_value; - - end_timer!(timer); - if claim_1.expected_evaluation != got_value { - return Err(GKRError::VerifyError("phase1 output step1 failed")); - } - self.to_next_step_point_and_eval = PointAndEval::new_from_ref(&claim1_point, &f_value); - self.to_next_phase_point_and_evals = vec![self.to_next_step_point_and_eval.clone()]; - - self.subset_point_and_evals[self.layer_id as usize].clear(); - Ok(()) - } -} diff --git a/gkr/src/verifier_v2/phase2.rs b/gkr/src/verifier_v2/phase2.rs deleted file mode 100644 index bc44b3817..000000000 --- a/gkr/src/verifier_v2/phase2.rs +++ /dev/null @@ -1,240 +0,0 @@ -use ark_std::{end_timer, start_timer}; -use ff_ext::ExtensionField; -use itertools::{chain, izip, Itertools}; -use multilinear_extensions::virtual_poly::{build_eq_x_r_vec, eq_eval, VPAuxInfo}; -use std::mem; -use transcript::Transcript; - -use crate::{ - circuit::{EvaluateGate1In, EvaluateGate2In, EvaluateGate3In, EvaluateGateCIn}, - error::GKRError, - structs::{Circuit, IOPProverStepMessage, IOPVerifierStateV2, PointAndEval}, - utils::{eq3_eval, eq4_eval, MatrixMLEColumnFirst}, -}; - -use super::SumcheckState; - -impl IOPVerifierStateV2 { - pub(super) fn verify_and_update_state_phase2_step1( - &mut self, - circuit: &Circuit, - step_msg: IOPProverStepMessage, - transcript: &mut Transcript, - ) -> Result<(), GKRError> { - let timer = start_timer!(|| "Verifier sumcheck phase 2 step 1"); - let layer = &circuit.layers[self.layer_id as usize]; - let lo_out_num_vars = layer.num_vars; - let lo_in_num_vars = layer.max_previous_num_vars; - let hi_num_vars = self.instance_num_vars; - let in_num_vars = lo_in_num_vars + hi_num_vars; - - self.out_point = mem::take(&mut self.to_next_step_point_and_eval.point); - let lo_point = &self.out_point[..lo_out_num_vars]; - let hi_point = &self.out_point[lo_out_num_vars..]; - - self.eq_y_ry = build_eq_x_r_vec(lo_point); - - // sigma = layers[i](rt || ry) - add_const(ry), - let sumcheck_sigma = self.to_next_step_point_and_eval.eval - - layer - .add_consts - .as_slice() - .eval(&self.eq_y_ry, &self.challenges); - - // Sumcheck 1: sigma = \sum_{s1 || x1} f1(s1 || x1) * g1(s1 || x1) + \sum_j f1'_j(s1 || x1) - // * g1'_j(s1 || x1) f1(s1 || x1) = layers[i + 1](s1 || x1) - // g1(s1 || x1) = \sum_{s2}( \sum_{s3}( \sum_{x2}( \sum_{x3}( - // eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s2 || x2) * layers[i + - // 1](s3 || x3) ) ) ) ) + \sum_{s2}( \sum_{x2}( - // eq(rt, s1, s2) * mul2(ry, x1, x2) * layers[i + 1](s2 || x2) - // ) ) + eq(rt, s1) * add(ry, x1) - // f1'^{(j)}(s1 || x1) = subset[j][i](s1 || x1) - // g1'^{(j)}(s1 || x1) = eq(rt, s1) paste_from[j](ry, x1) - let claim_1 = SumcheckState::verify( - sumcheck_sigma, - &step_msg.sumcheck_proof, - &VPAuxInfo { - max_degree: 2, - num_variables: in_num_vars, - phantom: std::marker::PhantomData, - }, - transcript, - ); - let claim1_point = claim_1.point.iter().map(|x| x.elements).collect_vec(); - let hi_point_sc1 = &claim1_point[lo_in_num_vars..]; - - let (f1_values, received_g1_values) = step_msg - .sumcheck_eval_values - .split_at(step_msg.sumcheck_eval_values.len() - 1); - - let hi_eq_eval = eq_eval(&hi_point, hi_point_sc1); - self.eq_x1_rx1 = build_eq_x_r_vec(&claim1_point[..lo_in_num_vars]); - let g1_values_iter = chain![ - received_g1_values.iter().cloned(), - layer.paste_from.iter().map(|(_, paste_from)| { - hi_eq_eval - * paste_from - .as_slice() - .eval_col_first(&self.eq_y_ry, &self.eq_x1_rx1) - }) - ]; - let got_value_1 = - izip!(f1_values.iter(), g1_values_iter).fold(E::ZERO, |acc, (&f1, g1)| acc + f1 * g1); - - end_timer!(timer); - if claim_1.expected_evaluation != got_value_1 { - return Err(GKRError::VerifyError("phase2 step1 failed")); - } - - self.to_next_phase_point_and_evals = - vec![PointAndEval::new_from_ref(&claim1_point, &f1_values[0])]; - izip!(layer.paste_from.iter(), f1_values.iter().skip(1)).for_each( - |((&old_layer_id, _), &subset_value)| { - self.subset_point_and_evals[old_layer_id as usize].push(( - self.layer_id, - PointAndEval::new_from_ref(&claim1_point, &subset_value), - )); - }, - ); - self.to_next_step_point_and_eval = PointAndEval::new(claim1_point, received_g1_values[0]); - - Ok(()) - } - - pub(super) fn verify_and_update_state_phase2_step2( - &mut self, - circuit: &Circuit, - step_msg: IOPProverStepMessage, - transcript: &mut Transcript, - no_step3: bool, - ) -> Result<(), GKRError> { - let timer = start_timer!(|| "Verifier sumcheck phase 2 step 2"); - let layer = &circuit.layers[self.layer_id as usize]; - let lo_out_num_vars = layer.num_vars; - let lo_in_num_vars = layer.max_previous_num_vars; - let hi_num_vars = self.instance_num_vars; - let in_num_vars = lo_in_num_vars + hi_num_vars; - - // sigma = g1(rs1 || rx1) - eq(rt, rs1) * add(ry, rx1) - let sumcheck_sigma = self.to_next_step_point_and_eval.eval - - eq_eval( - &self.out_point[lo_out_num_vars..], - &self.to_next_phase_point_and_evals[0].point[lo_in_num_vars..], - ) * layer - .adds - .as_slice() - .eval(&self.eq_y_ry, &self.eq_x1_rx1, &self.challenges); - - // Sumcheck 2 sigma = \sum_{s2 || x2} f2(s2 || x2) * g2(s2 || x2) - // f2(s2 || x2) = layers[i + 1](s2 || x2) - // g2(s2 || x2) = \sum_{s3}( \sum_{x3}( - // eq(rt, rs1, s2, s3) * mul3(ry, rx1, x2, x3) * layers[i + 1](s3 || x3) - // ) ) + eq(rt, rs1, s2) * mul2(ry, rx1, x2)} - let claim_2 = SumcheckState::verify( - sumcheck_sigma, - &step_msg.sumcheck_proof, - &VPAuxInfo { - max_degree: 2, - num_variables: in_num_vars, - phantom: std::marker::PhantomData, - }, - transcript, - ); - let claim2_point = claim_2.point.iter().map(|x| x.elements).collect_vec(); - let f2_value = step_msg.sumcheck_eval_values[0]; - - self.eq_x2_rx2 = build_eq_x_r_vec(&claim2_point[..lo_in_num_vars]); - let g2_value = if no_step3 { - eq3_eval( - &self.out_point[lo_out_num_vars..], - &self.to_next_phase_point_and_evals[0].point[lo_in_num_vars..], - &claim2_point[lo_in_num_vars..], - ) * layer.mul2s.as_slice().eval( - &self.eq_y_ry, - &self.eq_x1_rx1, - &self.eq_x2_rx2, - &self.challenges, - ) - } else { - step_msg.sumcheck_eval_values[1] - }; - let got_value_2 = f2_value * g2_value; - - end_timer!(timer); - if claim_2.expected_evaluation != got_value_2 { - return Err(GKRError::VerifyError("phase2 step2 failed")); - } - - self.to_next_phase_point_and_evals - .push(PointAndEval::new_from_ref(&claim2_point, &f2_value)); - self.to_next_step_point_and_eval = PointAndEval::new(claim2_point, g2_value); - Ok(()) - } - - pub(super) fn verify_and_update_state_phase2_step3( - &mut self, - circuit: &Circuit, - step_msg: IOPProverStepMessage, - transcript: &mut Transcript, - ) -> Result<(), GKRError> { - let timer = start_timer!(|| "Verifier sumcheck phase 2 step 3"); - let layer = &circuit.layers[self.layer_id as usize]; - let lo_out_num_vars = layer.num_vars; - let lo_in_num_vars = layer.max_previous_num_vars; - let hi_num_vars = self.instance_num_vars; - let in_num_vars = lo_in_num_vars + hi_num_vars; - - // sigma = g2(rs2 || rx2) - eq(rt, rs1, rs2) * mul2(ry, rx1, rx2) - let sumcheck_sigma = self.to_next_step_point_and_eval.eval - - eq3_eval( - &self.out_point[lo_out_num_vars..], - &self.to_next_phase_point_and_evals[0].point[lo_in_num_vars..], - &self.to_next_phase_point_and_evals[1].point[lo_in_num_vars..], - ) * layer.mul2s.as_slice().eval( - &self.eq_y_ry, - &self.eq_x1_rx1, - &self.eq_x2_rx2, - &self.challenges, - ); - - // Sumcheck 3 sigma = \sum_{s3 || x3} f3(s3 || x3) * g3(s3 || x3) - // f3(s3 || x3) = layers[i + 1](s3 || x3) - // g3(s3 || x3) = eq(rt, rs1, rs2, s3) * mul3(ry, rx1, rx2, x3) - let claim_3 = SumcheckState::verify( - sumcheck_sigma, - &step_msg.sumcheck_proof, - &VPAuxInfo { - max_degree: 2, - num_variables: in_num_vars, - phantom: std::marker::PhantomData, - }, - transcript, - ); - let claim3_point = claim_3.point.iter().map(|x| x.elements).collect_vec(); - let eq_x3_rx3 = build_eq_x_r_vec(&claim3_point[..lo_in_num_vars]); - let f3_value = step_msg.sumcheck_eval_values[0]; - let g3_value = eq4_eval( - &&self.out_point[lo_out_num_vars..], - &self.to_next_phase_point_and_evals[0].point[lo_in_num_vars..], - &self.to_next_phase_point_and_evals[1].point[lo_in_num_vars..], - &claim3_point[lo_in_num_vars..], - ) * layer.mul3s.as_slice().eval( - &self.eq_y_ry, - &self.eq_x1_rx1, - &self.eq_x2_rx2, - &eq_x3_rx3, - &self.challenges, - ); - - let got_value_3 = f3_value * g3_value; - end_timer!(timer); - if claim_3.expected_evaluation != got_value_3 { - return Err(GKRError::VerifyError("phase2 step3 failed")); - } - - self.to_next_phase_point_and_evals - .push(PointAndEval::new_from_ref(&claim3_point, &f3_value)); - self.to_next_step_point_and_eval = PointAndEval::new(claim3_point, E::ZERO); - Ok(()) - } -} diff --git a/gkr/src/verifier_v2/phase2_input.rs b/gkr/src/verifier_v2/phase2_input.rs deleted file mode 100644 index 560066e30..000000000 --- a/gkr/src/verifier_v2/phase2_input.rs +++ /dev/null @@ -1,146 +0,0 @@ -use ark_std::{end_timer, start_timer}; -use ff_ext::ExtensionField; -use itertools::{chain, izip, Itertools}; -use multilinear_extensions::virtual_poly::{build_eq_x_r_vec, VPAuxInfo}; -use std::mem; -use sumcheck::util::ceil_log2; -use transcript::Transcript; - -use crate::{ - circuit::EvaluateGateCIn, - error::GKRError, - structs::{Circuit, IOPProverStepMessage, IOPVerifierStateV2, PointAndEval}, - utils::{counter_eval, i64_to_field, segment_eval_greater_than}, -}; - -use super::SumcheckState; - -impl IOPVerifierStateV2 { - pub(super) fn verify_and_update_state_input_phase2_step1( - &mut self, - circuit: &Circuit, - step_msg: IOPProverStepMessage, - transcript: &mut Transcript, - ) -> Result<(), GKRError> { - let timer = start_timer!(|| "Verifier sumcheck phase 2 step 1"); - let layer = &circuit.layers[self.layer_id as usize]; - let lo_out_num_vars = layer.num_vars; - let lo_in_num_vars = circuit.max_wit_in_num_vars; - let hi_num_vars = self.instance_num_vars; - - self.out_point = mem::take(&mut self.to_next_step_point_and_eval.point); - let lo_point = &self.out_point[..lo_out_num_vars]; - let hi_point = &self.out_point[lo_out_num_vars..]; - - self.eq_y_ry = build_eq_x_r_vec(lo_point); - - let g_value_const = circuit - .paste_from_consts_in - .iter() - .map(|(c, (l, r))| { - let c = i64_to_field::(*c); - let segment_greater_than_l_1 = if *l == 0 { - E::ONE - } else { - segment_eval_greater_than(l - 1, lo_point) - }; - let segment_greater_than_r_1 = segment_eval_greater_than(r - 1, lo_point); - (segment_greater_than_l_1 - segment_greater_than_r_1) * &c - }) - .sum::(); - - let mut sumcheck_sigma = self.to_next_step_point_and_eval.eval - g_value_const; - if !layer.add_consts.is_empty() { - sumcheck_sigma -= layer - .add_consts - .as_slice() - .eval(&self.eq_y_ry, &self.challenges); - } - - if lo_in_num_vars.is_none() { - if sumcheck_sigma != E::ZERO { - return Err(GKRError::VerifyError("input phase2 step1 failed")); - } - return Ok(()); - } - - let lo_in_num_vars = lo_in_num_vars.unwrap(); - - let claim = SumcheckState::verify( - sumcheck_sigma, - &step_msg.sumcheck_proof, - &VPAuxInfo { - max_degree: 2, - num_variables: lo_in_num_vars, - phantom: std::marker::PhantomData, - }, - transcript, - ); - - let claim_point = claim.point.iter().map(|x| x.elements).collect_vec(); - - self.eq_x1_rx1 = build_eq_x_r_vec(&claim_point); - let g_values_iter = chain![ - circuit.paste_from_wits_in.iter().cloned(), - circuit - .paste_from_counter_in - .iter() - .map(|(_, (l, r))| (*l, *r)) - ] - .map(|(l, r)| { - (l..r) - .map(|i| self.eq_y_ry[i] * self.eq_x1_rx1[i - l]) - .sum::() - }); - - // TODO: Double check here. - let f_counter_values = circuit - .paste_from_counter_in - .iter() - .map(|(num_vars, _)| { - let point = [&claim_point[..*num_vars], hi_point].concat(); - counter_eval(num_vars + hi_num_vars, &point) - * claim_point[*num_vars..] - .iter() - .map(|x| E::ONE - *x) - .product::() - }) - .collect_vec(); - let got_value = izip!( - chain![ - step_msg.sumcheck_eval_values.iter(), - f_counter_values.iter() - ], - g_values_iter - ) - .map(|(f, g)| *f * g) - .sum::(); - - self.to_next_phase_point_and_evals = izip!( - circuit.paste_from_wits_in.iter(), - step_msg.sumcheck_eval_values.into_iter() - ) - .map(|((l, r), eval)| { - let num_vars = ceil_log2(*r - *l); - let point = [&claim_point[..num_vars], hi_point].concat(); - let wit_in_eval = eval - * claim_point[num_vars..] - .iter() - .map(|x| E::ONE - *x) - .product::() - .invert() - .unwrap(); - PointAndEval::new_from_ref(&point, &wit_in_eval) - }) - .collect_vec(); - self.to_next_step_point_and_eval = - PointAndEval::new([&claim_point, hi_point].concat(), E::ZERO); - - end_timer!(timer); - if claim.expected_evaluation != got_value { - return Err(GKRError::VerifyError("input phase2 step1 failed")); - } - - Ok(()) - } -} diff --git a/gkr/src/verifier_v2/phase2_linear.rs b/gkr/src/verifier_v2/phase2_linear.rs deleted file mode 100644 index 9586cbaa1..000000000 --- a/gkr/src/verifier_v2/phase2_linear.rs +++ /dev/null @@ -1,97 +0,0 @@ -use ark_std::{end_timer, start_timer}; -use ff_ext::ExtensionField; -use itertools::{chain, izip, Itertools}; -use multilinear_extensions::virtual_poly::{build_eq_x_r_vec, VPAuxInfo}; -use std::{iter, mem}; -use transcript::Transcript; - -use crate::{ - circuit::{EvaluateGate1In, EvaluateGateCIn}, - error::GKRError, - structs::{Circuit, IOPProverStepMessage, IOPVerifierStateV2, PointAndEval}, - utils::MatrixMLEColumnFirst, -}; - -use super::SumcheckState; - -impl IOPVerifierStateV2 { - pub(super) fn verify_and_update_state_linear_phase2_step1( - &mut self, - circuit: &Circuit, - step_msg: IOPProverStepMessage, - transcript: &mut Transcript, - ) -> Result<(), GKRError> { - let timer = start_timer!(|| "Verifier sumcheck phase 2 step 1"); - let layer = &circuit.layers[self.layer_id as usize]; - let lo_out_num_vars = layer.num_vars; - let lo_in_num_vars = layer.max_previous_num_vars; - - self.out_point = mem::take(&mut self.to_next_step_point_and_eval.point); - let lo_point = &self.out_point[..lo_out_num_vars]; - - self.eq_y_ry = build_eq_x_r_vec(lo_point); - - // sigma = layers[i](rt || ry) - add_const(ry), - let sumcheck_sigma = self.to_next_step_point_and_eval.eval - - layer - .add_consts - .as_slice() - .eval(&self.eq_y_ry, &self.challenges); - - // Sumcheck 1: sigma = \sum_{x1} f1(x1) * g1(x1) + \sum_j f1'_j(x1) * g1'_j(x1) - // sigma = layers[i](rt || ry) - add_const(ry), - // f1(x1) = layers[i + 1](rt || x1) - // g1(x1) = add(ry, x1) - // f1'^{(j)}(x1) = subset[j][i](rt || x1) - // g1'^{(j)}(x1) = paste_from[j](ry, x1) - let claim_1 = SumcheckState::verify( - sumcheck_sigma, - &step_msg.sumcheck_proof, - &VPAuxInfo { - max_degree: 2, - num_variables: lo_in_num_vars, - phantom: std::marker::PhantomData, - }, - transcript, - ); - let claim1_point = claim_1.point.iter().map(|x| x.elements).collect_vec(); - - self.eq_x1_rx1 = build_eq_x_r_vec(&claim1_point[..lo_in_num_vars]); - let g1_values_iter = chain![ - iter::once(layer.adds.as_slice().eval( - &self.eq_y_ry, - &self.eq_x1_rx1, - &self.challenges - )), - layer.paste_from.iter().map(|(_, paste_from)| { - paste_from - .as_slice() - .eval_col_first(&self.eq_y_ry, &self.eq_x1_rx1) - }) - ]; - - let f1_values = &step_msg.sumcheck_eval_values; - let got_value_1 = - izip!(f1_values.iter(), g1_values_iter).fold(E::ZERO, |acc, (&f1, g1)| acc + f1 * g1); - - end_timer!(timer); - if claim_1.expected_evaluation != got_value_1 { - return Err(GKRError::VerifyError("phase2 step1 failed")); - } - - let new_point = [&claim1_point, &self.out_point[lo_out_num_vars..]].concat(); - self.to_next_phase_point_and_evals = - vec![PointAndEval::new_from_ref(&new_point, &f1_values[0])]; - izip!(layer.paste_from.iter(), f1_values.iter().skip(1)).for_each( - |((&old_layer_id, _), &subset_value)| { - self.subset_point_and_evals[old_layer_id as usize].push(( - self.layer_id, - PointAndEval::new_from_ref(&new_point, &subset_value), - )); - }, - ); - self.to_next_step_point_and_eval = self.to_next_phase_point_and_evals[0].clone(); - - Ok(()) - } -} diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 66505390e..032d6f892 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -75,6 +75,20 @@ impl Into> for Vec> +pub trait IntoMLE: Sized { + /// Converts this type into the (usually inferred) input type. + fn into_mle(self) -> T; +} + +impl IntoMLE> for Vec { + fn into_mle(mut self) -> DenseMultilinearExtension { + let next_pow2 = ceil_log2(self.len()); + self.resize(next_pow2, E::BaseField::ZERO); + DenseMultilinearExtension::from_evaluations_vec(next_pow2, self) + } +} + #[derive(Clone, PartialEq, Eq, Hash, Default, Debug, Serialize, Deserialize)] #[serde(untagged)] /// Differentiate inner vector on base/extension field.