diff --git a/ecgo/builder/api.go b/ecgo/builder/api.go index 11106c4..7dd0d24 100644 --- a/ecgo/builder/api.go +++ b/ecgo/builder/api.go @@ -45,6 +45,7 @@ func (builder *builder) MulAcc(a, b, c frontend.Variable) frontend.Variable { } // Sub computes the difference between the given variables. +// When more than two variables are provided, the difference is computed as i1 - Σ(i2...). func (builder *builder) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { vars := builder.toVariableIds(append([]frontend.Variable{i1, i2}, in...)...) return builder.add(vars, true) @@ -142,6 +143,8 @@ func (builder *builder) ToBinary(i1 frontend.Variable, n ...int) []frontend.Vari if nbBits < 0 { panic("invalid n") } + } else { + panic("only one argument is supported") } return bits.ToBinary(builder, i1, bits.WithNbDigits(nbBits)) diff --git a/ecgo/builder/builder.go b/ecgo/builder/builder.go index e6831b9..e90f8f0 100644 --- a/ecgo/builder/builder.go +++ b/ecgo/builder/builder.go @@ -104,7 +104,7 @@ func (builder *builder) Compile() (constraint.ConstraintSystem, error) { return nil, nil } -// ConstantValue returns the big.Int value of v and panics if v is not a constant. +// ConstantValue always returns (nil, false) now, since the Golang frontend doesn't know the values of variables. 
func (builder *builder) ConstantValue(v frontend.Variable) (*big.Int, bool) { return nil, false } @@ -154,8 +154,6 @@ func (builder *builder) toVariableIds(in ...frontend.Variable) []int { v := builder.toVariableId(i) r = append(r, v) } - // e(i1) - // e(i2) for i := 0; i < len(in); i++ { e(in[i]) } diff --git a/ecgo/builder/finalize.go b/ecgo/builder/finalize.go index f38642f..d8956fe 100644 --- a/ecgo/builder/finalize.go +++ b/ecgo/builder/finalize.go @@ -1,6 +1,8 @@ package builder import ( + "fmt" + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/irsource" ) @@ -25,7 +27,7 @@ func (builder *builder) Finalize() *irsource.Circuit { cb := builder.defers[i] err := cb(builder) if err != nil { - panic(err) + panic(fmt.Sprintf("deferred function failed: %v", err)) } } diff --git a/ecgo/builder/sub_circuit.go b/ecgo/builder/sub_circuit.go index 16d6e0a..72add55 100644 --- a/ecgo/builder/sub_circuit.go +++ b/ecgo/builder/sub_circuit.go @@ -32,6 +32,7 @@ type SubCircuit struct { type SubCircuitRegistry struct { m map[uint64]*SubCircuit outputStructure map[uint64]*sliceStructure + fullHash map[uint64][32]byte } // SubCircuitAPI defines methods for working with subcircuits. 
@@ -44,9 +45,22 @@ func newSubCircuitRegistry() *SubCircuitRegistry { return &SubCircuitRegistry{ m: make(map[uint64]*SubCircuit), outputStructure: make(map[uint64]*sliceStructure), + fullHash: make(map[uint64][32]byte), } } +func (sr *SubCircuitRegistry) getFullHashId(h [32]byte) uint64 { + id := binary.LittleEndian.Uint64(h[:8]) + if v, ok := sr.fullHash[id]; ok { + if v != h { + panic("subcircuit id collision") + } + return id + } + sr.fullHash[id] = h + return id +} + func (parent *builder) callSubCircuit( circuitId uint64, input_ []frontend.Variable, @@ -93,7 +107,7 @@ func (parent *builder) callSubCircuit( func (parent *builder) MemorizedSimpleCall(f SubCircuitSimpleFunc, input []frontend.Variable) []frontend.Variable { name := GetFuncName(f) h := sha256.Sum256([]byte(fmt.Sprintf("simple_%d(%s)_%d", len(name), name, len(input)))) - circuitId := binary.LittleEndian.Uint64(h[:8]) + circuitId := parent.root.registry.getFullHashId(h) return parent.callSubCircuit(circuitId, input, f) } @@ -205,13 +219,10 @@ func rebuildSliceVariables(vars []frontend.Variable, s *sliceStructure) reflect. 
func isTypeSimple(t reflect.Type) bool { k := t.Kind() switch k { - case reflect.Bool: - return true - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return true - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return true - case reflect.String: + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.String: return true default: return false @@ -310,7 +321,9 @@ func (parent *builder) MemorizedCall(fn SubCircuitFunc, inputs ...interface{}) i vs := inputVals[i].String() h.Write([]byte(strconv.Itoa(len(vs)) + vs)) } - circuitId := binary.LittleEndian.Uint64(h.Sum(nil)[:8]) + var tmp [32]byte + copy(tmp[:], h.Sum(nil)) + circuitId := parent.root.registry.getFullHashId(tmp) // sub-circuit caller fnInner := func(api frontend.API, input []frontend.Variable) []frontend.Variable { diff --git a/ecgo/field/bn254/field_wrapper.go b/ecgo/field/bn254/field_wrapper.go index e2ea4c6..864a74f 100644 --- a/ecgo/field/bn254/field_wrapper.go +++ b/ecgo/field/bn254/field_wrapper.go @@ -80,15 +80,16 @@ func (engine *Field) Inverse(a constraint.Element) (constraint.Element, bool) { return a, false } else if e.IsOne() { return a, true - } - var t fr.Element - t.Neg(e) - if t.IsOne() { + } else { + var t fr.Element + t.Neg(e) + if t.IsOne() { + return a, true + } + + e.Inverse(e) return a, true } - - e.Inverse(e) - return a, true } func (engine *Field) IsOne(a constraint.Element) bool { diff --git a/ecgo/field/m31/field.go b/ecgo/field/m31/field.go index 9139f5f..86a7097 100644 --- a/ecgo/field/m31/field.go +++ b/ecgo/field/m31/field.go @@ -8,7 +8,7 @@ import ( "github.com/consensys/gnark/constraint" ) -const P = 2147483647 +const P = 0x7fffffff var Pbig = big.NewInt(P) var ScalarField = Pbig diff --git a/ecgo/irwg/witness_gen.go b/ecgo/irwg/witness_gen.go index e86d10b..1595f1a 100644 --- 
a/ecgo/irwg/witness_gen.go +++ b/ecgo/irwg/witness_gen.go @@ -194,6 +194,9 @@ func (rc *RootCircuit) evalSub(circuitId uint64, inputs []constraint.Element, pu func callHint(hintId uint64, field *big.Int, inputs []*big.Int, outputs []*big.Int) error { // The only required builtin hint (Div) if hintId == 0xCCC000000001 { + if len(inputs) != 2 || len(outputs) != 1 { + return errors.New("Div hint requires 2 inputs and 1 output") + } x := (&big.Int{}).Mod(inputs[0], field) y := (&big.Int{}).Mod(inputs[1], field) if y.Cmp(big.NewInt(0)) == 0 { diff --git a/ecgo/layered/serialize.go b/ecgo/layered/serialize.go index 66bb8ad..1664a67 100644 --- a/ecgo/layered/serialize.go +++ b/ecgo/layered/serialize.go @@ -8,6 +8,8 @@ import ( "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/utils" ) +const MAGIC = 3914834606642317635 + func serializeCoef(o *utils.OutputBuf, bnlen int, coef *big.Int, coefType uint8, publicInputId uint64) { if coefType == 1 { o.AppendUint8(1) @@ -35,7 +37,7 @@ func deserializeCoef(in *utils.InputBuf, bnlen int) (*big.Int, uint8, uint64) { func (rc *RootCircuit) Serialize() []byte { bnlen := field.GetFieldFromOrder(rc.Field).SerializedLen() o := utils.OutputBuf{} - o.AppendUint64(3914834606642317635) + o.AppendUint64(MAGIC) o.AppendBigInt(32, rc.Field) o.AppendUint64(uint64(rc.NumPublicInputs)) o.AppendUint64(uint64(rc.NumActualOutputs)) @@ -91,7 +93,7 @@ func (rc *RootCircuit) Serialize() []byte { func DeserializeRootCircuit(buf []byte) *RootCircuit { in := utils.NewInputBuf(buf) - if in.ReadUint64() != 3914834606642317635 { + if in.ReadUint64() != MAGIC { panic("invalid file header") } rc := &RootCircuit{} @@ -178,7 +180,7 @@ func DetectFieldIdFromFile(fn string) uint64 { panic(err) } in := utils.NewInputBuf(buf) - if in.ReadUint64() != 3914834606642317635 { + if in.ReadUint64() != MAGIC { panic("invalid file header") } f := in.ReadBigInt(32) diff --git a/ecgo/utils/buf.go b/ecgo/utils/buf.go index a16f812..34f9cbb 100644 --- a/ecgo/utils/buf.go 
+++ b/ecgo/utils/buf.go @@ -2,6 +2,7 @@ package utils import ( "encoding/binary" + "fmt" "math/big" "github.com/consensys/gnark/constraint" @@ -20,12 +21,12 @@ type SimpleField interface { func (o *OutputBuf) AppendBigInt(n int, x *big.Int) { zbuf := make([]byte, n) b := x.Bytes() + if len(b) > n { + panic(fmt.Sprintf("big.Int is too large to serialize: %d > %d", len(b), n)) + } for i := 0; i < len(b); i++ { zbuf[i] = b[len(b)-i-1] } - for i := len(b); i < n; i++ { - zbuf[i] = 0 - } o.buf = append(o.buf, zbuf...) } @@ -53,7 +54,9 @@ func (o *OutputBuf) AppendIntSlice(x []int) { } func (o *OutputBuf) Bytes() []byte { - return o.buf + res := o.buf + o.buf = nil + return res } type InputBuf struct { diff --git a/ecgo/utils/map.go b/ecgo/utils/map.go index 295e0bb..b97ecb6 100644 --- a/ecgo/utils/map.go +++ b/ecgo/utils/map.go @@ -46,7 +46,7 @@ func (m Map) Set(e Hashable, v interface{}) { }) } -// adds (e, v) to the map, does nothing when e already exists +// adds (e, v) to the map, returns the current value when e already exists func (m Map) Add(e Hashable, v interface{}) interface{} { h := e.HashCode() s, ok := m[h] @@ -66,7 +66,7 @@ func (m Map) Add(e Hashable, v interface{}) interface{} { return v } -// filter keys in the map using the given function +// filter (e, v) in the map using f(v), returns the keys func (m Map) FilterKeys(f func(interface{}) bool) []Hashable { keys := []Hashable{} for _, s := range m { diff --git a/ecgo/utils/power.go b/ecgo/utils/power.go index 116d06d..a30a115 100644 --- a/ecgo/utils/power.go +++ b/ecgo/utils/power.go @@ -1,12 +1,15 @@ package utils +import "math/bits" + // pad to 2^n gates (and 4^n for first layer) // 4^n exists for historical reasons, not used now func NextPowerOfTwo(x int, is4 bool) int { - padk := 0 - for x > (1 << padk) { - padk++ + if x < 0 { + panic("x must be non-negative") } + + padk := bits.Len(uint(x)) if is4 && padk%2 != 0 { padk++ } diff --git a/ecgo/utils/sort.go b/ecgo/utils/sort.go index adf3b3f..519cef5 
100644 --- a/ecgo/utils/sort.go +++ b/ecgo/utils/sort.go @@ -20,10 +20,10 @@ func (l *IntSeq) Less(i, j int) bool { } // SortIntSeq sorts an integer sequence using a given compare function -func SortIntSeq(s []int, cmp func(int, int) bool) { +func SortIntSeq(s []int, cmpLess func(int, int) bool) { l := &IntSeq{ s: s, - cmp: cmp, + cmp: cmpLess, } sort.Sort(l) } diff --git a/expander_compiler/src/frontend/builder.rs b/expander_compiler/src/frontend/builder.rs index eba6748..220927d 100644 --- a/expander_compiler/src/frontend/builder.rs +++ b/expander_compiler/src/frontend/builder.rs @@ -1,9 +1,7 @@ -use std::{ - collections::HashMap, - hash::{Hash, Hasher}, -}; +use std::collections::HashMap; use ethnum::U256; +use tiny_keccak::Hasher; use crate::{ circuit::{ @@ -373,6 +371,7 @@ pub struct RootBuilder { num_public_inputs: usize, current_builders: Vec<(usize, Builder)>, sub_circuits: HashMap>, + full_hash_id: HashMap, } macro_rules! root_binary_op { @@ -465,6 +464,7 @@ impl RootBuilder { num_public_inputs, current_builders: vec![(0, builder0)], sub_circuits: HashMap::new(), + full_hash_id: HashMap::new(), }, inputs, public_inputs, @@ -530,11 +530,22 @@ impl RootBuilder { f: F, inputs: &[Variable], ) -> Vec { - let mut hasher = std::collections::hash_map::DefaultHasher::new(); - "simple".hash(&mut hasher); - inputs.len().hash(&mut hasher); - get_function_id::().hash(&mut hasher); - let circuit_id = hasher.finish() as usize; + let mut hasher = tiny_keccak::Keccak::v256(); + hasher.update(b"simple"); + hasher.update(&inputs.len().to_le_bytes()); + hasher.update(&get_function_id::().to_le_bytes()); + let mut hash = [0u8; 32]; + hasher.finalize(&mut hash); + + let circuit_id = usize::from_le_bytes(hash[0..8].try_into().unwrap()); + if let Some(prev_hash) = self.full_hash_id.get(&circuit_id) { + if *prev_hash != hash { + panic!("subcircuit id collision"); + } + } else { + self.full_hash_id.insert(circuit_id, hash); + } + self.call_sub_circuit(circuit_id, inputs, f) } diff 
--git a/expander_compiler/src/layering/compile.rs b/expander_compiler/src/layering/compile.rs index 06080b7..f2b9214 100644 --- a/expander_compiler/src/layering/compile.rs +++ b/expander_compiler/src/layering/compile.rs @@ -88,6 +88,8 @@ pub struct SubCircuitInsn<'a> { pub outputs: Vec, } +const EXTRA_PRE_ALLOC_SIZE: usize = 1000; + impl<'a, C: Config> CompileContext<'a, C> { pub fn compile(&mut self) { // 1. do a toposort of the circuits @@ -187,7 +189,7 @@ impl<'a, C: Config> CompileContext<'a, C> { let mut n = nv + ns; let circuit = self.rc.circuits.get(&circuit_id).unwrap(); - let pre_alloc_size = n + (if n < 1000 { n } else { 1000 }); + let pre_alloc_size = n + EXTRA_PRE_ALLOC_SIZE.min(n); ic.min_layer = Vec::with_capacity(pre_alloc_size); ic.max_layer = Vec::with_capacity(pre_alloc_size); diff --git a/expander_compiler/src/layering/layer_layout.rs b/expander_compiler/src/layering/layer_layout.rs index bd29f2a..930bc88 100644 --- a/expander_compiler/src/layering/layer_layout.rs +++ b/expander_compiler/src/layering/layer_layout.rs @@ -27,8 +27,6 @@ pub struct PlacementRequest { pub input_ids: Vec, } -// TODO: use better data structure to maintain the segments - // finalized layout of a layer // dense -> placementDense[i] = variable on slot i (placementDense[i] == j means i-th slot stores varIdx[j]) // sparse -> placementSparse[i] = variable on slot i, and there are subLayouts. @@ -83,7 +81,6 @@ pub struct SubLayout { // request for layer layout #[derive(Hash, Clone, PartialEq, Eq)] pub struct LayerReq { - // TODO: more requirements, e.g. alignment pub circuit_id: usize, pub layer: usize, // which layer to solve? 
} @@ -179,7 +176,6 @@ impl<'a, C: Config> CompileContext<'a, C> { lc.parent.push(parent); } } - // TODO: partial merge } self.circuits.insert(circuit_id, ic); } @@ -263,11 +259,6 @@ impl<'a, C: Config> CompileContext<'a, C> { placements[i] = merge_layouts(s, mem::take(&mut children_variables[i])); } - // now placements[0] contains all direct variables - // we only need to merge with middle layers - // currently it's the most basic merging algorithm - just put them together - // TODO: optimize the merging algorithm - if lc.middle_sub_circuits.is_empty() { self.circuits.insert(req.circuit_id, ic); return LayerLayout { @@ -348,7 +339,6 @@ fn merge_layouts(s: Vec>, additional: Vec) -> Vec { // sort groups by size, and then place them one by one // since their size are always 2^n, the result is aligned // finally we insert the remaining variables to the empty slots - // TODO: improve this let mut n = 0; for x in s.iter() { let m = x.len(); @@ -379,7 +369,6 @@ fn merge_layouts(s: Vec>, additional: Vec) -> Vec { panic!("unexpected situation"); } let mut placed = false; - // TODO: better collision detection for i in (0..res.len()).step_by(pg.len()) { let mut ok = true; for j in 0..pg.len() { diff --git a/expander_compiler/src/layering/tests.rs b/expander_compiler/src/layering/tests.rs index faeb6e7..6eef63f 100644 --- a/expander_compiler/src/layering/tests.rs +++ b/expander_compiler/src/layering/tests.rs @@ -18,8 +18,6 @@ pub fn test_input( let (rc_output, rc_cond) = rc.eval_unsafe(input.clone()); let lc_input = input_mapping.map_inputs(input); let (lc_output, lc_cond) = lc.eval_unsafe(lc_input); - //println!("{:?}", rc_output); - //println!("{:?}", lc_output); assert_eq!(rc_cond, lc_cond); assert_eq!(rc_output, lc_output); } @@ -30,7 +28,6 @@ pub fn compile_and_random_test( ) -> (layered::Circuit, InputMapping) { assert!(rc.validate().is_ok()); let (lc, input_mapping) = compile(rc); - //print!("{}", lc); assert_eq!(lc.validate(), Ok(())); assert_eq!(rc.input_size(), 
input_mapping.cur_size()); let input_size = rc.input_size(); diff --git a/expander_compiler/src/layering/wire.rs b/expander_compiler/src/layering/wire.rs index 4bd1577..c7d0a71 100644 --- a/expander_compiler/src/layering/wire.rs +++ b/expander_compiler/src/layering/wire.rs @@ -21,14 +21,15 @@ struct LayoutQuery { } impl LayoutQuery { + // given a parent layer layout, this function queries the layout of a sub circuit fn query( &self, layer_layout_pool: &mut Pool, circuits: &HashMap>, - vs: &[usize], - f: F, - cid: usize, - lid: usize, + vs: &[usize], // variables to query (in parent layer) + f: F, // f(i) = id of i-th variable in the sub circuit + cid: usize, // target circuit id + lid: usize, // target layer id ) -> SubLayout where F: Fn(usize) -> usize, @@ -64,9 +65,13 @@ impl LayoutQuery { } } let mut xor = if l <= r { l ^ r } else { 0 }; - while xor != 0 && (xor & (xor - 1)) != 0 { - xor &= xor - 1; - } + xor |= xor >> 1; + xor |= xor >> 2; + xor |= xor >> 4; + xor |= xor >> 8; + xor |= xor >> 16; + xor |= xor >> 32; + xor ^= xor >> 1; let n = if xor == 0 { 1 } else { xor << 1 }; let offset = if l <= r { l & !(n - 1) } else { 0 }; let mut placement = vec![EMPTY; n]; @@ -135,15 +140,6 @@ impl<'a, C: Config> CompileContext<'a, C> { let aq = self.layout_query(&a, cur_lc.vars.vec()); let bq = self.layout_query(&b, next_lc.vars.vec()); - /*println!( - "connect_wires: {} {} circuit_id={} cur_layer={} output_layer={}", - a_, b_, a.circuit_id, cur_layer, ic.output_layer - ); - println!("cur: {:?}", a.inner); - println!("next: {:?}", b.inner); - println!("cur_var: {:?}", cur_lc.vars.vec()); - println!("next_var: {:?}", next_lc.vars.vec());*/ - // check if all variables exist in the layout for x in cur_lc.vars.vec().iter() { if !aq.var_pos.contains_key(x) {