diff --git a/crates/prover/src/core/backend/cpu/lookups/gkr.rs b/crates/prover/src/core/backend/cpu/lookups/gkr.rs index deae2fd26..4dbbc2322 100644 --- a/crates/prover/src/core/backend/cpu/lookups/gkr.rs +++ b/crates/prover/src/core/backend/cpu/lookups/gkr.rs @@ -42,13 +42,12 @@ impl GkrOps for CpuBackend { let n_variables = h.n_variables(); assert!(!n_variables.is_zero()); let n_terms = 1 << (n_variables - 1); - let eq_evals = h.eq_evals; + let eq_evals = h.eq_evals.as_ref(); // Vector used to generate evaluations of `eq(x, y)` for `x` in the boolean hypercube. let y = eq_evals.y(); let lambda = h.lambda; - let input_layer = &h.input_layer; - let (mut eval_at_0, mut eval_at_2) = match input_layer { + let (mut eval_at_0, mut eval_at_2) = match &h.input_layer { Layer::GrandProduct(col) => eval_grand_product_sum(eq_evals, col, n_terms), Layer::LogUpGeneric { numerators, @@ -88,6 +87,9 @@ fn eval_grand_product_sum( let inp_at_r1i1 = input_layer[(n_terms + i) * 2 + 1]; // Note `inp(r, t, x) = eq(t, 0) * inp(r, 0, x) + eq(t, 1) * inp(r, 1, x)` // => `inp(r, 2, x) = 2 * inp(r, 1, x) - inp(r, 0, x)` + // TODO(andrew): Consider evaluation at `1/2` to save an addition operation since + // `inp(r, 1/2, x) = 1/2 * (inp(r, 1, x) + inp(r, 0, x))`. `1/2 * ...` can be factored + // outside the loop. let inp_at_r2i0 = inp_at_r1i0.double() - inp_at_r0i0; let inp_at_r2i1 = inp_at_r1i1.double() - inp_at_r0i1; diff --git a/crates/prover/src/core/backend/simd/lookups/gkr.rs b/crates/prover/src/core/backend/simd/lookups/gkr.rs index 94065f3a9..51163e35c 100644 --- a/crates/prover/src/core/backend/simd/lookups/gkr.rs +++ b/crates/prover/src/core/backend/simd/lookups/gkr.rs @@ -1,14 +1,19 @@ use std::iter::zip; +use num_traits::Zero; + use crate::core::backend::cpu::lookups::gkr::gen_eq_evals as cpu_gen_eq_evals; use crate::core::backend::simd::column::SecureColumn; use crate::core::backend::simd::m31::{LOG_N_LANES, N_LANES}; use crate::core::backend::simd::qm31::PackedSecureField; use crate::core::backend::simd::SimdBackend; -use crate::core::backend::Column; +use crate::core::backend::{Column, CpuBackend}; use crate::core::fields::qm31::SecureField; -use crate::core::lookups::gkr_prover::{GkrMultivariatePolyOracle, GkrOps, Layer}; +use crate::core::lookups::gkr_prover::{ + correct_sum_as_poly_in_first_variable, EqEvals, GkrMultivariatePolyOracle, GkrOps, Layer, +}; use crate::core::lookups::mle::Mle; +use crate::core::lookups::sumcheck::MultivariatePolyOracle; use crate::core::lookups::utils::UnivariatePoly; impl GkrOps for SimdBackend { @@ -47,15 +52,144 @@ impl GkrOps for SimdBackend { Mle::new(SecureColumn { data, length }) } - fn next_layer(_layer: &Layer) -> Layer { - todo!() + fn next_layer(layer: &Layer) -> Layer { + // Offload to CPU backend to avoid dealing with instances smaller than a SIMD vector. + if layer.n_variables() as u32 <= LOG_N_LANES { + return into_simd_layer(layer.to_cpu().next_layer().unwrap()); + } + + match layer { + Layer::GrandProduct(col) => next_grand_product_layer(col), + Layer::LogUpGeneric { + numerators: _, + denominators: _, + } => todo!(), + Layer::LogUpMultiplicities { + numerators: _, + denominators: _, + } => todo!(), + Layer::LogUpSingles { denominators: _ } => todo!(), + } } fn sum_as_poly_in_first_variable( - _h: &GkrMultivariatePolyOracle<'_, Self>, - _claim: SecureField, + h: &GkrMultivariatePolyOracle<'_, Self>, + claim: SecureField, ) -> UnivariatePoly { - todo!() + let n_variables = h.n_variables(); + let n_terms = 1 << n_variables.saturating_sub(1); + let eq_evals = h.eq_evals.as_ref(); + // Vector used to generate evaluations of `eq(x, y)` for `x` in the boolean hypercube. + let y = eq_evals.y(); + + // Offload to CPU backend to avoid dealing with instances smaller than a SIMD vector. + if n_terms < N_LANES { + return h.to_cpu().sum_as_poly_in_first_variable(claim); + } + + let n_packed_terms = n_terms / N_LANES; + + let (mut eval_at_0, mut eval_at_2) = match &h.input_layer { + Layer::GrandProduct(col) => eval_grand_product_sum(eq_evals, col, n_packed_terms), + Layer::LogUpGeneric { + numerators: _, + denominators: _, + } => todo!(), + Layer::LogUpMultiplicities { + numerators: _, + denominators: _, + } => todo!(), + Layer::LogUpSingles { denominators: _ } => todo!(), + }; + + eval_at_0 *= h.eq_fixed_var_correction; + eval_at_2 *= h.eq_fixed_var_correction; + correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, n_variables) + } +} + +/// Generates the next GKR layer for Grand Product. +/// +/// Assumption: `len(layer) > N_LANES`. +fn next_grand_product_layer(layer: &Mle) -> Layer { + assert!(layer.len() > N_LANES); + let next_layer_len = layer.len() / 2; + + let data = layer + .data + .array_chunks() + .map(|&[a, b]| { + let (evens, odds) = a.deinterleave(b); + evens * odds + }) + .collect(); + + Layer::GrandProduct(Mle::new(SecureFieldVec { + data, + length: next_layer_len, + })) +} + +/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * inp(r, t, x, 0) * inp(r, t, x, 1)` at `t=0` and `t=2`. +/// +/// Output of the form: `(eval_at_0, eval_at_2)`. +fn eval_grand_product_sum( + eq_evals: &EqEvals, + col: &Mle, + n_packed_terms: usize, +) -> (SecureField, SecureField) { + let mut packed_eval_at_0 = PackedSecureField::zero(); + let mut packed_eval_at_2 = PackedSecureField::zero(); + + for i in 0..n_packed_terms { + // Input polynomial at points `(r, {0, 1, 2}, bits(i), v, {0, 1})` + // for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. + let (inp_at_r0iv0, inp_at_r0iv1) = col.data[i * 2].deinterleave(col.data[i * 2 + 1]); + let (inp_at_r1iv0, inp_at_r1iv1) = + col.data[(n_packed_terms + i) * 2].deinterleave(col.data[(n_packed_terms + i) * 2 + 1]); + // Note `inp(r, t, x) = eq(t, 0) * inp(r, 0, x) + eq(t, 1) * inp(r, 1, x)` + // => `inp(r, 2, x) = 2 * inp(r, 1, x) - inp(r, 0, x)` + let inp_at_r2iv0 = inp_at_r1iv0.double() - inp_at_r0iv0; + let inp_at_r2iv1 = inp_at_r1iv1.double() - inp_at_r0iv1; + + // Product polynomial `prod(x) = inp(x, 0) * inp(x, 1)` at points `(r, {0, 2}, bits(i), v)`. + // for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. + let prod_at_r2iv = inp_at_r2iv0 * inp_at_r2iv1; + let prod_at_r0iv = inp_at_r0iv0 * inp_at_r0iv1; + + let eq_eval_at_0iv = eq_evals.data[i]; + packed_eval_at_0 += eq_eval_at_0iv * prod_at_r0iv; + packed_eval_at_2 += eq_eval_at_0iv * prod_at_r2iv; + } + + ( + packed_eval_at_0.pointwise_sum(), + packed_eval_at_2.pointwise_sum(), + ) +} + +fn into_simd_layer(cpu_layer: Layer) -> Layer { + match cpu_layer { + Layer::GrandProduct(mle) => { + Layer::GrandProduct(Mle::new(mle.into_evals().into_iter().collect())) + } + Layer::LogUpGeneric { + numerators, + denominators, + } => Layer::LogUpGeneric { + numerators: Mle::new(numerators.into_evals().into_iter().collect()), + denominators: Mle::new(denominators.into_evals().into_iter().collect()), + }, + Layer::LogUpMultiplicities { + numerators, + denominators, + } => Layer::LogUpMultiplicities { + numerators: Mle::new(numerators.into_evals().into_iter().collect()), + denominators: Mle::new(denominators.into_evals().into_iter().collect()), + }, + Layer::LogUpSingles { denominators } => Layer::LogUpSingles { + denominators: Mle::new(denominators.into_evals().into_iter().collect()), + }, } } @@ -63,8 +197,13 @@ impl GkrOps for SimdBackend { mod tests { use crate::core::backend::simd::SimdBackend; use crate::core::backend::{Column, CpuBackend}; + use crate::core::channel::Channel; use crate::core::fields::m31::BaseField; - use crate::core::lookups::gkr_prover::GkrOps; + use crate::core::fields::qm31::SecureField; + use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer}; + use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError}; + use crate::core::lookups::mle::Mle; + use crate::core::test_utils::test_channel; #[test] fn gen_eq_evals_matches_cpu() { @@ -87,4 +226,27 @@ mod tests { assert_eq!(eq_evals_simd.to_cpu(), *eq_evals_cpu); } + + #[test] + fn grand_product_works() -> Result<(), GkrError> { + const N: usize = 1 << 8; + let values = test_channel().draw_felts(N); + let product = values.iter().product(); + let col = Mle::::new(values.into_iter().collect()); + let input_layer = Layer::GrandProduct(col.clone()); + let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?; + + assert_eq!(proof.output_claims_by_instance, [vec![product]]); + assert_eq!( + claims_to_verify_by_instance, + [vec![col.eval_at_point(&ood_point)]] + ); + Ok(()) + } } diff --git a/crates/prover/src/core/lookups/gkr_prover.rs b/crates/prover/src/core/lookups/gkr_prover.rs index c0486f3b5..4413322ff 100644 --- a/crates/prover/src/core/lookups/gkr_prover.rs +++ b/crates/prover/src/core/lookups/gkr_prover.rs @@ -1,7 +1,9 @@ //! GKR batch prover for Grand Product and LogUp lookup arguments. +use std::borrow::Cow; use std::iter::{successors, zip}; use std::ops::Deref; +use educe::Educe; use itertools::Itertools; use num_traits::{One, Zero}; use thiserror::Error; @@ -10,7 +12,7 @@ use super::gkr_verifier::{GkrArtifact, GkrBatchProof, GkrMask}; use super::mle::{Mle, MleOps}; use super::sumcheck::MultivariatePolyOracle; use super::utils::{eq, random_linear_combination, UnivariatePoly}; -use crate::core::backend::{Col, Column, ColumnOps}; +use crate::core::backend::{Col, Column, ColumnOps, CpuBackend}; use crate::core::channel::Channel; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; @@ -46,6 +48,8 @@ pub trait GkrOps: MleOps + MleOps { /// `evals[1] = eq((0, ..., 0, 1), y)`, etc. /// /// [`eq(x, y)`]: crate::core::lookups::utils::eq +#[derive(Educe)] +#[educe(Debug, Clone)] pub struct EqEvals> { y: Vec, evals: Mle, @@ -103,7 +107,7 @@ pub enum Layer { impl Layer { /// Returns the number of variables used to interpolate the layer's gate values. - fn n_variables(&self) -> usize { + pub fn n_variables(&self) -> usize { match self { Self::GrandProduct(mle) | Self::LogUpSingles { denominators: mle } @@ -227,12 +231,40 @@ impl Layer { eq_evals: &EqEvals, ) -> GkrMultivariatePolyOracle<'_, B> { GkrMultivariatePolyOracle { - eq_evals, + eq_evals: Cow::Borrowed(eq_evals), input_layer: self, eq_fixed_var_correction: SecureField::one(), lambda, } } + + /// Returns a copy of this layer with the [`CpuBackend`]. + /// + /// This operation is expensive but can be useful for small traces that are difficult to handle + /// depending on the backend. For example, the SIMD backend offloads to the CPU backend when + /// trace length becomes smaller than the SIMD lane count. + pub fn to_cpu(&self) -> Layer { + match self { + Layer::GrandProduct(mle) => Layer::GrandProduct(Mle::new(mle.to_cpu())), + Layer::LogUpGeneric { + numerators, + denominators, + } => Layer::LogUpGeneric { + numerators: Mle::new(numerators.to_cpu()), + denominators: Mle::new(denominators.to_cpu()), + }, + Layer::LogUpMultiplicities { + numerators, + denominators, + } => Layer::LogUpMultiplicities { + numerators: Mle::new(numerators.to_cpu()), + denominators: Mle::new(denominators.to_cpu()), + }, + Layer::LogUpSingles { denominators } => Layer::LogUpSingles { + denominators: Mle::new(denominators.to_cpu()), + }, + } + } } #[derive(Debug)] @@ -258,7 +290,7 @@ struct NotOutputLayerError; /// ``` pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> { /// `eq_evals` passed by `Layer::into_multivariate_poly()`. - pub eq_evals: &'a EqEvals, + pub eq_evals: Cow<'a, EqEvals>, pub input_layer: Layer, pub eq_fixed_var_correction: SecureField, /// Used by LogUp to perform a random linear combination of the numerators and denominators. @@ -332,6 +364,27 @@ impl<'a, B: GkrOps> GkrMultivariatePolyOracle<'a, B> { Ok(GkrMask::new(columns)) } + + /// Returns a copy of this oracle with the [`CpuBackend`]. + /// + /// This operation is expensive but can be useful for small oracles that are difficult to handle + /// depending on the backend. For example, the SIMD backend offloads to the CPU backend when + /// trace length becomes smaller than the SIMD lane count. + pub fn to_cpu(&self) -> GkrMultivariatePolyOracle<'a, CpuBackend> { + // TODO(andrew): This block is not ideal. + let n_eq_evals = 1 << (self.n_variables() - 1); + let eq_evals = Cow::Owned(EqEvals { + evals: Mle::new((0..n_eq_evals).map(|i| self.eq_evals.at(i)).collect()), + y: self.eq_evals.y.to_vec(), + }); + + GkrMultivariatePolyOracle { + eq_evals, + eq_fixed_var_correction: self.eq_fixed_var_correction, + input_layer: self.input_layer.to_cpu(), + lambda: self.lambda, + } + } } /// Error returned when a polynomial is expected to be constant but it is not.