From 0a06e1c5302d6a072574247edb20f274742b7d88 Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Tue, 10 Dec 2024 10:11:15 +0900 Subject: [PATCH 1/3] add more large add/mul expr tests (#43) --- .../src/circuit/ir/source/chains.rs | 1 + .../src/circuit/ir/source/tests.rs | 74 ++++++++++++++++++- 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/expander_compiler/src/circuit/ir/source/chains.rs b/expander_compiler/src/circuit/ir/source/chains.rs index d90af87..64ed3b7 100644 --- a/expander_compiler/src/circuit/ir/source/chains.rs +++ b/expander_compiler/src/circuit/ir/source/chains.rs @@ -140,6 +140,7 @@ impl Circuit { } impl RootCircuit { + // this function must be used with remove_unreachable pub fn detect_chains(&mut self) { for (_, circuit) in self.circuits.iter_mut() { circuit.detect_chains(); diff --git a/expander_compiler/src/circuit/ir/source/tests.rs b/expander_compiler/src/circuit/ir/source/tests.rs index e4f7e08..7b789b3 100644 --- a/expander_compiler/src/circuit/ir/source/tests.rs +++ b/expander_compiler/src/circuit/ir/source/tests.rs @@ -1,7 +1,7 @@ use rand::{Rng, RngCore}; use super::{ - ConstraintType, + Circuit, ConstraintType, Instruction::{self, ConstantLike, LinComb, Mul}, RootCircuit, }; @@ -190,3 +190,75 @@ fn opt_remove_unreachable_2() { } } } + +fn test_detect_chains_inner(is_mul: bool, seq_typ: usize) { + let n = 1000000; + let mut root = RootCircuit::::default(); + let mut insns = vec![]; + let mut lst = 1; + let get_insn = if is_mul { + |x, y| Instruction::::Mul(vec![x, y]) + } else { + |x, y| { + Instruction::LinComb(expr::LinComb { + terms: vec![ + expr::LinCombTerm { + coef: CField::one(), + var: x, + }, + expr::LinCombTerm { + coef: CField::one(), + var: y, + }, + ], + constant: CField::zero(), + }) + } + }; + if seq_typ == 1 { + lst = n; + for i in (1..n).rev() { + insns.push(get_insn(lst, i)); + lst = n * 2 - i; + } + } else if seq_typ == 2 { + for i in 2..=n { + insns.push(get_insn(lst, i)); + lst = n - 1 + i; + } + } else { + let mut q: Vec = (1..=n).collect(); + let mut i = 0; + lst = n; + while i + 1 < q.len() { + lst += 1; + insns.push(get_insn(q[i], q[i + 1])); + q.push(lst); + i += 2; + } + } + root.circuits.insert( + 0, + Circuit:: { + num_inputs: n, + instructions: insns, + constraints: vec![], + outputs: vec![lst], + }, + ); + assert_eq!(root.validate(), Ok(())); + root.detect_chains(); + let (root, _) = root.remove_unreachable(); + println!("{:?}", root); + assert_eq!(root.validate(), Ok(())); +} + +#[test] +fn test_detect_chains() { + test_detect_chains_inner(false, 1); + test_detect_chains_inner(false, 2); + test_detect_chains_inner(false, 3); + test_detect_chains_inner(true, 1); + test_detect_chains_inner(true, 2); + test_detect_chains_inner(true, 3); +} From 5e136ea295fff03e71d1e0ad7c881281d2ff779d Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Tue, 10 Dec 2024 10:16:12 +0900 Subject: [PATCH 2/3] implement mul gate fanout limit (#48) * implement mul gate fanout limit * fmt * update gate order to compare input first * clippy * add dump circuit test --- expander_compiler/src/circuit/ir/dest/mod.rs | 1 + .../src/circuit/ir/dest/mul_fanout_limit.rs | 477 ++++++++++++++++++ expander_compiler/src/circuit/ir/expr.rs | 23 +- expander_compiler/src/circuit/layered/opt.rs | 36 +- expander_compiler/src/compile/mod.rs | 28 + expander_compiler/src/frontend/mod.rs | 16 + expander_compiler/src/layering/wire.rs | 5 +- expander_compiler/tests/mul_fanout_limit.rs | 74 +++ 8 files changed, 631 insertions(+), 29 deletions(-) create mode 100644 expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs create mode 100644 expander_compiler/tests/mul_fanout_limit.rs diff --git a/expander_compiler/src/circuit/ir/dest/mod.rs b/expander_compiler/src/circuit/ir/dest/mod.rs index 07415c7..f6cdc75 100644 --- a/expander_compiler/src/circuit/ir/dest/mod.rs +++ b/expander_compiler/src/circuit/ir/dest/mod.rs @@ -16,6 +16,7 @@ use super::{ pub mod tests; pub mod display; +pub mod mul_fanout_limit; #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub enum Instruction { diff --git a/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs b/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs new file mode 100644 index 0000000..442407f --- /dev/null +++ b/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs @@ -0,0 +1,477 @@ +use super::*; + +// This module contains the implementation of the optimization that reduces the fanout of the input variables in multiplication gates. +// There are two ways to reduce the fanout of a variable: +// 1. Copy the whole expression to a new variable. This will copy all gates, and may increase the number of gates by a lot. +// 2. Create a relay expression of the variable. This may increase the layer of the circuit by 1. + +// These are the limits for the first method. +const MAX_COPIES_OF_VARIABLES: usize = 4; +const MAX_COPIES_OF_GATES: usize = 64; + +fn compute_max_copy_cnt(num_gates: usize) -> usize { + if num_gates == 0 { + return 0; + } + MAX_COPIES_OF_VARIABLES.min(MAX_COPIES_OF_GATES / num_gates) +} + +struct NewIdQueue { + queue: Vec<(usize, usize)>, + next: usize, + default_id: usize, +} + +impl NewIdQueue { + fn new(default_id: usize) -> Self { + Self { + queue: Vec::new(), + next: 0, + default_id, + } + } + + fn push(&mut self, id: usize, num: usize) { + self.queue.push((id, num)); + } + + fn get(&mut self) -> usize { + while self.next < self.queue.len() { + let (id, num) = self.queue[self.next]; + if num > 0 { + self.queue[self.next].1 -= 1; + return id; + } + self.next += 1; + } + self.default_id + } +} + +impl CircuitRelaxed { + fn solve_mul_fanout_limit(&self, limit: usize) -> CircuitRelaxed { + let mut max_copy_cnt = vec![0; self.num_inputs + 1]; + let mut mul_ref_cnt = vec![0; self.num_inputs + 1]; + let mut internal_var_insn_id = vec![None; self.num_inputs + 1]; + + for (i, insn) in self.instructions.iter().enumerate() { + match insn { + Instruction::ConstantLike { .. } => { + mul_ref_cnt.push(0); + max_copy_cnt.push(compute_max_copy_cnt(1)); + internal_var_insn_id.push(None); + } + Instruction::SubCircuitCall { num_outputs, .. } => { + for _ in 0..*num_outputs { + mul_ref_cnt.push(0); + max_copy_cnt.push(0); + internal_var_insn_id.push(None); + } + } + Instruction::InternalVariable { expr } => { + for term in expr.iter() { + if let VarSpec::Quad(x, y) = term.vars { + mul_ref_cnt[x] += 1; + mul_ref_cnt[y] += 1; + } + } + mul_ref_cnt.push(0); + max_copy_cnt.push(compute_max_copy_cnt(expr.len())); + internal_var_insn_id.push(Some(i)) + } + } + } + + let mut add_copy_cnt = vec![0; max_copy_cnt.len()]; + let mut relay_cnt = vec![0; max_copy_cnt.len()]; + let mut any_new = false; + + for i in (1..max_copy_cnt.len()).rev() { + let mc = max_copy_cnt[i].max(1); + if mul_ref_cnt[i] <= mc * limit { + add_copy_cnt[i] = ((mul_ref_cnt[i] + limit - 1) / limit).max(1) - 1; + any_new = true; + if let Some(j) = internal_var_insn_id[i] { + if let Instruction::InternalVariable { expr } = &self.instructions[j] { + for term in expr.iter() { + if let VarSpec::Quad(x, y) = term.vars { + mul_ref_cnt[x] += add_copy_cnt[i]; + mul_ref_cnt[y] += add_copy_cnt[i]; + } + } + } else { + unreachable!(); + } + } + } else { + // mul_ref_cnt[i] + relay_cnt[i] <= limit * (1 + relay_cnt[i]) + relay_cnt[i] = (mul_ref_cnt[i] - 2) / (limit - 1); + any_new = true; + } + } + + if !any_new { + return self.clone(); + } + + let mut new_id = vec![]; + let mut new_insns: Vec> = Vec::new(); + let mut new_var_max = self.num_inputs; + let mut last_solved_id = 0; + + for i in 0..=self.num_inputs { + new_id.push(NewIdQueue::new(i)); + } + + for insn in self.instructions.iter() { + while last_solved_id + 1 < new_id.len() { + last_solved_id += 1; + let x = last_solved_id; + if add_copy_cnt[x] == 0 && relay_cnt[x] == 0 { + continue; + } + let y = new_id[x].default_id; + new_id[x].push(y, limit); + for _ in 0..add_copy_cnt[x] { + let insn = new_insns.last().unwrap().clone(); + new_insns.push(insn); + new_var_max += 1; + new_id[x].push(new_var_max, limit); + } + for _ in 0..relay_cnt[x] { + let y = new_id[x].get(); + new_insns.push(Instruction::InternalVariable { + expr: Expression::new_linear(C::CircuitField::one(), y), + }); + new_var_max += 1; + new_id[x].push(new_var_max, limit); + } + } + match insn { + Instruction::ConstantLike { value } => { + new_insns.push(Instruction::ConstantLike { + value: value.clone(), + }); + new_var_max += 1; + new_id.push(NewIdQueue::new(new_var_max)); + } + Instruction::SubCircuitCall { + sub_circuit_id, + inputs, + num_outputs, + } => { + new_insns.push(Instruction::SubCircuitCall { + sub_circuit_id: *sub_circuit_id, + inputs: inputs.iter().map(|x| new_id[*x].default_id).collect(), + num_outputs: *num_outputs, + }); + for _ in 0..*num_outputs { + new_var_max += 1; + let x = new_id.len(); + new_id.push(NewIdQueue::new(new_var_max)); + assert_eq!(add_copy_cnt[x], 0); + } + } + Instruction::InternalVariable { expr } => { + let x = new_id.len(); + if add_copy_cnt[x] > 0 { + assert_eq!(relay_cnt[x], 0); + } + for _ in 0..=add_copy_cnt[x] { + let mut new_terms = vec![]; + for term in expr.iter() { + if let VarSpec::Quad(x, y) = term.vars { + new_terms.push(Term { + vars: VarSpec::Quad(new_id[x].get(), new_id[y].get()), + coef: term.coef, + }); + } else { + new_terms.push(Term { + vars: term.vars.replace_vars(|x| new_id[x].default_id), + coef: term.coef, + }); + } + } + new_insns.push(Instruction::InternalVariable { + expr: Expression::from_terms(new_terms), + }); + new_var_max += 1; + } + new_id.push(NewIdQueue::new(new_var_max)); + if add_copy_cnt[x] > 0 { + for i in 0..=add_copy_cnt[x] { + new_id[x].push(new_var_max - add_copy_cnt[x] + i, limit); + } + last_solved_id = x; + } + } + } + } + + CircuitRelaxed { + instructions: new_insns, + num_inputs: self.num_inputs, + outputs: self.outputs.iter().map(|x| new_id[*x].default_id).collect(), + constraints: self + .constraints + .iter() + .map(|x| new_id[*x].default_id) + .collect(), + } + } +} + +impl RootCircuitRelaxed { + pub fn solve_mul_fanout_limit(&self, limit: usize) -> RootCircuitRelaxed { + if limit <= 1 { + panic!("limit must be greater than 1"); + } + + let mut circuits = HashMap::new(); + for (id, circuit) in self.circuits.iter() { + circuits.insert(*id, circuit.solve_mul_fanout_limit(limit)); + } + RootCircuitRelaxed { + circuits, + num_public_inputs: self.num_public_inputs, + expected_num_output_zeroes: self.expected_num_output_zeroes, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::circuit::config::{Config, M31Config as C}; + use crate::field::FieldArith; + use rand::{RngCore, SeedableRng}; + + type CField = ::CircuitField; + + fn verify_mul_fanout(rc: &RootCircuitRelaxed, limit: usize) { + for circuit in rc.circuits.values() { + let mut mul_ref_cnt = vec![0; circuit.num_inputs + 1]; + for insn in circuit.instructions.iter() { + match insn { + Instruction::ConstantLike { .. } => {} + Instruction::SubCircuitCall { .. } => {} + Instruction::InternalVariable { expr } => { + for term in expr.iter() { + if let VarSpec::Quad(x, y) = term.vars { + mul_ref_cnt[x] += 1; + mul_ref_cnt[y] += 1; + } + } + } + } + for _ in 0..insn.num_outputs() { + mul_ref_cnt.push(0); + } + } + for x in mul_ref_cnt.iter().skip(1) { + assert!(*x <= limit); + } + } + } + + fn do_test(root: RootCircuitRelaxed, limits: Vec) { + for lim in limits.iter() { + let new_root = root.solve_mul_fanout_limit(*lim); + assert_eq!(new_root.validate(), Ok(())); + assert_eq!(new_root.input_size(), root.input_size()); + verify_mul_fanout(&new_root, *lim); + let inputs: Vec = (0..root.input_size()) + .map(|_| CField::random_unsafe(&mut rand::thread_rng())) + .collect(); + let (out1, cond1) = root.eval_unsafe(inputs.clone()); + let (out2, cond2) = new_root.eval_unsafe(inputs); + assert_eq!(out1, out2); + assert_eq!(cond1, cond2); + } + } + + #[test] + fn fanout_test1() { + let mut circuit = CircuitRelaxed { + instructions: Vec::new(), + constraints: Vec::new(), + outputs: Vec::new(), + num_inputs: 2, + }; + for i in 3..=1003 { + circuit.instructions.push(Instruction::InternalVariable { + expr: Expression::new_quad(CField::one(), 1, 2), + }); + circuit.constraints.push(i); + circuit.outputs.push(i); + } + + let mut root = RootCircuitRelaxed::::default(); + root.circuits.insert(0, circuit); + + do_test(root, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 64, 2000]); + } + + #[test] + fn fanout_test2() { + let mut circuit = CircuitRelaxed { + instructions: Vec::new(), + constraints: Vec::new(), + outputs: Vec::new(), + num_inputs: 1, + }; + for _ in 0..2 { + circuit.instructions.push(Instruction::InternalVariable { + expr: Expression::new_quad(CField::from(100), 1, 1), + }); + } + for i in 4..=1003 { + circuit.instructions.push(Instruction::InternalVariable { + expr: Expression::new_quad(CField::from(10), 2, 3), + }); + circuit.constraints.push(i); + circuit.outputs.push(i); + } + + let mut root = RootCircuitRelaxed::::default(); + root.circuits.insert(0, circuit); + + do_test(root, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 64, 2000]); + } + + #[test] + fn fanout_test3() { + let mut circuit = CircuitRelaxed { + instructions: Vec::new(), + constraints: Vec::new(), + outputs: Vec::new(), + num_inputs: 1, + }; + for _ in 0..2 { + circuit.instructions.push(Instruction::SubCircuitCall { + sub_circuit_id: 1, + inputs: vec![1], + num_outputs: 1, + }); + } + for i in 4..=1003 { + circuit.instructions.push(Instruction::InternalVariable { + expr: Expression::new_quad(CField::from(10), 2, 3), + }); + circuit.constraints.push(i); + circuit.outputs.push(i); + } + + let mut root = RootCircuitRelaxed::::default(); + root.circuits.insert(0, circuit); + root.circuits.insert( + 1, + CircuitRelaxed { + instructions: vec![Instruction::InternalVariable { + expr: Expression::new_quad(CField::from(100), 1, 1), + }], + constraints: vec![], + outputs: vec![2], + num_inputs: 1, + }, + ); + + do_test(root, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 64, 2000]); + } + + #[test] + fn fanout_test_random() { + let mut rnd = rand::rngs::StdRng::seed_from_u64(3); + let mut circuit = CircuitRelaxed { + instructions: Vec::new(), + constraints: Vec::new(), + outputs: Vec::new(), + num_inputs: 100, + }; + let mut q = vec![]; + for i in 1..=100 { + for _ in 0..5 { + q.push(i); + } + if i % 20 == 0 { + for _ in 0..100 { + q.push(i); + } + } + } + + let n = 10003; + + for i in 101..=n { + let mut terms = vec![]; + let mut c = q.len() / 2; + if i != n { + c = c.min(5); + } + for _ in 0..c { + let x = q.swap_remove(rnd.next_u64() as usize % q.len()); + let y = q.swap_remove(rnd.next_u64() as usize % q.len()); + terms.push(Term { + vars: VarSpec::Quad(x, y), + coef: CField::one(), + }); + } + circuit.instructions.push(Instruction::InternalVariable { + expr: Expression::from_terms(terms), + }); + circuit.constraints.push(i); + circuit.outputs.push(i); + for _ in 0..5 { + q.push(i); + } + if i % 20 == 0 { + for _ in 0..100 { + q.push(i); + } + } + } + + let mut root = RootCircuitRelaxed::::default(); + root.circuits.insert(0, circuit); + + do_test(root, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 64, 2000]); + } + + #[test] + fn full_fanout_test_and_dump() { + use crate::circuit::ir::common::rand_gen::{RandomCircuitConfig, RandomRange}; + use crate::utils::serde::Serde; + + let config = RandomCircuitConfig { + seed: 2, + num_circuits: RandomRange { min: 20, max: 20 }, + num_inputs: RandomRange { min: 1, max: 3 }, + num_instructions: RandomRange { min: 30, max: 50 }, + num_constraints: RandomRange { min: 0, max: 5 }, + num_outputs: RandomRange { min: 1, max: 3 }, + num_terms: RandomRange { min: 1, max: 5 }, + sub_circuit_prob: 0.05, + }; + let root = crate::circuit::ir::source::RootCircuit::::random(&config); + assert_eq!(root.validate(), Ok(())); + let (_, circuit) = crate::compile::compile_with_options( + &root, + crate::compile::CompileOptions::default().with_mul_fanout_limit(256), + ) + .unwrap(); + assert_eq!(circuit.validate(), Ok(())); + for segment in circuit.segments.iter() { + let mut ref_num = vec![0; segment.num_inputs]; + for m in segment.gate_muls.iter() { + ref_num[m.inputs[0]] += 1; + ref_num[m.inputs[1]] += 1; + } + for x in ref_num.iter() { + assert!(*x <= 256); + } + } + + let mut buf = Vec::new(); + circuit.serialize_into(&mut buf).unwrap(); + } +} diff --git a/expander_compiler/src/circuit/ir/expr.rs b/expander_compiler/src/circuit/ir/expr.rs index e9743f8..d672409 100644 --- a/expander_compiler/src/circuit/ir/expr.rs +++ b/expander_compiler/src/circuit/ir/expr.rs @@ -79,6 +79,18 @@ impl VarSpec { (_, VarSpec::RandomLinear(_)) => panic!("unexpected situation: RandomLinear"), } } + pub fn replace_vars usize>(&self, f: F) -> Self { + match self { + VarSpec::Const => VarSpec::Const, + VarSpec::Linear(x) => VarSpec::Linear(f(*x)), + VarSpec::Quad(x, y) => VarSpec::Quad(f(*x), f(*y)), + VarSpec::Custom { gate_type, inputs } => VarSpec::Custom { + gate_type: *gate_type, + inputs: inputs.iter().cloned().map(&f).collect(), + }, + VarSpec::RandomLinear(x) => VarSpec::RandomLinear(f(*x)), + } + } } impl Ord for Term { @@ -310,16 +322,7 @@ impl Expression { .iter() .map(|term| Term { coef: term.coef, - vars: match &term.vars { - VarSpec::Const => VarSpec::Const, - VarSpec::Linear(index) => VarSpec::Linear(f(*index)), - VarSpec::Quad(index1, index2) => VarSpec::Quad(f(*index1), f(*index2)), - VarSpec::Custom { gate_type, inputs } => VarSpec::Custom { - gate_type: *gate_type, - inputs: inputs.iter().cloned().map(&f).collect(), - }, - VarSpec::RandomLinear(index) => VarSpec::RandomLinear(f(*index)), - }, + vars: term.vars.replace_vars(&f), }) .collect(); Expression { terms } diff --git a/expander_compiler/src/circuit/layered/opt.rs b/expander_compiler/src/circuit/layered/opt.rs index 9bc232b..afce7e5 100644 --- a/expander_compiler/src/circuit/layered/opt.rs +++ b/expander_compiler/src/circuit/layered/opt.rs @@ -17,15 +17,6 @@ impl PartialOrd for Gate { impl Ord for Gate { fn cmp(&self, other: &Self) -> Ordering { - match self.output.cmp(&other.output) { - Ordering::Less => { - return Ordering::Less; - } - Ordering::Greater => { - return Ordering::Greater; - } - Ordering::Equal => {} - }; for i in 0..INPUT_NUM { match self.inputs[i].cmp(&other.inputs[i]) { Ordering::Less => { @@ -37,6 +28,15 @@ impl Ord for Gate { Ordering::Equal => {} }; } + match self.output.cmp(&other.output) { + Ordering::Less => { + return Ordering::Less; + } + Ordering::Greater => { + return Ordering::Greater; + } + Ordering::Equal => {} + }; self.coef.cmp(&other.coef) } } @@ -58,15 +58,6 @@ impl Ord for GateCustom { } Ordering::Equal => {} }; - match self.output.cmp(&other.output) { - Ordering::Less => { - return Ordering::Less; - } - Ordering::Greater => { - return Ordering::Greater; - } - Ordering::Equal => {} - }; match self.inputs.len().cmp(&other.inputs.len()) { Ordering::Less => { return Ordering::Less; @@ -87,6 +78,15 @@ impl Ord for GateCustom { Ordering::Equal => {} }; } + match self.output.cmp(&other.output) { + Ordering::Less => { + return Ordering::Less; + } + Ordering::Greater => { + return Ordering::Greater; + } + Ordering::Equal => {} + }; self.coef.cmp(&other.coef) } } diff --git a/expander_compiler/src/compile/mod.rs b/expander_compiler/src/compile/mod.rs index a3fa6a0..b4148f5 100644 --- a/expander_compiler/src/compile/mod.rs +++ b/expander_compiler/src/compile/mod.rs @@ -10,6 +10,18 @@ mod random_circuit_tests; #[cfg(test)] mod tests; +#[derive(Default)] +pub struct CompileOptions { + pub mul_fanout_limit: Option, +} + +impl CompileOptions { + pub fn with_mul_fanout_limit(mut self, mul_fanout_limit: usize) -> Self { + self.mul_fanout_limit = Some(mul_fanout_limit); + self + } +} + fn optimize_until_fixed_point(x: &T, im: &mut InputMapping, f: F) -> T where T: Clone + Eq, @@ -49,6 +61,13 @@ fn print_stat(stat_name: &str, stat: usize, is_last: bool) { pub fn compile( r_source: &ir::source::RootCircuit, +) -> Result<(ir::hint_normalized::RootCircuit, layered::Circuit), Error> { + compile_with_options(r_source, CompileOptions::default()) +} + +pub fn compile_with_options( + r_source: &ir::source::RootCircuit, + options: CompileOptions, ) -> Result<(ir::hint_normalized::RootCircuit, layered::Circuit), Error> { r_source.validate()?; @@ -114,6 +133,15 @@ pub fn compile( .validate() .map_err(|e| e.prepend("dest relaxed ir circuit invalid"))?; + let r_dest_relaxed_opt = if let Some(limit) = options.mul_fanout_limit { + r_dest_relaxed_opt.solve_mul_fanout_limit(limit) + } else { + r_dest_relaxed_opt + }; + r_dest_relaxed_opt + .validate() + .map_err(|e| e.prepend("dest relaxed ir circuit invalid"))?; + let r_dest_relaxed_p2 = if C::ENABLE_RANDOM_COMBINATION { r_dest_relaxed_opt } else { diff --git a/expander_compiler/src/frontend/mod.rs b/expander_compiler/src/frontend/mod.rs index 1b087b3..0c75145 100644 --- a/expander_compiler/src/frontend/mod.rs +++ b/expander_compiler/src/frontend/mod.rs @@ -11,6 +11,7 @@ mod witness; pub use circuit::declare_circuit; pub type API = builder::RootBuilder; pub use crate::circuit::config::*; +pub use crate::compile::CompileOptions; pub use crate::field::{Field, BN254, GF2, M31}; pub use crate::utils::error::Error; pub use api::BasicAPI; @@ -64,3 +65,18 @@ pub fn compile + Define layered_circuit: lc, }) } + +pub fn compile_with_options< + C: Config, + Cir: internal::DumpLoadTwoVariables + Define + Clone, +>( + circuit: &Cir, + options: CompileOptions, +) -> Result, Error> { + let root = build(circuit); + let (irw, lc) = crate::compile::compile_with_options::(&root, options)?; + Ok(CompileResult { + witness_solver: WitnessSolver { circuit: irw }, + layered_circuit: lc, + }) +} diff --git a/expander_compiler/src/layering/wire.rs b/expander_compiler/src/layering/wire.rs index c7d0a71..c2cb21f 100644 --- a/expander_compiler/src/layering/wire.rs +++ b/expander_compiler/src/layering/wire.rs @@ -309,8 +309,11 @@ impl<'a, C: Config> CompileContext<'a, C> { }); } VarSpec::Quad(vid0, vid1) => { + let x = aq.var_pos[vid0]; + let y = aq.var_pos[vid1]; + let inputs = if x < y { [x, y] } else { [y, x] }; res.gate_muls.push(GateMul { - inputs: [aq.var_pos[vid0], aq.var_pos[vid1]], + inputs, output: pos, coef: Coef::Constant(term.coef), }); diff --git a/expander_compiler/tests/mul_fanout_limit.rs b/expander_compiler/tests/mul_fanout_limit.rs new file mode 100644 index 0000000..c0f3c68 --- /dev/null +++ b/expander_compiler/tests/mul_fanout_limit.rs @@ -0,0 +1,74 @@ +use expander_compiler::frontend::*; + +declare_circuit!(Circuit { + x: [Variable; 16], + y: [Variable; 512], + sum: Variable, +}); + +impl Define for Circuit { + fn define(&self, builder: &mut API) { + let mut sum = builder.constant(0); + for i in 0..16 { + for j in 0..512 { + let t = builder.mul(self.x[i], self.y[j]); + sum = builder.add(sum, t); + } + } + builder.assert_is_equal(self.sum, sum); + } +} + +fn mul_fanout_limit(limit: usize) { + let compile_result = compile_with_options( + &Circuit::default(), + CompileOptions::default().with_mul_fanout_limit(limit), + ) + .unwrap(); + let circuit = compile_result.layered_circuit; + for segment in circuit.segments.iter() { + let mut ref_num = vec![0; segment.num_inputs]; + for m in segment.gate_muls.iter() { + ref_num[m.inputs[0]] += 1; + ref_num[m.inputs[1]] += 1; + } + for x in ref_num.iter() { + assert!(*x <= limit); + } + } +} + +#[test] +fn mul_fanout_limit_2() { + mul_fanout_limit(2); +} + +#[test] +fn mul_fanout_limit_3() { + mul_fanout_limit(3); +} + +#[test] +fn mul_fanout_limit_4() { + mul_fanout_limit(4); +} + +#[test] +fn mul_fanout_limit_16() { + mul_fanout_limit(16); +} + +#[test] +fn mul_fanout_limit_64() { + mul_fanout_limit(64); +} + +#[test] +fn mul_fanout_limit_256() { + mul_fanout_limit(256); +} + +#[test] +fn mul_fanout_limit_1024() { + mul_fanout_limit(1024); +} From 3201cdd45f9970476222ae29089ad9a834cee68b Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Tue, 10 Dec 2024 10:16:33 +0900 Subject: [PATCH 3/3] Ecgo const variables (#51) * ecgo const variables * fix --- ecgo/builder/api.go | 105 ++++++++++++++++++++++++++++++++- ecgo/builder/api_assertions.go | 20 +++++++ ecgo/builder/builder.go | 31 ++++++++-- ecgo/utils/gnarkexpr/expr.go | 6 ++ 4 files changed, 157 insertions(+), 5 deletions(-) diff --git a/ecgo/builder/api.go b/ecgo/builder/api.go index 7dd0d24..b26e654 100644 --- a/ecgo/builder/api.go +++ b/ecgo/builder/api.go @@ -53,6 +53,26 @@ func (builder *builder) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) f // returns res = Σ(vars) or res = vars[0] - Σ(vars[1:]) if sub == true. func (builder *builder) add(vars []int, sub bool) frontend.Variable { + // check if all variables are constants + allConst := true + if sum, ok := builder.constantValue(vars[0]); ok { + for _, x := range vars[1:] { + if v, ok := builder.constantValue(x); ok { + if sub { + sum = builder.field.Sub(sum, v) + } else { + sum = builder.field.Add(sum, v) + } + } else { + allConst = false + break + } + } + if allConst { + return builder.toVariable(sum) + } + } + coef := make([]constraint.Element, len(vars)) coef[0] = builder.tOne if sub { @@ -75,6 +95,9 @@ func (builder *builder) add(vars []int, sub bool) frontend.Variable { // Neg returns the negation of the given variable. func (builder *builder) Neg(i frontend.Variable) frontend.Variable { v := builder.toVariableId(i) + if c, ok := builder.constantValue(v); ok { + return builder.toVariable(builder.field.Neg(c)) + } coef := []constraint.Element{builder.field.Neg(builder.tOne)} builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.LinComb, @@ -87,6 +110,20 @@ func (builder *builder) Neg(i frontend.Variable) frontend.Variable { // Mul computes the product of the given variables. func (builder *builder) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { vars := builder.toVariableIds(append([]frontend.Variable{i1, i2}, in...)...) + allConst := true + if sum, ok := builder.constantValue(vars[0]); ok { + for _, x := range vars[1:] { + if v, ok := builder.constantValue(x); ok { + sum = builder.field.Mul(sum, v) + } else { + allConst = false + break + } + } + if allConst { + return builder.toVariable(sum) + } + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.Mul, Inputs: vars, @@ -99,6 +136,18 @@ func (builder *builder) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable vars := builder.toVariableIds(i1, i2) v1 := vars[0] v2 := vars[1] + c1, ok1 := builder.constantValue(v1) + c2, ok2 := builder.constantValue(v2) + if ok1 && ok2 { + if c2.IsZero() { + if c1.IsZero() { + return builder.toVariable(constraint.Element{}) + } + panic("division by zero") + } + inv, _ := builder.field.Inverse(c2) + return builder.toVariable(builder.field.Mul(c1, inv)) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.Div, X: v1, @@ -113,6 +162,15 @@ func (builder *builder) Div(i1, i2 frontend.Variable) frontend.Variable { vars := builder.toVariableIds(i1, i2) v1 := vars[0] v2 := vars[1] + c1, ok1 := builder.constantValue(v1) + c2, ok2 := builder.constantValue(v2) + if ok1 && ok2 { + if c2.IsZero() { + panic("division by zero") + } + inv, _ := builder.field.Inverse(c2) + return builder.toVariable(builder.field.Mul(c1, inv)) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.Div, X: v1, @@ -160,6 +218,17 @@ func (builder *builder) Xor(_a, _b frontend.Variable) frontend.Variable { vars := builder.toVariableIds(_a, _b) a := vars[0] b := vars[1] + c1, ok1 := builder.constantValue(a) + c2, ok2 := builder.constantValue(b) + if ok1 && ok2 { + builder.AssertIsBoolean(_a) + builder.AssertIsBoolean(_b) + t := builder.field.Sub(c1, c2) + if t.IsZero() { + return builder.toVariable(constraint.Element{}) + } + return builder.toVariable(builder.tOne) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.BoolBinOp, X: a, @@ -174,6 +243,16 @@ func (builder *builder) Or(_a, _b frontend.Variable) frontend.Variable { vars := builder.toVariableIds(_a, _b) a := vars[0] b := vars[1] + c1, ok1 := builder.constantValue(a) + c2, ok2 := builder.constantValue(b) + if ok1 && ok2 { + builder.AssertIsBoolean(_a) + builder.AssertIsBoolean(_b) + if c1.IsZero() && c2.IsZero() { + return builder.toVariable(constraint.Element{}) + } + return builder.toVariable(builder.tOne) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.BoolBinOp, X: a, @@ -188,6 +267,16 @@ func (builder *builder) And(_a, _b frontend.Variable) frontend.Variable { vars := builder.toVariableIds(_a, _b) a := vars[0] b := vars[1] + c1, ok1 := builder.constantValue(a) + c2, ok2 := builder.constantValue(b) + if ok1 && ok2 { + builder.AssertIsBoolean(_a) + builder.AssertIsBoolean(_b) + if c1.IsZero() || c2.IsZero() { + return builder.toVariable(constraint.Element{}) + } + return builder.toVariable(builder.tOne) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.BoolBinOp, X: a, @@ -207,7 +296,15 @@ func (builder *builder) Select(i0, i1, i2 frontend.Variable) frontend.Variable { // ensures that cond is boolean builder.AssertIsBoolean(cond) - v := builder.Sub(i1, i2) // no constraint is recorded + cst, ok := builder.constantValue(builder.toVariableId(cond)) + if ok { + if cst.IsZero() { + return i2 + } + return i1 + } + + v := builder.Sub(i1, i2) w := builder.Mul(cond, v) return builder.Add(w, i2) } @@ -246,6 +343,12 @@ func (builder *builder) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 fronten // IsZero returns 1 if the given variable is zero, otherwise returns 0. func (builder *builder) IsZero(i1 frontend.Variable) frontend.Variable { a := builder.toVariableId(i1) + if c, ok := builder.constantValue(a); ok { + if c.IsZero() { + return builder.toVariable(builder.tOne) + } + return builder.toVariable(constraint.Element{}) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.IsZero, X: a, diff --git a/ecgo/builder/api_assertions.go b/ecgo/builder/api_assertions.go index c010259..ae8d937 100644 --- a/ecgo/builder/api_assertions.go +++ b/ecgo/builder/api_assertions.go @@ -13,6 +13,13 @@ import ( // AssertIsEqual adds an assertion that i1 is equal to i2. func (builder *builder) AssertIsEqual(i1, i2 frontend.Variable) { x := builder.toVariableId(builder.Sub(i1, i2)) + v, xConstant := builder.constantValue(x) + if xConstant { + if !v.IsZero() { + panic("AssertIsEqual will never be satisfied on nonzero constant") + } + return + } builder.constraints = append(builder.constraints, irsource.Constraint{ Typ: irsource.Zero, Var: x, @@ -22,6 +29,13 @@ func (builder *builder) AssertIsEqual(i1, i2 frontend.Variable) { // AssertIsDifferent constrains i1 and i2 to have different values. func (builder *builder) AssertIsDifferent(i1, i2 frontend.Variable) { x := builder.toVariableId(builder.Sub(i1, i2)) + v, xConstant := builder.constantValue(x) + if xConstant { + if v.IsZero() { + panic("AssertIsDifferent will never be satisfied on zero constant") + } + return + } builder.constraints = append(builder.constraints, irsource.Constraint{ Typ: irsource.NonZero, Var: x, @@ -31,6 +45,12 @@ func (builder *builder) AssertIsDifferent(i1, i2 frontend.Variable) { // AssertIsBoolean adds an assertion that the variable is either 0 or 1. func (builder *builder) AssertIsBoolean(i1 frontend.Variable) { x := builder.toVariableId(i1) + if b, ok := builder.constantValue(x); ok { + if !(b.IsZero() || builder.field.IsOne(b)) { + panic("assertIsBoolean failed: constant is not 0 or 1") + } + return + } builder.constraints = append(builder.constraints, irsource.Constraint{ Typ: irsource.Bool, Var: x, diff --git a/ecgo/builder/builder.go b/ecgo/builder/builder.go index e90f8f0..3367ed9 100644 --- a/ecgo/builder/builder.go +++ b/ecgo/builder/builder.go @@ -35,6 +35,8 @@ type builder struct { nbExternalInput int maxVar int + varConstId []int + constValues []constraint.Element // defers (for gnark API) defers []func(frontend.API) error @@ -58,6 +60,8 @@ func (r *Root) newBuilder(nbExternalInput int) *builder { builder.tOne = builder.field.One() builder.maxVar = nbExternalInput + builder.varConstId = make([]int, nbExternalInput+1) + builder.constValues = make([]constraint.Element, 1) return &builder } @@ -106,11 +110,24 @@ func (builder *builder) Compile() (constraint.ConstraintSystem, error) { // ConstantValue returns always returns (nil, false) now, since the Golang frontend doesn't know the values of variables. func (builder *builder) ConstantValue(v frontend.Variable) (*big.Int, bool) { - return nil, false + coeff, ok := builder.constantValue(builder.toVariableId(v)) + if !ok { + return nil, false + } + return builder.field.ToBigInt(coeff), true +} + +func (builder *builder) constantValue(x int) (constraint.Element, bool) { + i := builder.varConstId[x] + if i == 0 { + return constraint.Element{}, false + } + return builder.constValues[i], true } func (builder *builder) addVarId() int { builder.maxVar += 1 + builder.varConstId = append(builder.varConstId, 0) return builder.maxVar } @@ -124,7 +141,10 @@ func (builder *builder) ceToId(x constraint.Element) int { ExtraId: 0, Const: x, }) - return builder.addVarId() + res := builder.addVarId() + builder.constValues = append(builder.constValues, x) + builder.varConstId[res] = len(builder.constValues) - 1 + return res } // toVariable will return (and allocate if neccesary) an Expression from given value @@ -147,6 +167,10 @@ func (builder *builder) toVariableId(input interface{}) int { } } +func (builder *builder) toVariable(input interface{}) frontend.Variable { + return newVariable(builder.toVariableId(input)) +} + // toVariables return frontend.Variable corresponding to inputs and the total size of the linear expressions func (builder *builder) toVariableIds(in ...frontend.Variable) []int { r := make([]int, 0, len(in)) @@ -195,8 +219,7 @@ func (builder *builder) newHintForId(id solver.HintID, nbOutputs int, inputs []f res := make([]frontend.Variable, nbOutputs) for i := 0; i < nbOutputs; i++ { - builder.maxVar += 1 - res[i] = newVariable(builder.maxVar) + res[i] = builder.addVar() } return res, nil } diff --git a/ecgo/utils/gnarkexpr/expr.go b/ecgo/utils/gnarkexpr/expr.go index 115115d..e54ec63 100644 --- a/ecgo/utils/gnarkexpr/expr.go +++ b/ecgo/utils/gnarkexpr/expr.go @@ -22,7 +22,13 @@ func init() { } } +// gnark uses uint32 +const MaxVariables = (1 << 31) - 100 + func NewVar(x int) Expr { + if x < 0 || x >= MaxVariables { + panic("variable id out of range") + } v := builder.InternalVariable(uint32(x)) t := reflect.ValueOf(v).Index(0).Interface().(Expr) if t.WireID() != x {