diff --git a/src/core/backend/cpu/quotients.rs b/src/core/backend/cpu/quotients.rs index e46dd3234..af650595c 100644 --- a/src/core/backend/cpu/quotients.rs +++ b/src/core/backend/cpu/quotients.rs @@ -17,14 +17,14 @@ impl QuotientOps for CPUBackend { domain: CircleDomain, columns: &[&CircleEvaluation], random_coeff: SecureField, - samples: &[ColumnSampleBatch], + sample_batches: &[ColumnSampleBatch], ) -> SecureColumn { let mut res = SecureColumn::zeros(domain.size()); for row in 0..domain.size() { // TODO(alonh): Make an efficient bit reverse domain iterator, possibly for AVX backend. let domain_point = domain.at(bit_reverse_index(row, domain.log_size())); let row_value = - accumulate_row_quotients(samples, columns, row, random_coeff, domain_point); + accumulate_row_quotients(sample_batches, columns, row, random_coeff, domain_point); res.set(row, row_value); } res @@ -32,32 +32,34 @@ impl QuotientOps for CPUBackend { } pub fn accumulate_row_quotients( - samples: &[ColumnSampleBatch], + sample_batches: &[ColumnSampleBatch], columns: &[&CircleEvaluation], row: usize, random_coeff: SecureField, domain_point: CirclePoint, ) -> SecureField { - let mut row_accumlator = SecureField::zero(); - for sample in samples { + let mut row_accumulator = SecureField::zero(); + for sample_batch in sample_batches { let mut numerator = SecureField::zero(); - for (column_index, sampled_value) in &sample.columns_and_values { + for (column_index, sampled_value) in &sample_batch.columns_and_values { let column = &columns[*column_index]; let value = column[row]; - let linear_term = complex_conjugate_line(sample.point, *sampled_value, domain_point); + let linear_term = + complex_conjugate_line(sample_batch.point, *sampled_value, domain_point); numerator = numerator * random_coeff + value - linear_term; } let denominator = pair_vanishing( - sample.point, - sample.point.complex_conjugate(), + sample_batch.point, + sample_batch.point.complex_conjugate(), domain_point.into_ef(), ); - row_accumlator = row_accumlator * random_coeff.pow(sample.columns_and_values.len() as u128) + row_accumulator = row_accumulator + * random_coeff.pow(sample_batch.columns_and_values.len() as u128) + numerator / denominator; } - row_accumlator + row_accumulator } #[cfg(test)] diff --git a/src/core/commitment_scheme/quotients.rs b/src/core/commitment_scheme/quotients.rs index d4d7be58d..513961ca8 100644 --- a/src/core/commitment_scheme/quotients.rs +++ b/src/core/commitment_scheme/quotients.rs @@ -28,7 +28,7 @@ pub trait QuotientOps: Backend { domain: CircleDomain, columns: &[&CircleEvaluation], random_coeff: SecureField, - samples: &[ColumnSampleBatch], + sample_batches: &[ColumnSampleBatch], ) -> SecureColumn; } @@ -43,7 +43,7 @@ impl ColumnSampleBatch { /// Groups column samples by sampled point. /// # Arguments /// samples: For each column, a vector of samples. - pub fn new(samples: &[&Vec]) -> Vec { + pub fn new_vec(samples: &[&Vec]) -> Vec { // Group samples by point, and create a ColumnSampleBatch for each point. // This should keep a stable ordering. let mut grouped_samples = BTreeMap::new(); @@ -83,8 +83,8 @@ pub fn compute_fri_quotients( let (columns, samples): (Vec<_>, Vec<_>) = tuples.unzip(); let domain = CanonicCoset::new(log_size).circle_domain(); // TODO: slice. - let batched_samples = ColumnSampleBatch::new(&samples); - let values = B::accumulate_quotients(domain, &columns, random_coeff, &batched_samples); + let sample_batches = ColumnSampleBatch::new_vec(&samples); + let values = B::accumulate_quotients(domain, &columns, random_coeff, &sample_batches); SecureEvaluation { domain, values } }) .collect() @@ -123,7 +123,7 @@ pub fn fri_answers_for_log_size( queried_values_per_column: &[&Vec], ) -> Result, VerificationError> { let commitment_domain = CanonicCoset::new(log_size).circle_domain(); - let batched_samples = ColumnSampleBatch::new(samples); + let sample_batches = ColumnSampleBatch::new_vec(samples); for queried_values in queried_values_per_column { if queried_values.len() != query_domain.flatten().len() { return Err(VerificationError::InvalidStructure( @@ -152,7 +152,7 @@ pub fn fri_answers_for_log_size( for row in 0..domain.size() { let domain_point = domain.at(bit_reverse_index(row, log_size)); let value = accumulate_row_quotients( - &batched_samples, + &sample_batches, &column_evals.iter().collect_vec(), row, random_coeff,