Skip to content

Commit

Permalink
add more large add/mul expr tests
Browse files Browse the repository at this point in the history
  • Loading branch information
siq1 committed Nov 1, 2024
1 parent 17f2373 commit c9be69f
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 1 deletion.
1 change: 1 addition & 0 deletions expander_compiler/src/circuit/ir/source/chains.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ impl<C: Config> Circuit<C> {
}

impl<C: Config> RootCircuit<C> {
// this function must be used with remove_unreachable
pub fn detect_chains(&mut self) {
for (_, circuit) in self.circuits.iter_mut() {
circuit.detect_chains();
Expand Down
74 changes: 73 additions & 1 deletion expander_compiler/src/circuit/ir/source/tests.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use rand::{Rng, RngCore};

use super::{
ConstraintType,
Circuit, ConstraintType,
Instruction::{self, ConstantLike, LinComb, Mul},
RootCircuit,
};
Expand Down Expand Up @@ -190,3 +190,75 @@ fn opt_remove_unreachable_2() {
}
}
}

fn test_detect_chains_inner(is_mul: bool, seq_typ: usize) {
let n = 1000000;
let mut root = RootCircuit::<C>::default();
let mut insns = vec![];
let mut lst = 1;
let get_insn = if is_mul {
|x, y| Instruction::<C>::Mul(vec![x, y])
} else {
|x, y| {
Instruction::LinComb(expr::LinComb {
terms: vec![
expr::LinCombTerm {
coef: CField::one(),
var: x,
},
expr::LinCombTerm {
coef: CField::one(),
var: y,
},
],
constant: CField::zero(),
})
}
};
if seq_typ == 1 {
lst = n;
for i in (1..n).rev() {
insns.push(get_insn(lst, i));
lst = n * 2 - i;
}
} else if seq_typ == 2 {
for i in 2..=n {
insns.push(get_insn(lst, i));
lst = n - 1 + i;
}
} else {
let mut q: Vec<usize> = (1..=n).collect();
let mut i = 0;
lst = n;
while i + 1 < q.len() {
lst += 1;
insns.push(get_insn(q[i], q[i + 1]));
q.push(lst);
i += 2;
}
}
root.circuits.insert(
0,
Circuit::<C> {
num_inputs: n,
instructions: insns,
constraints: vec![],
outputs: vec![lst],
},
);
assert_eq!(root.validate(), Ok(()));
root.detect_chains();
let (root, _) = root.remove_unreachable();
println!("{:?}", root);
assert_eq!(root.validate(), Ok(()));
}

#[test]
fn test_detect_chains() {
test_detect_chains_inner(false, 1);
test_detect_chains_inner(false, 2);
test_detect_chains_inner(false, 3);
test_detect_chains_inner(true, 1);
test_detect_chains_inner(true, 2);
test_detect_chains_inner(true, 3);
}

0 comments on commit c9be69f

Please sign in to comment.