Skip to content

Commit

Permalink
fix compilation time of large multiply expr
Browse files Browse the repository at this point in the history
  • Loading branch information
siq1 committed Oct 22, 2024
1 parent 100d1fc commit 17f2373
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 22 deletions.
122 changes: 100 additions & 22 deletions expander_compiler/src/builder/final_build_opt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,33 +342,49 @@ impl<C: Config> Builder<C> {
Expression::from_terms(cur_terms)
}

fn cmp_expr_for_mul(&self, a: &Expression<C>, b: &Expression<C>) -> std::cmp::Ordering {
let la = self.layer_of_expr(a);
let lb = self.layer_of_expr(b);
if la != lb {
return la.cmp(&lb);
}
let la = a.len();
let lb = b.len();
if la != lb {
return la.cmp(&lb);
}
a.cmp(b)
}

fn mul_vec(&mut self, vars: &[usize]) -> Expression<C> {
use crate::utils::heap::{pop, push};
assert!(vars.len() >= 2);
let mut exprs: Vec<Expression<C>> = vars
.iter()
.map(|&v| self.try_make_single(self.in_var_exprs[v].clone()))
.collect();
while exprs.len() > 1 {
let mut exprs_pos: Vec<usize> = (0..exprs.len()).collect();
exprs_pos.sort_by(|a, b| {
let la = self.layer_of_expr(&exprs[*a]);
let lb = self.layer_of_expr(&exprs[*b]);
if la != lb {
la.cmp(&lb)
} else {
let la = exprs[*a].len();
let lb = exprs[*b].len();
if la != lb {
la.cmp(&lb)
} else {
exprs[*a].cmp(&exprs[*b])
}
}
});
let pos1 = exprs_pos[0];
let pos2 = exprs_pos[1];
let mut expr1 = exprs.swap_remove(pos1);
let mut expr2 = exprs.swap_remove(pos2 - (pos2 > pos1) as usize);
let mut exprs_pos_heap: Vec<usize> = vec![];
let mut next_push_pos = 0;
loop {
while next_push_pos != exprs.len() {
push(&mut exprs_pos_heap, next_push_pos, |a, b| {
self.cmp_expr_for_mul(&exprs[a], &exprs[b])
});
next_push_pos += 1;
}
if exprs_pos_heap.len() == 1 {
break;
}
let pos1 = pop(&mut exprs_pos_heap, |a, b| {
self.cmp_expr_for_mul(&exprs[a], &exprs[b])
})
.unwrap();
let pos2 = pop(&mut exprs_pos_heap, |a, b| {
self.cmp_expr_for_mul(&exprs[a], &exprs[b])
})
.unwrap();
let mut expr1 = std::mem::take(&mut exprs[pos1]);
let mut expr2 = std::mem::take(&mut exprs[pos2]);
if expr1.len() > expr2.len() {
std::mem::swap(&mut expr1, &mut expr2);
}
Expand Down Expand Up @@ -448,7 +464,8 @@ impl<C: Config> Builder<C> {
}
exprs.push(self.lin_comb_inner(vars, |_| C::CircuitField::one()));
}
exprs.remove(0)
let final_pos = exprs_pos_heap.pop().unwrap();
exprs.swap_remove(final_pos)
}

fn add_and_check_if_should_make_single(&mut self, e: Expression<C>) {
Expand Down Expand Up @@ -887,4 +904,65 @@ mod tests {
}
}
}

#[test]
fn large_add() {
let mut root = super::InRootCircuit::<C>::default();
let terms = (1..=100000)
.map(|i| ir::expr::LinCombTerm {
coef: CField::one(),
var: i,
})
.collect();
let lc = ir::expr::LinComb {
terms,
constant: CField::one(),
};
root.circuits.insert(
0,
super::InCircuit::<C> {
instructions: vec![super::InInstruction::<C>::LinComb(lc.clone())],
constraints: vec![100001],
outputs: vec![],
num_inputs: 100000,
},
);
assert_eq!(root.validate(), Ok(()));
let root_processed = super::process(&root).unwrap();
assert_eq!(root_processed.validate(), Ok(()));
match &root_processed.circuits[&0].instructions[0] {
ir::dest::Instruction::InternalVariable { expr } => {
assert_eq!(expr.len(), 100001);
}
_ => panic!(),
}
let inputs: Vec<CField> = (1..=100000).map(|i| CField::from(i)).collect();
let (out, ok) = root.eval_unsafe(inputs.clone());
let (out2, ok2) = root_processed.eval_unsafe(inputs);
assert_eq!(out, out2);
assert_eq!(ok, ok2);
}

#[test]
fn large_mul() {
let mut root = super::InRootCircuit::<C>::default();
let terms: Vec<usize> = (1..=100000).collect();
root.circuits.insert(
0,
super::InCircuit::<C> {
instructions: vec![super::InInstruction::<C>::Mul(terms.clone())],
constraints: vec![100001],
outputs: vec![],
num_inputs: 100000,
},
);
assert_eq!(root.validate(), Ok(()));
let root_processed = super::process(&root).unwrap();
assert_eq!(root_processed.validate(), Ok(()));
let inputs: Vec<CField> = (1..=100000).map(|i| CField::from(i)).collect();
let (out, ok) = root.eval_unsafe(inputs.clone());
let (out2, ok2) = root_processed.eval_unsafe(inputs);
assert_eq!(out, out2);
assert_eq!(ok, ok2);
}
}
63 changes: 63 additions & 0 deletions expander_compiler/src/builder/hint_normalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -467,4 +467,67 @@ mod tests {
}
}
}

#[test]
fn large_add() {
let mut root = ir::common::RootCircuit::<super::IrcIn<C>>::default();
let terms = (1..=100000)
.map(|i| ir::expr::LinCombTerm {
coef: CField::one(),
var: i,
})
.collect();
let lc = ir::expr::LinComb {
terms,
constant: CField::one(),
};
root.circuits.insert(
0,
ir::common::Circuit::<super::IrcIn<C>> {
instructions: vec![ir::source::Instruction::LinComb(lc.clone())],
constraints: vec![ir::source::Constraint {
typ: ir::source::ConstraintType::Zero,
var: 100001,
}],
outputs: vec![],
num_inputs: 100000,
},
);
assert_eq!(root.validate(), Ok(()));
let root_processed = super::process(&root).unwrap();
assert_eq!(root_processed.validate(), Ok(()));
match &root_processed.circuits[&0].instructions[0] {
ir::hint_normalized::Instruction::LinComb(lc2) => {
assert_eq!(lc, *lc2);
}
_ => panic!(),
}
}

#[test]
fn large_mul() {
let mut root = ir::common::RootCircuit::<super::IrcIn<C>>::default();
let terms: Vec<usize> = (1..=100000).collect();
root.circuits.insert(
0,
ir::common::Circuit::<super::IrcIn<C>> {
instructions: vec![ir::source::Instruction::Mul(terms.clone())],
constraints: vec![ir::source::Constraint {
typ: ir::source::ConstraintType::Zero,
var: 100001,
}],
outputs: vec![],
num_inputs: 100000,
},
);
assert_eq!(root.validate(), Ok(()));
let root_processed = super::process(&root).unwrap();
assert_eq!(root_processed.validate(), Ok(()));
match &root_processed.circuits[&0].instructions[0] {
ir::hint_normalized::Instruction::Mul(terms2) => {
assert_eq!(terms, *terms2);
}
_ => panic!(),
}
}
}
73 changes: 73 additions & 0 deletions expander_compiler/src/utils/heap.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Handwritten binary min-heap with custom comparator

use std::cmp::Ordering;

pub fn push<F: Fn(usize, usize) -> Ordering>(s: &mut Vec<usize>, x: usize, cmp: F) {
s.push(x);
let mut i = s.len() - 1;
while i > 0 {
let p = (i - 1) / 2;
if cmp(s[i], s[p]) == Ordering::Less {
s.swap(i, p);
i = p;
} else {
break;
}
}
}

pub fn pop<F: Fn(usize, usize) -> Ordering>(s: &mut Vec<usize>, cmp: F) -> Option<usize> {
if s.is_empty() {
return None;
}
let ret = Some(s[0]);
if s.len() == 1 {
s.pop();
return ret;
}
s[0] = s.pop().unwrap();
let mut i = 0;
while 2 * i + 1 < s.len() {
let mut j = 2 * i + 1;
if j + 1 < s.len() && cmp(s[j + 1], s[j]) == Ordering::Less {
j += 1;
}
if cmp(s[j], s[i]) == Ordering::Less {
s.swap(i, j);
i = j;
} else {
break;
}
}
ret
}

#[cfg(test)]
mod tests {
use super::*;
use rand::{Rng, SeedableRng};
use std::collections::BinaryHeap;

#[test]
fn test_heap() {
let mut my_heap = vec![];
let mut std_heap = BinaryHeap::new();
let mut rng = rand::rngs::StdRng::seed_from_u64(123);
for i in 0..100000 {
let op = if i < 50000 {
rng.gen_range(0..2)
} else {
rng.gen_range(0..3) % 2
};
if op == 0 {
let x = rng.gen_range(0..100000);
push(&mut my_heap, x, |a, b| b.cmp(&a));
std_heap.push(x);
} else {
let x = pop(&mut my_heap, |a, b| b.cmp(&a));
let y = std_heap.pop();
assert_eq!(x, y);
}
}
}
}
1 change: 1 addition & 0 deletions expander_compiler/src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod bucket_sort;
pub mod error;
pub mod function_id;
pub mod heap;
pub mod misc;
pub mod pool;
pub mod serde;
Expand Down

0 comments on commit 17f2373

Please sign in to comment.