From 43510762e1616219e114493668fa37afdaf8fea2 Mon Sep 17 00:00:00 2001 From: Shahar Papini Date: Sun, 28 Jul 2024 16:07:25 +0300 Subject: [PATCH] Optimized lookup combine --- .../prover/src/constraint_framework/logup.rs | 45 +++++++++++-------- .../src/examples/blake/xor_table/mod.rs | 2 +- crates/prover/src/examples/poseidon/mod.rs | 34 +++++++------- 3 files changed, 43 insertions(+), 38 deletions(-) diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index 929069771..1d9c4e180 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -1,4 +1,4 @@ -use std::ops::{Add, Mul, Sub}; +use std::ops::{Mul, Sub}; use itertools::Itertools; use num_traits::{One, Zero}; @@ -17,7 +17,6 @@ use crate::core::fields::secure_column::SecureColumnByCoords; use crate::core::fields::FieldExpOps; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; use crate::core::poly::BitReversedOrder; -use crate::core::utils::shifted_secure_combination; use crate::core::ColumnVec; /// Evaluates constraints for batched logups. @@ -57,11 +56,7 @@ impl LogupAtRow { values: &[E::F], lookup_elements: &LookupElements, ) { - let shifted_value = shifted_secure_combination( - values, - E::EF::from(lookup_elements.alpha), - E::EF::from(lookup_elements.z), - ); + let shifted_value = lookup_elements.combine(values); self.push_frac(eval, numerator, shifted_value); } @@ -115,32 +110,46 @@ impl LogupAtRow { } /// Interaction elements for the logup protocol. -#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct LookupElements { pub z: SecureField, pub alpha: SecureField, + alpha_powers: Vec, } impl LookupElements { - pub fn draw(channel: &mut Blake2sChannel) -> Self { + pub fn draw(channel: &mut Blake2sChannel, n_powers: usize) -> Self { let [z, alpha] = channel.draw_felts(2).try_into().unwrap(); - Self { z, alpha } + Self { + z, + alpha, + alpha_powers: (0..n_powers) + .scan(SecureField::one(), |acc, _| { + let res = *acc; + *acc *= alpha; + Some(res) + }) + .collect(), + } } pub fn combine(&self, values: &[F]) -> EF where - EF: Copy - + Zero - + Mul - + Add - + Sub - + From, + EF: Copy + Zero + From + From + Mul + Sub, { - shifted_secure_combination(values, EF::from(self.alpha), EF::from(self.z)) + EF::from(values[0]) + + values[1..] + .iter() + .zip(self.alpha_powers.iter()) + .fold(EF::zero(), |acc, (&value, &power)| { + acc + EF::from(power) * value + }) + - EF::from(self.z) } #[cfg(test)] - pub fn dummy() -> Self { + pub fn dummy(n_powers: usize) -> Self { Self { z: SecureField::one(), alpha: SecureField::one(), + alpha_powers: vec![SecureField::one(); n_powers], } } } diff --git a/crates/prover/src/examples/blake/xor_table/mod.rs b/crates/prover/src/examples/blake/xor_table/mod.rs index 5aa801829..1a55a4611 100644 --- a/crates/prover/src/examples/blake/xor_table/mod.rs +++ b/crates/prover/src/examples/blake/xor_table/mod.rs @@ -116,7 +116,7 @@ mod tests { xor_accum.add_input(u32x16::splat(1), u32x16::splat(2)); let (trace, lookup_data) = generate_trace(xor_accum); - let lookup_elements = LookupElements::dummy(); + let lookup_elements = LookupElements::dummy(3); let (interaction_trace, claimed_sum) = generate_interaction_trace(lookup_data, &lookup_elements); let constant_trace = generate_constant_trace::(); diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 4cf66dc46..7435b2c88 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -71,7 +71,7 @@ impl FrameworkComponent for PoseidonComponent { let poseidon_eval = PoseidonEval { eval, logup: LogupAtRow::new(1, self.claimed_sum, is_first), - lookup_elements: self.lookup_elements, + lookup_elements: &self.lookup_elements, }; poseidon_eval.eval() } @@ -146,20 +146,20 @@ fn pow5(x: F) -> F { x4 * x } -struct PoseidonEval { +struct PoseidonEval<'a, E: EvalAtRow> { eval: E, logup: LogupAtRow<2, E>, - lookup_elements: LookupElements, + lookup_elements: &'a LookupElements, } -impl PoseidonEval { +impl<'a, E: EvalAtRow> PoseidonEval<'a, E> { fn eval(mut self) -> E { for _ in 0..N_INSTANCES_PER_ROW { let mut state: [_; N_STATE] = std::array::from_fn(|_| self.eval.next_trace_mask()); // Require state lookup. self.logup - .push_lookup(&mut self.eval, E::EF::one(), &state, &self.lookup_elements); + .push_lookup(&mut self.eval, E::EF::one(), &state, self.lookup_elements); // 4 full rounds. (0..N_HALF_FULL_ROUNDS).for_each(|round| { @@ -201,7 +201,7 @@ impl PoseidonEval { // Provide state lookup. self.logup - .push_lookup(&mut self.eval, -E::EF::one(), &state, &self.lookup_elements); + .push_lookup(&mut self.eval, -E::EF::one(), &state, self.lookup_elements); } self.logup.finalize(&mut self.eval); @@ -311,7 +311,7 @@ pub fn gen_trace( pub fn gen_interaction_trace( log_size: u32, lookup_data: LookupData, - lookup_elements: LookupElements, + lookup_elements: &LookupElements, ) -> ( ColumnVec>, SecureField, @@ -368,12 +368,12 @@ pub fn prove_poseidon(log_n_instances: u32) -> (PoseidonAir, StarkProof