Skip to content

Commit

Permalink
FRI quotients y optimization (#686)
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware authored Jul 2, 2024
1 parent ffeb96a commit 6cef1fb
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 57 deletions.
6 changes: 6 additions & 0 deletions crates/prover/src/core/backend/cpu/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ pub fn accumulate_row_quotients(
{
let column = &columns[*column_index];
let value = column[row] * *c;
// The numerator is a line equation passing through
// (sample_point.y, sample_value), (conj(sample_point), conj(sample_value))
// evaluated at (domain_point.y, value).
// When substituting a polynomial in this line equation, we get a polynomial with a root
// at sample_point and conj(sample_point) if the original polynomial had the values
// sample_value and conj(sample_value) at these points.
let linear_term = *a * domain_point.y + *b;
numerator += value - linear_term;
}
Expand Down
162 changes: 106 additions & 56 deletions crates/prover/src/core/backend/simd/quotients.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use itertools::{izip, zip_eq, Itertools};
use num_traits::Zero;
use tracing::{span, Level};

use super::column::SecureFieldVec;
use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
Expand Down Expand Up @@ -31,55 +32,28 @@ impl QuotientOps for SimdBackend {
// TODO(spapini): Move to the caller when Columns support slices.
let (subdomain, mut subdomain_shifts) = domain.split(LOG_BLOWUP_FACTOR);

assert!(subdomain.log_size() >= LOG_N_LANES);
let mut values = SecureColumn::<Self>::zeros(subdomain.size());
let quotient_constants = quotient_constants(sample_batches, random_coeff, subdomain);

// TODO(spapini): bit reverse iterator.
for vec_row in 0..1 << (subdomain.log_size() - LOG_N_LANES) {
// TODO(spapini): Optimize this, for the small number of columns case.
let points = std::array::from_fn(|i| {
subdomain.at(bit_reverse_index(
(vec_row << LOG_N_LANES) + i,
subdomain.log_size(),
))
});
let domain_points_x = PackedBaseField::from_array(points.map(|p| p.x));
let domain_points_y = PackedBaseField::from_array(points.map(|p| p.y));
let row_accumulator = accumulate_row_quotients(
sample_batches,
columns,
&quotient_constants,
vec_row,
(domain_points_x, domain_points_y),
);
unsafe { values.set_packed(vec_row, row_accumulator) };
}

// Extend the evaluation to the full domain.
let mut extended_eval = SecureColumn::<Self>::zeros(domain.size());

let mut i = 0;
let values = values.columns;
let twiddles = SimdBackend::precompute_twiddles(subdomain.half_coset);
let subeval_polys = values.map(|c| {
i += 1;
CircleEvaluation::<SimdBackend, BaseField, BitReversedOrder>::new(subdomain, c)
.interpolate_with_twiddles(&twiddles)
});

// Bit reverse the shifts.
// Since we traverse the domain in bit-reversed order, we need to bit-reverse the shifts.
// To see why, consider the index of a point in the natural order of the domain
// (least to most):
// b0 b1 b2 b3 b4 b5
// b0 adds P, b1 adds 2P, etc.. (b5 is special and flips the sign of the point).
// b0 adds G, b1 adds 2G, etc.. (b5 is special and flips the sign of the point).
// Splitting the domain to 4 parts yields:
// subdomain: b2 b3 b4 b5, shifts: b0 b1.
// b2 b3 b4 b5 is indeed a circle domain, with a bigger jump.
// Traversing the domain in bit-reversed order, after we finish with b5, b4, b3, b2,
// we need to change b1 and then b0. This is the bit reverse of the shift b0 b1.
bit_reverse(&mut subdomain_shifts);

let (span, mut extended_eval, subeval_polys) = accumulate_quotients_on_subdomain(
subdomain,
sample_batches,
random_coeff,
columns,
domain,
);

// Extend the evaluation to the full domain.
// TODO(spapini): Try to optimize out all these copies.
for (ci, &c) in subdomain_shifts.iter().enumerate() {
let subdomain = subdomain.shift(c);
Expand All @@ -93,6 +67,7 @@ impl QuotientOps for SimdBackend {
.copy_from_slice(&eval.data);
}
}
span.exit();

SecureEvaluation {
domain,
Expand All @@ -101,39 +76,114 @@ impl QuotientOps for SimdBackend {
}
}

/// Accumulates the quotients over `subdomain` and interpolates the accumulated values.
///
/// The work is done in two traced phases:
/// 1. "Quotient accumulation": calls [`accumulate_row_quotients`] once per quad row and
///    stores the 4 packed results it returns into consecutive packed rows of the values
///    column.
/// 2. "Quotient extension": interpolates every column of the accumulated values on
///    `subdomain`, and allocates a zeroed column of `domain.size()` elements.
///
/// Returns, in order:
/// - the still-entered "Quotient extension" span, so the caller can exit it once it has
///   finished extending the evaluation,
/// - the zero-initialized column sized for the full `domain`,
/// - one interpolated polynomial per column of the accumulated values.
///
/// # Panics
///
/// Panics if `subdomain.log_size() < LOG_N_LANES + 2`.
fn accumulate_quotients_on_subdomain(
    subdomain: CircleDomain,
    sample_batches: &[ColumnSampleBatch],
    random_coeff: SecureField,
    columns: &[&CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>],
    domain: CircleDomain,
) -> (
    span::EnteredSpan,
    SecureColumn<SimdBackend>,
    [crate::core::poly::circle::CirclePoly<SimdBackend>; 4],
) {
    // Each iteration below handles 4 packed rows, so at least 4 * N_LANES points are needed.
    assert!(subdomain.log_size() >= LOG_N_LANES + 2);
    let mut values = SecureColumn::<SimdBackend>::zeros(subdomain.size());
    let quotient_constants = quotient_constants(sample_batches, random_coeff, subdomain);

    let span = span!(Level::INFO, "Quotient accumulation").entered();
    // TODO(spapini): bit reverse iterator.
    for quad_row in 0..1 << (subdomain.log_size() - LOG_N_LANES - 2) {
        // TODO(spapini): Use optimized domain iteration.
        // Lane `i` holds the y coordinate of the point at bit-reversed position
        // `quad_row * 4 * N_LANES + 4 * i`, i.e. every 4th point of this quad row.
        let spaced_ys = PackedBaseField::from_array(std::array::from_fn(|i| {
            subdomain
                .at(bit_reverse_index(
                    (quad_row << (LOG_N_LANES + 2)) + (i << 2),
                    subdomain.log_size(),
                ))
                .y
        }));
        let row_accumulator = accumulate_row_quotients(
            sample_batches,
            columns,
            &quotient_constants,
            quad_row,
            spaced_ys,
        );
        for (i, &packed) in row_accumulator.iter().enumerate() {
            // SAFETY: `(quad_row << 2) + i` is below `1 << (subdomain.log_size() - LOG_N_LANES)`
            // because `quad_row < 1 << (subdomain.log_size() - LOG_N_LANES - 2)` and `i < 4`.
            // NOTE(review): this assumes `set_packed` indexes packed vectors of N_LANES
            // elements within a column of `subdomain.size()` elements — confirm against
            // `SecureColumn::set_packed`.
            unsafe { values.set_packed((quad_row << 2) + i, packed) };
        }
    }
    span.exit();
    let span = span!(Level::INFO, "Quotient extension").entered();

    // Extend the evaluation to the full domain.
    let extended_eval = SecureColumn::<SimdBackend>::zeros(domain.size());

    // The same twiddles are reused for interpolating every column.
    let twiddles = SimdBackend::precompute_twiddles(subdomain.half_coset);
    let subeval_polys = values.columns.map(|c| {
        CircleEvaluation::<SimdBackend, BaseField, BitReversedOrder>::new(subdomain, c)
            .interpolate_with_twiddles(&twiddles)
    });
    (span, extended_eval, subeval_polys)
}

/// Accumulates the quotients for 4 * N_LANES rows at a time.
/// spaced_ys - y values for N_LANES points in the domain, in jumps of 4.
pub fn accumulate_row_quotients(
sample_batches: &[ColumnSampleBatch],
columns: &[&CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>],
quotient_constants: &QuotientConstants<SimdBackend>,
vec_row: usize,
domain_point_vec: (PackedBaseField, PackedBaseField),
) -> PackedSecureField {
let mut row_accumulator = PackedSecureField::zero();
quad_row: usize,
spaced_ys: PackedBaseField,
) -> [PackedSecureField; 4] {
let mut row_accumulator = [PackedSecureField::zero(); 4];
for (sample_batch, line_coeffs, batch_coeff, denominator_inverses) in izip!(
sample_batches,
&quotient_constants.line_coeffs,
&quotient_constants.batch_random_coeffs,
&quotient_constants.denominator_inverses
) {
let mut numerator = PackedSecureField::zero();
let mut numerator = [PackedSecureField::zero(); 4];
for ((column_index, _), (a, b, c)) in zip_eq(&sample_batch.columns_and_values, line_coeffs)
{
let column = &columns[*column_index];
let value = PackedSecureField::broadcast(*c) * column.data[vec_row];
// The numerator is a line equation passing through
// (sample_point.y, sample_value), (conj(sample_point), conj(sample_value))
// evaluated at (domain_point.y, value).
// When substituting a polynomial in this line equation, we get a polynomial with a root
// at sample_point and conj(sample_point) if the original polynomial had the values
// sample_value and conj(sample_value) at these points.
// TODO(AlonH): Use single point vanishing to save a multiplication.
let linear_term = PackedSecureField::broadcast(*a) * domain_point_vec.1
+ PackedSecureField::broadcast(*b);
numerator += value - linear_term;
let cvalues: [_; 4] = std::array::from_fn(|i| {
PackedSecureField::broadcast(*c) * column.data[(quad_row << 2) + i]
});

// The numerator is the line equation:
// c * value - a * point.y - b;
// Note that a, b, c were already multiplied by random_coeff^i.
// See [column_line_coeffs()] for more details.
// This is why we only add here.
// 4 consecutive points in the domain in bit-reversed order are:
// P, -P, P + H, -P + H.
// H being the half point (-1,0). The y values for these are
// P.y, -P.y, -P.y, P.y.
// We use this fact to save multiplications.
// spaced_ys are the y values in jumps of 4:
// P0.y, P1.y, P2.y, ...
let spaced_ay = PackedSecureField::broadcast(*a) * spaced_ys;
// t0:t1 = a*P0.y, -a*P0.y, a*P1.y, -a*P1.y, ...
let (t0, t1) = spaced_ay.interleave(-spaced_ay);
// t2:t3:t4:t5 = a*P0.y, -a*P0.y, -a*P0.y, a*P0.y, a*P1.y, -a*P1.y, ...
let (t2, t3) = t0.interleave(-t0);
let (t4, t5) = t1.interleave(-t1);
let ay = [t2, t3, t4, t5];
for i in 0..4 {
numerator[i] += cvalues[i] - ay[i] - PackedSecureField::broadcast(*b);
}
}

row_accumulator = row_accumulator * PackedSecureField::broadcast(*batch_coeff)
+ numerator * denominator_inverses.data[vec_row];
for i in 0..4 {
row_accumulator[i] = row_accumulator[i] * PackedSecureField::broadcast(*batch_coeff)
+ numerator[i] * denominator_inverses.data[(quad_row << 2) + i];
}
}
row_accumulator
}
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/examples/poseidon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ mod tests {

// Get from environment variable:
let log_n_instances = env::var("LOG_N_INSTANCES")
.unwrap_or_else(|_| "8".to_string())
.unwrap_or_else(|_| "10".to_string())
.parse::<u32>()
.unwrap();
let log_n_rows = log_n_instances - N_LOG_INSTANCES_PER_ROW as u32;
Expand Down

0 comments on commit 6cef1fb

Please sign in to comment.