Skip to content

Commit

Permalink
Add GKR implementation of Grand Product lookups
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Jul 8, 2024
1 parent ab2322c commit f8ff1d6
Show file tree
Hide file tree
Showing 5 changed files with 341 additions and 27 deletions.
109 changes: 101 additions & 8 deletions crates/prover/src/core/backend/cpu/lookups/gkr.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,87 @@
use num_traits::Zero;

use crate::core::backend::CpuBackend;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr_prover::{GkrMultivariatePolyOracle, GkrOps, Layer};
use crate::core::fields::Field;
use crate::core::lookups::gkr_prover::{
correct_sum_as_poly_in_first_variable, EqEvals, GkrMultivariatePolyOracle, GkrOps, Layer,
};
use crate::core::lookups::mle::Mle;
use crate::core::lookups::sumcheck::MultivariatePolyOracle;
use crate::core::lookups::utils::UnivariatePoly;

impl GkrOps for CpuBackend {
fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle<Self, SecureField> {
Mle::new(gen_eq_evals(y, v))
}

fn next_layer(_layer: &Layer<Self>) -> Layer<Self> {
todo!()
fn next_layer(layer: &Layer<Self>) -> Layer<Self> {
match layer {
Layer::GrandProduct(layer) => next_grand_product_layer(layer),
Layer::_LogUp(_) => todo!(),
}
}

fn sum_as_poly_in_first_variable(
_h: &GkrMultivariatePolyOracle<'_, Self>,
_claim: SecureField,
) -> crate::core::lookups::utils::UnivariatePoly<SecureField> {
todo!()
h: &GkrMultivariatePolyOracle<'_, Self>,
claim: SecureField,
) -> UnivariatePoly<SecureField> {
let n_variables = h.n_variables();
assert!(!n_variables.is_zero());
let n_terms = 1 << (n_variables - 1);
let eq_evals = h.eq_evals;
// Vector used to generate evaluations of `eq(x, y)` for `x` in the boolean hypercube.
let y = eq_evals.y();
let input_layer = &h.input_layer;

let (mut eval_at_0, mut eval_at_2) = match input_layer {
Layer::GrandProduct(col) => eval_grand_product_sum(eq_evals, col, n_terms),
Layer::_LogUp(_) => todo!(),
};

// Corrects the difference between two univariate sums in `t`:
// 1. `sum_x eq(({0}^|r|, 0, x), y) * F(r, t, x)`
// 2. `sum_x eq((r, t, x), y) * F(r, t, x)`
{
eval_at_0 *= h.eq_fixed_var_correction;
eval_at_2 *= h.eq_fixed_var_correction;
correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, n_variables)
}
}
}

/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * inp(r, t, x, 0) * inp(r, t, x, 1)` at `t=0` and `t=2`.
///
/// Output of the form: `(eval_at_0, eval_at_2)`.
fn eval_grand_product_sum(
eq_evals: &EqEvals<CpuBackend>,
input_layer: &Mle<CpuBackend, SecureField>,
n_terms: usize,
) -> (SecureField, SecureField) {
let mut eval_at_0 = SecureField::zero();
let mut eval_at_2 = SecureField::zero();

for i in 0..n_terms {
// Input polynomial at points `(r, {0, 1, 2}, bits(i), {0, 1})`.
let inp_at_r0i0 = input_layer[i * 2];
let inp_at_r0i1 = input_layer[i * 2 + 1];
let inp_at_r1i0 = input_layer[(n_terms + i) * 2];
let inp_at_r1i1 = input_layer[(n_terms + i) * 2 + 1];
// Note `inp(r, t, x) = eq(t, 0) * inp(r, 0, x) + eq(t, 1) * inp(r, 1, x)`
// => `inp(r, 2, x) = 2 * inp(r, 1, x) - inp(r, 0, x)`
let inp_at_r2i0 = inp_at_r1i0.double() - inp_at_r0i0;
let inp_at_r2i1 = inp_at_r1i1.double() - inp_at_r0i1;

// Product polynomial `prod(x) = inp(x, 0) * inp(x, 1)` at points `(r, {0, 2}, bits(i))`.
let prod_at_r2i = inp_at_r2i0 * inp_at_r2i1;
let prod_at_r0i = inp_at_r0i0 * inp_at_r0i1;

let eq_eval_at_0i = eq_evals[i];
eval_at_0 += eq_eval_at_0i * prod_at_r0i;
eval_at_2 += eq_eval_at_0i * prod_at_r2i;
}

(eval_at_0, eval_at_2)
}

/// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`.
Expand All @@ -40,15 +104,24 @@ fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Vec<SecureField> {
evals
}

fn next_grand_product_layer(layer: &Mle<CpuBackend, SecureField>) -> Layer<CpuBackend> {
let res = layer.array_chunks().map(|&[a, b]| a * b).collect();
Layer::GrandProduct(Mle::new(res))
}

#[cfg(test)]
mod tests {
use num_traits::{One, Zero};

use crate::core::backend::CpuBackend;
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr_prover::GkrOps;
use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer};
use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError};
use crate::core::lookups::mle::Mle;
use crate::core::lookups::utils::eq;
use crate::core::test_utils::test_channel;

#[test]
fn gen_eq_evals() {
Expand All @@ -69,4 +142,24 @@ mod tests {
]
);
}

#[test]
fn grand_product_works() -> Result<(), GkrError> {
const N: usize = 1 << 5;
let values = test_channel().draw_felts(N);
let product = values.iter().product::<SecureField>();
let col = Mle::<CpuBackend, SecureField>::new(values);
let input_layer = Layer::GrandProduct(col.clone());
let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]);

let GkrArtifact {
ood_point: r,
claims_to_verify_by_instance,
..
} = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?;

assert_eq!(proof.output_claims_by_instance, [vec![product]]);
assert_eq!(claims_to_verify_by_instance, [vec![col.eval_at_point(&r)]]);
Ok(())
}
}
160 changes: 147 additions & 13 deletions crates/prover/src/core/lookups/gkr_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ use super::sumcheck::MultivariatePolyOracle;
use super::utils::{eq, random_linear_combination, UnivariatePoly};
use crate::core::backend::{Col, Column, ColumnOps};
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::lookups::sumcheck;

pub trait GkrOps: MleOps<SecureField> {
pub trait GkrOps: MleOps<BaseField> + MleOps<SecureField> {
/// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`.
///
/// Note [`Mle`] stores values in bit-reversed order.
Expand Down Expand Up @@ -85,13 +87,16 @@ impl<B: ColumnOps<SecureField>> Deref for EqEvals<B> {
/// [LogUp]: https://eprint.iacr.org/2023/1284.pdf
pub enum Layer<B: GkrOps> {
_LogUp(B),
_GrandProduct(B),
GrandProduct(Mle<B, SecureField>),
}

impl<B: GkrOps> Layer<B> {
/// Returns the number of variables used to interpolate the layer's gate values.
fn n_variables(&self) -> usize {
todo!()
match self {
Self::_LogUp(_) => todo!(),
Self::GrandProduct(mle) => mle.n_variables(),
}
}

/// Produces the next layer from the current layer.
Expand All @@ -112,7 +117,28 @@ impl<B: GkrOps> Layer<B> {

/// Returns each column output if the layer is an output layer, otherwise returns an `Err`.
fn try_into_output_layer_values(self) -> Result<Vec<SecureField>, NotOutputLayerError> {
todo!()
if !self.is_output_layer() {
return Err(NotOutputLayerError);
}

Ok(match self {
Self::GrandProduct(col) => {
vec![col.at(0)]
}
Self::_LogUp(_) => todo!(),
})
}

/// Returns a transformed layer with the first variable of each column fixed to `assignment`.
fn fix_first_variable(self, x0: SecureField) -> Self {
if self.n_variables() == 0 {
return self;
}

match self {
Self::_LogUp(_) => todo!(),
Self::GrandProduct(mle) => Self::GrandProduct(mle.fix_first_variable(x0)),
}
}

/// Represents the next GKR layer evaluation as a multivariate polynomial which uses this GKR
Expand Down Expand Up @@ -145,16 +171,37 @@ impl<B: GkrOps> Layer<B> {
fn into_multivariate_poly(
self,
_lambda: SecureField,
_eq_evals: &EqEvals<B>,
eq_evals: &EqEvals<B>,
) -> GkrMultivariatePolyOracle<'_, B> {
todo!()
GkrMultivariatePolyOracle {
eq_evals,
input_layer: self,
eq_fixed_var_correction: SecureField::one(),
}
}
}

#[derive(Debug)]
struct NotOutputLayerError;

/// A multivariate polynomial that expresses the relation between two consecutive GKR layers.
/// Multivariate polynomial `P` that expresses the relation between two consecutive GKR layers.
///
/// When the input layer is [`Layer::GrandProduct`] (represented by multilinear column `inp`)
/// the polynomial represents:
///
/// ```text
/// P(x) = eq(x, y) * inp(x, 0) * inp(x, 1)
/// ```
///
/// When the input layer is LogUp (represented multilinear columns `inp_numer` and
/// `inp_denom`) the polynomial represents:
///
/// ```text
/// numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0)
/// denom(x) = inp_denom(x, 0) * inp_denom(x, 1)
///
/// P(x) = eq(x, y) * (numer(x) + lambda * denom(x))
/// ```
pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> {
/// `eq_evals` passed by `Layer::into_multivariate_poly()`.
pub eq_evals: &'a EqEvals<B>,
Expand All @@ -164,15 +211,26 @@ pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> {

impl<'a, B: GkrOps> MultivariatePolyOracle for GkrMultivariatePolyOracle<'a, B> {
fn n_variables(&self) -> usize {
todo!()
self.input_layer.n_variables() - 1
}

fn sum_as_poly_in_first_variable(&self, _claim: SecureField) -> UnivariatePoly<SecureField> {
todo!()
fn sum_as_poly_in_first_variable(&self, claim: SecureField) -> UnivariatePoly<SecureField> {
B::sum_as_poly_in_first_variable(self, claim)
}

fn fix_first_variable(self, _challenge: SecureField) -> Self {
todo!()
fn fix_first_variable(self, challenge: SecureField) -> Self {
if self.n_variables() == 0 {
return self;
}

let z0 = self.eq_evals.y()[self.eq_evals.y().len() - self.n_variables()];
let eq_fixed_var_correction = self.eq_fixed_var_correction * eq(&[challenge], &[z0]);

Self {
eq_evals: self.eq_evals,
eq_fixed_var_correction,
input_layer: self.input_layer.fix_first_variable(challenge),
}
}
}

Expand All @@ -188,7 +246,14 @@ impl<'a, B: GkrOps> GkrMultivariatePolyOracle<'a, B> {
///
/// For more context see <https://people.cs.georgetown.edu/jthaler/ProofsArgsAndZK.pdf> page 64.
fn try_into_mask(self) -> Result<GkrMask, NotConstantPolyError> {
todo!()
if self.n_variables() != 0 {
return Err(NotConstantPolyError);
}

match self.input_layer {
Layer::_LogUp(_) => todo!(),
Layer::GrandProduct(mle) => Ok(GkrMask::new(vec![mle.to_cpu().try_into().unwrap()])),
}
}
}

Expand Down Expand Up @@ -319,3 +384,72 @@ fn gen_layers<B: GkrOps>(input_layer: Layer<B>) -> Vec<Layer<B>> {
assert_eq!(layers.len(), n_variables + 1);
layers
}

/// Corrects and interpolates GKR instance sumcheck round polynomials that are generated with the
/// precomputed `eq(x, y)` evaluations provided by `Layer::into_multivariate_poly()`.
///
/// Let `y` be a fixed vector of length `n` and let `z` be a subvector comprising of the last `k`
/// elements of `y`. Returns the univariate polynomial `f(t) = sum_x eq((t, x), z) * p(t, x)` for
/// `x` in the boolean hypercube `{0, 1}^(k-1)` when provided with:
///
/// * `claim` equalling `f(0) + f(1)`.
/// * `eval_at_0/2` equalling `sum_x eq(({0}^(n-k+1), x), y) * p(t, x)` at `t=0,2` respectively.
///
/// Note that `f` must have degree <= 3.
///
/// For more context see `Layer::into_multivariate_poly()` docs.
/// See also <https://ia.cr/2024/108> (section 3.2).
///
/// # Panics
///
/// Panics if:
/// * `k` is zero or greater than the length of `y`.
/// * `z_0` is zero.
pub fn correct_sum_as_poly_in_first_variable(
eval_at_0: SecureField,
eval_at_2: SecureField,
claim: SecureField,
y: &[SecureField],
k: usize,
) -> UnivariatePoly<SecureField> {
assert_ne!(k, 0);
let n = y.len();
assert!(k <= n);

let z = &y[n - k..];

// Corrects the difference between two sums:
// 1. `sum_x eq(({0}^(n-k+1), x), y) * p(t, x)`
// 2. `sum_x eq((0, x), z) * p(t, x)`
let eq_y_to_z_correction_factor = eq(&vec![SecureField::zero(); n - k], &y[0..n - k]).inverse();

// Corrects the difference between two sums:
// 1. `sum_x eq((0, x), z) * p(t, x)`
// 2. `sum_x eq((t, x), z) * p(t, x)`
let eq_correction_factor_at = |t| eq(&[t], &[z[0]]) / eq(&[SecureField::zero()], &[z[0]]);

// Let `v(t) = sum_x eq((0, x), z) * p(t, x)`. Apply trick from
// <https://ia.cr/2024/108> (section 3.2) to obtain `f` from `v`.
let t0: SecureField = BaseField::zero().into();
let t1: SecureField = BaseField::one().into();
let t2: SecureField = BaseField::from(2).into();
let t3: SecureField = BaseField::from(3).into();

// Obtain evals `v(0)`, `v(1)`, `v(2)`.
let mut y0 = eq_y_to_z_correction_factor * eval_at_0;
let mut y1 = (claim - y0) / eq_correction_factor_at(t1);
let mut y2 = eq_y_to_z_correction_factor * eval_at_2;

// Interpolate `v` to find `v(3)`. Note `v` has degree <= 2.
let v = UnivariatePoly::interpolate_lagrange(&[t0, t1, t2], &[y0, y1, y2]);
let mut y3 = v.eval_at_point(t3);

// Obtain evals of `f(0)`, `f(1)`, `f(2)`, `f(3)`.
y0 *= eq_correction_factor_at(t0);
y1 *= eq_correction_factor_at(t1);
y2 *= eq_correction_factor_at(t2);
y3 *= eq_correction_factor_at(t3);

// Interpolate `f(t)`. Note `f(t)` has degree <= 3.
UnivariatePoly::interpolate_lagrange(&[t0, t1, t2, t3], &[y0, y1, y2, y3])
}
Loading

0 comments on commit f8ff1d6

Please sign in to comment.