From ed308eb5e4107989f41e486b7178bae8a2581577 Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Wed, 15 May 2024 22:15:31 -0400 Subject: [PATCH] Implement GkrOps for SIMD backend --- .../src/core/backend/cpu/lookups/gkr.rs | 15 ++-- .../src/core/backend/cpu/lookups/mod.rs | 2 +- crates/prover/src/core/backend/cpu/mod.rs | 2 +- .../src/core/backend/simd/lookups/gkr.rs | 79 +++++++++++++++++++ .../src/core/backend/simd/lookups/mod.rs | 1 + crates/prover/src/lib.rs | 1 + 6 files changed, 91 insertions(+), 9 deletions(-) create mode 100644 crates/prover/src/core/backend/simd/lookups/gkr.rs diff --git a/crates/prover/src/core/backend/cpu/lookups/gkr.rs b/crates/prover/src/core/backend/cpu/lookups/gkr.rs index 8e4e16ded..2933c4dc1 100644 --- a/crates/prover/src/core/backend/cpu/lookups/gkr.rs +++ b/crates/prover/src/core/backend/cpu/lookups/gkr.rs @@ -234,7 +234,7 @@ fn process_logup_singles_sum( /// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`. /// /// Evaluations are returned in bit-reversed order. -fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Vec { +pub fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Vec { let mut evals = Vec::with_capacity(1 << y.len()); evals.push(v); @@ -321,17 +321,18 @@ mod tests { fn gen_eq_evals() { let zero = SecureField::zero(); let one = SecureField::one(); + let two = BaseField::from(2).into(); let y = [7, 3].map(|v| BaseField::from(v).into()); - let eq_evals = CpuBackend::gen_eq_evals(&y, one); + let eq_evals = CpuBackend::gen_eq_evals(&y, two); assert_eq!( - **eq_evals, + *eq_evals, [ - eq(&[zero, zero], &y), - eq(&[zero, one], &y), - eq(&[one, zero], &y), - eq(&[one, one], &y), + eq(&[zero, zero], &y) * two, + eq(&[zero, one], &y) * two, + eq(&[one, zero], &y) * two, + eq(&[one, one], &y) * two, ] ); } diff --git a/crates/prover/src/core/backend/cpu/lookups/mod.rs b/crates/prover/src/core/backend/cpu/lookups/mod.rs index 34395e985..cd8dedf46 100644 --- a/crates/prover/src/core/backend/cpu/lookups/mod.rs +++ b/crates/prover/src/core/backend/cpu/lookups/mod.rs @@ -1,2 +1,2 @@ -mod gkr; +pub mod gkr; mod mle; diff --git a/crates/prover/src/core/backend/cpu/mod.rs b/crates/prover/src/core/backend/cpu/mod.rs index c033e4cc2..c054d5933 100644 --- a/crates/prover/src/core/backend/cpu/mod.rs +++ b/crates/prover/src/core/backend/cpu/mod.rs @@ -2,7 +2,7 @@ mod accumulation; mod blake2s; mod circle; mod fri; -mod lookups; +pub mod lookups; pub mod quotients; use std::fmt::Debug; diff --git a/crates/prover/src/core/backend/simd/lookups/gkr.rs b/crates/prover/src/core/backend/simd/lookups/gkr.rs new file mode 100644 index 000000000..bd9d635b5 --- /dev/null +++ b/crates/prover/src/core/backend/simd/lookups/gkr.rs @@ -0,0 +1,79 @@ +use std::iter::zip; + +use crate::core::backend::cpu::lookups::gkr::gen_eq_evals as cpu_gen_eq_evals; +use crate::core::backend::simd::column::SecureFieldVec; +use crate::core::backend::simd::m31::{LOG_N_LANES, N_LANES}; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::fields::qm31::SecureField; +use crate::core::lookups::gkr_prover::{GkrMultivariatePolyOracle, GkrOps, Layer}; +use crate::core::lookups::mle::Mle; +use crate::core::lookups::utils::UnivariatePoly; + +impl GkrOps for SimdBackend { + #[allow(clippy::uninit_vec)] + fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle { + if y.len() < LOG_N_LANES as usize { + return Mle::new(cpu_gen_eq_evals(y, v).into_iter().collect()); + } + + // Start DP with CPU backend to prevent dealing with instances smaller than a SIMD vector. + let (y_initial, y_rem) = y.split_last_chunk::<{ LOG_N_LANES as usize }>().unwrap(); + let initial = SecureFieldVec::from_iter(cpu_gen_eq_evals(y_initial, v)); + assert_eq!(initial.len(), N_LANES); + + let packed_len = 1 << y_rem.len(); + let mut data = initial.data; + + data.reserve(packed_len - data.len()); + unsafe { data.set_len(packed_len) }; + + for (i, &y_j) in y_rem.iter().rev().enumerate() { + let packed_y_j = PackedSecureField::broadcast(y_j); + + let (lhs_evals, rhs_evals) = data.split_at_mut(1 << i); + + for (lhs, rhs) in zip(lhs_evals, rhs_evals) { + // Equivalent to: + // `rhs = eq(1, y_j) * lhs`, + // `lhs = eq(0, y_j) * lhs` + *rhs = *lhs * packed_y_j; + *lhs -= *rhs; + } + } + + let length = packed_len * N_LANES; + Mle::new(SecureFieldVec { data, length }) + } + + fn next_layer(_layer: &Layer) -> Layer { + todo!() + } + + fn sum_as_poly_in_first_variable( + _h: &GkrMultivariatePolyOracle<'_, Self>, + _claim: SecureField, + ) -> UnivariatePoly { + todo!() + } +} + +#[cfg(test)] +mod tests { + use crate::core::backend::simd::SimdBackend; + use crate::core::backend::{Column, CpuBackend}; + use crate::core::fields::m31::BaseField; + use crate::core::lookups::gkr_prover::GkrOps; + + #[test] + fn gen_eq_evals_matches_cpu() { + let two = BaseField::from(2).into(); + let y = [7, 3, 5, 6, 1, 1, 9].map(|v| BaseField::from(v).into()); + let cpu_eq_evals = CpuBackend::gen_eq_evals(&y, two); + + let simd_eq_evals = SimdBackend::gen_eq_evals(&y, two); + + assert_eq!(*cpu_eq_evals, simd_eq_evals.to_cpu()); + } +} diff --git a/crates/prover/src/core/backend/simd/lookups/mod.rs b/crates/prover/src/core/backend/simd/lookups/mod.rs index 1c1a481ae..34395e985 100644 --- a/crates/prover/src/core/backend/simd/lookups/mod.rs +++ b/crates/prover/src/core/backend/simd/lookups/mod.rs @@ -1 +1,2 @@ +mod gkr; mod mle; diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index 6d70f257e..10e2670b6 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -8,6 +8,7 @@ stdsimd, get_many_mut, int_roundings, + slice_first_last_chunk, slice_flatten, assert_matches, portable_simd