Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor constraint evaluation. #696

Merged
merged 1 commit into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion crates/prover/src/core/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ use std::ops::Add;

use num_traits::{One, Zero};

use super::circle::CirclePoint;
use super::constraints::point_vanishing;
use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
use super::fields::ExtensionOf;
use super::fields::{ExtensionOf, FieldExpOps};
use super::poly::circle::CircleDomain;

pub trait IteratorMutExt<'a, T: 'a>: Iterator<Item = &'a mut T> {
fn assign(self, other: impl IntoIterator<Item = T>)
Expand Down Expand Up @@ -150,6 +153,21 @@ where
res - z
}

pub fn point_vanish_denominator_inverses(
domain: CircleDomain,
vanish_point: CirclePoint<BaseField>,
) -> Vec<BaseField> {
let mut denoms = vec![];
for point in domain.iter() {
// TODO(AlonH): Use `point_vanishing_fraction` instead of `point_vanishing` everywhere.
denoms.push(point_vanishing(vanish_point, point));
}
bit_reverse(&mut denoms);
let mut denom_inverses = vec![BaseField::zero(); 1 << (domain.log_size())];
BaseField::batch_inverse(&denoms, &mut denom_inverses);
denom_inverses
}

#[cfg(test)]
mod tests {
use itertools::Itertools;
Expand Down
131 changes: 42 additions & 89 deletions crates/prover/src/examples/wide_fibonacci/constraint_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::core::air::{AirProver, Component, ComponentProver, ComponentTrace};
use crate::core::backend::CpuBackend;
use crate::core::channel::{Blake2sChannel, Channel};
use crate::core::circle::Coset;
use crate::core::constraints::{coset_vanishing, point_excluder, point_vanishing};
use crate::core::constraints::{coset_vanishing, point_excluder};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
Expand All @@ -22,7 +22,8 @@ use crate::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::prover::{BASE_TRACE, INTERACTION_TRACE};
use crate::core::utils::{
bit_reverse, previous_bit_reversed_circle_domain_index, shifted_secure_combination,
bit_reverse, point_vanish_denominator_inverses, previous_bit_reversed_circle_domain_index,
shifted_secure_combination,
};
use crate::core::{ColumnVec, InteractionElements, LookupValues};
use crate::examples::wide_fibonacci::component::LOG_N_COLUMNS;
Expand Down Expand Up @@ -72,56 +73,35 @@ impl WideFibComponent {
accum: &mut ColumnAccumulator<'_, CpuBackend>,
lookup_values: &LookupValues,
) {
let max_constraint_degree = self.max_constraint_log_degree_bound();
let mut first_point_denoms = vec![];
let mut last_point_denoms = vec![];
for point in trace_eval_domain.iter() {
// TODO(AlonH): Use `point_vanishing_fraction` instead of `point_vanishing` everywhere.
first_point_denoms.push(point_vanishing(zero_domain.at(0), point));
last_point_denoms.push(point_vanishing(
zero_domain.at(zero_domain.size() - 1),
point,
));
}
bit_reverse(&mut first_point_denoms);
bit_reverse(&mut last_point_denoms);
let mut first_point_denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)];
let mut last_point_denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)];
BaseField::batch_inverse(&first_point_denoms, &mut first_point_denom_inverses);
BaseField::batch_inverse(&last_point_denoms, &mut last_point_denom_inverses);
let mut first_point_numerators = vec![SecureField::zero(); 1 << (max_constraint_degree)];
let mut last_point_numerators = vec![SecureField::zero(); 1 << (max_constraint_degree)];
let first_point_denom_inverses =
point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0));
let last_point_denom_inverses = point_vanish_denominator_inverses(
trace_eval_domain,
zero_domain.at(zero_domain.size() - 1),
);
let (lookup_value_0, lookup_value_1, lookup_value_n_minus_2, lookup_value_n_minus_1) = (
lookup_values[LOOKUP_VALUE_0_ID],
lookup_values[LOOKUP_VALUE_1_ID],
lookup_values[LOOKUP_VALUE_N_MINUS_2_ID],
lookup_values[LOOKUP_VALUE_N_MINUS_1_ID],
);

#[allow(clippy::needless_range_loop)]
for i in 0..trace_eval_domain.size() {
first_point_numerators[i] = accum.random_coeff_powers[self.n_columns() + 4]
for (i, (first_point_denom_inverse, last_point_denom_inverse)) in
zip_eq(first_point_denom_inverses, last_point_denom_inverses).enumerate()
{
let first_point_numerator = accum.random_coeff_powers[self.n_columns() + 4]
* (trace_evals[BASE_TRACE][0][i] - lookup_value_0)
+ accum.random_coeff_powers[self.n_columns() + 3]
* (trace_evals[BASE_TRACE][1][i] - lookup_value_1);
last_point_numerators[i] = accum.random_coeff_powers[self.n_columns() + 2]
let last_point_numerator = accum.random_coeff_powers[self.n_columns() + 2]
* (trace_evals[BASE_TRACE][self.n_columns() - 2][i] - lookup_value_n_minus_2)
+ accum.random_coeff_powers[self.n_columns() + 1]
* (trace_evals[BASE_TRACE][self.n_columns() - 1][i] - lookup_value_n_minus_1);
}
for (i, (num, denom_inverse)) in first_point_numerators
.iter()
.zip(first_point_denom_inverses.iter())
.enumerate()
{
accum.accumulate(i, *num * *denom_inverse);
}
for (i, (num, denom_inverse)) in last_point_numerators
.iter()
.zip(last_point_denom_inverses.iter())
.enumerate()
{
accum.accumulate(i, *num * *denom_inverse);
accum.accumulate(
i,
first_point_numerator * first_point_denom_inverse
+ last_point_numerator * last_point_denom_inverse,
);
}
}

Expand All @@ -140,19 +120,16 @@ impl WideFibComponent {
bit_reverse(&mut denoms);
let mut denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)];
BaseField::batch_inverse(&denoms, &mut denom_inverses);
let mut numerators = vec![SecureField::zero(); 1 << (max_constraint_degree)];

#[allow(clippy::needless_range_loop)]
for i in 0..trace_eval_domain.size() {
for (i, denom_inverse) in denom_inverses.iter().enumerate() {
let mut numerator = SecureField::zero();
for j in 0..self.n_columns() - 2 {
numerators[i] += accum.random_coeff_powers[self.n_columns() - 3 - j]
numerator += accum.random_coeff_powers[self.n_columns() - 3 - j]
* (trace_evals[BASE_TRACE][j][i].square()
+ trace_evals[BASE_TRACE][j + 1][i].square()
- trace_evals[BASE_TRACE][j + 2][i]);
}
}
for (i, (num, denom_inverse)) in numerators.iter().zip(denom_inverses.iter()).enumerate() {
accum.accumulate(i, *num * *denom_inverse);
accum.accumulate(i, numerator * *denom_inverse)
}
}

Expand All @@ -165,24 +142,12 @@ impl WideFibComponent {
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
) {
let max_constraint_degree = self.max_constraint_log_degree_bound();
let mut first_point_denoms = vec![];
let mut last_point_denoms = vec![];
for point in trace_eval_domain.iter() {
first_point_denoms.push(point_vanishing(zero_domain.at(0), point));
last_point_denoms.push(point_vanishing(
zero_domain.at(zero_domain.size() - 1),
point,
));
}
bit_reverse(&mut first_point_denoms);
bit_reverse(&mut last_point_denoms);
let mut first_point_denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)];
let mut last_point_denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)];
BaseField::batch_inverse(&first_point_denoms, &mut first_point_denom_inverses);
BaseField::batch_inverse(&last_point_denoms, &mut last_point_denom_inverses);
let mut first_point_numerators = vec![SecureField::zero(); 1 << (max_constraint_degree)];
let mut last_point_numerators = vec![SecureField::zero(); 1 << (max_constraint_degree)];
let first_point_denom_inverses =
point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0));
let last_point_denom_inverses = point_vanish_denominator_inverses(
trace_eval_domain,
zero_domain.at(zero_domain.size() - 1),
);
let (alpha, z) = (interaction_elements[ALPHA_ID], interaction_elements[Z_ID]);
let (lookup_value_0, lookup_value_1, lookup_value_n_minus_2, lookup_value_n_minus_1) = (
lookup_values[LOOKUP_VALUE_0_ID],
Expand All @@ -191,12 +156,13 @@ impl WideFibComponent {
lookup_values[LOOKUP_VALUE_N_MINUS_1_ID],
);

#[allow(clippy::needless_range_loop)]
for i in 0..trace_eval_domain.size() {
for (i, (first_point_denom_inverse, last_point_denom_inverse)) in
zip_eq(first_point_denom_inverses, last_point_denom_inverses).enumerate()
{
let value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][i]
}));
first_point_numerators[i] = accum.random_coeff_powers[self.n_columns() - 1]
let first_point_numerator = accum.random_coeff_powers[self.n_columns() - 1]
* ((value
* shifted_secure_combination(
&[
Expand All @@ -211,28 +177,19 @@ impl WideFibComponent {
alpha,
z,
));
last_point_numerators[i] = accum.random_coeff_powers[self.n_columns() - 2]
let last_point_numerator = accum.random_coeff_powers[self.n_columns() - 2]
* ((value
* shifted_secure_combination(
&[lookup_value_n_minus_2, lookup_value_n_minus_1],
alpha,
z,
))
- shifted_secure_combination(&[lookup_value_0, lookup_value_1], alpha, z));
}
for (i, (num, denom_inverse)) in first_point_numerators
.iter()
.zip(first_point_denom_inverses.iter())
.enumerate()
{
accum.accumulate(i, *num * *denom_inverse);
}
for (i, (num, denom_inverse)) in last_point_numerators
.iter()
.zip(last_point_denom_inverses.iter())
.enumerate()
{
accum.accumulate(i, *num * *denom_inverse);
accum.accumulate(
i,
first_point_numerator * first_point_denom_inverse
+ last_point_numerator * last_point_denom_inverse,
);
}
}

Expand All @@ -255,11 +212,9 @@ impl WideFibComponent {
bit_reverse(&mut denoms);
let mut denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)];
BaseField::batch_inverse(&denoms, &mut denom_inverses);
let mut numerators = vec![SecureField::zero(); 1 << (max_constraint_degree)];
let (alpha, z) = (interaction_elements[ALPHA_ID], interaction_elements[Z_ID]);

#[allow(clippy::needless_range_loop)]
for i in 0..trace_eval_domain.size() {
for (i, denom_inverse) in denom_inverses.iter().enumerate() {
let value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][i]
}));
Expand All @@ -271,7 +226,7 @@ impl WideFibComponent {
let prev_value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][prev_index]
}));
numerators[i] = accum.random_coeff_powers[self.n_columns()]
let numerator = accum.random_coeff_powers[self.n_columns()]
* ((value
* shifted_secure_combination(
&[
Expand All @@ -287,9 +242,7 @@ impl WideFibComponent {
alpha,
z,
)));
}
for (i, (num, denom_inverse)) in numerators.iter().zip(denom_inverses.iter()).enumerate() {
accum.accumulate(i, *num * *denom_inverse);
accum.accumulate(i, numerator * *denom_inverse);
}
}
}
Expand Down
Loading