Skip to content

Commit

Permalink
Batch inverse cpu quotient denominators. (#559)
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 authored Apr 2, 2024
1 parent bc540de commit ddc8d3b
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 14 deletions.
1 change: 0 additions & 1 deletion src/core/backend/avx512/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ impl AVX512Backend {
denominators.push(denominators[i - 1] * mappings[i]);
}

// TODO(Ohad): batch inverse.
let mut denom_inverses = vec![F::zero(); denominators.len()];
F::batch_inverse(&denominators, &mut denom_inverses);

Expand Down
2 changes: 1 addition & 1 deletion src/core/backend/avx512/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ impl QuotientOps for AVX512Backend {
) -> SecureEvaluation<Self> {
assert!(domain.log_size() >= VECS_LOG_SIZE as u32);
let mut values = SecureColumn::<AVX512Backend>::zeros(domain.size());
let quotient_constants = quotient_constants(sample_batches, random_coeff);
let quotient_constants = quotient_constants(sample_batches, random_coeff, domain);

// TODO(spapini): bit reverse iterator.
for vec_row in 0..(1 << (domain.log_size() - VECS_LOG_SIZE as u32)) {
Expand Down
53 changes: 42 additions & 11 deletions src/core/backend/cpu/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use itertools::{izip, zip_eq};
use num_traits::{One, Zero};

use super::CPUBackend;
use crate::core::backend::Col;
use crate::core::circle::CirclePoint;
use crate::core::commitment_scheme::quotients::{ColumnSampleBatch, PointSample, QuotientOps};
use crate::core::constraints::{complex_conjugate_line_coeffs, pair_vanishing};
Expand All @@ -11,7 +12,7 @@ use crate::core::fields::secure_column::SecureColumn;
use crate::core::fields::{ComplexConjugate, FieldExpOps};
use crate::core::poly::circle::{CircleDomain, CircleEvaluation, SecureEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::bit_reverse_index;
use crate::core::utils::{bit_reverse, bit_reverse_index};

impl QuotientOps for CPUBackend {
fn accumulate_quotients(
Expand All @@ -21,7 +22,7 @@ impl QuotientOps for CPUBackend {
sample_batches: &[ColumnSampleBatch],
) -> SecureEvaluation<Self> {
let mut values = SecureColumn::zeros(domain.size());
let quotient_constants = quotient_constants(sample_batches, random_coeff);
let quotient_constants = quotient_constants(sample_batches, random_coeff, domain);

for row in 0..domain.size() {
// TODO(alonh): Make an efficient bit reverse domain iterator, possibly for AVX backend.
Expand All @@ -47,10 +48,11 @@ pub fn accumulate_row_quotients(
domain_point: CirclePoint<BaseField>,
) -> SecureField {
let mut row_accumulator = SecureField::zero();
for (sample_batch, line_coeffs, batch_coeff) in izip!(
for (sample_batch, line_coeffs, batch_coeff, denominator_inverses) in izip!(
sample_batches,
&quotient_constants.line_coeffs,
&quotient_constants.batch_random_coeffs
&quotient_constants.batch_random_coeffs,
&quotient_constants.denominator_inverses
) {
let mut numerator = SecureField::zero();
for ((column_index, _), (a, b, c)) in zip_eq(&sample_batch.columns_and_values, line_coeffs)
Expand All @@ -61,13 +63,7 @@ pub fn accumulate_row_quotients(
numerator += value - linear_term;
}

let denominator = pair_vanishing(
sample_batch.point,
sample_batch.point.complex_conjugate(),
domain_point.into_ef(),
);

row_accumulator = row_accumulator * *batch_coeff + numerator / denominator;
row_accumulator = row_accumulator * *batch_coeff + numerator * denominator_inverses[row];
}
row_accumulator
}
Expand Down Expand Up @@ -113,15 +109,48 @@ pub fn batch_random_coeffs(
.collect()
}

// TODO(AlonH): Make generic over backend.
fn denominator_inverses(
sample_batches: &[ColumnSampleBatch],
domain: CircleDomain,
) -> Vec<Col<CPUBackend, SecureField>> {
let mut flat_denominators = Vec::with_capacity(sample_batches.len() * domain.size());
for sample_batch in sample_batches {
for row in 0..domain.size() {
let domain_point = domain.at(row);
let denominator = pair_vanishing(
sample_batch.point,
sample_batch.point.complex_conjugate(),
domain_point.into_ef(),
);
flat_denominators.push(denominator);
}
}

let mut flat_denominator_inverses = vec![SecureField::zero(); flat_denominators.len()];
SecureField::batch_inverse(&flat_denominators, &mut flat_denominator_inverses);

flat_denominator_inverses
.chunks_mut(domain.size())
.map(|denominator_inverses| {
bit_reverse(denominator_inverses);
denominator_inverses.to_vec()
})
.collect()
}

pub fn quotient_constants(
sample_batches: &[ColumnSampleBatch],
random_coeff: SecureField,
domain: CircleDomain,
) -> QuotientConstants {
let line_coeffs = column_line_coeffs(sample_batches, random_coeff);
let batch_random_coeffs = batch_random_coeffs(sample_batches, random_coeff);
let denominator_inverses = denominator_inverses(sample_batches, domain);
QuotientConstants {
line_coeffs,
batch_random_coeffs,
denominator_inverses,
}
}

Expand All @@ -133,6 +162,8 @@ pub struct QuotientConstants {
/// The random coefficients used to linearly combine the batched quotients For more details see
/// [self::batch_random_coeffs].
pub batch_random_coeffs: Vec<SecureField>,
/// The inverses of the denominators of the quotients.
pub denominator_inverses: Vec<Col<CPUBackend, SecureField>>,
}

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion src/core/commitment_scheme/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ pub fn fri_answers_for_log_size(
) -> Result<SparseCircleEvaluation<SecureField>, VerificationError> {
let commitment_domain = CanonicCoset::new(log_size).circle_domain();
let sample_batches = ColumnSampleBatch::new_vec(samples);
let quotient_constants = quotient_constants(&sample_batches, random_coeff);
for queried_values in queried_values_per_column {
if queried_values.len() != query_domain.flatten().len() {
return Err(VerificationError::InvalidStructure(
Expand All @@ -140,6 +139,7 @@ pub fn fri_answers_for_log_size(
let mut evals = Vec::new();
for subdomain in query_domain.iter() {
let domain = subdomain.to_circle_domain(&commitment_domain);
let quotient_constants = quotient_constants(&sample_batches, random_coeff, domain);
let mut column_evals = Vec::new();
for queried_values in queried_values_per_column.iter_mut() {
let eval = CircleEvaluation::new(
Expand Down

0 comments on commit ddc8d3b

Please sign in to comment.