Skip to content

Commit

Permalink
Optimize polynomial batching computation (#325)
Browse files Browse the repository at this point in the history
  • Loading branch information
moodlezoup authored Jun 27, 2024
1 parent 8d690e8 commit aa5eb7f
Showing 1 changed file with 87 additions and 38 deletions.
125 changes: 87 additions & 38 deletions src/spartan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,33 +51,57 @@ impl<E: Engine> PolyEvalWitness<E> {

let size_max = W.iter().map(|w| w.p.len()).max().unwrap();
// Scale the input polynomials by the power of s
let p = W
.into_par_iter()
.zip_eq(powers.par_iter())
.map(|(mut w, s)| {
if *s != E::Scalar::ONE {
w.p.par_iter_mut().for_each(|e| *e *= s);
}
w.p
})
.reduce(
|| vec![E::Scalar::ZERO; size_max],
|left, right| {
// Sum into the largest polynomial
let (mut big, small) = if left.len() > right.len() {
(left, right)
} else {
(right, left)
};

big
.par_iter_mut()
.zip(small.par_iter())
.for_each(|(b, s)| *b += s);

big
},
);
let num_chunks = rayon::current_num_threads().next_power_of_two();
let chunk_size = size_max / num_chunks;

let p = if chunk_size > 0 {
(0..num_chunks)
.into_par_iter()
.flat_map_iter(|chunk_index| {
let mut chunk = vec![E::Scalar::ZERO; chunk_size];
for (coeff, poly) in powers.iter().zip(W.iter()) {
for (rlc, poly_eval) in chunk
.iter_mut()
.zip(poly.p[chunk_index * chunk_size..].iter())
{
if *coeff == E::Scalar::ONE {
*rlc += *poly_eval;
} else {
*rlc += *coeff * poly_eval;
};
}
}
chunk
})
.collect::<Vec<_>>()
} else {
W.into_par_iter()
.zip_eq(powers.par_iter())
.map(|(mut w, s)| {
if *s != E::Scalar::ONE {
w.p.par_iter_mut().for_each(|e| *e *= s);
}
w.p
})
.reduce(
|| vec![E::Scalar::ZERO; size_max],
|left, right| {
// Sum into the largest polynomial
let (mut big, small) = if left.len() > right.len() {
(left, right)
} else {
(right, left)
};

big
.par_iter_mut()
.zip(small.par_iter())
.for_each(|(b, s)| *b += s);

big
},
)
};

PolyEvalWitness { p }
}
Expand All @@ -96,17 +120,42 @@ impl<E: Engine> PolyEvalWitness<E> {

let powers_of_s = powers::<E>(s, p_vec.len());

let p = zip_with!(par_iter, (p_vec, powers_of_s), |v, weight| {
// compute the weighted sum for each vector
v.iter().map(|&x| x * *weight).collect::<Vec<E::Scalar>>()
})
.reduce(
|| vec![E::Scalar::ZERO; p_vec[0].len()],
|acc, v| {
// perform vector addition to combine the weighted vectors
zip_with!((acc.into_iter(), v), |x, y| x + y).collect()
},
);
let num_chunks = rayon::current_num_threads().next_power_of_two();
let chunk_size = p_vec[0].len() / num_chunks;

let p = if chunk_size > 0 {
(0..num_chunks)
.into_par_iter()
.flat_map_iter(|chunk_index| {
let mut chunk = vec![E::Scalar::ZERO; chunk_size];
for (coeff, poly) in powers_of_s.iter().zip(p_vec.iter()) {
for (rlc, poly_eval) in chunk
.iter_mut()
.zip(poly[chunk_index * chunk_size..].iter())
{
if *coeff == E::Scalar::ONE {
*rlc += *poly_eval;
} else {
*rlc += *coeff * poly_eval;
};
}
}
chunk
})
.collect::<Vec<_>>()
} else {
zip_with!(par_iter, (p_vec, powers_of_s), |v, weight| {
// compute the weighted sum for each vector
v.iter().map(|&x| x * *weight).collect::<Vec<E::Scalar>>()
})
.reduce(
|| vec![E::Scalar::ZERO; p_vec[0].len()],
|acc, v| {
// perform vector addition to combine the weighted vectors
zip_with!((acc.into_iter(), v), |x, y| x + y).collect()
},
)
};

PolyEvalWitness { p }
}
Expand Down

0 comments on commit aa5eb7f

Please sign in to comment.