From 7ae75ab6925b008574a071959c22073399639f6b Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Fri, 24 May 2024 21:23:49 -0400 Subject: [PATCH] Implement GrandProductOps for SIMD backend --- .../src/core/backend/cpu/lookups/gkr.rs | 5 +- .../src/core/backend/simd/lookups/gkr.rs | 184 +++++++++++++++++- .../src/core/backend/simd/lookups/mod.rs | 1 + crates/prover/src/core/lookups/gkr_prover.rs | 61 +++++- 4 files changed, 236 insertions(+), 15 deletions(-) diff --git a/crates/prover/src/core/backend/cpu/lookups/gkr.rs b/crates/prover/src/core/backend/cpu/lookups/gkr.rs index d38928706..855a23e0a 100644 --- a/crates/prover/src/core/backend/cpu/lookups/gkr.rs +++ b/crates/prover/src/core/backend/cpu/lookups/gkr.rs @@ -41,15 +41,14 @@ impl GkrOps for CpuBackend { ) -> UnivariatePoly { let k = h.n_variables(); let n_terms = 1 << (k - 1); - let eq_evals = h.eq_evals; + let eq_evals = h.eq_evals.as_ref(); let y = eq_evals.y(); let lambda = h.lambda; - let input_layer = &h.input_layer; let mut eval_at_0 = SecureField::zero(); let mut eval_at_2 = SecureField::zero(); - match input_layer { + match &h.input_layer { Layer::GrandProduct(col) => { process_grand_product_sum(&mut eval_at_0, &mut eval_at_2, eq_evals, col, n_terms) } diff --git a/crates/prover/src/core/backend/simd/lookups/gkr.rs b/crates/prover/src/core/backend/simd/lookups/gkr.rs index bd9d635b5..e69daa2f7 100644 --- a/crates/prover/src/core/backend/simd/lookups/gkr.rs +++ b/crates/prover/src/core/backend/simd/lookups/gkr.rs @@ -1,14 +1,19 @@ use std::iter::zip; +use num_traits::Zero; + use crate::core::backend::cpu::lookups::gkr::gen_eq_evals as cpu_gen_eq_evals; use crate::core::backend::simd::column::SecureFieldVec; use crate::core::backend::simd::m31::{LOG_N_LANES, N_LANES}; use crate::core::backend::simd::qm31::PackedSecureField; use crate::core::backend::simd::SimdBackend; -use crate::core::backend::Column; +use crate::core::backend::{Column, CpuBackend}; use crate::core::fields::qm31::SecureField; -use crate::core::lookups::gkr_prover::{GkrMultivariatePolyOracle, GkrOps, Layer}; +use crate::core::lookups::gkr_prover::{ + correct_sum_as_poly_in_first_variable, EqEvals, GkrMultivariatePolyOracle, GkrOps, Layer, +}; use crate::core::lookups::mle::Mle; +use crate::core::lookups::sumcheck::MultivariatePolyOracle; use crate::core::lookups::utils::UnivariatePoly; impl GkrOps for SimdBackend { @@ -47,15 +52,150 @@ impl GkrOps for SimdBackend { Mle::new(SecureFieldVec { data, length }) } - fn next_layer(_layer: &Layer) -> Layer { - todo!() + fn next_layer(layer: &Layer) -> Layer { + // Offload to CPU backend to prevent dealing with instances smaller than a SIMD vector. + if layer.n_variables() as u32 <= LOG_N_LANES { + return into_simd_layer(layer.to_cpu().next_layer().unwrap()); + } + + match layer { + Layer::GrandProduct(col) => next_grand_product_layer(col), + Layer::LogUpGeneric { + numerators: _, + denominators: _, + } => todo!(), + Layer::LogUpMultiplicities { + numerators: _, + denominators: _, + } => todo!(), + Layer::LogUpSingles { denominators: _ } => todo!(), + } } fn sum_as_poly_in_first_variable( - _h: &GkrMultivariatePolyOracle<'_, Self>, - _claim: SecureField, + h: &GkrMultivariatePolyOracle<'_, Self>, + claim: SecureField, ) -> UnivariatePoly { - todo!() + let k = h.n_variables(); + let n_terms = 1 << (k - 1); + let eq_evals = h.eq_evals.as_ref(); + let y = eq_evals.y(); + + // Offload to CPU backend to prevent dealing with instances smaller than a SIMD vector. + if n_terms <= N_LANES { + return h.to_cpu().sum_as_poly_in_first_variable(claim); + } + + let mut packed_eval_at_0 = PackedSecureField::zero(); + let mut packed_eval_at_2 = PackedSecureField::zero(); + + match &h.input_layer { + Layer::GrandProduct(col) => process_grand_product_sum( + &mut packed_eval_at_0, + &mut packed_eval_at_2, + eq_evals, + col, + n_terms, + ), + Layer::LogUpGeneric { + numerators: _, + denominators: _, + } => todo!(), + Layer::LogUpMultiplicities { + numerators: _, + denominators: _, + } => todo!(), + Layer::LogUpSingles { denominators: _ } => todo!(), + } + + let eval_at_0 = packed_eval_at_0.pointwise_sum() * h.eq_fixed_var_correction; + let eval_at_2 = packed_eval_at_2.pointwise_sum() * h.eq_fixed_var_correction; + + correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, k) + } +} + +// Can assume `len(layer) > N_LANES * 2` +fn next_grand_product_layer(layer: &Mle) -> Layer { + assert!(layer.len() > N_LANES); + let next_layer_len = layer.len() / 2; + + let data = layer + .data + .array_chunks() + .map(|&[a, b]| { + let (evens, odds) = a.deinterleave(b); + evens * odds + }) + .collect(); + + Layer::GrandProduct(Mle::new(SecureFieldVec { + data, + length: next_layer_len, + })) +} + +// Can assume `n_terms > N_LANES` +fn process_grand_product_sum( + packed_eval_at_0: &mut PackedSecureField, + packed_eval_at_2: &mut PackedSecureField, + eq_evals: &EqEvals, + col: &Mle, + n_terms: usize, +) { + assert!(n_terms > N_LANES); + + #[allow(clippy::needless_range_loop)] + for i in 0..n_terms { + // Let `p` be the multilinear polynomial representing `col`. + let (p0x0 /* = p(0, x, 0) */, p0x1 /* = p(0, x, 1) */) = + col.data[i * 2].deinterleave(col.data[i * 2 + 1]); + + // We obtain `p(2, x)` for some `x` in the boolean + // hypercube using `p(0, x)` and `p(1, x)`: + // + // ```text + // p(t, x) = eq(t, 0) * p(0, x) + eq(t, 1) * p(1, x) + // = (1 - t) * p(0, x) + t * p(1, x) + // + // p(2, x) = 2 * p(1, x) - p(0, x) + // ``` + let (p1x0 /* = p(1, x, 0) */, p1x1 /* = p(1, x, 1) */) = + col.data[(n_terms + i) * 2].deinterleave(col.data[(n_terms + i) * 2 + 1]); + let p2x0 /* = p(2, x, 0) */ = p1x0.double() - p0x0; + let p2x1 /* = p(2, x, 1) */ = p1x1.double() - p0x1; + + let product2 = p2x0 * p2x1; + let product0 = p0x0 * p0x1; + + let eq_eval = eq_evals.data[i]; + *packed_eval_at_0 += eq_eval * product0; + *packed_eval_at_2 += eq_eval * product2; + } +} + +fn into_simd_layer(cpu_layer: Layer) -> Layer { + match cpu_layer { + Layer::GrandProduct(mle) => { + Layer::GrandProduct(Mle::new(mle.into_evals().into_iter().collect())) + } + Layer::LogUpGeneric { + numerators, + denominators, + } => Layer::LogUpGeneric { + numerators: Mle::new(numerators.into_evals().into_iter().collect()), + denominators: Mle::new(denominators.into_evals().into_iter().collect()), + }, + Layer::LogUpMultiplicities { + numerators, + denominators, + } => Layer::LogUpMultiplicities { + numerators: Mle::new(numerators.into_evals().into_iter().collect()), + denominators: Mle::new(denominators.into_evals().into_iter().collect()), + }, + Layer::LogUpSingles { denominators } => Layer::LogUpSingles { + denominators: Mle::new(denominators.into_evals().into_iter().collect()), + }, } } @@ -63,8 +203,13 @@ impl GkrOps for SimdBackend { mod tests { use crate::core::backend::simd::SimdBackend; use crate::core::backend::{Column, CpuBackend}; + use crate::core::channel::Channel; use crate::core::fields::m31::BaseField; - use crate::core::lookups::gkr_prover::GkrOps; + use crate::core::fields::qm31::SecureField; + use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer}; + use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError}; + use crate::core::lookups::mle::Mle; + use crate::core::test_utils::test_channel; #[test] fn gen_eq_evals_matches_cpu() { @@ -76,4 +221,27 @@ mod tests { assert_eq!(*cpu_eq_evals, simd_eq_evals.to_cpu()); } + + #[test] + fn grand_product_works() -> Result<(), GkrError> { + const N: usize = 1 << 6; + let values = test_channel().draw_felts(N); + let product = values.iter().product(); + let col = Mle::::new(values.into_iter().collect()); + let input_layer = Layer::GrandProduct(col.clone()); + let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?; + + assert_eq!(proof.output_claims_by_instance, [vec![product]]); + assert_eq!( + claims_to_verify_by_instance, + [vec![col.eval_at_point(&ood_point)]] + ); + Ok(()) + } } diff --git a/crates/prover/src/core/backend/simd/lookups/mod.rs b/crates/prover/src/core/backend/simd/lookups/mod.rs index 34395e985..e2ee801c4 100644 --- a/crates/prover/src/core/backend/simd/lookups/mod.rs +++ b/crates/prover/src/core/backend/simd/lookups/mod.rs @@ -1,2 +1,3 @@ mod gkr; +// mod grandproduct; mod mle; diff --git a/crates/prover/src/core/lookups/gkr_prover.rs b/crates/prover/src/core/lookups/gkr_prover.rs index 82be44f93..5a92d6f00 100644 --- a/crates/prover/src/core/lookups/gkr_prover.rs +++ b/crates/prover/src/core/lookups/gkr_prover.rs @@ -1,7 +1,9 @@ //! Batch GKR protocol implementation designed to prove lookup arguments. +use std::borrow::Cow; use std::iter::{successors, zip}; use std::ops::Deref; +use educe::Educe; use itertools::Itertools; use num_traits::{One, Zero}; use thiserror::Error; @@ -10,7 +12,7 @@ use super::gkr_verifier::{GkrArtifact, GkrBatchProof, GkrMask}; use super::mle::{Mle, MleOps}; use super::sumcheck::MultivariatePolyOracle; use super::utils::{eq, random_linear_combination, UnivariatePoly}; -use crate::core::backend::{Col, Column, ColumnOps}; +use crate::core::backend::{Col, Column, ColumnOps, CpuBackend}; use crate::core::channel::Channel; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; @@ -44,6 +46,8 @@ pub trait GkrOps: MleOps + MleOps { /// `evals[1] = eq((0, ..., 0, 1), y)`, etc. /// /// [`eq(x, y)`]: crate::core::lookups::utils::eq +#[derive(Educe)] +#[educe(Debug, Clone)] pub struct EqEvals> { y: Vec, evals: Mle, @@ -101,7 +105,7 @@ pub enum Layer { impl Layer { /// Returns the number of variables used to interpolate the layer's gate values. - fn n_variables(&self) -> usize { + pub fn n_variables(&self) -> usize { match self { Self::GrandProduct(mle) | Self::LogUpSingles { denominators: mle } @@ -190,7 +194,7 @@ impl Layer { eq_evals: &EqEvals, ) -> GkrMultivariatePolyOracle<'_, B> { GkrMultivariatePolyOracle { - eq_evals, + eq_evals: Cow::Borrowed(eq_evals), input_layer: self, eq_fixed_var_correction: SecureField::one(), lambda, @@ -230,6 +234,34 @@ impl Layer { } }) } + + /// Returns a copy of this layer with the [`CpuBackend`]. + /// + /// This operation is expensive but can be useful for small traces that are difficult to handle + /// depending on the backend. For example, the SIMD backend offloads to the CPU backend when + /// trace length becomes smaller than the SIMD lane count. + pub fn to_cpu(&self) -> Layer { + match self { + Layer::GrandProduct(mle) => Layer::GrandProduct(Mle::new(mle.to_cpu())), + Layer::LogUpGeneric { + numerators, + denominators, + } => Layer::LogUpGeneric { + numerators: Mle::new(numerators.to_cpu()), + denominators: Mle::new(denominators.to_cpu()), + }, + Layer::LogUpMultiplicities { + numerators, + denominators, + } => Layer::LogUpMultiplicities { + numerators: Mle::new(numerators.to_cpu()), + denominators: Mle::new(denominators.to_cpu()), + }, + Layer::LogUpSingles { denominators } => Layer::LogUpSingles { + denominators: Mle::new(denominators.to_cpu()), + }, + } + } } #[derive(Debug)] @@ -238,7 +270,7 @@ struct NotOutputLayerError; /// A multivariate polynomial that expresses the relation between two consecutive GKR layers. pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> { /// `eq_evals` passed by [`Layer::into_multivariate_poly()`]. - pub eq_evals: &'a EqEvals, + pub eq_evals: Cow<'a, EqEvals>, pub input_layer: Layer, pub eq_fixed_var_correction: SecureField, /// Used by LogUp to perform a random linear combination of the numerators and denominators. @@ -312,6 +344,27 @@ impl<'a, B: GkrOps> GkrMultivariatePolyOracle<'a, B> { Ok(GkrMask::new(columns)) } + + /// Returns a copy of this oracle with the [`CpuBackend`]. + /// + /// This operation is expensive but can be useful for small oracles that are difficult to handle + /// depending on the backend. For example, the SIMD backend offloads to the CPU backend when + /// trace length becomes smaller than the SIMD lane count. + pub fn to_cpu(&self) -> GkrMultivariatePolyOracle<'a, CpuBackend> { + // TODO(andrew): This block is not ideal. + let n_eq_evals = 1 << (self.n_variables() - 1); + let eq_evals = Cow::Owned(EqEvals { + evals: Mle::new((0..n_eq_evals).map(|i| self.eq_evals.at(i)).collect()), + y: self.eq_evals.y.to_vec(), + }); + + GkrMultivariatePolyOracle { + eq_evals, + eq_fixed_var_correction: self.eq_fixed_var_correction, + input_layer: self.input_layer.to_cpu(), + lambda: self.lambda, + } + } } /// Error returned when a polynomial is expected to be constant but it is not.