diff --git a/crates/prover/src/core/air/air_ext.rs b/crates/prover/src/core/air/air_ext.rs index 3389b2cf3..e870a90fe 100644 --- a/crates/prover/src/core/air/air_ext.rs +++ b/crates/prover/src/core/air/air_ext.rs @@ -90,35 +90,39 @@ pub trait AirExt: Air { &'a self, trees: &'a [CommitmentTreeProver], ) -> Vec> { - let poly_iter = &mut trees[0].polynomials.iter(); - let eval_iter = &mut trees[0].evaluations.iter(); - let mut component_traces = vec![]; - self.components().iter().for_each(|component| { - let n_columns = component.trace_log_degree_bounds()[0].len(); - let polys = poly_iter.take(n_columns).collect_vec(); - let evals = eval_iter.take(n_columns).collect_vec(); - - component_traces.push(ComponentTrace { - polys: TreeVec::new(vec![polys]), - evals: TreeVec::new(vec![evals]), - }); - }); + let mut poly_iters = trees + .iter() + .map(|tree| tree.polynomials.iter()) + .collect_vec(); + let mut eval_iters = trees + .iter() + .map(|tree| tree.evaluations.iter()) + .collect_vec(); - if trees.len() > 1 { - let poly_iter = &mut trees[1].polynomials.iter(); - let eval_iter = &mut trees[1].evaluations.iter(); - self.components() - .iter() - .zip_eq(&mut component_traces) - .for_each(|(component, component_trace)| { - let n_columns = component.trace_log_degree_bounds()[1].len(); - let polys = poly_iter.take(n_columns).collect_vec(); - let evals = eval_iter.take(n_columns).collect_vec(); - component_trace.polys.push(polys); - component_trace.evals.push(evals); - }); - } - component_traces + self.components() + .iter() + .map(|component| { + let col_sizes_per_tree = component + .trace_log_degree_bounds() + .iter() + .map(|col_sizes| col_sizes.len()) + .collect_vec(); + let polys = col_sizes_per_tree + .iter() + .zip(poly_iters.iter_mut()) + .map(|(n_columns, iter)| iter.take(*n_columns).collect_vec()) + .collect_vec(); + let evals = col_sizes_per_tree + .iter() + .zip(eval_iters.iter_mut()) + .map(|(n_columns, iter)| iter.take(*n_columns).collect_vec()) + .collect_vec(); + ComponentTrace { + polys: TreeVec::new(polys), + evals: TreeVec::new(evals), + } + }) + .collect_vec() } }