Skip to content

Commit

Permalink
Implement GrandProductOps for SIMD backend
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed May 25, 2024
1 parent f7837b1 commit e17f7b8
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 15 deletions.
114 changes: 114 additions & 0 deletions crates/prover/src/core/backend/simd/lookups/grandproduct.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
use num_traits::Zero;

use crate::core::backend::simd::column::SecureFieldVec;
use crate::core::backend::simd::m31::N_LANES;
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Column, CpuBackend};
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr::correct_sum_as_poly_in_first_variable;
use crate::core::lookups::grandproduct::{GrandProductOps, GrandProductOracle, GrandProductTrace};
use crate::core::lookups::mle::Mle;
use crate::core::lookups::sumcheck::MultivariatePolyOracle;
use crate::core::lookups::utils::UnivariatePoly;

impl GrandProductOps for SimdBackend {
fn next_layer(layer: &GrandProductTrace<Self>) -> GrandProductTrace<Self> {
let next_layer_len = layer.len() / 2;

// Offload to CPU backend to prevent dealing with instances smaller than a SIMD vector.
if next_layer_len < N_LANES {
return to_simd_trace(&CpuBackend::next_layer(&layer.to_cpu()));
}

let data = layer
.data
.array_chunks()
.map(|&[a, b]| {
let (evens, odds) = a.deinterleave(b);
evens * odds
})
.collect();

GrandProductTrace::new(Mle::new(SecureFieldVec {
data,
length: next_layer_len,
}))
}

fn sum_as_poly_in_first_variable(
h: &GrandProductOracle<'_, Self>,
claim: SecureField,
) -> UnivariatePoly<SecureField> {
let k = h.n_variables();
let n_terms = 1 << (k - 1);
let eq_evals = h.eq_evals();
let y = eq_evals.y();
let trace = h.trace();

// Offload to CPU backend to prevent dealing with instances smaller than a SIMD vector.
if n_terms < 2 * N_LANES {
return h.to_cpu().sum_as_poly_in_first_variable(claim);
}

let n_packed_terms = n_terms / N_LANES;
let (lhs_data, rhs_data) = trace.data.split_at(trace.data.len() / 2);

let mut packed_eval_at_0 = PackedSecureField::zero();
let mut packed_eval_at_2 = PackedSecureField::zero();

for i in 0..n_packed_terms {
let (lhs0, lhs1) = lhs_data[i * 2].deinterleave(lhs_data[i * 2 + 1]);
let (rhs0, rhs1) = rhs_data[i * 2].deinterleave(rhs_data[i * 2 + 1]);

let product2 = (rhs0.double() - lhs0) * (rhs1.double() - lhs1);
let product0 = lhs0 * lhs1;

let eq_eval = eq_evals.data[i];
packed_eval_at_0 += eq_eval * product0;
packed_eval_at_2 += eq_eval * product2;
}

let eval_at_0 = packed_eval_at_0.pointwise_sum() * h.eq_fixed_var_correction();
let eval_at_2 = packed_eval_at_2.pointwise_sum() * h.eq_fixed_var_correction();

correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, k)
}
}

fn to_simd_trace(cpu_trace: &GrandProductTrace<CpuBackend>) -> GrandProductTrace<SimdBackend> {
GrandProductTrace::new(Mle::new((**cpu_trace).to_cpu().into_iter().collect()))
}

#[cfg(test)]
mod tests {
use crate::core::backend::simd::SimdBackend;
use crate::core::channel::Channel;
use crate::core::lookups::gkr::{partially_verify_batch, prove_batch, GkrArtifact, GkrError};
use crate::core::lookups::grandproduct::{GrandProductGate, GrandProductTrace};
use crate::core::lookups::mle::Mle;
use crate::core::test_utils::test_channel;

#[test]
fn grand_product_works() -> Result<(), GkrError> {
const N: usize = 1 << 6;
let values = test_channel().draw_felts(N);
let product = values.iter().product();
let mle = Mle::new(values.into_iter().collect());
let top_layer = GrandProductTrace::<SimdBackend>::new(mle);
let (proof, _) = prove_batch(&mut test_channel(), vec![top_layer.clone()]);

let GkrArtifact {
ood_point,
claims_to_verify_by_component,
n_variables_by_component: _,
} = partially_verify_batch(vec![&GrandProductGate], &proof, &mut test_channel())?;

assert_eq!(proof.output_claims_by_component, [vec![product]]);
assert_eq!(
claims_to_verify_by_component,
[vec![top_layer.eval_at_point(&ood_point)]]
);
Ok(())
}
}
1 change: 1 addition & 0 deletions crates/prover/src/core/backend/simd/lookups/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
mod gkr;
mod grandproduct;
mod mle;
16 changes: 9 additions & 7 deletions crates/prover/src/core/lookups/gkr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use std::iter::{successors, zip};
use std::ops::Deref;

use derivative::Derivative;
use itertools::Itertools;
use num_traits::{One, Zero};
use thiserror::Error;
Expand All @@ -28,9 +29,11 @@ pub trait GkrOps: MleOps<SecureField> {
///
/// Evaluations are stored in bit-reversed order i.e. `evals[0] = eq((0, ..., 0, 0), y)`,
/// `evals[1] = eq((0, ..., 0, 1), y)`, etc.
#[derive(Derivative)]
#[derivative(Debug(bound = ""), Clone(bound = ""))]
pub struct EqEvals<B: ColumnOps<SecureField>> {
y: Vec<SecureField>,
evals: Mle<B, SecureField>,
pub y: Vec<SecureField>,
pub evals: Mle<B, SecureField>,
}

impl<B: GkrOps> EqEvals<B> {
Expand Down Expand Up @@ -570,7 +573,6 @@ mod tests {
use super::GkrError;
use crate::core::backend::CpuBackend;
use crate::core::channel::Channel;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr::{partially_verify_batch, prove_batch, GkrArtifact};
use crate::core::lookups::grandproduct::{GrandProductGate, GrandProductTrace};
use crate::core::lookups::mle::Mle;
Expand All @@ -582,8 +584,8 @@ mod tests {
let mut channel = test_channel();
let col0 = GrandProductTrace::<CpuBackend>::new(Mle::new(channel.draw_felts(1 << LOG_N)));
let col1 = GrandProductTrace::<CpuBackend>::new(Mle::new(channel.draw_felts(1 << LOG_N)));
let product0 = col0.iter().product::<SecureField>();
let product1 = col1.iter().product::<SecureField>();
let product0 = col0.iter().product();
let product1 = col1.iter().product();
let top_layers = vec![col0.clone(), col1.clone()];
let (proof, _) = prove_batch(&mut test_channel(), top_layers);

Expand Down Expand Up @@ -612,8 +614,8 @@ mod tests {
let mut channel = test_channel();
let col0 = GrandProductTrace::<CpuBackend>::new(Mle::new(channel.draw_felts(1 << LOG_N0)));
let col1 = GrandProductTrace::<CpuBackend>::new(Mle::new(channel.draw_felts(1 << LOG_N1)));
let product0 = col0.iter().product::<SecureField>();
let product1 = col1.iter().product::<SecureField>();
let product0 = col0.iter().product();
let product1 = col1.iter().product();
let top_layers = vec![col0.clone(), col1.clone()];
let (proof, _) = prove_batch(&mut test_channel(), top_layers);

Expand Down
48 changes: 40 additions & 8 deletions crates/prover/src/core/lookups/grandproduct.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::borrow::{Borrow, Cow};
use std::ops::Deref;

use derivative::Derivative;
Expand All @@ -10,7 +11,7 @@ use super::gkr::{
use super::mle::{Mle, MleOps};
use super::sumcheck::MultivariatePolyOracle;
use super::utils::{eq, UnivariatePoly};
use crate::core::backend::{Column, ColumnOps};
use crate::core::backend::{Column, ColumnOps, CpuBackend};
use crate::core::fields::qm31::SecureField;

pub trait GrandProductOps: MleOps<SecureField> + GkrOps + Sized + 'static {
Expand All @@ -30,12 +31,21 @@ pub trait GrandProductOps: MleOps<SecureField> + GkrOps + Sized + 'static {

#[derive(Derivative)]
#[derivative(Debug(bound = ""), Clone(bound = ""))]
pub struct GrandProductTrace<B: ColumnOps<SecureField>>(pub Mle<B, SecureField>);
pub struct GrandProductTrace<B: ColumnOps<SecureField>>(Mle<B, SecureField>);

impl<B: ColumnOps<SecureField>> GrandProductTrace<B> {
pub fn new(column: Mle<B, SecureField>) -> Self {
Self(column)
}

/// Returns a copy of this trace with the [`CpuBackend`].
///
/// This operation is expensive but can be useful for small traces that are difficult to handle
/// depending on the backend. For example, the SIMD backend offloads to the CPU backend when
/// trace length becomes smaller than the SIMD lane count.
pub fn to_cpu(&self) -> GrandProductTrace<CpuBackend> {
GrandProductTrace::new(Mle::new(self.0.to_cpu()))
}
}

impl<B: ColumnOps<SecureField>> Deref for GrandProductTrace<B> {
Expand All @@ -60,7 +70,7 @@ impl<B: GrandProductOps> GkrBinaryLayer for GrandProductTrace<B> {
let next_layer = B::next_layer(self);

if next_layer.n_variables() == 0 {
Layer::Output(next_layer.to_cpu())
Layer::Output(next_layer.0.to_cpu())
} else {
Layer::Internal(next_layer)
}
Expand All @@ -71,20 +81,20 @@ impl<B: GrandProductOps> GkrBinaryLayer for GrandProductTrace<B> {
_: SecureField,
eq_evals: &EqEvals<B>,
) -> GrandProductOracle<'_, B> {
GrandProductOracle::new(eq_evals, self)
GrandProductOracle::new(Cow::Borrowed(eq_evals), self)
}
}

/// Multivariate polynomial oracle.
pub struct GrandProductOracle<'a, B: GrandProductOps> {
/// `eq_evals` passed by [`GkrBinaryLayer::into_multivariate_poly()`].
eq_evals: &'a EqEvals<B>,
eq_evals: Cow<'a, EqEvals<B>>,
eq_fixed_var_correction: SecureField,
trace: GrandProductTrace<B>,
}

impl<'a, B: GrandProductOps> GrandProductOracle<'a, B> {
pub fn new(eq_evals: &'a EqEvals<B>, trace: GrandProductTrace<B>) -> Self {
pub fn new(eq_evals: Cow<'a, EqEvals<B>>, trace: GrandProductTrace<B>) -> Self {
Self {
eq_evals,
eq_fixed_var_correction: SecureField::one(),
Expand All @@ -93,7 +103,7 @@ impl<'a, B: GrandProductOps> GrandProductOracle<'a, B> {
}

pub fn eq_evals(&self) -> &EqEvals<B> {
self.eq_evals
self.eq_evals.borrow()
}

pub fn eq_fixed_var_correction(&self) -> SecureField {
Expand All @@ -103,6 +113,26 @@ impl<'a, B: GrandProductOps> GrandProductOracle<'a, B> {
pub fn trace(&self) -> &GrandProductTrace<B> {
&self.trace
}

/// Returns a copy of this oracle with the [`CpuBackend`].
///
/// This operation is expensive but can be useful for small oracles that are difficult to handle
/// depending on the backend. For example, the SIMD backend offloads to the CPU backend when
/// trace length becomes smaller than the SIMD lane count.
pub fn to_cpu(&self) -> GrandProductOracle<'a, CpuBackend> {
// TODO(andrew): This block is not ideal.
let n_eq_evals = 1 << (self.n_variables() - 1);
let eq_evals = Cow::Owned(EqEvals {
evals: Mle::new((0..n_eq_evals).map(|i| self.eq_evals.at(i)).collect()),
y: self.eq_evals.y.to_vec(),
});

GrandProductOracle {
eq_evals,
eq_fixed_var_correction: self.eq_fixed_var_correction,
trace: self.trace.to_cpu(),
}
}
}

impl<'a, B: GrandProductOps> MultivariatePolyOracle for GrandProductOracle<'a, B> {
Expand Down Expand Up @@ -136,7 +166,9 @@ impl<'a, B: GrandProductOps> GkrMultivariatePolyOracle for GrandProductOracle<'a
return Err(NotConstantPolyError);
}

Ok(GkrMask::new(vec![self.trace.to_cpu().try_into().unwrap()]))
let evals = self.trace.0.to_cpu().try_into().unwrap();

Ok(GkrMask::new(vec![evals]))
}
}

Expand Down

0 comments on commit e17f7b8

Please sign in to comment.