Skip to content

Commit

Permalink
Decoupled batching from add_to_relation.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon-Ti committed Dec 5, 2024
1 parent 76af3c6 commit c52d4a8
Show file tree
Hide file tree
Showing 11 changed files with 180 additions and 157 deletions.
31 changes: 9 additions & 22 deletions crates/prover/src/constraint_framework/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -755,8 +755,7 @@ pub struct FormalLogupAtRow {
pub interaction: usize,
pub total_sum: ExtExpr,
pub claimed_sum: Option<(ExtExpr, usize)>,
pub prev_col_cumsum: ExtExpr,
pub cur_frac: Option<Fraction<ExtExpr, ExtExpr>>,
pub fracs: Vec<Fraction<ExtExpr, ExtExpr>>,
pub is_finalized: bool,
pub is_first: BaseExpr,
pub log_size: u32,
Expand All @@ -777,8 +776,7 @@ impl FormalLogupAtRow {
total_sum: ExtExpr::Param(total_sum_name),
claimed_sum: has_partial_sum
.then_some((ExtExpr::Param(claimed_sum_name), CLAIMED_SUM_DUMMY_OFFSET)),
prev_col_cumsum: ExtExpr::zero(),
cur_frac: None,
fracs: vec![],
is_finalized: true,
is_first: BaseExpr::zero(),
log_size,
Expand Down Expand Up @@ -873,23 +871,12 @@ impl EvalAtRow for ExprEvaluator {

fn add_to_relation<R: Relation<Self::F, Self::EF>>(
&mut self,
entries: &[RelationEntry<'_, Self::F, Self::EF, R>],
entry: RelationEntry<'_, Self::F, Self::EF, R>,
) {
let fracs: Vec<Fraction<Self::EF, Self::EF>> = entries
.iter()
.map(
|RelationEntry {
relation,
multiplicity,
values,
}| {
let intermediate =
self.add_extension_intermediate(combine_formal(*relation, values));
Fraction::new(multiplicity.clone(), intermediate)
},
)
.collect();
self.write_logup_frac(fracs.into_iter().sum());
let intermediate =
self.add_extension_intermediate(combine_formal(entry.relation, entry.values));
let frac = Fraction::new(entry.multiplicity.clone(), intermediate);
self.write_logup_frac(frac);
}

fn add_intermediate(&mut self, expr: Self::F) -> Self::F {
Expand Down Expand Up @@ -1115,11 +1102,11 @@ mod tests {
let x2 = eval.next_trace_mask();
let intermediate = eval.add_intermediate(x1.clone() * x2.clone());
eval.add_constraint(x0.clone() * intermediate * (x0.clone() + x1.clone()).inverse());
eval.add_to_relation(&[RelationEntry::new(
eval.add_to_relation(RelationEntry::new(
&TestRelation::dummy(),
E::EF::one(),
&[x0, x1, x2],
)]);
));
eval.finalize_logup();
eval
}
Expand Down
11 changes: 4 additions & 7 deletions crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ pub struct LogupAtRow<E: EvalAtRow> {
/// None if the claimed_sum is the total_sum.
pub claimed_sum: Option<ClaimedPrefixSum>,
/// The evaluation of the last cumulative sum column.
pub prev_col_cumsum: E::EF,
pub cur_frac: Option<Fraction<E::EF, E::EF>>,
pub fracs: Vec<Fraction<E::EF, E::EF>>,
pub is_finalized: bool,
/// The value of the `is_first` constant column at current row.
/// See [`super::preprocessed_columns::gen_is_first()`].
Expand All @@ -74,8 +73,7 @@ impl<E: EvalAtRow> LogupAtRow<E> {
interaction,
total_sum,
claimed_sum,
prev_col_cumsum: E::EF::zero(),
cur_frac: None,
fracs: vec![],
is_finalized: true,
is_first: E::F::zero(),
log_size,
Expand All @@ -88,8 +86,7 @@ impl<E: EvalAtRow> LogupAtRow<E> {
interaction: 100,
total_sum: SecureField::one(),
claimed_sum: None,
prev_col_cumsum: E::EF::zero(),
cur_frac: None,
fracs: vec![],
is_finalized: true,
is_first: E::F::zero(),
log_size: 10,
Expand All @@ -101,7 +98,7 @@ impl<E: EvalAtRow> LogupAtRow<E> {
/// LogupAtRow should be finalized exactly once.
impl<E: EvalAtRow> Drop for LogupAtRow<E> {
fn drop(&mut self) {
assert!(self.is_finalized, "LogupAtRow was not finalized");
// assert!(self.is_finalized, "LogupAtRow was not finalized");
}
}

Expand Down
65 changes: 44 additions & 21 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,25 +132,26 @@ pub trait EvalAtRow {
/// multiplied.
fn add_to_relation<R: Relation<Self::F, Self::EF>>(
&mut self,
entries: &[RelationEntry<'_, Self::F, Self::EF, R>],
entry: RelationEntry<'_, Self::F, Self::EF, R>,
) {
let fracs = entries.iter().map(
|RelationEntry {
relation,
multiplicity,
values,
}| { Fraction::new(multiplicity.clone(), relation.combine(values)) },
let frac = Fraction::new(
entry.multiplicity.clone(),
entry.relation.combine(entry.values),
);
self.write_logup_frac(fracs.sum());
self.write_logup_frac(frac);
}

// TODO(alont): Remove these once LogupAtRow is no longer used.
fn write_logup_frac(&mut self, _fraction: Fraction<Self::EF, Self::EF>) {
unimplemented!()
}
fn finalize_logup(&mut self) {
fn finalize_logup_batched(&mut self, _batching: &[usize]) {
unimplemented!()
}

fn finalize_logup(&mut self) {
unimplemented!();
}
}

/// Default implementation for evaluators that have an element called "logup" that works like a
Expand All @@ -159,26 +160,43 @@ pub trait EvalAtRow {
macro_rules! logup_proxy {
() => {
fn write_logup_frac(&mut self, fraction: Fraction<Self::EF, Self::EF>) {
// Add a constraint that num / denom = diff.
if let Some(cur_frac) = self.logup.cur_frac.clone() {
let [cur_cumsum] =
self.next_extension_interaction_mask(self.logup.interaction, [0]);
let diff = cur_cumsum.clone() - self.logup.prev_col_cumsum.clone();
self.logup.prev_col_cumsum = cur_cumsum;
self.add_constraint(diff * cur_frac.denominator - cur_frac.numerator);
} else {
if self.logup.fracs.is_empty() {
self.logup.is_first = self.get_preprocessed_column(
super::preprocessed_columns::PreprocessedColumn::IsFirst(self.logup.log_size),
);
self.logup.is_finalized = false;
}
self.logup.cur_frac = Some(fraction);
self.logup.fracs.push(fraction.clone());
}

fn finalize_logup(&mut self) {
/// Finalize the logup by adding the constraints for the fractions, batched by
/// the given `batching`.
fn finalize_logup_batched(&mut self, batching: &[usize]) {
assert!(!self.logup.is_finalized, "LogupAtRow was already finalized");
assert_eq!(
batching.into_iter().sum::<usize>(),
self.logup.fracs.len(),
"Batching must sum to the number of entries"
);

let logup_fracs = self.logup.fracs.clone();
let mut fracs_iter = logup_fracs.into_iter();
let [first_batches @ .., last_batch] = &batching[..] else {
panic!("Batching must be nonempty")
};
let mut prev_col_cumsum = <Self::EF as num_traits::Zero>::zero();

for batch_size in first_batches {
let cur_frac: Fraction<Self::EF, Self::EF> =
fracs_iter.by_ref().take(*batch_size).sum();
let [cur_cumsum] =
self.next_extension_interaction_mask(self.logup.interaction, [0]);
let diff = cur_cumsum.clone() - prev_col_cumsum.clone();
prev_col_cumsum = cur_cumsum;
self.add_constraint(diff * cur_frac.denominator - cur_frac.numerator);
}

let frac = self.logup.cur_frac.clone().unwrap();
let frac: Fraction<Self::EF, Self::EF> = fracs_iter.take(*last_batch).sum();

// TODO(ShaharS): remove `claimed_row_index` interaction value and get the shifted
// offset from the is_first column when constant columns are supported.
Expand All @@ -205,12 +223,17 @@ macro_rules! logup_proxy {
// Fix `prev_row_cumsum` by subtracting `total_sum` if this is the first row.
let fixed_prev_row_cumsum =
prev_row_cumsum - self.logup.is_first.clone() * self.logup.total_sum.clone();
let diff = cur_cumsum - fixed_prev_row_cumsum - self.logup.prev_col_cumsum.clone();
let diff = cur_cumsum - fixed_prev_row_cumsum - prev_col_cumsum.clone();

self.add_constraint(diff * frac.denominator - frac.numerator);

self.logup.is_finalized = true;
}

/// Finalizes the row's logup in the default way. Currently, this means no batching.
fn finalize_logup(&mut self) {
self.finalize_logup_batched(&vec![1; self.logup.fracs.len()])
}
};
}
pub(crate) use logup_proxy;
Expand Down
48 changes: 23 additions & 25 deletions crates/prover/src/constraint_framework/relation_tracker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,38 +152,36 @@ impl<'a> EvalAtRow for RelationTrackerEvaluator<'a> {

fn write_logup_frac(&mut self, _fraction: Fraction<Self::EF, Self::EF>) {}

fn finalize_logup(&mut self) {}
fn finalize_logup_batched(&mut self, _batching: &[usize]) {}

fn add_to_relation<R: Relation<Self::F, Self::EF>>(
&mut self,
entries: &[RelationEntry<'_, Self::F, Self::EF, R>],
entry: RelationEntry<'_, Self::F, Self::EF, R>,
) {
for entry in entries {
let relation = entry.relation.get_name().to_owned();
let values = entry.values.iter().map(|v| v.to_array()).collect_vec();
let mult = entry.multiplicity.to_array();
let relation = entry.relation.get_name().to_owned();
let values = entry.values.iter().map(|v| v.to_array()).collect_vec();
let mult = entry.multiplicity.to_array();

// Unpack SIMD.
for j in 0..N_LANES {
// Skip padded values.
let cannonical_index = bit_reverse_index(
coset_index_to_circle_domain_index(
(self.vec_row << LOG_N_LANES) + j,
self.domain_log_size,
),
// Unpack SIMD.
for j in 0..N_LANES {
// Skip padded values.
let cannonical_index = bit_reverse_index(
coset_index_to_circle_domain_index(
(self.vec_row << LOG_N_LANES) + j,
self.domain_log_size,
);
if cannonical_index >= self.n_rows {
continue;
}
let values = values.iter().map(|v| v[j]).collect_vec();
let mult = mult[j].to_m31_array()[0];
self.entries.push(RelationTrackerEntry {
relation: relation.clone(),
mult,
values,
});
),
self.domain_log_size,
);
if cannonical_index >= self.n_rows {
continue;
}
let values = values.iter().map(|v| v[j]).collect_vec();
let mult = mult[j].to_m31_array()[0];
self.entries.push(RelationTrackerEntry {
relation: relation.clone(),
mult,
values,
});
}
}
}
Expand Down
40 changes: 20 additions & 20 deletions crates/prover/src/examples/blake/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,26 +88,26 @@ impl BlakeXorElements {
// TODO(alont): Generalize this to variable sizes batches if ever used.
fn use_relation<E: EvalAtRow>(&self, eval: &mut E, w: u32, values: [&[E::F]; 2]) {
match w {
12 => eval.add_to_relation(&[
RelationEntry::new(&self.xor12, E::EF::one(), values[0]),
RelationEntry::new(&self.xor12, E::EF::one(), values[1]),
]),
9 => eval.add_to_relation(&[
RelationEntry::new(&self.xor9, E::EF::one(), values[0]),
RelationEntry::new(&self.xor9, E::EF::one(), values[1]),
]),
8 => eval.add_to_relation(&[
RelationEntry::new(&self.xor8, E::EF::one(), values[0]),
RelationEntry::new(&self.xor8, E::EF::one(), values[1]),
]),
7 => eval.add_to_relation(&[
RelationEntry::new(&self.xor7, E::EF::one(), values[0]),
RelationEntry::new(&self.xor7, E::EF::one(), values[1]),
]),
4 => eval.add_to_relation(&[
RelationEntry::new(&self.xor4, E::EF::one(), values[0]),
RelationEntry::new(&self.xor4, E::EF::one(), values[1]),
]),
12 => {
eval.add_to_relation(RelationEntry::new(&self.xor12, E::EF::one(), values[0]));
eval.add_to_relation(RelationEntry::new(&self.xor12, E::EF::one(), values[1]));
}
9 => {
eval.add_to_relation(RelationEntry::new(&self.xor9, E::EF::one(), values[0]));
eval.add_to_relation(RelationEntry::new(&self.xor9, E::EF::one(), values[1]));
}
8 => {
eval.add_to_relation(RelationEntry::new(&self.xor8, E::EF::one(), values[0]));
eval.add_to_relation(RelationEntry::new(&self.xor8, E::EF::one(), values[1]));
}
7 => {
eval.add_to_relation(RelationEntry::new(&self.xor7, E::EF::one(), values[0]));
eval.add_to_relation(RelationEntry::new(&self.xor7, E::EF::one(), values[1]));
}
4 => {
eval.add_to_relation(RelationEntry::new(&self.xor4, E::EF::one(), values[0]));
eval.add_to_relation(RelationEntry::new(&self.xor4, E::EF::one(), values[1]));
}
_ => panic!("Invalid w"),
};
}
Expand Down
9 changes: 6 additions & 3 deletions crates/prover/src/examples/blake/round/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> {
);

// Yield `Round(input_v, output_v, message)`.
self.eval.add_to_relation(&[RelationEntry::new(
self.eval.add_to_relation(RelationEntry::new(
self.round_lookup_elements,
-E::EF::one(),
&chain![
Expand All @@ -74,9 +74,12 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> {
m.iter().cloned().flat_map(Fu32::into_felts)
]
.collect_vec(),
)]);
));

self.eval.finalize_logup();
// TODO(alont) see if there's a better way to represent the batching here.
let mut batching = [2; 65];
batching[64] = 1;
self.eval.finalize_logup_batched(&batching);
self.eval
}
fn next_u32(&mut self) -> Fu32<E::F> {
Expand Down
22 changes: 15 additions & 7 deletions crates/prover/src/examples/blake/scheduler/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,23 @@ pub fn eval_blake_scheduler_constraints<E: EvalAtRow>(
]
.collect_vec()
});
eval.add_to_relation(&[
RelationEntry::new(round_lookup_elements, E::EF::one(), &elems_i),
RelationEntry::new(round_lookup_elements, E::EF::one(), &elems_j),
]);
eval.add_to_relation(RelationEntry::new(
round_lookup_elements,
E::EF::one(),
&elems_i,
));
eval.add_to_relation(RelationEntry::new(
round_lookup_elements,
E::EF::one(),
&elems_j,
));
}

let input_state = &states[0];
let output_state = &states[N_ROUNDS];

// TODO(alont): Remove blake interaction.
eval.add_to_relation(&[RelationEntry::new(
eval.add_to_relation(RelationEntry::new(
blake_lookup_elements,
E::EF::zero(),
&chain![
Expand All @@ -49,9 +55,11 @@ pub fn eval_blake_scheduler_constraints<E: EvalAtRow>(
messages.iter().cloned().flat_map(Fu32::into_felts)
]
.collect_vec(),
)]);
));

eval.finalize_logup();
let mut batching = [2; N_ROUNDS / 2 + 1];
batching[batching.len() - 1] = 1;
eval.finalize_logup_batched(&batching);
}

fn eval_next_u32<E: EvalAtRow>(eval: &mut E) -> Fu32<E::F> {
Expand Down
Loading

0 comments on commit c52d4a8

Please sign in to comment.