Skip to content

Commit

Permalink
Implement GrandProductOps for SIMD backend
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Jul 18, 2024
1 parent be26562 commit a454e56
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 15 deletions.
8 changes: 5 additions & 3 deletions crates/prover/src/core/backend/cpu/lookups/gkr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,12 @@ impl GkrOps for CpuBackend {
let n_variables = h.n_variables();
assert!(!n_variables.is_zero());
let n_terms = 1 << (n_variables - 1);
let eq_evals = h.eq_evals;
let eq_evals = h.eq_evals.as_ref();
// Vector used to generate evaluations of `eq(x, y)` for `x` in the boolean hypercube.
let y = eq_evals.y();
let lambda = h.lambda;
let input_layer = &h.input_layer;

let (mut eval_at_0, mut eval_at_2) = match input_layer {
let (mut eval_at_0, mut eval_at_2) = match &h.input_layer {
Layer::GrandProduct(col) => eval_grand_product_sum(eq_evals, col, n_terms),
Layer::LogUpGeneric {
numerators,
Expand Down Expand Up @@ -88,6 +87,9 @@ fn eval_grand_product_sum(
let inp_at_r1i1 = input_layer[(n_terms + i) * 2 + 1];
// Note `inp(r, t, x) = eq(t, 0) * inp(r, 0, x) + eq(t, 1) * inp(r, 1, x)`
// => `inp(r, 2, x) = 2 * inp(r, 1, x) - inp(r, 0, x)`
// TODO(andrew): Consider evaluation at `1/2` to save an addition operation since
// `inp(r, 1/2, x) = 1/2 * (inp(r, 1, x) + inp(r, 0, x))`. `1/2 * ...` can be factored
// outside the loop.
let inp_at_r2i0 = inp_at_r1i0.double() - inp_at_r0i0;
let inp_at_r2i1 = inp_at_r1i1.double() - inp_at_r0i1;

Expand Down
178 changes: 170 additions & 8 deletions crates/prover/src/core/backend/simd/lookups/gkr.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
use std::iter::zip;

use num_traits::Zero;

use crate::core::backend::cpu::lookups::gkr::gen_eq_evals as cpu_gen_eq_evals;
use crate::core::backend::simd::column::SecureFieldVec;
use crate::core::backend::simd::m31::{LOG_N_LANES, N_LANES};
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::Column;
use crate::core::backend::{Column, CpuBackend};
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr_prover::{GkrMultivariatePolyOracle, GkrOps, Layer};
use crate::core::lookups::gkr_prover::{
correct_sum_as_poly_in_first_variable, EqEvals, GkrMultivariatePolyOracle, GkrOps, Layer,
};
use crate::core::lookups::mle::Mle;
use crate::core::lookups::sumcheck::MultivariatePolyOracle;
use crate::core::lookups::utils::UnivariatePoly;

impl GkrOps for SimdBackend {
Expand Down Expand Up @@ -47,24 +52,158 @@ impl GkrOps for SimdBackend {
Mle::new(SecureFieldVec { data, length })
}

fn next_layer(_layer: &Layer<Self>) -> Layer<Self> {
todo!()
fn next_layer(layer: &Layer<Self>) -> Layer<Self> {
// Offload to CPU backend to avoid dealing with instances smaller than a SIMD vector.
if layer.n_variables() as u32 <= LOG_N_LANES {
return into_simd_layer(layer.to_cpu().next_layer().unwrap());
}

match layer {
Layer::GrandProduct(col) => next_grand_product_layer(col),
Layer::LogUpGeneric {
numerators: _,
denominators: _,
} => todo!(),
Layer::LogUpMultiplicities {
numerators: _,
denominators: _,
} => todo!(),
Layer::LogUpSingles { denominators: _ } => todo!(),
}
}

fn sum_as_poly_in_first_variable(
_h: &GkrMultivariatePolyOracle<'_, Self>,
_claim: SecureField,
h: &GkrMultivariatePolyOracle<'_, Self>,
claim: SecureField,
) -> UnivariatePoly<SecureField> {
todo!()
let n_variables = h.n_variables();
let n_terms = 1 << n_variables.saturating_sub(1);
let eq_evals = h.eq_evals.as_ref();
// Vector used to generate evaluations of `eq(x, y)` for `x` in the boolean hypercube.
let y = eq_evals.y();

// Offload to CPU backend to avoid dealing with instances smaller than a SIMD vector.
if n_terms < N_LANES {
return h.to_cpu().sum_as_poly_in_first_variable(claim);
}

let n_packed_terms = n_terms / N_LANES;

let (mut eval_at_0, mut eval_at_2) = match &h.input_layer {
Layer::GrandProduct(col) => eval_grand_product_sum(eq_evals, col, n_packed_terms),
Layer::LogUpGeneric {
numerators: _,
denominators: _,
} => todo!(),
Layer::LogUpMultiplicities {
numerators: _,
denominators: _,
} => todo!(),
Layer::LogUpSingles { denominators: _ } => todo!(),
};

eval_at_0 *= h.eq_fixed_var_correction;
eval_at_2 *= h.eq_fixed_var_correction;
correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, n_variables)
}
}

/// Generates the next GKR layer for Grand Product.
///
/// Assumption `len(layer) > N_LANES`.
fn next_grand_product_layer(layer: &Mle<SimdBackend, SecureField>) -> Layer<SimdBackend> {
assert!(layer.len() > N_LANES);
let next_layer_len = layer.len() / 2;

let data = layer
.data
.array_chunks()
.map(|&[a, b]| {
let (evens, odds) = a.deinterleave(b);
evens * odds
})
.collect();

Layer::GrandProduct(Mle::new(SecureFieldVec {
data,
length: next_layer_len,
}))
}

/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * inp(r, t, x, 0) * inp(r, t, x, 1)` at `t=0` and `t=2`.
///
/// Output of the form: `(eval_at_0, eval_at_2)`.
fn eval_grand_product_sum(
eq_evals: &EqEvals<SimdBackend>,
col: &Mle<SimdBackend, SecureField>,
n_packed_terms: usize,
) -> (SecureField, SecureField) {
let mut packed_eval_at_0 = PackedSecureField::zero();
let mut packed_eval_at_2 = PackedSecureField::zero();

for i in 0..n_packed_terms {
// Input polynomial at points `(r, {0, 1, 2}, bits(i), v, {0, 1})`
// for all `v` in `{0, 1}^LOG_N_SIMD_LANES`.
let (inp_at_r0iv0, inp_at_r0iv1) = col.data[i * 2].deinterleave(col.data[i * 2 + 1]);
let (inp_at_r1iv0, inp_at_r1iv1) =
col.data[(n_packed_terms + i) * 2].deinterleave(col.data[(n_packed_terms + i) * 2 + 1]);
// Note `inp(r, t, x) = eq(t, 0) * inp(r, 0, x) + eq(t, 1) * inp(r, 1, x)`
// => `inp(r, 2, x) = 2 * inp(r, 1, x) - inp(r, 0, x)`
let inp_at_r2iv0 = inp_at_r1iv0.double() - inp_at_r0iv0;
let inp_at_r2iv1 = inp_at_r1iv1.double() - inp_at_r0iv1;

// Product polynomial `prod(x) = inp(x, 0) * inp(x, 1)` at points `(r, {0, 2}, bits(i), v)`.
// for all `v` in `{0, 1}^LOG_N_SIMD_LANES`.
let prod_at_r2iv = inp_at_r2iv0 * inp_at_r2iv1;
let prod_at_r0iv = inp_at_r0iv0 * inp_at_r0iv1;

let eq_eval_at_0iv = eq_evals.data[i];
packed_eval_at_0 += eq_eval_at_0iv * prod_at_r0iv;
packed_eval_at_2 += eq_eval_at_0iv * prod_at_r2iv;
}

(
packed_eval_at_0.pointwise_sum(),
packed_eval_at_2.pointwise_sum(),
)
}

fn into_simd_layer(cpu_layer: Layer<CpuBackend>) -> Layer<SimdBackend> {
match cpu_layer {
Layer::GrandProduct(mle) => {
Layer::GrandProduct(Mle::new(mle.into_evals().into_iter().collect()))
}
Layer::LogUpGeneric {
numerators,
denominators,
} => Layer::LogUpGeneric {
numerators: Mle::new(numerators.into_evals().into_iter().collect()),
denominators: Mle::new(denominators.into_evals().into_iter().collect()),
},
Layer::LogUpMultiplicities {
numerators,
denominators,
} => Layer::LogUpMultiplicities {
numerators: Mle::new(numerators.into_evals().into_iter().collect()),
denominators: Mle::new(denominators.into_evals().into_iter().collect()),
},
Layer::LogUpSingles { denominators } => Layer::LogUpSingles {
denominators: Mle::new(denominators.into_evals().into_iter().collect()),
},
}
}

#[cfg(test)]
mod tests {
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Column, CpuBackend};
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::lookups::gkr_prover::GkrOps;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer};
use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError};
use crate::core::lookups::mle::Mle;
use crate::core::test_utils::test_channel;

#[test]
fn gen_eq_evals_matches_cpu() {
Expand All @@ -87,4 +226,27 @@ mod tests {

assert_eq!(eq_evals_simd.to_cpu(), *eq_evals_cpu);
}

#[test]
fn grand_product_works() -> Result<(), GkrError> {
const N: usize = 1 << 8;
let values = test_channel().draw_felts(N);
let product = values.iter().product();
let col = Mle::<SimdBackend, SecureField>::new(values.into_iter().collect());
let input_layer = Layer::GrandProduct(col.clone());
let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]);

let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?;

assert_eq!(proof.output_claims_by_instance, [vec![product]]);
assert_eq!(
claims_to_verify_by_instance,
[vec![col.eval_at_point(&ood_point)]]
);
Ok(())
}
}
61 changes: 57 additions & 4 deletions crates/prover/src/core/lookups/gkr_prover.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
//! GKR batch prover for Grand Product and LogUp lookup arguments.
use std::borrow::Cow;
use std::iter::{successors, zip};
use std::ops::Deref;

use educe::Educe;
use itertools::Itertools;
use num_traits::{One, Zero};
use thiserror::Error;
Expand All @@ -10,7 +12,7 @@ use super::gkr_verifier::{GkrArtifact, GkrBatchProof, GkrMask};
use super::mle::{Mle, MleOps};
use super::sumcheck::MultivariatePolyOracle;
use super::utils::{eq, random_linear_combination, UnivariatePoly};
use crate::core::backend::{Col, Column, ColumnOps};
use crate::core::backend::{Col, Column, ColumnOps, CpuBackend};
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
Expand Down Expand Up @@ -46,6 +48,8 @@ pub trait GkrOps: MleOps<BaseField> + MleOps<SecureField> {
/// `evals[1] = eq((0, ..., 0, 1), y)`, etc.
///
/// [`eq(x, y)`]: crate::core::lookups::utils::eq
#[derive(Educe)]
#[educe(Debug, Clone)]
pub struct EqEvals<B: ColumnOps<SecureField>> {
y: Vec<SecureField>,
evals: Mle<B, SecureField>,
Expand Down Expand Up @@ -103,7 +107,7 @@ pub enum Layer<B: GkrOps> {

impl<B: GkrOps> Layer<B> {
/// Returns the number of variables used to interpolate the layer's gate values.
fn n_variables(&self) -> usize {
pub fn n_variables(&self) -> usize {
match self {
Self::GrandProduct(mle)
| Self::LogUpSingles { denominators: mle }
Expand Down Expand Up @@ -227,12 +231,40 @@ impl<B: GkrOps> Layer<B> {
eq_evals: &EqEvals<B>,
) -> GkrMultivariatePolyOracle<'_, B> {
GkrMultivariatePolyOracle {
eq_evals,
eq_evals: Cow::Borrowed(eq_evals),
input_layer: self,
eq_fixed_var_correction: SecureField::one(),
lambda,
}
}

/// Returns a copy of this layer with the [`CpuBackend`].
///
/// This operation is expensive but can be useful for small traces that are difficult to handle
/// depending on the backend. For example, the SIMD backend offloads to the CPU backend when
/// trace length becomes smaller than the SIMD lane count.
pub fn to_cpu(&self) -> Layer<CpuBackend> {
match self {
Layer::GrandProduct(mle) => Layer::GrandProduct(Mle::new(mle.to_cpu())),
Layer::LogUpGeneric {
numerators,
denominators,
} => Layer::LogUpGeneric {
numerators: Mle::new(numerators.to_cpu()),
denominators: Mle::new(denominators.to_cpu()),
},
Layer::LogUpMultiplicities {
numerators,
denominators,
} => Layer::LogUpMultiplicities {
numerators: Mle::new(numerators.to_cpu()),
denominators: Mle::new(denominators.to_cpu()),
},
Layer::LogUpSingles { denominators } => Layer::LogUpSingles {
denominators: Mle::new(denominators.to_cpu()),
},
}
}
}

#[derive(Debug)]
Expand All @@ -258,7 +290,7 @@ struct NotOutputLayerError;
/// ```
pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> {
/// `eq_evals` passed by `Layer::into_multivariate_poly()`.
pub eq_evals: &'a EqEvals<B>,
pub eq_evals: Cow<'a, EqEvals<B>>,
pub input_layer: Layer<B>,
pub eq_fixed_var_correction: SecureField,
/// Used by LogUp to perform a random linear combination of the numerators and denominators.
Expand Down Expand Up @@ -332,6 +364,27 @@ impl<'a, B: GkrOps> GkrMultivariatePolyOracle<'a, B> {

Ok(GkrMask::new(columns))
}

/// Returns a copy of this oracle with the [`CpuBackend`].
///
/// This operation is expensive but can be useful for small oracles that are difficult to handle
/// depending on the backend. For example, the SIMD backend offloads to the CPU backend when
/// trace length becomes smaller than the SIMD lane count.
pub fn to_cpu(&self) -> GkrMultivariatePolyOracle<'a, CpuBackend> {
// TODO(andrew): This block is not ideal.
let n_eq_evals = 1 << (self.n_variables() - 1);
let eq_evals = Cow::Owned(EqEvals {
evals: Mle::new((0..n_eq_evals).map(|i| self.eq_evals.at(i)).collect()),
y: self.eq_evals.y.to_vec(),
});

GkrMultivariatePolyOracle {
eq_evals,
eq_fixed_var_correction: self.eq_fixed_var_correction,
input_layer: self.input_layer.to_cpu(),
lambda: self.lambda,
}
}
}

/// Error returned when a polynomial is expected to be constant but it is not.
Expand Down

0 comments on commit a454e56

Please sign in to comment.