Skip to content

Commit

Permalink
Implement LogupOps for SIMD backend
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Jul 11, 2024
1 parent 410fcd3 commit 8e801e7
Show file tree
Hide file tree
Showing 4 changed files with 476 additions and 64 deletions.
39 changes: 7 additions & 32 deletions crates/prover/src/core/backend/cpu/lookups/gkr.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ops::{Add, Index};
use std::ops::Index;

use num_traits::{One, Zero};

Expand All @@ -11,7 +11,7 @@ use crate::core::lookups::gkr_prover::{
};
use crate::core::lookups::mle::{Mle, MleOps};
use crate::core::lookups::sumcheck::MultivariatePolyOracle;
use crate::core::lookups::utils::{Fraction, UnivariatePoly};
use crate::core::lookups::utils::{Fraction, Reciprocal, UnivariatePoly};

impl GkrOps for CpuBackend {
fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle<Self, SecureField> {
Expand Down Expand Up @@ -177,23 +177,6 @@ fn eval_logup_singles_sum(
n_terms: usize,
lambda: SecureField,
) -> (SecureField, SecureField) {
/// Represents the fraction `1 / x`
struct Reciprocal {
x: SecureField,
}

impl Add for Reciprocal {
type Output = Fraction<SecureField>;

fn add(self, rhs: Self) -> Fraction<SecureField> {
// `1/a + 1/b = (a + b)/(a * b)`
Fraction {
numerator: self.x + rhs.x,
denominator: self.x * rhs.x,
}
}
}

let mut eval_at_0 = SecureField::zero();
let mut eval_at_2 = SecureField::zero();

Expand All @@ -215,19 +198,11 @@ fn eval_logup_singles_sum(
let Fraction {
numerator: numer_at_r0i,
denominator: denom_at_r0i,
} = Reciprocal {
x: inp_denom_at_r0i0,
} + Reciprocal {
x: inp_denom_at_r0i1,
};
} = Reciprocal::new(inp_denom_at_r0i0) + Reciprocal::new(inp_denom_at_r0i1);
let Fraction {
numerator: numer_at_r2i,
denominator: denom_at_r2i,
} = Reciprocal {
x: inp_denom_at_r2i0,
} + Reciprocal {
x: inp_denom_at_r2i1,
};
} = Reciprocal::new(inp_denom_at_r2i0) + Reciprocal::new(inp_denom_at_r2i1);

let eq_eval_at_0i = eq_evals[i];
eval_at_0 += eq_eval_at_0i * (numer_at_r0i + lambda * denom_at_r0i);
Expand Down Expand Up @@ -372,7 +347,7 @@ mod tests {
let denominator_values = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
let sum = zip(&numerator_values, &denominator_values)
.map(|(&n, &d)| Fraction::new(n, d))
.sum::<Fraction<SecureField>>();
.sum::<Fraction<SecureField, SecureField>>();
let numerators = Mle::<CpuBackend, SecureField>::new(numerator_values);
let denominators = Mle::<CpuBackend, SecureField>::new(denominator_values);
let top_layer = Layer::LogUpGeneric {
Expand Down Expand Up @@ -414,7 +389,7 @@ mod tests {
let sum = denominator_values
.iter()
.map(|&d| Fraction::new(SecureField::one(), d))
.sum::<Fraction<SecureField>>();
.sum::<Fraction<SecureField, SecureField>>();
let denominators = Mle::<CpuBackend, SecureField>::new(denominator_values);
let top_layer = Layer::LogUpSingles {
denominators: denominators.clone(),
Expand Down Expand Up @@ -448,7 +423,7 @@ mod tests {
let denominator_values = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
let sum = zip(&numerator_values, &denominator_values)
.map(|(&n, &d)| Fraction::new(n.into(), d))
.sum::<Fraction<SecureField>>();
.sum::<Fraction<SecureField, SecureField>>();
let numerators = Mle::<CpuBackend, BaseField>::new(numerator_values);
let denominators = Mle::<CpuBackend, SecureField>::new(denominator_values);
let top_layer = Layer::LogUpMultiplicities {
Expand Down
Loading

0 comments on commit 8e801e7

Please sign in to comment.