diff --git a/halo2_proofs/src/arithmetic.rs b/halo2_proofs/src/arithmetic.rs
index 76a40e71b..99aafe1d2 100644
--- a/halo2_proofs/src/arithmetic.rs
+++ b/halo2_proofs/src/arithmetic.rs
@@ -110,16 +110,31 @@ pub fn eval_polynomial<F: Field>(poly: &[F], point: F) -> F {
 /// This computes the inner product of two vectors `a` and `b`.
 ///
 /// This function will panic if the two vectors are not the same size.
+/// For vectors smaller than 32 elements, it uses sequential computation for better performance.
+/// For larger vectors, the element-wise products are computed in parallel and then summed.
 pub fn compute_inner_product<F: Field>(a: &[F], b: &[F]) -> F {
-    // TODO: parallelize?
     assert_eq!(a.len(), b.len());
 
-    let mut acc = F::ZERO;
-    for (a, b) in a.iter().zip(b.iter()) {
-        acc += (*a) * (*b);
+    if a.len() < 32 {
+        // Use sequential computation for small vectors
+        let mut acc = F::ZERO;
+        for (a, b) in a.iter().zip(b.iter()) {
+            acc += (*a) * (*b);
+        }
+        return acc;
     }
 
-    acc
+    // Use parallel computation for large vectors. `parallelize` passes each worker
+    // its sub-slice of `products` and that sub-slice's starting index in the full
+    // vector (not a chunk size), so the index locates the matching `a`/`b` elements.
+    let mut products = vec![F::ZERO; a.len()];
+    parallelize(&mut products, |products, start| {
+        for ((product, x), y) in products.iter_mut().zip(&a[start..]).zip(&b[start..]) {
+            *product = (*x) * (*y);
+        }
+    });
+
+    products.iter().fold(F::ZERO, |acc, product| acc + *product)
 }
 
 /// Divides polynomial `a` in `X` by `X - b` with
@@ -328,3 +343,26 @@ fn test_lagrange_interpolate() {
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use rand_core::OsRng;
+
+    #[test]
+    fn test_compute_inner_product() {
+        let rng = OsRng;
+
+        // Cover the sequential path (< 32), the threshold itself, and the parallel
+        // path with lengths that do not divide evenly among worker threads.
+        for len in [0, 1, 16, 31, 32, 33, 64, 100] {
+            let a: Vec<Fp> = (0..len).map(|_| Fp::random(rng)).collect();
+            let b: Vec<Fp> = (0..len).map(|_| Fp::random(rng)).collect();
+            let expected = a
+                .iter()
+                .zip(b.iter())
+                .fold(Fp::ZERO, |acc, (a, b)| acc + (*a) * (*b));
+            assert_eq!(compute_inner_product(&a, &b), expected, "len = {}", len);
+        }
+    }
+}