Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
siq1 authored and zhenfeizhang committed Nov 11, 2024
1 parent fcf74e4 commit eb08646
Show file tree
Hide file tree
Showing 17 changed files with 98 additions and 75 deletions.
3 changes: 3 additions & 0 deletions ecgo/builder/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func (builder *builder) MulAcc(a, b, c frontend.Variable) frontend.Variable {
}

// Sub computes the difference between the given variables.
// When more than two variables are provided, the difference is computed as i1 - Σ(i2...).
func (builder *builder) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
vars := builder.toVariableIds(append([]frontend.Variable{i1, i2}, in...)...)
return builder.add(vars, true)
Expand Down Expand Up @@ -142,6 +143,8 @@ func (builder *builder) ToBinary(i1 frontend.Variable, n ...int) []frontend.Vari
if nbBits < 0 {
panic("invalid n")
}
} else {
panic("only one argument is supported")
}

return bits.ToBinary(builder, i1, bits.WithNbDigits(nbBits))
Expand Down
4 changes: 1 addition & 3 deletions ecgo/builder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func (builder *builder) Compile() (constraint.ConstraintSystem, error) {
return nil, nil
}

// ConstantValue returns the big.Int value of v and panics if v is not a constant.
// ConstantValue returns always returns (nil, false) now, since the Golang frontend doesn't know the values of variables.
func (builder *builder) ConstantValue(v frontend.Variable) (*big.Int, bool) {
return nil, false
}
Expand Down Expand Up @@ -154,8 +154,6 @@ func (builder *builder) toVariableIds(in ...frontend.Variable) []int {
v := builder.toVariableId(i)
r = append(r, v)
}
// e(i1)
// e(i2)
for i := 0; i < len(in); i++ {
e(in[i])
}
Expand Down
4 changes: 3 additions & 1 deletion ecgo/builder/finalize.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package builder

import (
"fmt"

"github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/irsource"
)

Expand All @@ -25,7 +27,7 @@ func (builder *builder) Finalize() *irsource.Circuit {
cb := builder.defers[i]
err := cb(builder)
if err != nil {
panic(err)
panic(fmt.Sprintf("deferred function failed: %v", err))
}
}

Expand Down
31 changes: 22 additions & 9 deletions ecgo/builder/sub_circuit.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type SubCircuit struct {
type SubCircuitRegistry struct {
m map[uint64]*SubCircuit
outputStructure map[uint64]*sliceStructure
fullHash map[uint64][32]byte
}

// SubCircuitAPI defines methods for working with subcircuits.
Expand All @@ -44,9 +45,22 @@ func newSubCircuitRegistry() *SubCircuitRegistry {
return &SubCircuitRegistry{
m: make(map[uint64]*SubCircuit),
outputStructure: make(map[uint64]*sliceStructure),
fullHash: make(map[uint64][32]byte),
}
}

func (sr *SubCircuitRegistry) getFullHashId(h [32]byte) uint64 {
id := binary.LittleEndian.Uint64(h[:8])
if v, ok := sr.fullHash[id]; ok {
if v != h {
panic("subcircuit id collision")
}
return id
}
sr.fullHash[id] = h
return id
}

func (parent *builder) callSubCircuit(
circuitId uint64,
input_ []frontend.Variable,
Expand Down Expand Up @@ -93,7 +107,7 @@ func (parent *builder) callSubCircuit(
func (parent *builder) MemorizedSimpleCall(f SubCircuitSimpleFunc, input []frontend.Variable) []frontend.Variable {
name := GetFuncName(f)
h := sha256.Sum256([]byte(fmt.Sprintf("simple_%d(%s)_%d", len(name), name, len(input))))
circuitId := binary.LittleEndian.Uint64(h[:8])
circuitId := parent.root.registry.getFullHashId(h)
return parent.callSubCircuit(circuitId, input, f)
}

Expand Down Expand Up @@ -205,13 +219,10 @@ func rebuildSliceVariables(vars []frontend.Variable, s *sliceStructure) reflect.
func isTypeSimple(t reflect.Type) bool {
k := t.Kind()
switch k {
case reflect.Bool:
return true
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return true
case reflect.String:
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.String:
return true
default:
return false
Expand Down Expand Up @@ -310,7 +321,9 @@ func (parent *builder) MemorizedCall(fn SubCircuitFunc, inputs ...interface{}) i
vs := inputVals[i].String()
h.Write([]byte(strconv.Itoa(len(vs)) + vs))
}
circuitId := binary.LittleEndian.Uint64(h.Sum(nil)[:8])
var tmp [32]byte
copy(tmp[:], h.Sum(nil))
circuitId := parent.root.registry.getFullHashId(tmp)

// sub-circuit caller
fnInner := func(api frontend.API, input []frontend.Variable) []frontend.Variable {
Expand Down
15 changes: 8 additions & 7 deletions ecgo/field/bn254/field_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,16 @@ func (engine *Field) Inverse(a constraint.Element) (constraint.Element, bool) {
return a, false
} else if e.IsOne() {
return a, true
}
var t fr.Element
t.Neg(e)
if t.IsOne() {
} else {
var t fr.Element
t.Neg(e)
if t.IsOne() {
return a, true
}

e.Inverse(e)
return a, true
}

e.Inverse(e)
return a, true
}

func (engine *Field) IsOne(a constraint.Element) bool {
Expand Down
2 changes: 1 addition & 1 deletion ecgo/field/m31/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/consensys/gnark/constraint"
)

const P = 2147483647
const P = 0x7fffffff

var Pbig = big.NewInt(P)
var ScalarField = Pbig
Expand Down
3 changes: 3 additions & 0 deletions ecgo/irwg/witness_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ func (rc *RootCircuit) evalSub(circuitId uint64, inputs []constraint.Element, pu
func callHint(hintId uint64, field *big.Int, inputs []*big.Int, outputs []*big.Int) error {
// The only required builtin hint (Div)
if hintId == 0xCCC000000001 {
if len(inputs) != 2 || len(outputs) != 1 {
return errors.New("Div hint requires 2 inputs and 1 output")
}
x := (&big.Int{}).Mod(inputs[0], field)
y := (&big.Int{}).Mod(inputs[1], field)
if y.Cmp(big.NewInt(0)) == 0 {
Expand Down
8 changes: 5 additions & 3 deletions ecgo/layered/serialize.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/utils"
)

const MAGIC = 3914834606642317635

func serializeCoef(o *utils.OutputBuf, bnlen int, coef *big.Int, coefType uint8, publicInputId uint64) {
if coefType == 1 {
o.AppendUint8(1)
Expand Down Expand Up @@ -35,7 +37,7 @@ func deserializeCoef(in *utils.InputBuf, bnlen int) (*big.Int, uint8, uint64) {
func (rc *RootCircuit) Serialize() []byte {
bnlen := field.GetFieldFromOrder(rc.Field).SerializedLen()
o := utils.OutputBuf{}
o.AppendUint64(3914834606642317635)
o.AppendUint64(MAGIC)
o.AppendBigInt(32, rc.Field)
o.AppendUint64(uint64(rc.NumPublicInputs))
o.AppendUint64(uint64(rc.NumActualOutputs))
Expand Down Expand Up @@ -91,7 +93,7 @@ func (rc *RootCircuit) Serialize() []byte {

func DeserializeRootCircuit(buf []byte) *RootCircuit {
in := utils.NewInputBuf(buf)
if in.ReadUint64() != 3914834606642317635 {
if in.ReadUint64() != MAGIC {
panic("invalid file header")
}
rc := &RootCircuit{}
Expand Down Expand Up @@ -178,7 +180,7 @@ func DetectFieldIdFromFile(fn string) uint64 {
panic(err)
}
in := utils.NewInputBuf(buf)
if in.ReadUint64() != 3914834606642317635 {
if in.ReadUint64() != MAGIC {
panic("invalid file header")
}
f := in.ReadBigInt(32)
Expand Down
11 changes: 7 additions & 4 deletions ecgo/utils/buf.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package utils

import (
"encoding/binary"
"fmt"
"math/big"

"github.com/consensys/gnark/constraint"
Expand All @@ -20,12 +21,12 @@ type SimpleField interface {
func (o *OutputBuf) AppendBigInt(n int, x *big.Int) {
zbuf := make([]byte, n)
b := x.Bytes()
if len(b) > n {
panic(fmt.Sprintf("big.Int is too large to serialize: %d > %d", len(b), n))
}
for i := 0; i < len(b); i++ {
zbuf[i] = b[len(b)-i-1]
}
for i := len(b); i < n; i++ {
zbuf[i] = 0
}
o.buf = append(o.buf, zbuf...)
}

Expand Down Expand Up @@ -53,7 +54,9 @@ func (o *OutputBuf) AppendIntSlice(x []int) {
}

func (o *OutputBuf) Bytes() []byte {
return o.buf
res := o.buf
o.buf = nil
return res
}

type InputBuf struct {
Expand Down
4 changes: 2 additions & 2 deletions ecgo/utils/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (m Map) Set(e Hashable, v interface{}) {
})
}

// adds (e, v) to the map, does nothing when e already exists
// adds (e, v) to the map, returns the current value when e already exists
func (m Map) Add(e Hashable, v interface{}) interface{} {
h := e.HashCode()
s, ok := m[h]
Expand All @@ -66,7 +66,7 @@ func (m Map) Add(e Hashable, v interface{}) interface{} {
return v
}

// filter keys in the map using the given function
// filter (e, v) in the map using f(v), returns the keys
func (m Map) FilterKeys(f func(interface{}) bool) []Hashable {
keys := []Hashable{}
for _, s := range m {
Expand Down
9 changes: 6 additions & 3 deletions ecgo/utils/power.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package utils

import "math/bits"

// pad to 2^n gates (and 4^n for first layer)
// 4^n exists for historical reasons, not used now
func NextPowerOfTwo(x int, is4 bool) int {
padk := 0
for x > (1 << padk) {
padk++
if x < 0 {
panic("x must be non-negative")
}

padk := bits.Len(uint(x))
if is4 && padk%2 != 0 {
padk++
}
Expand Down
4 changes: 2 additions & 2 deletions ecgo/utils/sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ func (l *IntSeq) Less(i, j int) bool {
}

// SortIntSeq sorts an integer sequence using a given compare function
func SortIntSeq(s []int, cmp func(int, int) bool) {
func SortIntSeq(s []int, cmpLess func(int, int) bool) {
l := &IntSeq{
s: s,
cmp: cmp,
cmp: cmpLess,
}
sort.Sort(l)
}
29 changes: 20 additions & 9 deletions expander_compiler/src/frontend/builder.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use std::{
collections::HashMap,
hash::{Hash, Hasher},
};
use std::collections::HashMap;

use ethnum::U256;
use tiny_keccak::Hasher;

use crate::{
circuit::{
Expand Down Expand Up @@ -373,6 +371,7 @@ pub struct RootBuilder<C: Config> {
num_public_inputs: usize,
current_builders: Vec<(usize, Builder<C>)>,
sub_circuits: HashMap<usize, source::Circuit<C>>,
full_hash_id: HashMap<usize, [u8; 32]>,
}

macro_rules! root_binary_op {
Expand Down Expand Up @@ -465,6 +464,7 @@ impl<C: Config> RootBuilder<C> {
num_public_inputs,
current_builders: vec![(0, builder0)],
sub_circuits: HashMap::new(),
full_hash_id: HashMap::new(),
},
inputs,
public_inputs,
Expand Down Expand Up @@ -530,11 +530,22 @@ impl<C: Config> RootBuilder<C> {
f: F,
inputs: &[Variable],
) -> Vec<Variable> {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
"simple".hash(&mut hasher);
inputs.len().hash(&mut hasher);
get_function_id::<F>().hash(&mut hasher);
let circuit_id = hasher.finish() as usize;
let mut hasher = tiny_keccak::Keccak::v256();
hasher.update(b"simple");
hasher.update(&inputs.len().to_le_bytes());
hasher.update(&get_function_id::<F>().to_le_bytes());
let mut hash = [0u8; 32];
hasher.finalize(&mut hash);

let circuit_id = usize::from_le_bytes(hash[0..8].try_into().unwrap());
if let Some(prev_hash) = self.full_hash_id.get(&circuit_id) {
if *prev_hash != hash {
panic!("subcircuit id collision");
}
} else {
self.full_hash_id.insert(circuit_id, hash);
}

self.call_sub_circuit(circuit_id, inputs, f)
}

Expand Down
4 changes: 3 additions & 1 deletion expander_compiler/src/layering/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ pub struct SubCircuitInsn<'a> {
pub outputs: Vec<usize>,
}

const EXTRA_PRE_ALLOC_SIZE: usize = 1000;

impl<'a, C: Config> CompileContext<'a, C> {
pub fn compile(&mut self) {
// 1. do a toposort of the circuits
Expand Down Expand Up @@ -187,7 +189,7 @@ impl<'a, C: Config> CompileContext<'a, C> {
let mut n = nv + ns;
let circuit = self.rc.circuits.get(&circuit_id).unwrap();

let pre_alloc_size = n + (if n < 1000 { n } else { 1000 });
let pre_alloc_size = n + EXTRA_PRE_ALLOC_SIZE.min(n);
ic.min_layer = Vec::with_capacity(pre_alloc_size);
ic.max_layer = Vec::with_capacity(pre_alloc_size);

Expand Down
Loading

0 comments on commit eb08646

Please sign in to comment.