From 8de5550c31cf6013ca5ac83affcf4002cb9763bd Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Wed, 15 May 2024 22:15:31 -0400 Subject: [PATCH] Implement GkrOps for SIMD backend --- .../src/core/backend/cpu/lookups/gkr.rs | 4 +- .../src/core/backend/cpu/lookups/mod.rs | 2 +- crates/prover/src/core/backend/cpu/mod.rs | 2 +- .../src/core/backend/simd/lookups/gkr.rs | 90 +++++++++++++++++++ .../src/core/backend/simd/lookups/mod.rs | 1 + crates/prover/src/lib.rs | 1 + 6 files changed, 96 insertions(+), 4 deletions(-) create mode 100644 crates/prover/src/core/backend/simd/lookups/gkr.rs diff --git a/crates/prover/src/core/backend/cpu/lookups/gkr.rs b/crates/prover/src/core/backend/cpu/lookups/gkr.rs index bfac7e140..2f735ea1e 100644 --- a/crates/prover/src/core/backend/cpu/lookups/gkr.rs +++ b/crates/prover/src/core/backend/cpu/lookups/gkr.rs @@ -239,7 +239,7 @@ fn eval_logup_singles_sum( /// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`. /// /// Evaluations are returned in bit-reversed order. -fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Vec { +pub fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Vec { let mut evals = Vec::with_capacity(1 << y.len()); evals.push(v); @@ -333,7 +333,7 @@ mod tests { let eq_evals = CpuBackend::gen_eq_evals(&y, two); assert_eq!( - **eq_evals, + *eq_evals, [ eq(&[zero, zero], &y) * two, eq(&[zero, one], &y) * two, diff --git a/crates/prover/src/core/backend/cpu/lookups/mod.rs b/crates/prover/src/core/backend/cpu/lookups/mod.rs index 34395e985..cd8dedf46 100644 --- a/crates/prover/src/core/backend/cpu/lookups/mod.rs +++ b/crates/prover/src/core/backend/cpu/lookups/mod.rs @@ -1,2 +1,2 @@ -mod gkr; +pub mod gkr; mod mle; diff --git a/crates/prover/src/core/backend/cpu/mod.rs b/crates/prover/src/core/backend/cpu/mod.rs index c033e4cc2..c054d5933 100644 --- a/crates/prover/src/core/backend/cpu/mod.rs +++ b/crates/prover/src/core/backend/cpu/mod.rs @@ -2,7 +2,7 @@ mod accumulation; mod blake2s; mod circle; mod fri; -mod lookups; +pub mod lookups; pub mod quotients; use std::fmt::Debug; diff --git a/crates/prover/src/core/backend/simd/lookups/gkr.rs b/crates/prover/src/core/backend/simd/lookups/gkr.rs new file mode 100644 index 000000000..e7e5fcdbf --- /dev/null +++ b/crates/prover/src/core/backend/simd/lookups/gkr.rs @@ -0,0 +1,90 @@ +use std::iter::zip; + +use crate::core::backend::cpu::lookups::gkr::gen_eq_evals as cpu_gen_eq_evals; +use crate::core::backend::simd::column::SecureFieldVec; +use crate::core::backend::simd::m31::{LOG_N_LANES, N_LANES}; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::fields::qm31::SecureField; +use crate::core::lookups::gkr_prover::{GkrMultivariatePolyOracle, GkrOps, Layer}; +use crate::core::lookups::mle::Mle; +use crate::core::lookups::utils::UnivariatePoly; + +impl GkrOps for SimdBackend { + #[allow(clippy::uninit_vec)] + fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle { + if y.len() < LOG_N_LANES as usize { + return Mle::new(cpu_gen_eq_evals(y, v).into_iter().collect()); + } + + // Start DP with CPU backend to prevent dealing with instances smaller than a SIMD vector. + let (y_last_chunk, y_rem) = y.split_last_chunk::<{ LOG_N_LANES as usize }>().unwrap(); + let initial = SecureFieldVec::from_iter(cpu_gen_eq_evals(y_last_chunk, v)); + assert_eq!(initial.len(), N_LANES); + + let packed_len = 1 << y_rem.len(); + let mut data = initial.data; + + data.reserve(packed_len - data.len()); + unsafe { data.set_len(packed_len) }; + + for (i, &y_j) in y_rem.iter().rev().enumerate() { + let packed_y_j = PackedSecureField::broadcast(y_j); + + let (lhs_evals, rhs_evals) = data.split_at_mut(1 << i); + + for (lhs, rhs) in zip(lhs_evals, rhs_evals) { + // Equivalent to: + // `rhs = eq(1, y_j) * lhs`, + // `lhs = eq(0, y_j) * lhs` + *rhs = *lhs * packed_y_j; + *lhs -= *rhs; + } + } + + let length = packed_len * N_LANES; + Mle::new(SecureFieldVec { data, length }) + } + + fn next_layer(_layer: &Layer) -> Layer { + todo!() + } + + fn sum_as_poly_in_first_variable( + _h: &GkrMultivariatePolyOracle<'_, Self>, + _claim: SecureField, + ) -> UnivariatePoly { + todo!() + } +} + +#[cfg(test)] +mod tests { + use crate::core::backend::simd::SimdBackend; + use crate::core::backend::{Column, CpuBackend}; + use crate::core::fields::m31::BaseField; + use crate::core::lookups::gkr_prover::GkrOps; + + #[test] + fn gen_eq_evals_matches_cpu() { + let two = BaseField::from(2).into(); + let y = [7, 3, 5, 6, 1, 1, 9].map(|v| BaseField::from(v).into()); + let eq_evals_cpu = CpuBackend::gen_eq_evals(&y, two); + + let eq_evals_simd = SimdBackend::gen_eq_evals(&y, two); + + assert_eq!(eq_evals_simd.to_cpu(), *eq_evals_cpu); + } + + #[test] + fn gen_eq_evals_with_small_assignment_matches_cpu() { + let two = BaseField::from(2).into(); + let y = [7, 3, 5].map(|v| BaseField::from(v).into()); + let eq_evals_cpu = CpuBackend::gen_eq_evals(&y, two); + + let eq_evals_simd = SimdBackend::gen_eq_evals(&y, two); + + assert_eq!(eq_evals_simd.to_cpu(), *eq_evals_cpu); + } +} diff --git a/crates/prover/src/core/backend/simd/lookups/mod.rs b/crates/prover/src/core/backend/simd/lookups/mod.rs index 1c1a481ae..34395e985 100644 --- a/crates/prover/src/core/backend/simd/lookups/mod.rs +++ b/crates/prover/src/core/backend/simd/lookups/mod.rs @@ -1 +1,2 @@ +mod gkr; mod mle; diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index 79887aae3..b85ac28dd 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -8,6 +8,7 @@ stdsimd, get_many_mut, int_roundings, + slice_first_last_chunk, slice_flatten, assert_matches, portable_simd