From 1b2c994bc35e65e629397cf57ad5b261284649c0 Mon Sep 17 00:00:00 2001 From: Srinath Setty Date: Mon, 6 Nov 2023 11:25:03 -0800 Subject: [PATCH] optimize ppsnark (#252) * optimize; check claims about Az,Bz,Cz * add rayon --- Cargo.toml | 2 +- src/spartan/ppsnark.rs | 273 ++++++++++++++++++++++++----------------- 2 files changed, 162 insertions(+), 113 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3c9e6ce23..38cf438ae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "nova-snark" -version = "0.29.0" +version = "0.30.0" authors = ["Srinath Setty "] edition = "2021" description = "Recursive zkSNARKs without trusted setup" diff --git a/src/spartan/ppsnark.rs b/src/spartan/ppsnark.rs index 6fa084dfe..b0a34ae5d 100644 --- a/src/spartan/ppsnark.rs +++ b/src/spartan/ppsnark.rs @@ -266,7 +266,7 @@ pub trait SumcheckEngine: Send + Sync { fn final_claims(&self) -> Vec>; } -struct LookupSumcheckInstance { +struct MemorySumcheckInstance { // row w_plus_r_row: MultilinearPolynomial, t_plus_r_row: MultilinearPolynomial, @@ -283,9 +283,12 @@ struct LookupSumcheckInstance { // eq poly_eq: MultilinearPolynomial, + + // zero polynomial + poly_zero: MultilinearPolynomial, } -impl LookupSumcheckInstance { +impl MemorySumcheckInstance { pub fn new( ck: &CommitmentKey, r: &G::Scalar, @@ -426,6 +429,8 @@ impl LookupSumcheckInstance { w_plus_r_inv_col.clone(), ]; + let zero = vec![G::Scalar::ZERO; t_plus_r_inv_row.len()]; + Ok(( Self { w_plus_r_row: MultilinearPolynomial::new(w_plus_r_row?), @@ -439,6 +444,7 @@ impl LookupSumcheckInstance { w_plus_r_inv_col: MultilinearPolynomial::new(w_plus_r_inv_col), ts_col: MultilinearPolynomial::new(ts_col), poly_eq, + poly_zero: MultilinearPolynomial::new(zero), }, comm_vec, poly_vec, @@ -446,7 +452,7 @@ impl LookupSumcheckInstance { } } -impl SumcheckEngine for LookupSumcheckInstance { +impl SumcheckEngine for MemorySumcheckInstance { fn initial_claims(&self) -> Vec { vec![G::Scalar::ZERO; 6] } @@ -467,8 +473,6 @@ impl SumcheckEngine for LookupSumcheckInstance { } fn evaluation_points(&self) -> Vec> { - let poly_zero = MultilinearPolynomial::new(vec![G::Scalar::ZERO; self.w_plus_r_row.len()]); - let comb_func = |poly_A_comp: &G::Scalar, poly_B_comp: &G::Scalar, _poly_C_comp: &G::Scalar| @@ -493,7 +497,7 @@ impl SumcheckEngine for LookupSumcheckInstance { SumcheckProof::::compute_eval_points_cubic( &self.t_plus_r_inv_row, &self.w_plus_r_inv_row, - &poly_zero, + &self.poly_zero, &comb_func, ); @@ -501,7 +505,7 @@ impl SumcheckEngine for LookupSumcheckInstance { SumcheckProof::::compute_eval_points_cubic( &self.t_plus_r_inv_col, &self.w_plus_r_inv_col, - &poly_zero, + &self.poly_zero, &comb_func, ); @@ -519,7 +523,7 @@ impl SumcheckEngine for LookupSumcheckInstance { &self.poly_eq, &self.w_plus_r_inv_row, &self.w_plus_r_row, - &poly_zero, + &self.poly_zero, &comb_func2, ); @@ -537,7 +541,7 @@ impl SumcheckEngine for LookupSumcheckInstance { &self.poly_eq, &self.w_plus_r_inv_col, &self.w_plus_r_col, - &poly_zero, + &self.poly_zero, &comb_func2, ); @@ -591,11 +595,38 @@ struct OuterSumcheckInstance { poly_Az: MultilinearPolynomial, poly_Bz: MultilinearPolynomial, poly_uCz_E: MultilinearPolynomial, + + poly_Mz: MultilinearPolynomial, + eval_Mz_at_tau: G::Scalar, + + poly_zero: MultilinearPolynomial, +} + +impl OuterSumcheckInstance { + pub fn new( + tau: Vec, + Az: Vec, + Bz: Vec, + uCz_E: Vec, + Mz: Vec, + eval_Mz_at_tau: &G::Scalar, + ) -> Self { + let zero = vec![G::Scalar::ZERO; tau.len()]; + Self { + poly_tau: MultilinearPolynomial::new(tau), + poly_Az: MultilinearPolynomial::new(Az), + poly_Bz: MultilinearPolynomial::new(Bz), + poly_uCz_E: MultilinearPolynomial::new(uCz_E), + poly_Mz: MultilinearPolynomial::new(Mz), + eval_Mz_at_tau: *eval_Mz_at_tau, + poly_zero: MultilinearPolynomial::new(zero), + } + } } impl SumcheckEngine for OuterSumcheckInstance { fn initial_claims(&self) -> Vec { - vec![G::Scalar::ZERO] + vec![G::Scalar::ZERO, self.eval_Mz_at_tau] } fn degree(&self) -> usize { @@ -606,16 +637,11 @@ impl SumcheckEngine for OuterSumcheckInstance { assert_eq!(self.poly_tau.len(), self.poly_Az.len()); assert_eq!(self.poly_tau.len(), self.poly_Bz.len()); assert_eq!(self.poly_tau.len(), self.poly_uCz_E.len()); + assert_eq!(self.poly_tau.len(), self.poly_Mz.len()); self.poly_tau.len() } fn evaluation_points(&self) -> Vec> { - let (poly_A, poly_B, poly_C, poly_D) = ( - &self.poly_tau, - &self.poly_Az, - &self.poly_Bz, - &self.poly_uCz_E, - ); let comb_func = |poly_A_comp: &G::Scalar, poly_B_comp: &G::Scalar, @@ -623,12 +649,32 @@ impl SumcheckEngine for OuterSumcheckInstance { poly_D_comp: &G::Scalar| -> G::Scalar { *poly_A_comp * (*poly_B_comp * *poly_C_comp - *poly_D_comp) }; - let (eval_point_0, eval_point_2, eval_point_3) = + let (eval_point_h_0, eval_point_h_2, eval_point_h_3) = SumcheckProof::::compute_eval_points_cubic_with_additive_term( - poly_A, poly_B, poly_C, poly_D, &comb_func, + &self.poly_tau, + &self.poly_Az, + &self.poly_Bz, + &self.poly_uCz_E, + &comb_func, ); - vec![vec![eval_point_0, eval_point_2, eval_point_3]] + let comb_func2 = |poly_A_comp: &G::Scalar, + poly_B_comp: &G::Scalar, + _poly_C_comp: &G::Scalar| + -> G::Scalar { *poly_A_comp * *poly_B_comp }; + + let (eval_point_e_0, eval_point_e_2, eval_point_e_3) = + SumcheckProof::::compute_eval_points_cubic( + &self.poly_tau, + &self.poly_Mz, + &self.poly_zero, + &comb_func2, + ); + + vec![ + vec![eval_point_h_0, eval_point_h_2, eval_point_h_3], + vec![eval_point_e_0, eval_point_e_2, eval_point_e_3], + ] } fn bound(&mut self, r: &G::Scalar) { @@ -637,6 +683,7 @@ impl SumcheckEngine for OuterSumcheckInstance { &mut self.poly_Az, &mut self.poly_Bz, &mut self.poly_uCz_E, + &mut self.poly_Mz, ] .par_iter_mut() .for_each(|poly| poly.bound_poly_var_top(r)); @@ -969,9 +1016,6 @@ where let W = W.pad(&S); // pad the witness let mut transcript = G::TE::new(b"RelaxedR1CSSNARK"); - // a list of polynomial evaluation claims that will be batched - let mut w_u_vec = Vec::new(); - // append the verifier key (which includes commitment to R1CS matrices) and the RelaxedR1CSInstance to the transcript transcript.absorb(b"vk", &pk.vk_digest); transcript.absorb(b"U", U); @@ -1028,7 +1072,6 @@ where // absorb commitments to L_row and L_col in the transcript transcript.absorb(b"e", &vec![comm_L_row, comm_L_col].as_slice()); - // add claims about Az, Bz, and Cz to be checked later // since all the three polynomials are opened at tau, // we can combine them into a single polynomial opened at tau let eval_vec = vec![eval_Az_at_tau, eval_Bz_at_tau, eval_Cz_at_tau]; @@ -1038,84 +1081,91 @@ where let c = transcript.squeeze(b"c")?; let w: PolyEvalWitness = PolyEvalWitness::batch(&poly_vec, &c); let u: PolyEvalInstance = PolyEvalInstance::batch(&comm_vec, &tau_coords, &eval_vec, &c); - w_u_vec.push((w, u)); // we now need to prove three claims - // (1) 0 = \sum_x poly_tau(x) * (poly_Az(x) * poly_Bz(x) - poly_uCz_E(x)) - // (2) eval_Az_at_tau + r * eval_Bz_at_tau + r^2 * eval_Cz_at_tau = \sum_y L_row(y) * (val_A(y) + r * val_B(y) + r^2 * val_C(y)) * L_col(y) + // (1) 0 = \sum_x poly_tau(x) * (poly_Az(x) * poly_Bz(x) - poly_uCz_E(x)), and eval_Az_at_tau + r * eval_Az_at_tau + r^2 * eval_Cz_at_tau = (Az+r*Bz+r^2*Cz)(tau) + // (2) eval_Az_at_tau + c * eval_Bz_at_tau + c^2 * eval_Cz_at_tau = \sum_y L_row(y) * (val_A(y) + c * val_B(y) + c^2 * val_C(y)) * L_col(y) // (3) L_row(i) = eq(tau, row(i)) and L_col(i) = z(col(i)) + let gamma = transcript.squeeze(b"g")?; + let r = transcript.squeeze(b"r")?; - // a sum-check instance to prove the first claim - let mut outer_sc_inst = OuterSumcheckInstance { - poly_tau: MultilinearPolynomial::new(PowPolynomial::new(&tau, num_rounds_sc).evals()), - poly_Az: MultilinearPolynomial::new(Az.clone()), - poly_Bz: MultilinearPolynomial::new(Bz.clone()), - poly_uCz_E: { - let uCz_E = (0..Cz.len()) - .map(|i| U.u * Cz[i] + E[i]) + let ((mut outer_sc_inst, mut inner_sc_inst), mem_res) = rayon::join( + || { + // a sum-check instance to prove the first claim + let outer_sc_inst = OuterSumcheckInstance::new( + PowPolynomial::new(&tau, num_rounds_sc).evals(), + Az.clone(), + Bz.clone(), + (0..Cz.len()) + .map(|i| U.u * Cz[i] + E[i]) + .collect::>(), + w.p.clone(), // Mz = Az + r * Bz + r^2 * Cz + &u.e, // eval_Az_at_tau + r * eval_Az_at_tau + r^2 * eval_Cz_at_tau + ); + + // a sum-check instance to prove the second claim + let val = pk + .S_repr + .val_A + .par_iter() + .zip(pk.S_repr.val_B.par_iter()) + .zip(pk.S_repr.val_C.par_iter()) + .map(|((v_a, v_b), v_c)| *v_a + c * *v_b + c * c * *v_c) .collect::>(); - MultilinearPolynomial::new(uCz_E) - }, - }; - - // a sum-check instance to prove the second claim - let val = pk - .S_repr - .val_A - .iter() - .zip(pk.S_repr.val_B.iter()) - .zip(pk.S_repr.val_C.iter()) - .map(|((v_a, v_b), v_c)| *v_a + c * *v_b + c * c * *v_c) - .collect::>(); - let mut inner_sc_inst = InnerSumcheckInstance { - claim: eval_Az_at_tau + c * eval_Bz_at_tau + c * c * eval_Cz_at_tau, - poly_L_row: MultilinearPolynomial::new(L_row.clone()), - poly_L_col: MultilinearPolynomial::new(L_col.clone()), - poly_val: MultilinearPolynomial::new(val), - }; - - // a third sum-check instance to prove the read-only memory claim - // we now need to prove that L_row and L_col are well-formed - let gamma = transcript.squeeze(b"g")?; + let inner_sc_inst = InnerSumcheckInstance { + claim: eval_Az_at_tau + c * eval_Bz_at_tau + c * c * eval_Cz_at_tau, + poly_L_row: MultilinearPolynomial::new(L_row.clone()), + poly_L_col: MultilinearPolynomial::new(L_col.clone()), + poly_val: MultilinearPolynomial::new(val), + }; - // hash the tuples of (addr,val) memory contents and read responses into a single field element using `hash_func` - let hash_func_vec = |mem: &[G::Scalar], - addr: &[G::Scalar], - lookups: &[G::Scalar]| - -> (Vec, Vec) { - let hash_func = |addr: &G::Scalar, val: &G::Scalar| -> G::Scalar { *val * gamma + *addr }; - assert_eq!(addr.len(), lookups.len()); - rayon::join( - || { - (0..mem.len()) - .map(|i| hash_func(&G::Scalar::from(i as u64), &mem[i])) - .collect::>() - }, - || { - (0..addr.len()) - .map(|i| hash_func(&addr[i], &lookups[i])) - .collect::>() - }, - ) - }; + (outer_sc_inst, inner_sc_inst) + }, + || { + // a third sum-check instance to prove the read-only memory claim + // we now need to prove that L_row and L_col are well-formed + + // hash the tuples of (addr,val) memory contents and read responses into a single field element using `hash_func` + let hash_func_vec = |mem: &[G::Scalar], + addr: &[G::Scalar], + lookups: &[G::Scalar]| + -> (Vec, Vec) { + let hash_func = |addr: &G::Scalar, val: &G::Scalar| -> G::Scalar { *val * gamma + *addr }; + assert_eq!(addr.len(), lookups.len()); + rayon::join( + || { + (0..mem.len()) + .map(|i| hash_func(&G::Scalar::from(i as u64), &mem[i])) + .collect::>() + }, + || { + (0..addr.len()) + .map(|i| hash_func(&addr[i], &lookups[i])) + .collect::>() + }, + ) + }; - let ((T_row, W_row), (T_col, W_col)) = rayon::join( - || hash_func_vec(&mem_row, &pk.S_repr.row, &L_row), - || hash_func_vec(&mem_col, &pk.S_repr.col, &L_col), + let ((T_row, W_row), (T_col, W_col)) = rayon::join( + || hash_func_vec(&mem_row, &pk.S_repr.row, &L_row), + || hash_func_vec(&mem_col, &pk.S_repr.col, &L_col), + ); + + MemorySumcheckInstance::new( + ck, + &r, + T_row, + W_row, + pk.S_repr.ts_row.clone(), + T_col, + W_col, + pk.S_repr.ts_col.clone(), + &mut transcript, + ) + }, ); - let r = transcript.squeeze(b"r")?; - let (mut mem_sc_inst, comm_lookup, polys_lookup) = LookupSumcheckInstance::new( - ck, - &r, - T_row, - W_row, - pk.S_repr.ts_row.clone(), - T_col, - W_col, - pk.S_repr.ts_col.clone(), - &mut transcript, - )?; + let (mut mem_sc_inst, comm_mem_oracles, mem_oracles) = mem_res?; let (sc, rand_sc, claims_mem, claims_outer, claims_inner) = Self::prove_helper( &mut mem_sc_inst, @@ -1192,13 +1242,13 @@ where pk.S_comm.comm_val_A, pk.S_comm.comm_val_B, pk.S_comm.comm_val_C, - comm_lookup[0], + comm_mem_oracles[0], pk.S_comm.comm_row, - comm_lookup[1], + comm_mem_oracles[1], pk.S_comm.comm_ts_row, - comm_lookup[2], + comm_mem_oracles[2], pk.S_comm.comm_col, - comm_lookup[3], + comm_mem_oracles[3], pk.S_comm.comm_ts_col, ]; let poly_vec = [ @@ -1212,13 +1262,13 @@ where &pk.S_repr.val_A, &pk.S_repr.val_B, &pk.S_repr.val_C, - polys_lookup[0].as_ref(), + mem_oracles[0].as_ref(), &pk.S_repr.row, - polys_lookup[1].as_ref(), + mem_oracles[1].as_ref(), &pk.S_repr.ts_row, - polys_lookup[2].as_ref(), + mem_oracles[2].as_ref(), &pk.S_repr.col, - polys_lookup[3].as_ref(), + mem_oracles[3].as_ref(), &pk.S_repr.ts_col, ]; transcript.absorb(b"e", &eval_vec.as_slice()); // comm_vec is already in the transcript @@ -1235,10 +1285,10 @@ where comm_L_row: comm_L_row.compress(), comm_L_col: comm_L_col.compress(), - comm_t_plus_r_inv_row: comm_lookup[0].compress(), - comm_w_plus_r_inv_row: comm_lookup[1].compress(), - comm_t_plus_r_inv_col: comm_lookup[2].compress(), - comm_w_plus_r_inv_col: comm_lookup[3].compress(), + comm_t_plus_r_inv_row: comm_mem_oracles[0].compress(), + comm_w_plus_r_inv_row: comm_mem_oracles[1].compress(), + comm_t_plus_r_inv_col: comm_mem_oracles[2].compress(), + comm_w_plus_r_inv_col: comm_mem_oracles[3].compress(), eval_Az_at_tau, eval_Bz_at_tau, @@ -1275,7 +1325,6 @@ where /// verifies a proof of satisfiability of a `RelaxedR1CS` instance fn verify(&self, vk: &Self::VerifierKey, U: &RelaxedR1CSInstance) -> Result<(), NovaError> { let mut transcript = G::TE::new(b"RelaxedR1CSSNARK"); - let mut u_vec: Vec> = Vec::new(); // append the verifier key (including commitment to R1CS matrices) and the RelaxedR1CSInstance to the transcript transcript.absorb(b"vk", &vk.digest()); @@ -1320,9 +1369,8 @@ where let comm_vec = vec![comm_Az, comm_Bz, comm_Cz]; transcript.absorb(b"e", &eval_vec.as_slice()); // c_vec is already in the transcript let c = transcript.squeeze(b"c")?; - let u = PolyEvalInstance::batch(&comm_vec, &tau_coords, &eval_vec, &c); - let claim_inner = u.e; - u_vec.push(u); + let u: PolyEvalInstance = PolyEvalInstance::batch(&comm_vec, &tau_coords, &eval_vec, &c); + let claim = u.e; let gamma = transcript.squeeze(b"g")?; @@ -1341,10 +1389,10 @@ where let rho = transcript.squeeze(b"r")?; - let num_claims = 8; + let num_claims = 9; let s = transcript.squeeze(b"r")?; let coeffs = powers::(&s, num_claims); - let claim = coeffs[7] * claim_inner; // rest are zeros + let claim = (coeffs[7] + coeffs[8]) * claim; // rest are zeros // verify sc let (claim_sc_final, rand_sc) = self.sc.verify(claim, num_rounds_sc, 3, &mut transcript)?; @@ -1436,8 +1484,9 @@ where let claim_outer_final_expected = coeffs[6] * taus_bound_rand_sc - * (self.eval_Az * self.eval_Bz - U.u * self.eval_Cz - self.eval_E); - let claim_inner_final_expected = coeffs[7] + * (self.eval_Az * self.eval_Bz - U.u * self.eval_Cz - self.eval_E) + + coeffs[7] * taus_bound_rand_sc * (self.eval_Az + c * self.eval_Bz + c * c * self.eval_Cz); + let claim_inner_final_expected = coeffs[8] * self.eval_L_row * self.eval_L_col * (self.eval_val_A + c * self.eval_val_B + c * c * self.eval_val_C);