Skip to content

Commit

Permalink
Optimized lookup combine
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Aug 5, 2024
1 parent 17591ea commit 4351076
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 38 deletions.
45 changes: 27 additions & 18 deletions crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ops::{Add, Mul, Sub};
use std::ops::{Mul, Sub};

use itertools::Itertools;
use num_traits::{One, Zero};
Expand All @@ -17,7 +17,6 @@ use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::fields::FieldExpOps;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::shifted_secure_combination;
use crate::core::ColumnVec;

/// Evaluates constraints for batched logups.
Expand Down Expand Up @@ -57,11 +56,7 @@ impl<const BATCH_SIZE: usize, E: EvalAtRow> LogupAtRow<BATCH_SIZE, E> {
values: &[E::F],
lookup_elements: &LookupElements,
) {
let shifted_value = shifted_secure_combination(
values,
E::EF::from(lookup_elements.alpha),
E::EF::from(lookup_elements.z),
);
let shifted_value = lookup_elements.combine(values);
self.push_frac(eval, numerator, shifted_value);
}

Expand Down Expand Up @@ -115,32 +110,46 @@ impl<const BATCH_SIZE: usize, E: EvalAtRow> LogupAtRow<BATCH_SIZE, E> {
}

/// Interaction elements for the logup protocol.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LookupElements {
pub z: SecureField,
pub alpha: SecureField,
alpha_powers: Vec<SecureField>,
}
impl LookupElements {
pub fn draw(channel: &mut Blake2sChannel) -> Self {
pub fn draw(channel: &mut Blake2sChannel, n_powers: usize) -> Self {
let [z, alpha] = channel.draw_felts(2).try_into().unwrap();
Self { z, alpha }
Self {
z,
alpha,
alpha_powers: (0..n_powers)
.scan(SecureField::one(), |acc, _| {
let res = *acc;
*acc *= alpha;
Some(res)
})
.collect(),
}
}
pub fn combine<F: Copy, EF>(&self, values: &[F]) -> EF
where
EF: Copy
+ Zero
+ Mul<EF, Output = EF>
+ Add<F, Output = EF>
+ Sub<EF, Output = EF>
+ From<SecureField>,
EF: Copy + Zero + From<F> + From<SecureField> + Mul<F, Output = EF> + Sub<EF, Output = EF>,
{
shifted_secure_combination(values, EF::from(self.alpha), EF::from(self.z))
EF::from(values[0])
+ values[1..]
.iter()
.zip(self.alpha_powers.iter())
.fold(EF::zero(), |acc, (&value, &power)| {
acc + EF::from(power) * value
})
- EF::from(self.z)
}
#[cfg(test)]
pub fn dummy() -> Self {
pub fn dummy(n_powers: usize) -> Self {
Self {
z: SecureField::one(),
alpha: SecureField::one(),
alpha_powers: vec![SecureField::one(); n_powers],
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/examples/blake/xor_table/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ mod tests {
xor_accum.add_input(u32x16::splat(1), u32x16::splat(2));

let (trace, lookup_data) = generate_trace(xor_accum);
let lookup_elements = LookupElements::dummy();
let lookup_elements = LookupElements::dummy(3);
let (interaction_trace, claimed_sum) =
generate_interaction_trace(lookup_data, &lookup_elements);
let constant_trace = generate_constant_trace::<ELEM_BITS, EXPAND_BITS>();
Expand Down
34 changes: 15 additions & 19 deletions crates/prover/src/examples/poseidon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ impl FrameworkComponent for PoseidonComponent {
let poseidon_eval = PoseidonEval {
eval,
logup: LogupAtRow::new(1, self.claimed_sum, is_first),
lookup_elements: self.lookup_elements,
lookup_elements: &self.lookup_elements,
};
poseidon_eval.eval()
}
Expand Down Expand Up @@ -146,20 +146,20 @@ fn pow5<F: FieldExpOps>(x: F) -> F {
x4 * x
}

struct PoseidonEval<E: EvalAtRow> {
struct PoseidonEval<'a, E: EvalAtRow> {
eval: E,
logup: LogupAtRow<2, E>,
lookup_elements: LookupElements,
lookup_elements: &'a LookupElements,
}

impl<E: EvalAtRow> PoseidonEval<E> {
impl<'a, E: EvalAtRow> PoseidonEval<'a, E> {
fn eval(mut self) -> E {
for _ in 0..N_INSTANCES_PER_ROW {
let mut state: [_; N_STATE] = std::array::from_fn(|_| self.eval.next_trace_mask());

// Require state lookup.
self.logup
.push_lookup(&mut self.eval, E::EF::one(), &state, &self.lookup_elements);
.push_lookup(&mut self.eval, E::EF::one(), &state, self.lookup_elements);

// 4 full rounds.
(0..N_HALF_FULL_ROUNDS).for_each(|round| {
Expand Down Expand Up @@ -201,7 +201,7 @@ impl<E: EvalAtRow> PoseidonEval<E> {

// Provide state lookup.
self.logup
.push_lookup(&mut self.eval, -E::EF::one(), &state, &self.lookup_elements);
.push_lookup(&mut self.eval, -E::EF::one(), &state, self.lookup_elements);
}

self.logup.finalize(&mut self.eval);
Expand Down Expand Up @@ -311,7 +311,7 @@ pub fn gen_trace(
pub fn gen_interaction_trace(
log_size: u32,
lookup_data: LookupData,
lookup_elements: LookupElements,
lookup_elements: &LookupElements,
) -> (
ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
SecureField,
Expand Down Expand Up @@ -368,12 +368,12 @@ pub fn prove_poseidon(log_n_instances: u32) -> (PoseidonAir, StarkProof<Blake2sM
tree_builder.commit(channel);
span.exit();

// Draw lookup element.
let lookup_elements = LookupElements::draw(channel);
// Draw lookup elements.
let lookup_elements = LookupElements::draw(channel, N_STATE * 2);

// Interaction trace.
let span = span!(Level::INFO, "Interaction").entered();
let (trace, claimed_sum) = gen_interaction_trace(log_n_rows, lookup_data, lookup_elements);
let (trace, claimed_sum) = gen_interaction_trace(log_n_rows, lookup_data, &lookup_elements);
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(trace);
tree_builder.commit(channel);
Expand Down Expand Up @@ -425,10 +425,9 @@ mod tests {
use crate::core::InteractionElements;
use crate::examples::poseidon::{
apply_internal_round_matrix, apply_m4, gen_interaction_trace, gen_trace, prove_poseidon,
PoseidonEval,
PoseidonEval, N_STATE,
};
use crate::math::matrix::{RowMajorMatrix, SquareMatrix};
use crate::qm31;

#[test]
fn test_apply_m4() {
Expand Down Expand Up @@ -473,12 +472,9 @@ mod tests {

// Trace.
let (trace0, interaction_data) = gen_trace(LOG_N_ROWS);
let lookup_elements = LookupElements {
z: qm31!(1, 2, 3, 4),
alpha: qm31!(5, 6, 7, 8),
};
let lookup_elements = LookupElements::dummy(N_STATE * 2);
let (trace1, claimed_sum) =
gen_interaction_trace(LOG_N_ROWS, interaction_data, lookup_elements);
gen_interaction_trace(LOG_N_ROWS, interaction_data, &lookup_elements);
let trace2 = vec![gen_is_first(LOG_N_ROWS)];

let traces = TreeVec::new(vec![trace0, trace1, trace2]);
Expand All @@ -489,7 +485,7 @@ mod tests {
PoseidonEval {
eval,
logup: LogupAtRow::new(1, claimed_sum, is_first),
lookup_elements,
lookup_elements: &lookup_elements,
}
.eval();
});
Expand Down Expand Up @@ -522,7 +518,7 @@ mod tests {
// Trace columns.
commitment_scheme.commit(proof.commitments[0], &sizes[0], channel);
// Draw lookup element.
let lookup_elements = LookupElements::draw(channel);
let lookup_elements = LookupElements::draw(channel, N_STATE * 2);
assert_eq!(lookup_elements, air.component.lookup_elements);
// TODO(spapini): Check claimed sum against first and last instances.
// Interaction columns.
Expand Down

0 comments on commit 4351076

Please sign in to comment.