diff --git a/expander_compiler/src/circuit/ir/source/chains.rs b/expander_compiler/src/circuit/ir/source/chains.rs index d90af87..64ed3b7 100644 --- a/expander_compiler/src/circuit/ir/source/chains.rs +++ b/expander_compiler/src/circuit/ir/source/chains.rs @@ -140,6 +140,7 @@ impl Circuit { } impl RootCircuit { + // this function must be used with remove_unreachable pub fn detect_chains(&mut self) { for (_, circuit) in self.circuits.iter_mut() { circuit.detect_chains(); diff --git a/expander_compiler/src/circuit/ir/source/tests.rs b/expander_compiler/src/circuit/ir/source/tests.rs index e4f7e08..7b789b3 100644 --- a/expander_compiler/src/circuit/ir/source/tests.rs +++ b/expander_compiler/src/circuit/ir/source/tests.rs @@ -1,7 +1,7 @@ use rand::{Rng, RngCore}; use super::{ - ConstraintType, + Circuit, ConstraintType, Instruction::{self, ConstantLike, LinComb, Mul}, RootCircuit, }; @@ -190,3 +190,75 @@ fn opt_remove_unreachable_2() { } } } + +fn test_detect_chains_inner(is_mul: bool, seq_typ: usize) { + let n = 1000000; + let mut root = RootCircuit::::default(); + let mut insns = vec![]; + let mut lst = 1; + let get_insn = if is_mul { + |x, y| Instruction::::Mul(vec![x, y]) + } else { + |x, y| { + Instruction::LinComb(expr::LinComb { + terms: vec![ + expr::LinCombTerm { + coef: CField::one(), + var: x, + }, + expr::LinCombTerm { + coef: CField::one(), + var: y, + }, + ], + constant: CField::zero(), + }) + } + }; + if seq_typ == 1 { + lst = n; + for i in (1..n).rev() { + insns.push(get_insn(lst, i)); + lst = n * 2 - i; + } + } else if seq_typ == 2 { + for i in 2..=n { + insns.push(get_insn(lst, i)); + lst = n - 1 + i; + } + } else { + let mut q: Vec = (1..=n).collect(); + let mut i = 0; + lst = n; + while i + 1 < q.len() { + lst += 1; + insns.push(get_insn(q[i], q[i + 1])); + q.push(lst); + i += 2; + } + } + root.circuits.insert( + 0, + Circuit:: { + num_inputs: n, + instructions: insns, + constraints: vec![], + outputs: vec![lst], + }, + ); + assert_eq!(root.validate(), Ok(())); + root.detect_chains(); + let (root, _) = root.remove_unreachable(); + println!("{:?}", root); + assert_eq!(root.validate(), Ok(())); +} + +#[test] +fn test_detect_chains() { + test_detect_chains_inner(false, 1); + test_detect_chains_inner(false, 2); + test_detect_chains_inner(false, 3); + test_detect_chains_inner(true, 1); + test_detect_chains_inner(true, 2); + test_detect_chains_inner(true, 3); +}