Skip to content

Commit

Permalink
Poseidon with logup (#712)
Browse files Browse the repository at this point in the history
<!-- Reviewable:start -->
This change is [<img src="https://reviewable.io/review_button.svg" height="34" align="absmiddle" alt="Reviewable"/>](https://reviewable.io/reviews/starkware-libs/stwo/712)
<!-- Reviewable:end -->
  • Loading branch information
spapinistarkware authored Jul 30, 2024
1 parent 68ec132 commit b06f87e
Show file tree
Hide file tree
Showing 3 changed files with 238 additions and 82 deletions.
99 changes: 97 additions & 2 deletions crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::ops::{Add, Mul, Sub};

use itertools::Itertools;
use num_traits::Zero;
use num_traits::{One, Zero};

use super::EvalAtRow;
use crate::core::backend::simd::column::SecureColumn;
use crate::core::backend::simd::m31::LOG_N_LANES;
use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum;
Expand All @@ -19,8 +20,102 @@ use crate::core::poly::BitReversedOrder;
use crate::core::utils::shifted_secure_combination;
use crate::core::ColumnVec;

/// Evaluates constraints for batched logups.
/// These constraint enforce the sum of multiplicity_i / (z + sum_j alpha^j * x_j) = claimed_sum.
/// BATCH_SIZE is the number of fractions to batch together. The degree of the resulting constraints
/// will be BATCH_SIZE + 1.
pub struct LogupAtRow<const BATCH_SIZE: usize, E: EvalAtRow> {
/// The index of the interaction used for the cumulative sum columns.
pub interaction: usize,
/// Queue of fractions waiting to be batched together.
pub queue: [(E::EF, E::EF); BATCH_SIZE],
/// Number of fractions in the queue.
pub queue_size: usize,
/// The claimed sum of all the fractions.
pub claimed_sum: SecureField,
/// The evaluation of the last cumulative sum column.
pub prev_col_cumsum: E::EF,
/// The value of the `is_first` constant column at current row.
/// See [`super::constant_columns::gen_is_first()`].
pub is_first: E::F,
}
impl<const BATCH_SIZE: usize, E: EvalAtRow> LogupAtRow<BATCH_SIZE, E> {
pub fn new(interaction: usize, claimed_sum: SecureField, is_first: E::F) -> Self {
Self {
interaction,
queue: [(E::EF::zero(), E::EF::zero()); BATCH_SIZE],
queue_size: 0,
claimed_sum,
prev_col_cumsum: E::EF::zero(),
is_first,
}
}
pub fn push_lookup(
&mut self,
eval: &mut E,
numerator: E::EF,
values: &[E::F],
lookup_elements: LookupElements,
) {
let shifted_value = shifted_secure_combination(
values,
E::EF::from(lookup_elements.alpha),
E::EF::from(lookup_elements.z),
);
self.push_frac(eval, numerator, shifted_value);
}

pub fn push_frac(&mut self, eval: &mut E, numerator: E::EF, denominator: E::EF) {
if self.queue_size < BATCH_SIZE {
self.queue[self.queue_size] = (numerator, denominator);
self.queue_size += 1;
return;
}

// Compute sum_i pi/qi over batch, as a fraction, num/denom.
let (num, denom) = self
.queue
.iter()
.copied()
.fold((E::EF::zero(), E::EF::one()), |(p0, q0), (pi, qi)| {
(p0 * qi + pi * q0, qi * q0)
});

self.queue[0] = (numerator, denominator);
self.queue_size = 1;

// Add a constraint that num / denom = diff.
let cur_cumsum = E::combine_ef(std::array::from_fn(|_| {
eval.next_interaction_mask(self.interaction, [0])[0]
}));
let diff = cur_cumsum - self.prev_col_cumsum;
self.prev_col_cumsum = cur_cumsum;
eval.add_constraint(diff * denom - num);
}

pub fn finalize(self, eval: &mut E) {
let (num, denom) = self.queue[0..self.queue_size]
.iter()
.copied()
.fold((E::EF::zero(), E::EF::one()), |(p0, q0), (pi, qi)| {
(p0 * qi + pi * q0, qi * q0)
});

let cumsum_mask =
std::array::from_fn(|_| eval.next_interaction_mask(self.interaction, [0, -1]));
let cur_cumsum = E::combine_ef(cumsum_mask.map(|[cur_row, _prev_row]| cur_row));
let prev_row_cumsum = E::combine_ef(cumsum_mask.map(|[_cur_row, prev_row]| prev_row));

// Fix `prev_row_cumsum` by subtracting `claimed_sum` if this is the first row.
let fixed_prev_row_cumsum = prev_row_cumsum - self.is_first * self.claimed_sum;
let diff = cur_cumsum - fixed_prev_row_cumsum - self.prev_col_cumsum;

eval.add_constraint(diff * denom - num);
}
}

/// Interaction elements for the logup protocol.
#[derive(Copy, Clone, Debug)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct LookupElements {
pub z: SecureField,
pub alpha: SecureField,
Expand Down
6 changes: 4 additions & 2 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mod simd_domain;

use std::array;
use std::fmt::Debug;
use std::ops::{Add, AddAssign, Mul, Sub};
use std::ops::{Add, AddAssign, Mul, Neg, Sub};

pub use assert::{assert_constraints, AssertEvaluator};
pub use info::InfoEvaluator;
Expand Down Expand Up @@ -45,13 +45,15 @@ pub trait EvalAtRow {
+ Copy
+ Debug
+ Zero
+ Neg<Output = Self::EF>
+ Add<SecureField, Output = Self::EF>
+ Sub<SecureField, Output = Self::EF>
+ Mul<SecureField, Output = Self::EF>
+ Add<Self::F, Output = Self::EF>
+ Mul<Self::F, Output = Self::EF>
+ Sub<Self::EF, Output = Self::EF>
+ Mul<Self::EF, Output = Self::EF>;
+ Mul<Self::EF, Output = Self::EF>
+ From<SecureField>;

/// Returns the next mask value for the first interaction at offset 0.
fn next_trace_mask(&mut self) -> Self::F {
Expand Down
Loading

0 comments on commit b06f87e

Please sign in to comment.