Implement GrandProductOps for SIMD backend #640

Merged · 1 commit · Jul 25, 2024
8 changes: 5 additions & 3 deletions crates/prover/src/core/backend/cpu/lookups/gkr.rs
@@ -42,13 +42,12 @@ impl GkrOps for CpuBackend {
let n_variables = h.n_variables();
assert!(!n_variables.is_zero());
let n_terms = 1 << (n_variables - 1);
let eq_evals = h.eq_evals;
let eq_evals = h.eq_evals.as_ref();
// Vector used to generate evaluations of `eq(x, y)` for `x` in the boolean hypercube.
let y = eq_evals.y();
let lambda = h.lambda;
let input_layer = &h.input_layer;

let (mut eval_at_0, mut eval_at_2) = match input_layer {
let (mut eval_at_0, mut eval_at_2) = match &h.input_layer {
Layer::GrandProduct(col) => eval_grand_product_sum(eq_evals, col, n_terms),
Layer::LogUpGeneric {
numerators,
@@ -88,6 +87,9 @@ fn eval_grand_product_sum(
let inp_at_r1i1 = input_layer[(n_terms + i) * 2 + 1];
// Note `inp(r, t, x) = eq(t, 0) * inp(r, 0, x) + eq(t, 1) * inp(r, 1, x)`
// => `inp(r, 2, x) = 2 * inp(r, 1, x) - inp(r, 0, x)`
// TODO(andrew): Consider evaluation at `1/2` to save an addition operation since
// `inp(r, 1/2, x) = 1/2 * (inp(r, 1, x) + inp(r, 0, x))`. `1/2 * ...` can be factored
// outside the loop.
let inp_at_r2i0 = inp_at_r1i0.double() - inp_at_r0i0;
let inp_at_r2i1 = inp_at_r1i1.double() - inp_at_r0i1;

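For context on the extrapolation comment above, here is a minimal standalone sketch (plain integers stand in for `SecureField` values; the helper name is made up): the restriction of a multilinear polynomial to the variable `t` is linear, so its values at `t = 0` and `t = 1` determine it everywhere, including at `t = 2` and at the `t = 1/2` variant mentioned in the TODO.

```rust
// Sketch only, not crate code.
fn extrapolate_at_2(at_0: i64, at_1: i64) -> i64 {
    // inp(r, 2, x) = 2 * inp(r, 1, x) - inp(r, 0, x)
    2 * at_1 - at_0
}

fn main() {
    // p(t) = 3 + 5t: p(0) = 3, p(1) = 8, p(2) = 13.
    assert_eq!(extrapolate_at_2(3, 8), 13);
    // The TODO's alternative: p(1/2) = (p(0) + p(1)) / 2, so over a field the
    // constant 1/2 can be factored outside the summation loop, leaving a single
    // addition per term instead of a doubling plus a subtraction.
}
```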
178 changes: 170 additions & 8 deletions crates/prover/src/core/backend/simd/lookups/gkr.rs
@@ -1,14 +1,19 @@
use std::iter::zip;

use num_traits::Zero;

use crate::core::backend::cpu::lookups::gkr::gen_eq_evals as cpu_gen_eq_evals;
use crate::core::backend::simd::column::SecureFieldVec;
use crate::core::backend::simd::m31::{LOG_N_LANES, N_LANES};
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::Column;
use crate::core::backend::{Column, CpuBackend};
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr_prover::{GkrMultivariatePolyOracle, GkrOps, Layer};
use crate::core::lookups::gkr_prover::{
correct_sum_as_poly_in_first_variable, EqEvals, GkrMultivariatePolyOracle, GkrOps, Layer,
};
use crate::core::lookups::mle::Mle;
use crate::core::lookups::sumcheck::MultivariatePolyOracle;
use crate::core::lookups::utils::UnivariatePoly;

impl GkrOps for SimdBackend {
@@ -47,24 +52,158 @@ impl GkrOps for SimdBackend {
Mle::new(SecureFieldVec { data, length })
}

fn next_layer(_layer: &Layer<Self>) -> Layer<Self> {
todo!()
fn next_layer(layer: &Layer<Self>) -> Layer<Self> {
// Offload to CPU backend to avoid dealing with instances smaller than a SIMD vector.
if layer.n_variables() as u32 <= LOG_N_LANES {
return into_simd_layer(layer.to_cpu().next_layer().unwrap());
}

match layer {
Layer::GrandProduct(col) => next_grand_product_layer(col),
Layer::LogUpGeneric {
numerators: _,
denominators: _,
} => todo!(),
Layer::LogUpMultiplicities {
numerators: _,
denominators: _,
} => todo!(),
Layer::LogUpSingles { denominators: _ } => todo!(),
}
}
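As a sketch of the threshold above (assuming the SIMD backend packs `N_LANES = 16` field elements per vector, i.e. `LOG_N_LANES = 4`; the helper below is illustrative only): a layer over at most `LOG_N_LANES` variables has at most `N_LANES` gate values, so it fits in a single packed vector and is simply handled on the CPU backend.

```rust
// Illustrative only: relates the variable-count threshold to the layer length.
const LOG_N_LANES: u32 = 4; // assumed lane-count exponent
const N_LANES: usize = 1 << LOG_N_LANES;

fn offload_to_cpu(n_variables: u32) -> bool {
    // A layer over `n_variables` variables has `2^n_variables` gate values.
    n_variables <= LOG_N_LANES
}

fn main() {
    assert!(offload_to_cpu(4)); // 2^4 = 16 values: exactly one packed vector.
    assert!(!offload_to_cpu(5)); // 2^5 = 32 values: enough to vectorize.
    assert_eq!(N_LANES, 16);
}
```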

fn sum_as_poly_in_first_variable(
_h: &GkrMultivariatePolyOracle<'_, Self>,
_claim: SecureField,
h: &GkrMultivariatePolyOracle<'_, Self>,
claim: SecureField,
) -> UnivariatePoly<SecureField> {
todo!()
let n_variables = h.n_variables();
let n_terms = 1 << n_variables.saturating_sub(1);
let eq_evals = h.eq_evals.as_ref();
// Vector used to generate evaluations of `eq(x, y)` for `x` in the boolean hypercube.
let y = eq_evals.y();

// Offload to CPU backend to avoid dealing with instances smaller than a SIMD vector.
if n_terms < N_LANES {
return h.to_cpu().sum_as_poly_in_first_variable(claim);
}

let n_packed_terms = n_terms / N_LANES;

let (mut eval_at_0, mut eval_at_2) = match &h.input_layer {
Layer::GrandProduct(col) => eval_grand_product_sum(eq_evals, col, n_packed_terms),
Layer::LogUpGeneric {
numerators: _,
denominators: _,
} => todo!(),
Layer::LogUpMultiplicities {
numerators: _,
denominators: _,
} => todo!(),
Layer::LogUpSingles { denominators: _ } => todo!(),
};

eval_at_0 *= h.eq_fixed_var_correction;
eval_at_2 *= h.eq_fixed_var_correction;
correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, n_variables)
}
}

/// Generates the next GKR layer for Grand Product.
///
/// Assumption: `len(layer) > N_LANES`.
fn next_grand_product_layer(layer: &Mle<SimdBackend, SecureField>) -> Layer<SimdBackend> {
assert!(layer.len() > N_LANES);
let next_layer_len = layer.len() / 2;

let data = layer
.data
.array_chunks()
.map(|&[a, b]| {
let (evens, odds) = a.deinterleave(b);
evens * odds
})
.collect();

Layer::GrandProduct(Mle::new(SecureFieldVec {
data,
length: next_layer_len,
}))
}
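A scalar reference for the packed routine above (hypothetical helper, plain `u64`s instead of field elements): each output value is the product of an adjacent pair of inputs, which is exactly what the deinterleave-then-multiply computes `N_LANES` pairs at a time.

```rust
// Scalar sketch of the grand-product layer transition: out[i] = in[2*i] * in[2*i + 1].
fn next_grand_product_layer_scalar(layer: &[u64]) -> Vec<u64> {
    layer.chunks_exact(2).map(|pair| pair[0] * pair[1]).collect()
}

fn main() {
    assert_eq!(next_grand_product_layer_scalar(&[2, 3, 5, 7]), vec![6, 35]);
    // Iterating down to a single element yields the grand product 2 * 3 * 5 * 7.
    assert_eq!(next_grand_product_layer_scalar(&[6, 35]), vec![210]);
}
```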

/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * inp(r, t, x, 0) * inp(r, t, x, 1)` at `t=0` and `t=2`.
///
/// Output of the form: `(eval_at_0, eval_at_2)`.
fn eval_grand_product_sum(
eq_evals: &EqEvals<SimdBackend>,
col: &Mle<SimdBackend, SecureField>,
n_packed_terms: usize,
) -> (SecureField, SecureField) {
let mut packed_eval_at_0 = PackedSecureField::zero();
let mut packed_eval_at_2 = PackedSecureField::zero();

for i in 0..n_packed_terms {
// Input polynomial at points `(r, {0, 1, 2}, bits(i), v, {0, 1})`
// for all `v` in `{0, 1}^LOG_N_LANES`.
let (inp_at_r0iv0, inp_at_r0iv1) = col.data[i * 2].deinterleave(col.data[i * 2 + 1]);
let (inp_at_r1iv0, inp_at_r1iv1) =
col.data[(n_packed_terms + i) * 2].deinterleave(col.data[(n_packed_terms + i) * 2 + 1]);
// Note `inp(r, t, x) = eq(t, 0) * inp(r, 0, x) + eq(t, 1) * inp(r, 1, x)`
// => `inp(r, 2, x) = 2 * inp(r, 1, x) - inp(r, 0, x)`
let inp_at_r2iv0 = inp_at_r1iv0.double() - inp_at_r0iv0;
let inp_at_r2iv1 = inp_at_r1iv1.double() - inp_at_r0iv1;

// Product polynomial `prod(x) = inp(x, 0) * inp(x, 1)` at points `(r, {0, 2}, bits(i), v)`
// for all `v` in `{0, 1}^LOG_N_LANES`.
let prod_at_r2iv = inp_at_r2iv0 * inp_at_r2iv1;
let prod_at_r0iv = inp_at_r0iv0 * inp_at_r0iv1;

let eq_eval_at_0iv = eq_evals.data[i];
packed_eval_at_0 += eq_eval_at_0iv * prod_at_r0iv;
packed_eval_at_2 += eq_eval_at_0iv * prod_at_r2iv;
}

(
packed_eval_at_0.pointwise_sum(),
packed_eval_at_2.pointwise_sum(),
)
}

fn into_simd_layer(cpu_layer: Layer<CpuBackend>) -> Layer<SimdBackend> {
match cpu_layer {
Layer::GrandProduct(mle) => {
Layer::GrandProduct(Mle::new(mle.into_evals().into_iter().collect()))
}
Layer::LogUpGeneric {
numerators,
denominators,
} => Layer::LogUpGeneric {
numerators: Mle::new(numerators.into_evals().into_iter().collect()),
denominators: Mle::new(denominators.into_evals().into_iter().collect()),
},
Layer::LogUpMultiplicities {
numerators,
denominators,
} => Layer::LogUpMultiplicities {
numerators: Mle::new(numerators.into_evals().into_iter().collect()),
denominators: Mle::new(denominators.into_evals().into_iter().collect()),
},
Layer::LogUpSingles { denominators } => Layer::LogUpSingles {
denominators: Mle::new(denominators.into_evals().into_iter().collect()),
},
}
}

#[cfg(test)]
mod tests {
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Column, CpuBackend};
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::lookups::gkr_prover::GkrOps;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer};
use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError};
use crate::core::lookups::mle::Mle;
use crate::core::test_utils::test_channel;

#[test]
fn gen_eq_evals_matches_cpu() {
@@ -87,4 +226,27 @@ mod tests {

assert_eq!(eq_evals_simd.to_cpu(), *eq_evals_cpu);
}

#[test]
fn grand_product_works() -> Result<(), GkrError> {
const N: usize = 1 << 8;
let values = test_channel().draw_felts(N);
let product = values.iter().product();
let col = Mle::<SimdBackend, SecureField>::new(values.into_iter().collect());
let input_layer = Layer::GrandProduct(col.clone());
let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]);

let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?;

assert_eq!(proof.output_claims_by_instance, [vec![product]]);
assert_eq!(
claims_to_verify_by_instance,
[vec![col.eval_at_point(&ood_point)]]
);
Ok(())
}
}
61 changes: 57 additions & 4 deletions crates/prover/src/core/lookups/gkr_prover.rs
@@ -1,7 +1,9 @@
//! GKR batch prover for Grand Product and LogUp lookup arguments.
use std::borrow::Cow;
use std::iter::{successors, zip};
use std::ops::Deref;

use educe::Educe;
use itertools::Itertools;
use num_traits::{One, Zero};
use thiserror::Error;
@@ -10,7 +12,7 @@ use super::gkr_verifier::{GkrArtifact, GkrBatchProof, GkrMask};
use super::mle::{Mle, MleOps};
use super::sumcheck::MultivariatePolyOracle;
use super::utils::{eq, random_linear_combination, UnivariatePoly};
use crate::core::backend::{Col, Column, ColumnOps};
use crate::core::backend::{Col, Column, ColumnOps, CpuBackend};
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
@@ -46,6 +48,8 @@ pub trait GkrOps: MleOps<BaseField> + MleOps<SecureField> {
/// `evals[1] = eq((0, ..., 0, 1), y)`, etc.
///
/// [`eq(x, y)`]: crate::core::lookups::utils::eq
#[derive(Educe)]
#[educe(Debug, Clone)]
pub struct EqEvals<B: ColumnOps<SecureField>> {
y: Vec<SecureField>,
evals: Mle<B, SecureField>,
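For reference, a standalone sketch of the `eq(x, y)` polynomial the doc comment above refers to, assuming the usual definition `eq(x, y) = prod_i (x_i * y_i + (1 - x_i) * (1 - y_i))` (integers stand in for field elements; on the boolean hypercube it is the indicator that `x == y`, and `gen_eq_evals` tabulates such values for a fixed `y`):

```rust
// Sketch only: the standard eq polynomial evaluated on integer points.
fn eq(x: &[i64], y: &[i64]) -> i64 {
    assert_eq!(x.len(), y.len());
    x.iter()
        .zip(y)
        .map(|(&xi, &yi)| xi * yi + (1 - xi) * (1 - yi))
        .product()
}

fn main() {
    assert_eq!(eq(&[0, 1, 1], &[0, 1, 1]), 1);
    assert_eq!(eq(&[0, 1, 1], &[0, 1, 0]), 0);
}
```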
@@ -103,7 +107,7 @@ pub enum Layer<B: GkrOps> {

impl<B: GkrOps> Layer<B> {
/// Returns the number of variables used to interpolate the layer's gate values.
fn n_variables(&self) -> usize {
pub fn n_variables(&self) -> usize {
match self {
Self::GrandProduct(mle)
| Self::LogUpSingles { denominators: mle }
@@ -227,12 +231,40 @@ impl<B: GkrOps> Layer<B> {
eq_evals: &EqEvals<B>,
) -> GkrMultivariatePolyOracle<'_, B> {
GkrMultivariatePolyOracle {
eq_evals,
eq_evals: Cow::Borrowed(eq_evals),
input_layer: self,
eq_fixed_var_correction: SecureField::one(),
lambda,
}
}

/// Returns a copy of this layer with the [`CpuBackend`].
///
/// This operation is expensive but can be useful for small traces that are difficult to handle
/// depending on the backend. For example, the SIMD backend offloads to the CPU backend when
/// trace length becomes smaller than the SIMD lane count.
pub fn to_cpu(&self) -> Layer<CpuBackend> {
match self {
Layer::GrandProduct(mle) => Layer::GrandProduct(Mle::new(mle.to_cpu())),
Layer::LogUpGeneric {
numerators,
denominators,
} => Layer::LogUpGeneric {
numerators: Mle::new(numerators.to_cpu()),
denominators: Mle::new(denominators.to_cpu()),
},
Layer::LogUpMultiplicities {
numerators,
denominators,
} => Layer::LogUpMultiplicities {
numerators: Mle::new(numerators.to_cpu()),
denominators: Mle::new(denominators.to_cpu()),
},
Layer::LogUpSingles { denominators } => Layer::LogUpSingles {
denominators: Mle::new(denominators.to_cpu()),
},
}
}
}

#[derive(Debug)]
@@ -258,7 +290,7 @@ struct NotOutputLayerError;
/// ```
pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> {
/// `eq_evals` passed by `Layer::into_multivariate_poly()`.
pub eq_evals: &'a EqEvals<B>,
pub eq_evals: Cow<'a, EqEvals<B>>,
pub input_layer: Layer<B>,
pub eq_fixed_var_correction: SecureField,
/// Used by LogUp to perform a random linear combination of the numerators and denominators.
@@ -332,6 +364,27 @@ impl<'a, B: GkrOps> GkrMultivariatePolyOracle<'a, B> {

Ok(GkrMask::new(columns))
}

/// Returns a copy of this oracle with the [`CpuBackend`].
///
/// This operation is expensive but can be useful for small oracles that are difficult to handle
/// depending on the backend. For example, the SIMD backend offloads to the CPU backend when
/// trace length becomes smaller than the SIMD lane count.
pub fn to_cpu(&self) -> GkrMultivariatePolyOracle<'a, CpuBackend> {
// TODO(andrew): This block is not ideal.
let n_eq_evals = 1 << (self.n_variables() - 1);
let eq_evals = Cow::Owned(EqEvals {
evals: Mle::new((0..n_eq_evals).map(|i| self.eq_evals.at(i)).collect()),
y: self.eq_evals.y.to_vec(),
});

GkrMultivariatePolyOracle {
eq_evals,
eq_fixed_var_correction: self.eq_fixed_var_correction,
input_layer: self.input_layer.to_cpu(),
lambda: self.lambda,
}
}
}

/// Error returned when a polynomial is expected to be constant but it is not.