Skip to content

Commit

Permalink
Fix CPU backend evaluation and interpolation bugs (#645)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson authored Jul 9, 2024
1 parent b3f9285 commit 56342f6
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 25 deletions.
203 changes: 179 additions & 24 deletions crates/prover/src/core/backend/cpu/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,37 @@ impl PolyOps for CpuBackend {
eval: CircleEvaluation<Self, BaseField, BitReversedOrder>,
twiddles: &TwiddleTree<Self>,
) -> CirclePoly<Self> {
let mut values = eval.values;

assert!(eval.domain.half_coset.is_doubling_of(twiddles.root_coset));
let line_twiddles = domain_line_twiddles_from_tree(eval.domain, &twiddles.itwiddles);

let mut values = eval.values;

if eval.domain.log_size() == 1 {
let (mut val0, mut val1) = (values[0], values[1]);
ibutterfly(
&mut val0,
&mut val1,
eval.domain.half_coset.initial.y.inverse(),
);
let inv = BaseField::from_u32_unchecked(2).inverse();
(values[0], values[1]) = (val0 * inv, val1 * inv);
return CirclePoly::new(values);
};
let y = eval.domain.half_coset.initial.y;
let n = BaseField::from(2);
let yn_inv = (y * n).inverse();
let y_inv = yn_inv * n;
let n_inv = yn_inv * y;
let (mut v0, mut v1) = (values[0], values[1]);
ibutterfly(&mut v0, &mut v1, y_inv);
return CirclePoly::new(vec![v0 * n_inv, v1 * n_inv]);
}

if eval.domain.log_size() == 2 {
let CirclePoint { x, y } = eval.domain.half_coset.initial;
let n = BaseField::from(4);
let xyn_inv = (x * y * n).inverse();
let x_inv = xyn_inv * y * n;
let y_inv = xyn_inv * x * n;
let n_inv = xyn_inv * x * y;
let (mut v0, mut v1, mut v2, mut v3) = (values[0], values[1], values[2], values[3]);
ibutterfly(&mut v0, &mut v1, y_inv);
ibutterfly(&mut v2, &mut v3, -y_inv);
ibutterfly(&mut v0, &mut v2, x_inv);
ibutterfly(&mut v1, &mut v3, x_inv);
return CirclePoly::new(vec![v0 * n_inv, v1 * n_inv, v2 * n_inv, v3 * n_inv]);
}

let line_twiddles = domain_line_twiddles_from_tree(eval.domain, &twiddles.itwiddles);
let circle_twiddles = circle_twiddles_from_line_twiddles(line_twiddles[0]);

for (h, t) in circle_twiddles.enumerate() {
Expand All @@ -71,14 +85,18 @@ impl PolyOps for CpuBackend {
}

fn eval_at_point(poly: &CirclePoly<Self>, point: CirclePoint<SecureField>) -> SecureField {
// TODO(Andrew): Allocation here expensive for small polynomials.
let mut mappings = vec![point.y, point.x];
if poly.log_size() == 0 {
return poly.coeffs[0].into();
}

let mut mappings = vec![point.y];
let mut x = point.x;
for _ in 2..poly.log_size() {
x = CirclePoint::double_x(x);
for _ in 1..poly.log_size() {
mappings.push(x);
x = CirclePoint::double_x(x);
}
mappings.reverse();

fold(&poly.coeffs, &mappings)
}

Expand All @@ -95,17 +113,27 @@ impl PolyOps for CpuBackend {
domain: CircleDomain,
twiddles: &TwiddleTree<Self>,
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
let mut values = poly.extend(domain.log_size()).coeffs;

assert!(domain.half_coset.is_doubling_of(twiddles.root_coset));
let line_twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles);

let mut values = poly.extend(domain.log_size()).coeffs;

if domain.log_size() == 1 {
let (mut val0, mut val1) = (values[0], values[1]);
butterfly(&mut val0, &mut val1, domain.half_coset.initial.y.inverse());
return CircleEvaluation::new(domain, values);
};
let (mut v0, mut v1) = (values[0], values[1]);
butterfly(&mut v0, &mut v1, domain.half_coset.initial.y);
return CircleEvaluation::new(domain, vec![v0, v1]);
}

if domain.log_size() == 2 {
let (mut v0, mut v1, mut v2, mut v3) = (values[0], values[1], values[2], values[3]);
let CirclePoint { x, y } = domain.half_coset.initial;
butterfly(&mut v0, &mut v2, x);
butterfly(&mut v1, &mut v3, x);
butterfly(&mut v0, &mut v1, y);
butterfly(&mut v2, &mut v3, -y);
return CircleEvaluation::new(domain, vec![v0, v1, v2, v3]);
}

let line_twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles);
let circle_twiddles = circle_twiddles_from_line_twiddles(line_twiddles[0]);

for (layer, layer_twiddles) in line_twiddles.iter().enumerate().rev() {
Expand Down Expand Up @@ -184,6 +212,8 @@ fn fft_layer_loop(
}

/// Computes the circle twiddles layer (layer 0) from the first line twiddles layer (layer 1).
///
/// Only works for line twiddles generated from a domain with size `>4`.
fn circle_twiddles_from_line_twiddles(
first_line_twiddles: &[BaseField],
) -> impl Iterator<Item = BaseField> + '_ {
Expand Down Expand Up @@ -219,3 +249,128 @@ impl<F: ExtensionOf<BaseField>, EvalOrder> IntoIterator
self.values.into_iter()
}
}

#[cfg(test)]
mod tests {
use std::iter::zip;

use num_traits::One;

use crate::core::backend::cpu::CpuCirclePoly;
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::poly::circle::CanonicCoset;

#[test]
fn test_eval_at_point_with_4_coeffs() {
// Represents the polynomial `1 + 2y + 3x + 4xy`.
// Note coefficients are passed in bit reversed order.
let poly = CpuCirclePoly::new([1, 3, 2, 4].map(BaseField::from).to_vec());
let x = BaseField::from(5).into();
let y = BaseField::from(8).into();

let eval = poly.eval_at_point(CirclePoint { x, y });

assert_eq!(
eval,
poly.coeffs[0] + poly.coeffs[1] * y + poly.coeffs[2] * x + poly.coeffs[3] * x * y
);
}

#[test]
fn test_eval_at_point_with_2_coeffs() {
// Represents the polynomial `1 + 2y`.
let poly = CpuCirclePoly::new(vec![BaseField::from(1), BaseField::from(2)]);
let x = BaseField::from(5).into();
let y = BaseField::from(8).into();

let eval = poly.eval_at_point(CirclePoint { x, y });

assert_eq!(eval, poly.coeffs[0] + poly.coeffs[1] * y);
}

#[test]
fn test_eval_at_point_with_1_coeff() {
// Represents the polynomial `1`.
let poly = CpuCirclePoly::new(vec![BaseField::one()]);
let x = BaseField::from(5).into();
let y = BaseField::from(8).into();

let eval = poly.eval_at_point(CirclePoint { x, y });

assert_eq!(eval, SecureField::one());
}

#[test]
fn test_evaluate_2_coeffs() {
let domain = CanonicCoset::new(1).circle_domain();
let poly = CpuCirclePoly::new((1..=2).map(BaseField::from).collect());

let evaluation = poly.clone().evaluate(domain).bit_reverse();

for (i, (p, eval)) in zip(domain, evaluation).enumerate() {
let eval: SecureField = eval.into();
assert_eq!(eval, poly.eval_at_point(p.into_ef()), "mismatch at i={i}");
}
}

#[test]
fn test_evaluate_4_coeffs() {
let domain = CanonicCoset::new(2).circle_domain();
let poly = CpuCirclePoly::new((1..=4).map(BaseField::from).collect());

let evaluation = poly.clone().evaluate(domain).bit_reverse();

for (i, (x, eval)) in zip(domain, evaluation).enumerate() {
let eval: SecureField = eval.into();
assert_eq!(eval, poly.eval_at_point(x.into_ef()), "mismatch at i={i}");
}
}

#[test]
fn test_evaluate_8_coeffs() {
let domain = CanonicCoset::new(3).circle_domain();
let poly = CpuCirclePoly::new((1..=8).map(BaseField::from).collect());

let evaluation = poly.clone().evaluate(domain).bit_reverse();

for (i, (x, eval)) in zip(domain, evaluation).enumerate() {
let eval: SecureField = eval.into();
assert_eq!(eval, poly.eval_at_point(x.into_ef()), "mismatch at i={i}");
}
}

#[test]
fn test_interpolate_2_evals() {
let poly = CpuCirclePoly::new(vec![BaseField::one(), BaseField::from(2)]);
let domain = CanonicCoset::new(1).circle_domain();
let evals = poly.clone().evaluate(domain);

let interpolated_poly = evals.interpolate();

assert_eq!(interpolated_poly.coeffs, poly.coeffs);
}

#[test]
fn test_interpolate_4_evals() {
let poly = CpuCirclePoly::new((1..=4).map(BaseField::from).collect());
let domain = CanonicCoset::new(2).circle_domain();
let evals = poly.clone().evaluate(domain);

let interpolated_poly = evals.interpolate();

assert_eq!(interpolated_poly.coeffs, poly.coeffs);
}

#[test]
fn test_interpolate_8_evals() {
let poly = CpuCirclePoly::new((1..=8).map(BaseField::from).collect());
let domain = CanonicCoset::new(3).circle_domain();
let evals = poly.clone().evaluate(domain);

let interpolated_poly = evals.interpolate();

assert_eq!(interpolated_poly.coeffs, poly.coeffs);
}
}
1 change: 0 additions & 1 deletion crates/prover/src/core/poly/line.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ impl LinePoly {

/// Evaluates the polynomial at a single point.
pub fn eval_at_point(&self, mut x: SecureField) -> SecureField {
// TODO(Andrew): Allocation here expensive for small polynomials.
let mut doublings = Vec::new();
for _ in 0..self.log_size {
doublings.push(x);
Expand Down

0 comments on commit 56342f6

Please sign in to comment.