diff --git a/crates/prover/src/core/backend/cpu/lookups/gkr.rs b/crates/prover/src/core/backend/cpu/lookups/gkr.rs index 53a14391f..11308c6d4 100644 --- a/crates/prover/src/core/backend/cpu/lookups/gkr.rs +++ b/crates/prover/src/core/backend/cpu/lookups/gkr.rs @@ -1,23 +1,87 @@ +use num_traits::Zero; + use crate::core::backend::CpuBackend; use crate::core::fields::qm31::SecureField; -use crate::core::lookups::gkr_prover::{GkrMultivariatePolyOracle, GkrOps, Layer}; +use crate::core::fields::Field; +use crate::core::lookups::gkr_prover::{ + correct_sum_as_poly_in_first_variable, EqEvals, GkrMultivariatePolyOracle, GkrOps, Layer, +}; use crate::core::lookups::mle::Mle; +use crate::core::lookups::sumcheck::MultivariatePolyOracle; +use crate::core::lookups::utils::UnivariatePoly; impl GkrOps for CpuBackend { fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle { Mle::new(gen_eq_evals(y, v)) } - fn next_layer(_layer: &Layer) -> Layer { - todo!() + fn next_layer(layer: &Layer) -> Layer { + match layer { + Layer::GrandProduct(layer) => next_grand_product_layer(layer), + Layer::_LogUp(_) => todo!(), + } } fn sum_as_poly_in_first_variable( - _h: &GkrMultivariatePolyOracle<'_, Self>, - _claim: SecureField, - ) -> crate::core::lookups::utils::UnivariatePoly { - todo!() + h: &GkrMultivariatePolyOracle<'_, Self>, + claim: SecureField, + ) -> UnivariatePoly { + let n_variables = h.n_variables(); + assert!(!n_variables.is_zero()); + let n_terms = 1 << (n_variables - 1); + let eq_evals = h.eq_evals; + // Vector used to generate evaluations of `eq(x, y)` for `x` in the boolean hypercube. + let y = eq_evals.y(); + let input_layer = &h.input_layer; + + let (mut eval_at_0, mut eval_at_2) = match input_layer { + Layer::GrandProduct(col) => eval_grand_product_sum(eq_evals, col, n_terms), + Layer::_LogUp(_) => todo!(), + }; + + // Corrects the difference between two univariate sums in `t`: + // 1. `sum_x eq(({0}^|r|, 0, x), y) * F(r, t, x)` + // 2. `sum_x eq((r, t, x), y) * F(r, t, x)` + { + eval_at_0 *= h.eq_fixed_var_correction; + eval_at_2 *= h.eq_fixed_var_correction; + correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, n_variables) + } + } +} + +/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * inp(r, t, x, 0) * inp(r, t, x, 1)` at `t=0` and `t=2`. +/// +/// Output of the form: `(eval_at_0, eval_at_2)`. +fn eval_grand_product_sum( + eq_evals: &EqEvals, + input_layer: &Mle, + n_terms: usize, +) -> (SecureField, SecureField) { + let mut eval_at_0 = SecureField::zero(); + let mut eval_at_2 = SecureField::zero(); + + for i in 0..n_terms { + // Input polynomial at points `(r, {0, 1, 2}, bits(i), {0, 1})`. + let inp_at_r0i0 = input_layer[i * 2]; + let inp_at_r0i1 = input_layer[i * 2 + 1]; + let inp_at_r1i0 = input_layer[(n_terms + i) * 2]; + let inp_at_r1i1 = input_layer[(n_terms + i) * 2 + 1]; + // Note `inp(r, t, x) = eq(t, 0) * inp(r, 0, x) + eq(t, 1) * inp(r, 1, x)` + // => `inp(r, 2, x) = 2 * inp(r, 1, x) - inp(r, 0, x)` + let inp_at_r2i0 = inp_at_r1i0.double() - inp_at_r0i0; + let inp_at_r2i1 = inp_at_r1i1.double() - inp_at_r0i1; + + // Product polynomial `prod(x) = inp(x, 0) * inp(x, 1)` at points `(r, {0, 2}, bits(i))`. + let prod_at_r2i = inp_at_r2i0 * inp_at_r2i1; + let prod_at_r0i = inp_at_r0i0 * inp_at_r0i1; + + let eq_eval_at_0i = eq_evals[i]; + eval_at_0 += eq_eval_at_0i * prod_at_r0i; + eval_at_2 += eq_eval_at_0i * prod_at_r2i; } + + (eval_at_0, eval_at_2) } /// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`. @@ -40,15 +104,24 @@ fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Vec { evals } +fn next_grand_product_layer(layer: &Mle) -> Layer { + let res = layer.array_chunks().map(|&[a, b]| a * b).collect(); + Layer::GrandProduct(Mle::new(res)) +} + #[cfg(test)] mod tests { use num_traits::{One, Zero}; use crate::core::backend::CpuBackend; + use crate::core::channel::Channel; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; - use crate::core::lookups::gkr_prover::GkrOps; + use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer}; + use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError}; + use crate::core::lookups::mle::Mle; use crate::core::lookups::utils::eq; + use crate::core::test_utils::test_channel; #[test] fn gen_eq_evals() { @@ -69,4 +142,24 @@ mod tests { ] ); } + + #[test] + fn grand_product_works() -> Result<(), GkrError> { + const N: usize = 1 << 5; + let values = test_channel().draw_felts(N); + let product = values.iter().product::(); + let col = Mle::::new(values); + let input_layer = Layer::GrandProduct(col.clone()); + let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); + + let GkrArtifact { + ood_point: r, + claims_to_verify_by_instance, + .. + } = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?; + + assert_eq!(proof.output_claims_by_instance, [vec![product]]); + assert_eq!(claims_to_verify_by_instance, [vec![col.eval_at_point(&r)]]); + Ok(()) + } } diff --git a/crates/prover/src/core/lookups/gkr_prover.rs b/crates/prover/src/core/lookups/gkr_prover.rs index 4485aa36c..1abf70e79 100644 --- a/crates/prover/src/core/lookups/gkr_prover.rs +++ b/crates/prover/src/core/lookups/gkr_prover.rs @@ -12,10 +12,12 @@ use super::sumcheck::MultivariatePolyOracle; use super::utils::{eq, random_linear_combination, UnivariatePoly}; use crate::core::backend::{Col, Column, ColumnOps}; use crate::core::channel::Channel; +use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; +use crate::core::fields::FieldExpOps; use crate::core::lookups::sumcheck; -pub trait GkrOps: MleOps { +pub trait GkrOps: MleOps + MleOps { /// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`. /// /// Note [`Mle`] stores values in bit-reversed order. @@ -85,13 +87,16 @@ impl> Deref for EqEvals { /// [LogUp]: https://eprint.iacr.org/2023/1284.pdf pub enum Layer { _LogUp(B), - _GrandProduct(B), + GrandProduct(Mle), } impl Layer { /// Returns the number of variables used to interpolate the layer's gate values. fn n_variables(&self) -> usize { - todo!() + match self { + Self::_LogUp(_) => todo!(), + Self::GrandProduct(mle) => mle.n_variables(), + } } /// Produces the next layer from the current layer. @@ -112,7 +117,28 @@ impl Layer { /// Returns each column output if the layer is an output layer, otherwise returns an `Err`. fn try_into_output_layer_values(self) -> Result, NotOutputLayerError> { - todo!() + if !self.is_output_layer() { + return Err(NotOutputLayerError); + } + + Ok(match self { + Self::GrandProduct(col) => { + vec![col.at(0)] + } + Self::_LogUp(_) => todo!(), + }) + } + + /// Returns a transformed layer with the first variable of each column fixed to `assignment`. + fn fix_first_variable(self, x0: SecureField) -> Self { + if self.n_variables() == 0 { + return self; + } + + match self { + Self::_LogUp(_) => todo!(), + Self::GrandProduct(mle) => Self::GrandProduct(mle.fix_first_variable(x0)), + } } /// Represents the next GKR layer evaluation as a multivariate polynomial which uses this GKR @@ -145,16 +171,37 @@ impl Layer { fn into_multivariate_poly( self, _lambda: SecureField, - _eq_evals: &EqEvals, + eq_evals: &EqEvals, ) -> GkrMultivariatePolyOracle<'_, B> { - todo!() + GkrMultivariatePolyOracle { + eq_evals, + input_layer: self, + eq_fixed_var_correction: SecureField::one(), + } } } #[derive(Debug)] struct NotOutputLayerError; -/// A multivariate polynomial that expresses the relation between two consecutive GKR layers. +/// Multivariate polynomial `P` that expresses the relation between two consecutive GKR layers. +/// +/// When the input layer is [`Layer::GrandProduct`] (represented by multilinear column `inp`) +/// the polynomial represents: +/// +/// ```text +/// P(x) = eq(x, y) * inp(x, 0) * inp(x, 1) +/// ``` +/// +/// When the input layer is LogUp (represented multilinear columns `inp_numer` and +/// `inp_denom`) the polynomial represents: +/// +/// ```text +/// numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0) +/// denom(x) = inp_denom(x, 0) * inp_denom(x, 1) +/// +/// P(x) = eq(x, y) * (numer(x) + lambda * denom(x)) +/// ``` pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> { /// `eq_evals` passed by `Layer::into_multivariate_poly()`. pub eq_evals: &'a EqEvals, @@ -164,15 +211,26 @@ pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> { impl<'a, B: GkrOps> MultivariatePolyOracle for GkrMultivariatePolyOracle<'a, B> { fn n_variables(&self) -> usize { - todo!() + self.input_layer.n_variables() - 1 } - fn sum_as_poly_in_first_variable(&self, _claim: SecureField) -> UnivariatePoly { - todo!() + fn sum_as_poly_in_first_variable(&self, claim: SecureField) -> UnivariatePoly { + B::sum_as_poly_in_first_variable(self, claim) } - fn fix_first_variable(self, _challenge: SecureField) -> Self { - todo!() + fn fix_first_variable(self, challenge: SecureField) -> Self { + if self.n_variables() == 0 { + return self; + } + + let z0 = self.eq_evals.y()[self.eq_evals.y().len() - self.n_variables()]; + let eq_fixed_var_correction = self.eq_fixed_var_correction * eq(&[challenge], &[z0]); + + Self { + eq_evals: self.eq_evals, + eq_fixed_var_correction, + input_layer: self.input_layer.fix_first_variable(challenge), + } } } @@ -188,7 +246,14 @@ impl<'a, B: GkrOps> GkrMultivariatePolyOracle<'a, B> { /// /// For more context see page 64. fn try_into_mask(self) -> Result { - todo!() + if self.n_variables() != 0 { + return Err(NotConstantPolyError); + } + + match self.input_layer { + Layer::_LogUp(_) => todo!(), + Layer::GrandProduct(mle) => Ok(GkrMask::new(vec![mle.to_cpu().try_into().unwrap()])), + } } } @@ -319,3 +384,72 @@ fn gen_layers(input_layer: Layer) -> Vec> { assert_eq!(layers.len(), n_variables + 1); layers } + +/// Corrects and interpolates GKR instance sumcheck round polynomials that are generated with the +/// precomputed `eq(x, y)` evaluations provided by `Layer::into_multivariate_poly()`. +/// +/// Let `y` be a fixed vector of length `n` and let `z` be a subvector comprising of the last `k` +/// elements of `y`. Returns the univariate polynomial `f(t) = sum_x eq((t, x), z) * p(t, x)` for +/// `x` in the boolean hypercube `{0, 1}^(k-1)` when provided with: +/// +/// * `claim` equalling `f(0) + f(1)`. +/// * `eval_at_0/2` equalling `sum_x eq(({0}^(n-k+1), x), y) * p(t, x)` at `t=0,2` respectively. +/// +/// Note that `f` must have degree <= 3. +/// +/// For more context see `Layer::into_multivariate_poly()` docs. +/// See also (section 3.2). +/// +/// # Panics +/// +/// Panics if: +/// * `k` is zero or greater than the length of `y`. +/// * `z_0` is zero. +pub fn correct_sum_as_poly_in_first_variable( + eval_at_0: SecureField, + eval_at_2: SecureField, + claim: SecureField, + y: &[SecureField], + k: usize, +) -> UnivariatePoly { + assert_ne!(k, 0); + let n = y.len(); + assert!(k <= n); + + let z = &y[n - k..]; + + // Corrects the difference between two sums: + // 1. `sum_x eq(({0}^(n-k+1), x), y) * p(t, x)` + // 2. `sum_x eq((0, x), z) * p(t, x)` + let eq_y_to_z_correction_factor = eq(&vec![SecureField::zero(); n - k], &y[0..n - k]).inverse(); + + // Corrects the difference between two sums: + // 1. `sum_x eq((0, x), z) * p(t, x)` + // 2. `sum_x eq((t, x), z) * p(t, x)` + let eq_correction_factor_at = |t| eq(&[t], &[z[0]]) / eq(&[SecureField::zero()], &[z[0]]); + + // Let `v(t) = sum_x eq((0, x), z) * p(t, x)`. Apply trick from + // (section 3.2) to obtain `f` from `v`. + let t0: SecureField = BaseField::zero().into(); + let t1: SecureField = BaseField::one().into(); + let t2: SecureField = BaseField::from(2).into(); + let t3: SecureField = BaseField::from(3).into(); + + // Obtain evals `v(0)`, `v(1)`, `v(2)`. + let mut y0 = eq_y_to_z_correction_factor * eval_at_0; + let mut y1 = (claim - y0) / eq_correction_factor_at(t1); + let mut y2 = eq_y_to_z_correction_factor * eval_at_2; + + // Interpolate `v` to find `v(3)`. Note `v` has degree <= 2. + let v = UnivariatePoly::interpolate_lagrange(&[t0, t1, t2], &[y0, y1, y2]); + let mut y3 = v.eval_at_point(t3); + + // Obtain evals of `f(0)`, `f(1)`, `f(2)`, `f(3)`. + y0 *= eq_correction_factor_at(t0); + y1 *= eq_correction_factor_at(t1); + y2 *= eq_correction_factor_at(t2); + y3 *= eq_correction_factor_at(t3); + + // Interpolate `f(t)`. Note `f(t)` has degree <= 3. + UnivariatePoly::interpolate_lagrange(&[t0, t1, t2, t3], &[y0, y1, y2, y3]) +} diff --git a/crates/prover/src/core/lookups/gkr_verifier.rs b/crates/prover/src/core/lookups/gkr_verifier.rs index 90a403cff..7c9598b79 100644 --- a/crates/prover/src/core/lookups/gkr_verifier.rs +++ b/crates/prover/src/core/lookups/gkr_verifier.rs @@ -175,15 +175,26 @@ pub struct GkrArtifact { /// circuit) GKR prover implementations. /// /// [Thaler13]: https://eprint.iacr.org/2013/351.pdf +#[derive(Debug, Clone, Copy)] pub enum Gate { _LogUp, - _GrandProduct, + GrandProduct, } impl Gate { /// Returns the output after applying the gate to the mask. - fn eval(&self, _mask: &GkrMask) -> Result, InvalidNumMaskColumnsError> { - todo!() + fn eval(&self, mask: &GkrMask) -> Result, InvalidNumMaskColumnsError> { + Ok(match self { + Self::_LogUp => todo!(), + Self::GrandProduct => { + if mask.columns().len() != 1 { + return Err(InvalidNumMaskColumnsError); + } + + let [a, b] = mask.columns()[0]; + vec![a * b] + } + }) } } @@ -253,3 +264,80 @@ pub enum GkrError { /// GKR layer index where 0 corresponds to the output layer. pub type LayerIndex = usize; + +#[cfg(test)] +mod tests { + use super::{partially_verify_batch, Gate, GkrArtifact, GkrError}; + use crate::core::backend::CpuBackend; + use crate::core::channel::Channel; + use crate::core::fields::qm31::SecureField; + use crate::core::lookups::gkr_prover::{prove_batch, Layer}; + use crate::core::lookups::mle::Mle; + use crate::core::test_utils::test_channel; + + #[test] + fn prove_batch_works() -> Result<(), GkrError> { + const LOG_N: usize = 5; + let mut channel = test_channel(); + let col0 = Mle::::new(channel.draw_felts(1 << LOG_N)); + let col1 = Mle::::new(channel.draw_felts(1 << LOG_N)); + let product0 = col0.iter().product::(); + let product1 = col1.iter().product::(); + let input_layers = vec![ + Layer::GrandProduct(col0.clone()), + Layer::GrandProduct(col1.clone()), + ]; + let (proof, _) = prove_batch(&mut test_channel(), input_layers); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance, + } = partially_verify_batch(vec![Gate::GrandProduct; 2], &proof, &mut test_channel())?; + + assert_eq!(n_variables_by_instance, [LOG_N, LOG_N]); + assert_eq!(proof.output_claims_by_instance.len(), 2); + assert_eq!(claims_to_verify_by_instance.len(), 2); + assert_eq!(proof.output_claims_by_instance[0], &[product0]); + assert_eq!(proof.output_claims_by_instance[1], &[product1]); + let claim0 = &claims_to_verify_by_instance[0]; + let claim1 = &claims_to_verify_by_instance[1]; + assert_eq!(claim0, &[col0.eval_at_point(&ood_point)]); + assert_eq!(claim1, &[col1.eval_at_point(&ood_point)]); + Ok(()) + } + + #[test] + fn prove_batch_with_different_sizes_works() -> Result<(), GkrError> { + const LOG_N0: usize = 5; + const LOG_N1: usize = 7; + let mut channel = test_channel(); + let col0 = Mle::::new(channel.draw_felts(1 << LOG_N0)); + let col1 = Mle::::new(channel.draw_felts(1 << LOG_N1)); + let product0 = col0.iter().product::(); + let product1 = col1.iter().product::(); + let input_layers = vec![ + Layer::GrandProduct(col0.clone()), + Layer::GrandProduct(col1.clone()), + ]; + let (proof, _) = prove_batch(&mut test_channel(), input_layers); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance, + } = partially_verify_batch(vec![Gate::GrandProduct; 2], &proof, &mut test_channel())?; + + assert_eq!(n_variables_by_instance, [LOG_N0, LOG_N1]); + assert_eq!(proof.output_claims_by_instance.len(), 2); + assert_eq!(claims_to_verify_by_instance.len(), 2); + assert_eq!(proof.output_claims_by_instance[0], &[product0]); + assert_eq!(proof.output_claims_by_instance[1], &[product1]); + let claim0 = &claims_to_verify_by_instance[0]; + let claim1 = &claims_to_verify_by_instance[1]; + let n_vars = ood_point.len(); + assert_eq!(claim0, &[col0.eval_at_point(&ood_point[n_vars - LOG_N0..])]); + assert_eq!(claim1, &[col1.eval_at_point(&ood_point[n_vars - LOG_N1..])]); + Ok(()) + } +} diff --git a/crates/prover/src/core/lookups/mle.rs b/crates/prover/src/core/lookups/mle.rs index cd6212686..7ac7f9eb3 100644 --- a/crates/prover/src/core/lookups/mle.rs +++ b/crates/prover/src/core/lookups/mle.rs @@ -72,7 +72,7 @@ mod test { B: MleOps, { /// Evaluates the multilinear polynomial at `point`. - pub(crate) fn eval_at_point(self, point: &[SecureField]) -> SecureField { + pub(crate) fn eval_at_point(&self, point: &[SecureField]) -> SecureField { pub fn eval(mle_evals: &[SecureField], p: &[SecureField]) -> SecureField { match p { [] => mle_evals[0], diff --git a/crates/prover/src/core/test_utils.rs b/crates/prover/src/core/test_utils.rs index 083aa88ed..431c25778 100644 --- a/crates/prover/src/core/test_utils.rs +++ b/crates/prover/src/core/test_utils.rs @@ -3,6 +3,7 @@ use super::channel::Blake2sChannel; use super::fields::m31::BaseField; use super::fields::qm31::SecureField; use crate::core::channel::Channel; +use crate::core::vcs::blake2_hash::Blake2sHash; pub fn secure_eval_to_base_eval( eval: &CpuCircleEvaluation, @@ -14,8 +15,6 @@ pub fn secure_eval_to_base_eval( } pub fn test_channel() -> Blake2sChannel { - use crate::core::vcs::blake2_hash::Blake2sHash; - let seed = Blake2sHash::from(vec![0; 32]); Blake2sChannel::new(seed) }