diff --git a/src/core/backend/avx512/circle.rs b/src/core/backend/avx512/circle.rs index b67ba2e60..ec3b95830 100644 --- a/src/core/backend/avx512/circle.rs +++ b/src/core/backend/avx512/circle.rs @@ -93,7 +93,6 @@ impl AVX512Backend { denominators.push(denominators[i - 1] * mappings[i]); } - // TODO(Ohad): batch inverse. let mut denom_inverses = vec![F::zero(); denominators.len()]; F::batch_inverse(&denominators, &mut denom_inverses); diff --git a/src/core/backend/avx512/quotients.rs b/src/core/backend/avx512/quotients.rs index 78d7cf406..94db1daa9 100644 --- a/src/core/backend/avx512/quotients.rs +++ b/src/core/backend/avx512/quotients.rs @@ -23,7 +23,7 @@ impl QuotientOps for AVX512Backend { ) -> SecureEvaluation { assert!(domain.log_size() >= VECS_LOG_SIZE as u32); let mut values = SecureColumn::::zeros(domain.size()); - let quotient_constants = quotient_constants(sample_batches, random_coeff); + let quotient_constants = quotient_constants(sample_batches, random_coeff, domain); // TODO(spapini): bit reverse iterator. for vec_row in 0..(1 << (domain.log_size() - VECS_LOG_SIZE as u32)) { diff --git a/src/core/backend/cpu/quotients.rs b/src/core/backend/cpu/quotients.rs index 82b74c6f5..2990245ef 100644 --- a/src/core/backend/cpu/quotients.rs +++ b/src/core/backend/cpu/quotients.rs @@ -2,6 +2,7 @@ use itertools::{izip, zip_eq}; use num_traits::{One, Zero}; use super::CPUBackend; +use crate::core::backend::Col; use crate::core::circle::CirclePoint; use crate::core::commitment_scheme::quotients::{ColumnSampleBatch, PointSample, QuotientOps}; use crate::core::constraints::{complex_conjugate_line_coeffs, pair_vanishing}; @@ -11,7 +12,7 @@ use crate::core::fields::secure_column::SecureColumn; use crate::core::fields::{ComplexConjugate, FieldExpOps}; use crate::core::poly::circle::{CircleDomain, CircleEvaluation, SecureEvaluation}; use crate::core::poly::BitReversedOrder; -use crate::core::utils::bit_reverse_index; +use crate::core::utils::{bit_reverse, bit_reverse_index}; impl QuotientOps for CPUBackend { fn accumulate_quotients( @@ -21,7 +22,7 @@ impl QuotientOps for CPUBackend { sample_batches: &[ColumnSampleBatch], ) -> SecureEvaluation { let mut values = SecureColumn::zeros(domain.size()); - let quotient_constants = quotient_constants(sample_batches, random_coeff); + let quotient_constants = quotient_constants(sample_batches, random_coeff, domain); for row in 0..domain.size() { // TODO(alonh): Make an efficient bit reverse domain iterator, possibly for AVX backend. @@ -47,10 +48,11 @@ pub fn accumulate_row_quotients( domain_point: CirclePoint, ) -> SecureField { let mut row_accumulator = SecureField::zero(); - for (sample_batch, line_coeffs, batch_coeff) in izip!( + for (sample_batch, line_coeffs, batch_coeff, denominator_inverses) in izip!( sample_batches, "ient_constants.line_coeffs, - "ient_constants.batch_random_coeffs + "ient_constants.batch_random_coeffs, + "ient_constants.denominator_inverses ) { let mut numerator = SecureField::zero(); for ((column_index, _), (a, b, c)) in zip_eq(&sample_batch.columns_and_values, line_coeffs) @@ -61,13 +63,7 @@ pub fn accumulate_row_quotients( numerator += value - linear_term; } - let denominator = pair_vanishing( - sample_batch.point, - sample_batch.point.complex_conjugate(), - domain_point.into_ef(), - ); - - row_accumulator = row_accumulator * *batch_coeff + numerator / denominator; + row_accumulator = row_accumulator * *batch_coeff + numerator * denominator_inverses[row]; } row_accumulator } @@ -113,15 +109,48 @@ pub fn batch_random_coeffs( .collect() } +// TODO(AlonH): Make generic over backend. +fn denominator_inverses( + sample_batches: &[ColumnSampleBatch], + domain: CircleDomain, +) -> Vec> { + let mut flat_denominators = Vec::with_capacity(sample_batches.len() * domain.size()); + for sample_batch in sample_batches { + for row in 0..domain.size() { + let domain_point = domain.at(row); + let denominator = pair_vanishing( + sample_batch.point, + sample_batch.point.complex_conjugate(), + domain_point.into_ef(), + ); + flat_denominators.push(denominator); + } + } + + let mut flat_denominator_inverses = vec![SecureField::zero(); flat_denominators.len()]; + SecureField::batch_inverse(&flat_denominators, &mut flat_denominator_inverses); + + flat_denominator_inverses + .chunks_mut(domain.size()) + .map(|denominator_inverses| { + bit_reverse(denominator_inverses); + denominator_inverses.to_vec() + }) + .collect() +} + pub fn quotient_constants( sample_batches: &[ColumnSampleBatch], random_coeff: SecureField, + domain: CircleDomain, ) -> QuotientConstants { let line_coeffs = column_line_coeffs(sample_batches, random_coeff); let batch_random_coeffs = batch_random_coeffs(sample_batches, random_coeff); + let denominator_inverses = denominator_inverses(sample_batches, domain); QuotientConstants { line_coeffs, batch_random_coeffs, + denominator_inverses, } } @@ -133,6 +162,8 @@ pub struct QuotientConstants { /// The random coefficients used to linearly combine the batched quotients For more details see /// [self::batch_random_coeffs]. pub batch_random_coeffs: Vec, + /// The inverses of the denominators of the quotients. + pub denominator_inverses: Vec>, } #[cfg(test)] diff --git a/src/core/commitment_scheme/quotients.rs b/src/core/commitment_scheme/quotients.rs index d4217372c..ffb0415f2 100644 --- a/src/core/commitment_scheme/quotients.rs +++ b/src/core/commitment_scheme/quotients.rs @@ -124,7 +124,6 @@ pub fn fri_answers_for_log_size( ) -> Result, VerificationError> { let commitment_domain = CanonicCoset::new(log_size).circle_domain(); let sample_batches = ColumnSampleBatch::new_vec(samples); - let quotient_constants = quotient_constants(&sample_batches, random_coeff); for queried_values in queried_values_per_column { if queried_values.len() != query_domain.flatten().len() { return Err(VerificationError::InvalidStructure( @@ -140,6 +139,7 @@ pub fn fri_answers_for_log_size( let mut evals = Vec::new(); for subdomain in query_domain.iter() { let domain = subdomain.to_circle_domain(&commitment_domain); + let quotient_constants = quotient_constants(&sample_batches, random_coeff, domain); let mut column_evals = Vec::new(); for queried_values in queried_values_per_column.iter_mut() { let eval = CircleEvaluation::new(