Skip to content

Commit

Permalink
opt and test
Browse files Browse the repository at this point in the history
  • Loading branch information
siq1 committed Nov 28, 2024
1 parent a06d6f6 commit 586db8f
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 76 deletions.
2 changes: 1 addition & 1 deletion expander_compiler/src/layering/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub struct CompileContext<'a, C: Config> {

// compiled layered circuits
pub compiled_circuits: Vec<Segment<C>>,
pub conncected_wires: HashMap<u128, usize>,
pub conncected_wires: HashMap<Vec<usize>, usize>,

// layout id of each layer
pub layout_ids: Vec<usize>,
Expand Down
41 changes: 40 additions & 1 deletion expander_compiler/src/layering/tests.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
use crate::circuit::{
config::{Config, M31Config as C},
input_mapping::InputMapping,
ir::{common::rand_gen::*, dest::RootCircuit as IrRootCircuit},
ir::{
common::rand_gen::*,
dest::{Circuit as IrCircuit, Instruction as IrInstruction, RootCircuit as IrRootCircuit},
expr::{Expression, Term},
},
layered,
};

use crate::field::M31 as CField;

use crate::field::FieldArith;

use super::compile;
Expand Down Expand Up @@ -122,3 +128,36 @@ fn random_circuits_4() {
compile_and_random_test(&root, 5);
}
}

#[test]
fn cross_layer_circuit() {
let mut root = IrRootCircuit::<C>::default();
const N: usize = 1000;
root.circuits.insert(
0,
IrCircuit::<C> {
instructions: vec![],
constraints: vec![N * 2 - 1],
outputs: vec![],
num_inputs: N,
},
);
for i in 0..N - 1 {
root.circuits
.get_mut(&0)
.unwrap()
.instructions
.push(IrInstruction::InternalVariable {
expr: Expression::from_terms(vec![
Term::new_linear(CField::one(), N + i),
Term::new_linear(CField::one(), N - i - 1),
]),
});
}
assert_eq!(root.validate(), Ok(()));
let (lc, _) = compile_and_random_test(&root, 5);
assert!((lc.layer_ids.len() as isize - N as isize).abs() <= 10);
for i in lc.layer_ids.iter() {
assert!(lc.segments[*i].gate_adds.len() <= 10);
}
}
166 changes: 92 additions & 74 deletions expander_compiler/src/layering/wire.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,16 @@ impl<'a, C: Config> CompileContext<'a, C> {
});
}

let mut cached_ress = Vec::with_capacity(ic.output_layer);
for i in 1..=ic.output_layer {
let key = layout_ids[ic.min_used_layer[i]..=i].to_vec();
cached_ress.push(self.conncected_wires.get(&key).cloned());
}
let all_cached = cached_ress.iter().all(|x| x.is_some());
if all_cached {
return cached_ress.into_iter().map(|x| x.unwrap()).collect();
}

// connect sub circuits
for (i, insn_id) in ic.sub_circuit_insn_ids.iter().enumerate() {
let insn = &ic.sub_circuit_insn_refs[i];
Expand Down Expand Up @@ -262,65 +272,67 @@ impl<'a, C: Config> CompileContext<'a, C> {
if ic.min_layer[x] != 0 {
let next_layer = ic.min_layer[x];
let cur_layer = next_layer - 1;
let res = &mut ress[cur_layer];
let aq = &lqs[cur_layer];
let bq = &lqs[next_layer];
let pos = if let Some(p) = bq.var_pos.get(&x) {
*p
} else {
assert_eq!(cur_layer + 1, ic.output_layer);
continue;
};
if let Some(value) = ic.constant_like_variables.get(&x) {
res.gate_consts.push(GateConst {
inputs: [],
output: pos,
coef: value.clone(),
});
} else if ic.internal_variable_expr.contains_key(&x) {
for term in ic.internal_variable_expr[&x].iter() {
match &term.vars {
VarSpec::Const => {
res.gate_consts.push(GateConst {
inputs: [],
output: pos,
coef: Coef::Constant(term.coef),
});
}
VarSpec::Linear(vid) => {
res.gate_adds.push(GateAdd {
inputs: [Input::new(0, aq.var_pos[vid])],
output: pos,
coef: Coef::Constant(term.coef),
});
}
VarSpec::Quad(vid0, vid1) => {
res.gate_muls.push(GateMul {
inputs: [
Input::new(0, aq.var_pos[vid0]),
Input::new(0, aq.var_pos[vid1]),
],
output: pos,
coef: Coef::Constant(term.coef),
});
}
VarSpec::Custom { gate_type, inputs } => {
res.gate_customs.push(GateCustom {
gate_type: *gate_type,
inputs: inputs
.iter()
.map(|x| Input::new(0, aq.var_pos[x]))
.collect(),
output: pos,
coef: Coef::Constant(term.coef),
});
}
VarSpec::RandomLinear(vid) => {
res.gate_adds.push(GateAdd {
inputs: [Input::new(0, aq.var_pos[vid])],
output: pos,
coef: Coef::Random,
});
if cached_ress[cur_layer].is_none() {
let res = &mut ress[cur_layer];
let aq = &lqs[cur_layer];
let bq = &lqs[next_layer];
let pos = if let Some(p) = bq.var_pos.get(&x) {
*p
} else {
assert_eq!(cur_layer + 1, ic.output_layer);
continue;
};
if let Some(value) = ic.constant_like_variables.get(&x) {
res.gate_consts.push(GateConst {
inputs: [],
output: pos,
coef: value.clone(),
});
} else if ic.internal_variable_expr.contains_key(&x) {
for term in ic.internal_variable_expr[&x].iter() {
match &term.vars {
VarSpec::Const => {
res.gate_consts.push(GateConst {
inputs: [],
output: pos,
coef: Coef::Constant(term.coef),
});
}
VarSpec::Linear(vid) => {
res.gate_adds.push(GateAdd {
inputs: [Input::new(0, aq.var_pos[vid])],
output: pos,
coef: Coef::Constant(term.coef),
});
}
VarSpec::Quad(vid0, vid1) => {
res.gate_muls.push(GateMul {
inputs: [
Input::new(0, aq.var_pos[vid0]),
Input::new(0, aq.var_pos[vid1]),
],
output: pos,
coef: Coef::Constant(term.coef),
});
}
VarSpec::Custom { gate_type, inputs } => {
res.gate_customs.push(GateCustom {
gate_type: *gate_type,
inputs: inputs
.iter()
.map(|x| Input::new(0, aq.var_pos[x]))
.collect(),
output: pos,
coef: Coef::Constant(term.coef),
});
}
VarSpec::RandomLinear(vid) => {
res.gate_adds.push(GateAdd {
inputs: [Input::new(0, aq.var_pos[vid])],
output: pos,
coef: Coef::Random,
});
}
}
}
}
Expand All @@ -331,20 +343,22 @@ impl<'a, C: Config> CompileContext<'a, C> {
.iter()
.zip(ic.occured_layers[x].iter().skip(1))
{
let res = &mut ress[next_layer - 1];
let aq = &lqs[*cur_layer];
let bq = &lqs[*next_layer];
let pos = if let Some(p) = bq.var_pos.get(&x) {
*p
} else {
assert_eq!(*next_layer, ic.output_layer);
continue;
};
res.gate_adds.push(GateAdd {
inputs: [Input::new(next_layer - cur_layer - 1, aq.var_pos[&x])],
output: pos,
coef: Coef::Constant(C::CircuitField::one()),
});
if cached_ress[next_layer - 1].is_none() {
let res = &mut ress[next_layer - 1];
let aq = &lqs[*cur_layer];
let bq = &lqs[*next_layer];
let pos = if let Some(p) = bq.var_pos.get(&x) {
*p
} else {
assert_eq!(*next_layer, ic.output_layer);
continue;
};
res.gate_adds.push(GateAdd {
inputs: [Input::new(next_layer - cur_layer - 1, aq.var_pos[&x])],
output: pos,
coef: Coef::Constant(C::CircuitField::one()),
});
}
}
}

Expand Down Expand Up @@ -410,7 +424,11 @@ impl<'a, C: Config> CompileContext<'a, C> {

let mut ress_ids = Vec::new();

for res in ress.iter() {
for (res, cache) in ress.iter().zip(cached_ress.iter()) {
if let Some(cache) = cache {
ress_ids.push(*cache);
continue;
}
let res_id = self.compiled_circuits.len();
self.compiled_circuits.push(res.clone());
ress_ids.push(res_id);
Expand Down
3 changes: 3 additions & 0 deletions expander_compiler/tests/example_call_expander.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,21 @@ where
}

//#[test]
#[allow(dead_code)]
fn example_gf2() {
example::<GF2Config, GF2ExtConfigSha2>();
example::<GF2Config, GF2ExtConfigKeccak>();
}

//#[test]
#[allow(dead_code)]
fn example_m31() {
example::<M31Config, M31ExtConfigSha2>();
example::<M31Config, M31ExtConfigKeccak>();
}

//#[test]
#[allow(dead_code)]
fn example_bn254() {
example::<BN254Config, BN254ConfigSha2>();
example::<BN254Config, BN254ConfigKeccak>();
Expand Down
1 change: 1 addition & 0 deletions expander_compiler/tests/keccak_gf2_full.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ impl Define<GF2Config> for Keccak256Circuit<Variable> {
}

//#[test]
#[allow(dead_code)]
fn keccak_gf2_full() {
let compile_result = compile(&Keccak256Circuit::default()).unwrap();
let CompileResult {
Expand Down

0 comments on commit 586db8f

Please sign in to comment.