Skip to content

Commit

Permalink
Implement GrandProductOps for SIMD backend
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Jun 12, 2024
1 parent 7878fb0 commit 7ae75ab
Show file tree
Hide file tree
Showing 4 changed files with 236 additions and 15 deletions.
5 changes: 2 additions & 3 deletions crates/prover/src/core/backend/cpu/lookups/gkr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,14 @@ impl GkrOps for CpuBackend {
) -> UnivariatePoly<SecureField> {
let k = h.n_variables();
let n_terms = 1 << (k - 1);
let eq_evals = h.eq_evals;
let eq_evals = h.eq_evals.as_ref();
let y = eq_evals.y();
let lambda = h.lambda;
let input_layer = &h.input_layer;

let mut eval_at_0 = SecureField::zero();
let mut eval_at_2 = SecureField::zero();

match input_layer {
match &h.input_layer {
Layer::GrandProduct(col) => {
process_grand_product_sum(&mut eval_at_0, &mut eval_at_2, eq_evals, col, n_terms)
}
Expand Down
184 changes: 176 additions & 8 deletions crates/prover/src/core/backend/simd/lookups/gkr.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
use std::iter::zip;

use num_traits::Zero;

use crate::core::backend::cpu::lookups::gkr::gen_eq_evals as cpu_gen_eq_evals;
use crate::core::backend::simd::column::SecureFieldVec;
use crate::core::backend::simd::m31::{LOG_N_LANES, N_LANES};
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::Column;
use crate::core::backend::{Column, CpuBackend};
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr_prover::{GkrMultivariatePolyOracle, GkrOps, Layer};
use crate::core::lookups::gkr_prover::{
correct_sum_as_poly_in_first_variable, EqEvals, GkrMultivariatePolyOracle, GkrOps, Layer,
};
use crate::core::lookups::mle::Mle;
use crate::core::lookups::sumcheck::MultivariatePolyOracle;
use crate::core::lookups::utils::UnivariatePoly;

impl GkrOps for SimdBackend {
Expand Down Expand Up @@ -47,24 +52,164 @@ impl GkrOps for SimdBackend {
Mle::new(SecureFieldVec { data, length })
}

fn next_layer(_layer: &Layer<Self>) -> Layer<Self> {
todo!()
/// Produces the layer that follows `layer`.
///
/// Layers with at most `LOG_N_LANES` variables are converted to the CPU backend, advanced
/// there, and converted back. Larger layers are handled by the SIMD implementation, which
/// currently only supports [`Layer::GrandProduct`].
///
/// # Panics
///
/// Panics if the CPU backend does not return a next layer for a small `layer`.
fn next_layer(layer: &Layer<Self>) -> Layer<Self> {
    // Offload to CPU backend to prevent dealing with instances smaller than a SIMD vector.
    let fits_in_one_vector = layer.n_variables() as u32 <= LOG_N_LANES;
    if fits_in_one_vector {
        let cpu_next_layer = layer.to_cpu().next_layer().unwrap();
        return into_simd_layer(cpu_next_layer);
    }

    match layer {
        Layer::GrandProduct(col) => next_grand_product_layer(col),
        // The LogUp gates are not implemented for the SIMD backend yet.
        Layer::LogUpGeneric { .. }
        | Layer::LogUpMultiplicities { .. }
        | Layer::LogUpSingles { .. } => todo!(),
    }
}

/// Builds the univariate polynomial returned for `h` from its evaluations at `0` and `2`.
///
/// Instances with at most `N_LANES` terms are delegated to the CPU backend. Otherwise the
/// evaluations are accumulated lane-wise, reduced with `pointwise_sum`, scaled by
/// `h.eq_fixed_var_correction` and handed to `correct_sum_as_poly_in_first_variable` together
/// with `claim`, `y` and `k`. Only `Layer::GrandProduct` is implemented; the LogUp variants
/// hit `todo!()`.
//
// NOTE(review): this span comes from a rendered diff. The `_h` / `_claim` parameters and the
// first `todo!()` look like the removed pre-change lines; confirm against the actual file.
fn sum_as_poly_in_first_variable(
_h: &GkrMultivariatePolyOracle<'_, Self>,
_claim: SecureField,
h: &GkrMultivariatePolyOracle<'_, Self>,
claim: SecureField,
) -> UnivariatePoly<SecureField> {
todo!()
let k = h.n_variables();
// Number of terms summed per evaluation: `2^(k - 1)`.
let n_terms = 1 << (k - 1);
let eq_evals = h.eq_evals.as_ref();
let y = eq_evals.y();

// Offload to CPU backend to prevent dealing with instances smaller than a SIMD vector.
if n_terms <= N_LANES {
return h.to_cpu().sum_as_poly_in_first_variable(claim);
}

// Lane-wise accumulators; they are collapsed to single field elements further down.
let mut packed_eval_at_0 = PackedSecureField::zero();
let mut packed_eval_at_2 = PackedSecureField::zero();

match &h.input_layer {
Layer::GrandProduct(col) => process_grand_product_sum(
&mut packed_eval_at_0,
&mut packed_eval_at_2,
eq_evals,
col,
n_terms,
),
Layer::LogUpGeneric {
numerators: _,
denominators: _,
} => todo!(),
Layer::LogUpMultiplicities {
numerators: _,
denominators: _,
} => todo!(),
Layer::LogUpSingles { denominators: _ } => todo!(),
}

// Sum across lanes, then apply the oracle's stored correction factor.
let eval_at_0 = packed_eval_at_0.pointwise_sum() * h.eq_fixed_var_correction;
let eval_at_2 = packed_eval_at_2.pointwise_sum() * h.eq_fixed_var_correction;

correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, k)
}
}

/// Builds the next grand product layer by multiplying adjacent pairs of values in `layer`.
///
/// Each pair of consecutive packed vectors is deinterleaved into its even- and odd-indexed
/// values, which are multiplied lane-wise. The resulting layer has `layer.len() / 2` values.
///
/// # Panics
///
/// Panics if `layer.len() <= N_LANES`.
fn next_grand_product_layer(layer: &Mle<SimdBackend, SecureField>) -> Layer<SimdBackend> {
    assert!(layer.len() > N_LANES);

    let data = layer
        .data
        .chunks_exact(2)
        .map(|pair| {
            let (evens, odds) = pair[0].deinterleave(pair[1]);
            evens * odds
        })
        .collect();

    let halved = SecureFieldVec {
        data,
        length: layer.len() / 2,
    };
    Layer::GrandProduct(Mle::new(halved))
}

/// Accumulates, lane-wise, the grand product terms of the sum evaluated at `t = 0` and `t = 2`.
///
/// Let `p` be the multilinear polynomial representing `col`. For every packed group of terms
/// `x`, `eq_evals[x] * p(t, x, 0) * p(t, x, 1)` is added to `packed_eval_at_0` (for `t = 0`)
/// and to `packed_eval_at_2` (for `t = 2`).
///
/// `n_terms` counts *scalar* terms, whereas `col.data` and `eq_evals.data` are indexed in
/// packed vectors of `N_LANES` values each. All indexing below is therefore done in units of
/// `n_terms / N_LANES` packed terms.
///
/// # Panics
///
/// Panics if `n_terms <= N_LANES`, if `n_terms` is not a multiple of `N_LANES`, or if `col` or
/// `eq_evals` hold fewer packed vectors than the indices below require.
fn process_grand_product_sum(
    packed_eval_at_0: &mut PackedSecureField,
    packed_eval_at_2: &mut PackedSecureField,
    eq_evals: &EqEvals<SimdBackend>,
    col: &Mle<SimdBackend, SecureField>,
    n_terms: usize,
) {
    assert!(n_terms > N_LANES);
    // A truncating division below would silently drop terms.
    assert!(n_terms % N_LANES == 0);

    // Every packed vector covers `N_LANES` scalar terms.
    let n_packed_terms = n_terms / N_LANES;

    // Indexing is kept because `col.data` is read at four different offsets per iteration.
    #[allow(clippy::needless_range_loop)]
    for i in 0..n_packed_terms {
        // Two consecutive packed vectors hold the interleaved values `p(0, x, 0)` and
        // `p(0, x, 1)` for `N_LANES` consecutive `x`.
        let (p0x0 /* = p(0, x, 0) */, p0x1 /* = p(0, x, 1) */) =
            col.data[i * 2].deinterleave(col.data[i * 2 + 1]);

        // We obtain `p(2, x)` for some `x` in the boolean
        // hypercube using `p(0, x)` and `p(1, x)`:
        //
        // ```text
        // p(t, x) = eq(t, 0) * p(0, x) + eq(t, 1) * p(1, x)
        //         = (1 - t) * p(0, x) + t * p(1, x)
        //
        // p(2, x) = 2 * p(1, x) - p(0, x)
        // ```
        //
        // The `p(1, ·)` half starts `n_packed_terms` packed terms after the `p(0, ·)` half.
        let (p1x0 /* = p(1, x, 0) */, p1x1 /* = p(1, x, 1) */) = col.data
            [(n_packed_terms + i) * 2]
            .deinterleave(col.data[(n_packed_terms + i) * 2 + 1]);
        let p2x0 /* = p(2, x, 0) */ = p1x0.double() - p0x0;
        let p2x1 /* = p(2, x, 1) */ = p1x1.double() - p0x1;

        let product2 = p2x0 * p2x1;
        let product0 = p0x0 * p0x1;

        let eq_eval = eq_evals.data[i];
        *packed_eval_at_0 += eq_eval * product0;
        *packed_eval_at_2 += eq_eval * product2;
    }
}

/// Converts a CPU-backed layer into the equivalent SIMD-backed layer.
///
/// Every column of `cpu_layer` is consumed and its evaluations are collected into the
/// corresponding SIMD column; the layer variant is preserved.
fn into_simd_layer(cpu_layer: Layer<CpuBackend>) -> Layer<SimdBackend> {
    match cpu_layer {
        Layer::GrandProduct(mle) => {
            let simd_mle = Mle::new(mle.into_evals().into_iter().collect());
            Layer::GrandProduct(simd_mle)
        }
        Layer::LogUpGeneric {
            numerators: cpu_numerators,
            denominators: cpu_denominators,
        } => {
            let numerators = Mle::new(cpu_numerators.into_evals().into_iter().collect());
            let denominators = Mle::new(cpu_denominators.into_evals().into_iter().collect());
            Layer::LogUpGeneric {
                numerators,
                denominators,
            }
        }
        Layer::LogUpMultiplicities {
            numerators: cpu_numerators,
            denominators: cpu_denominators,
        } => {
            let numerators = Mle::new(cpu_numerators.into_evals().into_iter().collect());
            let denominators = Mle::new(cpu_denominators.into_evals().into_iter().collect());
            Layer::LogUpMultiplicities {
                numerators,
                denominators,
            }
        }
        Layer::LogUpSingles {
            denominators: cpu_denominators,
        } => {
            let denominators = Mle::new(cpu_denominators.into_evals().into_iter().collect());
            Layer::LogUpSingles { denominators }
        }
    }
}

#[cfg(test)]
mod tests {
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Column, CpuBackend};
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::lookups::gkr_prover::GkrOps;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer};
use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError};
use crate::core::lookups::mle::Mle;
use crate::core::test_utils::test_channel;

#[test]
fn gen_eq_evals_matches_cpu() {
Expand All @@ -76,4 +221,27 @@ mod tests {

assert_eq!(*cpu_eq_evals, simd_eq_evals.to_cpu());
}

/// Proves a grand product over `N` drawn values and checks both the claimed product and the
/// claim the verifier leaves to be checked against the input column.
#[test]
fn grand_product_works() -> Result<(), GkrError> {
    const N: usize = 1 << 6;
    let values = test_channel().draw_felts(N);
    let expected_product: SecureField = values.iter().product();
    let col = Mle::<SimdBackend, SecureField>::new(values.into_iter().collect());

    // Prover and verifier each start from a fresh test channel.
    let mut prover_channel = test_channel();
    let (proof, _) = prove_batch(&mut prover_channel, vec![Layer::GrandProduct(col.clone())]);

    let mut verifier_channel = test_channel();
    let artifact =
        partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut verifier_channel)?;
    let GkrArtifact {
        ood_point,
        claims_to_verify_by_instance,
        ..
    } = artifact;

    assert_eq!(proof.output_claims_by_instance, [vec![expected_product]]);
    assert_eq!(
        claims_to_verify_by_instance,
        [vec![col.eval_at_point(&ood_point)]]
    );
    Ok(())
}
}
1 change: 1 addition & 0 deletions crates/prover/src/core/backend/simd/lookups/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
mod gkr;
// mod grandproduct;
mod mle;
61 changes: 57 additions & 4 deletions crates/prover/src/core/lookups/gkr_prover.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
//! Batch GKR protocol implementation designed to prove lookup arguments.
use std::borrow::Cow;
use std::iter::{successors, zip};
use std::ops::Deref;

use educe::Educe;
use itertools::Itertools;
use num_traits::{One, Zero};
use thiserror::Error;
Expand All @@ -10,7 +12,7 @@ use super::gkr_verifier::{GkrArtifact, GkrBatchProof, GkrMask};
use super::mle::{Mle, MleOps};
use super::sumcheck::MultivariatePolyOracle;
use super::utils::{eq, random_linear_combination, UnivariatePoly};
use crate::core::backend::{Col, Column, ColumnOps};
use crate::core::backend::{Col, Column, ColumnOps, CpuBackend};
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
Expand Down Expand Up @@ -44,6 +46,8 @@ pub trait GkrOps: MleOps<BaseField> + MleOps<SecureField> {
/// `evals[1] = eq((0, ..., 0, 1), y)`, etc.
///
/// [`eq(x, y)`]: crate::core::lookups::utils::eq
#[derive(Educe)]
#[educe(Debug, Clone)]
pub struct EqEvals<B: ColumnOps<SecureField>> {
y: Vec<SecureField>,
evals: Mle<B, SecureField>,
Expand Down Expand Up @@ -101,7 +105,7 @@ pub enum Layer<B: GkrOps> {

impl<B: GkrOps> Layer<B> {
/// Returns the number of variables used to interpolate the layer's gate values.
fn n_variables(&self) -> usize {
pub fn n_variables(&self) -> usize {
match self {
Self::GrandProduct(mle)
| Self::LogUpSingles { denominators: mle }
Expand Down Expand Up @@ -190,7 +194,7 @@ impl<B: GkrOps> Layer<B> {
eq_evals: &EqEvals<B>,
) -> GkrMultivariatePolyOracle<'_, B> {
GkrMultivariatePolyOracle {
eq_evals,
eq_evals: Cow::Borrowed(eq_evals),
input_layer: self,
eq_fixed_var_correction: SecureField::one(),
lambda,
Expand Down Expand Up @@ -230,6 +234,34 @@ impl<B: GkrOps> Layer<B> {
}
})
}

/// Returns a copy of this layer with the [`CpuBackend`].
///
/// This operation is expensive but can be useful for small traces that are difficult to handle
/// depending on the backend. For example, the SIMD backend offloads to the CPU backend when
/// trace length becomes smaller than the SIMD lane count.
pub fn to_cpu(&self) -> Layer<CpuBackend> {
match self {
Layer::GrandProduct(mle) => Layer::GrandProduct(Mle::new(mle.to_cpu())),
Layer::LogUpGeneric {
numerators,
denominators,
} => Layer::LogUpGeneric {
numerators: Mle::new(numerators.to_cpu()),
denominators: Mle::new(denominators.to_cpu()),
},
Layer::LogUpMultiplicities {
numerators,
denominators,
} => Layer::LogUpMultiplicities {
numerators: Mle::new(numerators.to_cpu()),
denominators: Mle::new(denominators.to_cpu()),
},
Layer::LogUpSingles { denominators } => Layer::LogUpSingles {
denominators: Mle::new(denominators.to_cpu()),
},
}
}
}

#[derive(Debug)]
Expand All @@ -238,7 +270,7 @@ struct NotOutputLayerError;
/// A multivariate polynomial that expresses the relation between two consecutive GKR layers.
pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> {
/// `eq_evals` passed by [`Layer::into_multivariate_poly()`].
pub eq_evals: &'a EqEvals<B>,
pub eq_evals: Cow<'a, EqEvals<B>>,
pub input_layer: Layer<B>,
pub eq_fixed_var_correction: SecureField,
/// Used by LogUp to perform a random linear combination of the numerators and denominators.
Expand Down Expand Up @@ -312,6 +344,27 @@ impl<'a, B: GkrOps> GkrMultivariatePolyOracle<'a, B> {

Ok(GkrMask::new(columns))
}

/// Converts this oracle into an equivalent oracle backed by [`CpuBackend`].
///
/// The input layer is converted with [`Layer::to_cpu`], and the first `2^(n_variables - 1)`
/// `eq` evaluations are read back one at a time into an owned [`EqEvals`]. This is expensive
/// but useful for small oracles that are awkward to handle on a given backend; for example,
/// the SIMD backend offloads to the CPU backend once the trace length drops below the SIMD
/// lane count.
pub fn to_cpu(&self) -> GkrMultivariatePolyOracle<'a, CpuBackend> {
    // TODO(andrew): This block is not ideal.
    let n_eq_evals = 1 << (self.n_variables() - 1);
    let cpu_eq_evals = EqEvals {
        y: self.eq_evals.y.clone(),
        evals: Mle::new((0..n_eq_evals).map(|idx| self.eq_evals.at(idx)).collect()),
    };

    GkrMultivariatePolyOracle {
        // The CPU copy owns its evaluations, hence `Cow::Owned` rather than a borrow.
        eq_evals: Cow::Owned(cpu_eq_evals),
        input_layer: self.input_layer.to_cpu(),
        eq_fixed_var_correction: self.eq_fixed_var_correction,
        lambda: self.lambda,
    }
}
}

/// Error returned when a polynomial is expected to be constant but it is not.
Expand Down

0 comments on commit 7ae75ab

Please sign in to comment.