diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index 95f9f3342..bfa82363e 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -612,8 +612,7 @@ pub struct FormalLogupAtRow { pub interaction: usize, pub total_sum: ExtExpr, pub claimed_sum: Option<(ExtExpr, usize)>, - pub prev_col_cumsum: ExtExpr, - pub cur_frac: Option>, + pub fracs: Vec>, pub is_finalized: bool, pub is_first: BaseExpr, pub log_size: u32, @@ -634,8 +633,7 @@ impl FormalLogupAtRow { total_sum: ExtExpr::Param(total_sum_name), claimed_sum: has_partial_sum .then_some((ExtExpr::Param(claimed_sum_name), CLAIMED_SUM_DUMMY_OFFSET)), - prev_col_cumsum: ExtExpr::zero(), - cur_frac: None, + fracs: vec![], is_finalized: true, is_first: BaseExpr::zero(), log_size, @@ -724,22 +722,11 @@ impl EvalAtRow for ExprEvaluator { fn add_to_relation>( &mut self, - entries: &[RelationEntry<'_, Self::F, Self::EF, R>], + entry: RelationEntry<'_, Self::F, Self::EF, R>, ) { - let fracs: Vec> = entries - .iter() - .map( - |RelationEntry { - relation, - multiplicity, - values, - }| { - let intermediate = self.add_intermediate(combine_formal(*relation, values)); - Fraction::new(multiplicity.clone(), intermediate) - }, - ) - .collect(); - self.write_logup_frac(fracs.into_iter().sum()); + let intermediate = self.add_intermediate(combine_formal(entry.relation, entry.values)); + let frac = Fraction::new(entry.multiplicity.clone(), intermediate); + self.write_logup_frac(frac); } super::logup_proxy!(); @@ -945,12 +932,12 @@ mod tests { eval.add_constraint( x0.clone() * x1.clone() * x2.clone() * (x0.clone() + x1.clone()).inverse(), ); - eval.add_to_relation(&[RelationEntry::new( + eval.add_to_relation(RelationEntry::new( &TestRelation::dummy(), E::EF::one(), &[x0, x1, x2], - )]); - eval.finalize_logup(); + )); + eval.finalize_logup(&[1]); eval } } diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index d6af96e46..bb05c6b5c 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -49,8 +49,7 @@ pub struct LogupAtRow { /// None if the claimed_sum is the total_sum. pub claimed_sum: Option, /// The evaluation of the last cumulative sum column. - pub prev_col_cumsum: E::EF, - pub cur_frac: Option>, + pub fracs: Vec>, pub is_finalized: bool, /// The value of the `is_first` constant column at current row. /// See [`super::preprocessed_columns::gen_is_first()`]. @@ -74,8 +73,7 @@ impl LogupAtRow { interaction, total_sum, claimed_sum, - prev_col_cumsum: E::EF::zero(), - cur_frac: None, + fracs: vec![], is_finalized: true, is_first: E::F::zero(), log_size, @@ -88,8 +86,7 @@ impl LogupAtRow { interaction: 100, total_sum: SecureField::one(), claimed_sum: None, - prev_col_cumsum: E::EF::zero(), - cur_frac: None, + fracs: vec![], is_finalized: true, is_first: E::F::zero(), log_size: 10, diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 341e24511..21dfe12dd 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -120,23 +120,20 @@ pub trait EvalAtRow { /// multiplied. fn add_to_relation>( &mut self, - entries: &[RelationEntry<'_, Self::F, Self::EF, R>], + entry: RelationEntry<'_, Self::F, Self::EF, R>, ) { - let fracs = entries.iter().map( - |RelationEntry { - relation, - multiplicity, - values, - }| { Fraction::new(multiplicity.clone(), relation.combine(values)) }, + let frac = Fraction::new( + entry.multiplicity.clone(), + entry.relation.combine(entry.values), ); - self.write_logup_frac(fracs.sum()); + self.write_logup_frac(frac); } // TODO(alont): Remove these once LogupAtRow is no longer used. fn write_logup_frac(&mut self, _fraction: Fraction) { unimplemented!() } - fn finalize_logup(&mut self) { + fn finalize_logup(&mut self, _batching: &[usize]) { unimplemented!() } } @@ -147,26 +144,56 @@ pub trait EvalAtRow { macro_rules! logup_proxy { () => { fn write_logup_frac(&mut self, fraction: Fraction) { - // Add a constraint that num / denom = diff. - if let Some(cur_frac) = self.logup.cur_frac.clone() { - let [cur_cumsum] = - self.next_extension_interaction_mask(self.logup.interaction, [0]); - let diff = cur_cumsum.clone() - self.logup.prev_col_cumsum.clone(); - self.logup.prev_col_cumsum = cur_cumsum; - self.add_constraint(diff * cur_frac.denominator - cur_frac.numerator); - } else { + if self.logup.fracs.is_empty() { self.logup.is_first = self.get_preprocessed_column( super::preprocessed_columns::PreprocessedColumn::IsFirst(self.logup.log_size), ); self.logup.is_finalized = false; } - self.logup.cur_frac = Some(fraction); + self.logup.fracs.push(fraction.clone()); + // Add a constraint that num / denom = diff. + // if let Some(cur_frac) = self.logup.cur_frac.clone() { + // let [cur_cumsum] = + // self.next_extension_interaction_mask(self.logup.interaction, [0]); + // let diff = cur_cumsum.clone() - self.logup.prev_col_cumsum.clone(); + // self.logup.prev_col_cumsum = cur_cumsum; + // self.add_constraint(diff * cur_frac.denominator - cur_frac.numerator); + // } else { + // self.logup.is_first = self.get_preprocessed_column( + // super::preprocessed_columns::PreprocessedColumn::IsFirst(self.logup.log_size), + // ); + // self.logup.is_finalized = false; + // } + // self.logup.cur_frac = Some(fraction); } - fn finalize_logup(&mut self) { + /// Finalize the logup by adding the constraints for the fractions, batched by + /// the given `batching`. + fn finalize_logup(&mut self, batching: &[usize]) { assert!(!self.logup.is_finalized, "LogupAtRow was already finalized"); + assert_eq!( + batching.into_iter().sum::(), + self.logup.fracs.len(), + "Batching must sum to the number of entries" + ); + + let logup_fracs = self.logup.fracs.clone(); + let mut fracs_iter = logup_fracs.into_iter(); + let [first_batches @ .., last_batch] = &batching[..] else { + panic!("Batching must be nonempty") + }; + let mut prev_col_cumsum = ::zero(); + + for batch_size in first_batches { + let cur_frac: Fraction = fracs_iter.by_ref().take(*batch_size).sum(); + let [cur_cumsum] = + self.next_extension_interaction_mask(self.logup.interaction, [0]); + let diff = cur_cumsum.clone() - prev_col_cumsum.clone(); + prev_col_cumsum = cur_cumsum; + self.add_constraint(diff * cur_frac.denominator - cur_frac.numerator); + } - let frac = self.logup.cur_frac.clone().unwrap(); + let frac: Fraction = fracs_iter.take(*last_batch).sum(); // TODO(ShaharS): remove `claimed_row_index` interaction value and get the shifted // offset from the is_first column when constant columns are supported. @@ -193,7 +220,7 @@ macro_rules! logup_proxy { // Fix `prev_row_cumsum` by subtracting `total_sum` if this is the first row. let fixed_prev_row_cumsum = prev_row_cumsum - self.logup.is_first.clone() * self.logup.total_sum.clone(); - let diff = cur_cumsum - fixed_prev_row_cumsum - self.logup.prev_col_cumsum.clone(); + let diff = cur_cumsum - fixed_prev_row_cumsum - prev_col_cumsum.clone(); self.add_constraint(diff * frac.denominator - frac.numerator); diff --git a/crates/prover/src/constraint_framework/relation_tracker.rs b/crates/prover/src/constraint_framework/relation_tracker.rs index 3866df39a..7dfc0de61 100644 --- a/crates/prover/src/constraint_framework/relation_tracker.rs +++ b/crates/prover/src/constraint_framework/relation_tracker.rs @@ -152,38 +152,36 @@ impl<'a> EvalAtRow for RelationTrackerEvaluator<'a> { fn write_logup_frac(&mut self, _fraction: Fraction) {} - fn finalize_logup(&mut self) {} + fn finalize_logup(&mut self, _batching: &[usize]) {} fn add_to_relation>( &mut self, - entries: &[RelationEntry<'_, Self::F, Self::EF, R>], + entry: RelationEntry<'_, Self::F, Self::EF, R>, ) { - for entry in entries { - let relation = entry.relation.get_name().to_owned(); - let values = entry.values.iter().map(|v| v.to_array()).collect_vec(); - let mult = entry.multiplicity.to_array(); - - // Unpack SIMD. - for j in 0..N_LANES { - // Skip padded values. - let cannonical_index = bit_reverse_index( - coset_index_to_circle_domain_index( - (self.vec_row << LOG_N_LANES) + j, - self.domain_log_size, - ), + let relation = entry.relation.get_name().to_owned(); + let values = entry.values.iter().map(|v| v.to_array()).collect_vec(); + let mult = entry.multiplicity.to_array(); + + // Unpack SIMD. + for j in 0..N_LANES { + // Skip padded values. + let cannonical_index = bit_reverse_index( + coset_index_to_circle_domain_index( + (self.vec_row << LOG_N_LANES) + j, self.domain_log_size, - ); - if cannonical_index >= self.n_rows { - continue; - } - let values = values.iter().map(|v| v[j]).collect_vec(); - let mult = mult[j].to_m31_array()[0]; - self.entries.push(RelationTrackerEntry { - relation: relation.clone(), - mult, - values, - }); + ), + self.domain_log_size, + ); + if cannonical_index >= self.n_rows { + continue; } + let values = values.iter().map(|v| v[j]).collect_vec(); + let mult = mult[j].to_m31_array()[0]; + self.entries.push(RelationTrackerEntry { + relation: relation.clone(), + mult, + values, + }); } } } diff --git a/crates/prover/src/examples/blake/mod.rs b/crates/prover/src/examples/blake/mod.rs index ff62f9f7d..76feb7f8b 100644 --- a/crates/prover/src/examples/blake/mod.rs +++ b/crates/prover/src/examples/blake/mod.rs @@ -88,26 +88,26 @@ impl BlakeXorElements { // TODO(alont): Generalize this to variable sizes batches if ever used. fn use_relation(&self, eval: &mut E, w: u32, values: [&[E::F]; 2]) { match w { - 12 => eval.add_to_relation(&[ - RelationEntry::new(&self.xor12, E::EF::one(), values[0]), - RelationEntry::new(&self.xor12, E::EF::one(), values[1]), - ]), - 9 => eval.add_to_relation(&[ - RelationEntry::new(&self.xor9, E::EF::one(), values[0]), - RelationEntry::new(&self.xor9, E::EF::one(), values[1]), - ]), - 8 => eval.add_to_relation(&[ - RelationEntry::new(&self.xor8, E::EF::one(), values[0]), - RelationEntry::new(&self.xor8, E::EF::one(), values[1]), - ]), - 7 => eval.add_to_relation(&[ - RelationEntry::new(&self.xor7, E::EF::one(), values[0]), - RelationEntry::new(&self.xor7, E::EF::one(), values[1]), - ]), - 4 => eval.add_to_relation(&[ - RelationEntry::new(&self.xor4, E::EF::one(), values[0]), - RelationEntry::new(&self.xor4, E::EF::one(), values[1]), - ]), + 12 => { + eval.add_to_relation(RelationEntry::new(&self.xor12, E::EF::one(), values[0])); + eval.add_to_relation(RelationEntry::new(&self.xor12, E::EF::one(), values[1])); + } + 9 => { + eval.add_to_relation(RelationEntry::new(&self.xor9, E::EF::one(), values[0])); + eval.add_to_relation(RelationEntry::new(&self.xor9, E::EF::one(), values[1])); + } + 8 => { + eval.add_to_relation(RelationEntry::new(&self.xor8, E::EF::one(), values[0])); + eval.add_to_relation(RelationEntry::new(&self.xor8, E::EF::one(), values[1])); + } + 7 => { + eval.add_to_relation(RelationEntry::new(&self.xor7, E::EF::one(), values[0])); + eval.add_to_relation(RelationEntry::new(&self.xor7, E::EF::one(), values[1])); + } + 4 => { + eval.add_to_relation(RelationEntry::new(&self.xor4, E::EF::one(), values[0])); + eval.add_to_relation(RelationEntry::new(&self.xor4, E::EF::one(), values[1])); + } _ => panic!("Invalid w"), }; } diff --git a/crates/prover/src/examples/blake/round/constraints.rs b/crates/prover/src/examples/blake/round/constraints.rs index ada5fb287..94144b9a7 100644 --- a/crates/prover/src/examples/blake/round/constraints.rs +++ b/crates/prover/src/examples/blake/round/constraints.rs @@ -65,7 +65,7 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { ); // Yield `Round(input_v, output_v, message)`. - self.eval.add_to_relation(&[RelationEntry::new( + self.eval.add_to_relation(RelationEntry::new( self.round_lookup_elements, -E::EF::one(), &chain![ @@ -74,9 +74,12 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { m.iter().cloned().flat_map(Fu32::into_felts) ] .collect_vec(), - )]); + )); - self.eval.finalize_logup(); + // TODO(alont) see if there's a better way to represent the batching here. + let mut batching = [2; 65]; + batching[64] = 1; + self.eval.finalize_logup(&batching); self.eval } fn next_u32(&mut self) -> Fu32 { diff --git a/crates/prover/src/examples/blake/scheduler/constraints.rs b/crates/prover/src/examples/blake/scheduler/constraints.rs index 1bf93d1aa..cf77ed035 100644 --- a/crates/prover/src/examples/blake/scheduler/constraints.rs +++ b/crates/prover/src/examples/blake/scheduler/constraints.rs @@ -30,17 +30,23 @@ pub fn eval_blake_scheduler_constraints( ] .collect_vec() }); - eval.add_to_relation(&[ - RelationEntry::new(round_lookup_elements, E::EF::one(), &elems_i), - RelationEntry::new(round_lookup_elements, E::EF::one(), &elems_j), - ]); + eval.add_to_relation(RelationEntry::new( + round_lookup_elements, + E::EF::one(), + &elems_i, + )); + eval.add_to_relation(RelationEntry::new( + round_lookup_elements, + E::EF::one(), + &elems_j, + )); } let input_state = &states[0]; let output_state = &states[N_ROUNDS]; // TODO(alont): Remove blake interaction. - eval.add_to_relation(&[RelationEntry::new( + eval.add_to_relation(RelationEntry::new( blake_lookup_elements, E::EF::zero(), &chain![ @@ -49,9 +55,11 @@ pub fn eval_blake_scheduler_constraints( messages.iter().cloned().flat_map(Fu32::into_felts) ] .collect_vec(), - )]); + )); - eval.finalize_logup(); + let mut batching = [2; N_ROUNDS / 2 + 1]; + batching[batching.len() - 1] = 1; + eval.finalize_logup(&batching); } fn eval_next_u32(eval: &mut E) -> Fu32 { diff --git a/crates/prover/src/examples/blake/xor_table/constraints.rs b/crates/prover/src/examples/blake/xor_table/constraints.rs index 4df0a6c63..9cff84bd1 100644 --- a/crates/prover/src/examples/blake/xor_table/constraints.rs +++ b/crates/prover/src/examples/blake/xor_table/constraints.rs @@ -40,43 +40,32 @@ macro_rules! xor_table_eval { 2, )); - let entry_chunks = (0..(1 << (2 * EXPAND_BITS))) - .map(|i| { - let (i, j) = ((i >> EXPAND_BITS) as u32, (i % (1 << EXPAND_BITS)) as u32); - let multiplicity = self.eval.next_trace_mask(); + for i in (0..(1 << (2 * EXPAND_BITS))) { + let (i, j) = ((i >> EXPAND_BITS) as u32, (i % (1 << EXPAND_BITS)) as u32); + let multiplicity = self.eval.next_trace_mask(); - let a = al.clone() - + E::F::from(BaseField::from_u32_unchecked( - i << limb_bits::(), - )); - let b = bl.clone() - + E::F::from(BaseField::from_u32_unchecked( - j << limb_bits::(), - )); - let c = cl.clone() - + E::F::from(BaseField::from_u32_unchecked( - (i ^ j) << limb_bits::(), - )); + let a = al.clone() + + E::F::from(BaseField::from_u32_unchecked( + i << limb_bits::(), + )); + let b = bl.clone() + + E::F::from(BaseField::from_u32_unchecked( + j << limb_bits::(), + )); + let c = cl.clone() + + E::F::from(BaseField::from_u32_unchecked( + (i ^ j) << limb_bits::(), + )); - (self.lookup_elements, -multiplicity, [a, b, c]) - }) - .collect_vec(); - - for entry_chunk in entry_chunks.chunks(2) { - self.eval.add_to_relation( - &entry_chunk - .iter() - .map(|(lookup, multiplicity, values)| { - RelationEntry::new( - *lookup, - E::EF::from(multiplicity.clone()), - values, - ) - }) - .collect_vec(), - ); + self.eval.add_to_relation(RelationEntry::new( + self.lookup_elements, + -E::EF::from(multiplicity), + &[a, b, c], + )); } - self.eval.finalize_logup(); + + self.eval + .finalize_logup(&vec![2; (1 << (2 * EXPAND_BITS - 1))]); self.eval } } diff --git a/crates/prover/src/examples/plonk/mod.rs b/crates/prover/src/examples/plonk/mod.rs index 49da86f8a..02c24a2f5 100644 --- a/crates/prover/src/examples/plonk/mod.rs +++ b/crates/prover/src/examples/plonk/mod.rs @@ -66,18 +66,24 @@ impl FrameworkEval for PlonkEval { + (E::F::one() - op) * a_val.clone() * b_val.clone(), ); - eval.add_to_relation(&[ - RelationEntry::new(&self.lookup_elements, E::EF::one(), &[a_wire, a_val]), - RelationEntry::new(&self.lookup_elements, E::EF::one(), &[b_wire, b_val]), - ]); + eval.add_to_relation(RelationEntry::new( + &self.lookup_elements, + E::EF::one(), + &[a_wire, a_val], + )); + eval.add_to_relation(RelationEntry::new( + &self.lookup_elements, + E::EF::one(), + &[b_wire, b_val], + )); - eval.add_to_relation(&[RelationEntry::new( + eval.add_to_relation(RelationEntry::new( &self.lookup_elements, (-mult).into(), &[c_wire, c_val], - )]); + )); - eval.finalize_logup(); + eval.finalize_logup(&[2, 1]); eval } } diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 51b671580..f2a6be315 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -186,13 +186,15 @@ pub fn eval_poseidon_constraints(eval: &mut E, lookup_elements: &P }); // Provide state lookups. - eval.add_to_relation(&[ - RelationEntry::new(lookup_elements, E::EF::one(), &initial_state), - RelationEntry::new(lookup_elements, -E::EF::one(), &state), - ]) + eval.add_to_relation(RelationEntry::new( + lookup_elements, + E::EF::one(), + &initial_state, + )); + eval.add_to_relation(RelationEntry::new(lookup_elements, -E::EF::one(), &state)); } - eval.finalize_logup(); + eval.finalize_logup(&[2; N_INSTANCES_PER_ROW]); } pub struct LookupData { diff --git a/crates/prover/src/examples/state_machine/components.rs b/crates/prover/src/examples/state_machine/components.rs index 4600a3cf0..63d43c115 100644 --- a/crates/prover/src/examples/state_machine/components.rs +++ b/crates/prover/src/examples/state_machine/components.rs @@ -52,12 +52,18 @@ impl FrameworkEval for StateTransitionEval let mut output_state = input_state.clone(); output_state[COORDINATE] += E::F::one(); - eval.add_to_relation(&[ - RelationEntry::new(&self.lookup_elements, E::EF::one(), &input_state), - RelationEntry::new(&self.lookup_elements, -E::EF::one(), &output_state), - ]); - - eval.finalize_logup(); + eval.add_to_relation(RelationEntry::new( + &self.lookup_elements, + E::EF::one(), + &input_state, + )); + eval.add_to_relation(RelationEntry::new( + &self.lookup_elements, + -E::EF::one(), + &output_state, + )); + + eval.finalize_logup(&[2]); eval } }