diff --git a/apps/garbled/default.pgo b/apps/garbled/default.pgo index 549cc95d..ef43e117 100644 Binary files a/apps/garbled/default.pgo and b/apps/garbled/default.pgo differ diff --git a/apps/garbled/examples/aesblock.mpcl b/apps/garbled/examples/aesblock.mpcl new file mode 100644 index 00000000..51c772c8 --- /dev/null +++ b/apps/garbled/examples/aesblock.mpcl @@ -0,0 +1,11 @@ +// -*- go -*- + +package main + +import ( + "crypto/aes" +) + +func main(key, data [16]byte) []byte { + return aes.EncryptBlock(key, data) +} diff --git a/apps/garbled/examples/aesblock2.mpcl b/apps/garbled/examples/aesblock2.mpcl new file mode 100644 index 00000000..afb72a8f --- /dev/null +++ b/apps/garbled/examples/aesblock2.mpcl @@ -0,0 +1,11 @@ +// -*- go -*- + +package main + +import ( + "crypto/aes" +) + +func main(key, data [16]byte) []byte { + return aes.Block128(key, data) +} diff --git a/apps/garbled/examples/aesexpand.mpcl b/apps/garbled/examples/aesexpand.mpcl new file mode 100644 index 00000000..8d2786b2 --- /dev/null +++ b/apps/garbled/examples/aesexpand.mpcl @@ -0,0 +1,11 @@ +// -*- go -*- + +package main + +import ( + "crypto/aes" +) + +func main(key, data [16]byte) []uint { + return aes.ExpandEncryptionKey(key) +} diff --git a/apps/garbled/examples/encrypt.mpcl b/apps/garbled/examples/encrypt.mpcl new file mode 100644 index 00000000..fd84d402 --- /dev/null +++ b/apps/garbled/examples/encrypt.mpcl @@ -0,0 +1,43 @@ +// -*- go -*- + +// Run the Evaluator with two inputs: evaluator's key and nonce shares: +// +// $ ./garbled -e -i 0x8cd98b88adab08d6d60fe57c8b8a33f3,0xfd5e0f8f155e7102aa526ad0 examples/encrypt.mpcl +// +// The Garbler takes three arguments: the message to encrypt, and its +// key and nonce shares: +// +// $ ./garbled -i 0x48656c6c6f2c20776f726c6421,0xed800b17b0c9d2334b249332155ddef5,0xa300751458c775a08762c2cd examples/encrypt.mpcl + +package main + +import ( + "crypto/cipher/gcm" +) + +type Garbler struct { + msg [64]byte + keyShare [16]byte + nonceShare [12]byte +} + +type Evaluator struct { + keyShare [16]byte + nonceShare [12]byte +} + +func main(g Garbler, e Evaluator) []byte { + var key [16]byte + + for i := 0; i < len(key); i++ { + key[i] = g.keyShare[i] ^ e.keyShare[i] + } + + var nonce [12]byte + + for i := 0; i < len(nonce); i++ { + nonce[i] = g.nonceShare[i] ^ e.nonceShare[i] + } + + return gcm.EncryptAES128(key, nonce, g.msg, []byte("unused")) +} diff --git a/apps/garbled/main.go b/apps/garbled/main.go index 821edf49..a57faa70 100644 --- a/apps/garbled/main.go +++ b/apps/garbled/main.go @@ -27,6 +27,7 @@ import ( "github.com/markkurossi/mpc/compiler/utils" "github.com/markkurossi/mpc/ot" "github.com/markkurossi/mpc/p2p" + "github.com/markkurossi/mpc/types" ) var ( @@ -407,10 +408,11 @@ func printResults(results []*big.Int, outputs circuit.IO) { func printResult(result *big.Int, output circuit.IOArg, short bool) string { var str string - if strings.HasPrefix(output.Type, "string") { + switch output.Type.Type { + case types.TString: mask := big.NewInt(0xff) - for i := 0; i < output.Size/8; i++ { + for i := 0; i < int(output.Type.Bits)/8; i++ { tmp := new(big.Int).Rsh(result, uint(i*8)) r := rune(tmp.And(tmp, mask).Uint64()) if unicode.IsPrint(r) { @@ -419,11 +421,10 @@ func printResult(result *big.Int, output circuit.IOArg, short bool) string { str += fmt.Sprintf("\\u%04x", r) } } - } else if strings.HasPrefix(output.Type, "uint") || - strings.HasPrefix(output.Type, "int") { - if output.Type[0] == 'i' { - bits := circuit.Size(output.Type) + case types.TUint, types.TInt: + if output.Type.Type == types.TInt { + bits := int(output.Type.Bits) if result.Bit(bits-1) == 1 { // Negative number. tmp := new(big.Int) @@ -439,47 +440,53 @@ func printResult(result *big.Int, output circuit.IOArg, short bool) string { } if short { str = fmt.Sprintf("%v", result) - } else if output.Size <= 64 { + } else if output.Type.Bits <= 64 { str = fmt.Sprintf("0x%x\t%v", bytes, result) } else { str = fmt.Sprintf("0x%x", bytes) } - } else if strings.HasPrefix(output.Type, "bool") { + + case types.TBool: str = fmt.Sprintf("%v", result.Uint64() != 0) - } else { - ok, count, elSize, elType := circuit.ParseArrayType(output.Type) - if ok { - mask := new(big.Int) - for i := 0; i < elSize; i++ { - mask.SetBit(mask, i, 1) - } - hexString := elType == "uint8" - if !hexString { - str = "[" - } - for i := 0; i < count; i++ { - r := new(big.Int).Rsh(result, uint(i*elSize)) - r = r.And(r, mask) - - if hexString { - str += fmt.Sprintf("%02x", r.Int64()) - } else { - if i > 0 { - str += " " - } - str += printResult(r, circuit.IOArg{ - Type: elType, - Size: elSize, - }, true) + case types.TArray: + count := int(output.Type.ArraySize) + elSize := int(output.Type.ElementType.Bits) + + mask := new(big.Int) + for i := 0; i < elSize; i++ { + mask.SetBit(mask, i, 1) + } + + var hexString bool + if output.Type.ElementType.Type == types.TUint && + output.Type.ElementType.Bits == 8 { + hexString = true + } + if !hexString { + str = "[" + } + for i := 0; i < count; i++ { + r := new(big.Int).Rsh(result, uint(i*elSize)) + r = r.And(r, mask) + + if hexString { + str += fmt.Sprintf("%02x", r.Int64()) + } else { + if i > 0 { + str += " " } + str += printResult(r, circuit.IOArg{ + Type: *output.Type.ElementType, + }, true) } - if !hexString { - str += "]" - } - } else { - str = fmt.Sprintf("%v (%s)", result, output.Type) } + if !hexString { + str += "]" + } + + default: + str = fmt.Sprintf("%v (%s)", result, output.Type) } return str diff --git a/benchmarks.md b/benchmarks.md index e6f5e9f6..2654070d 100644 --- a/benchmarks.md +++ b/benchmarks.md @@ -538,6 +538,85 @@ Max permanent wires: 53913890, cached circuits: 25 #gates=830166294 (XOR=533177896 XNOR=28813441 AND=267575026 OR=496562 INV=103369 xor=561991337 !xor=268174957 levels=10548 width=1796) #w=853882864 ``` +Value.HashValue based `WireAllocator`: + +``` +┌──────────────┬────────────────┬─────────┬────────┐ +│ Op │ Time │ % │ Xfer │ +├──────────────┼────────────────┼─────────┼────────┤ +│ Compile │ 1.89192514s │ 2.70% │ │ +│ Init │ 2.706353ms │ 0.00% │ 0B │ +│ OT Init │ 11.731µs │ 0.00% │ 16kB │ +│ Peer Inputs │ 45.549977ms │ 0.07% │ 57kB │ +│ Stream │ 1m8.029901416s │ 97.23% │ 15GB │ +│ ├╴InstrInit │ 2.933100244s │ 4.31% │ │ +│ ├╴CircComp │ 30.339124ms │ 0.04% │ │ +│ ├╴StreamInit │ 2.608974096s │ 3.84% │ │ +│ ╰╴Garble │ 1m1.560371684s │ 90.49% │ │ +│ Result │ 324.555µs │ 0.00% │ 8kB │ +│ Total │ 1m9.970419172s │ │ 15GB │ +│ ├╴Sent │ │ 100.00% │ 15GB │ +│ ├╴Rcvd │ │ 0.00% │ 45kB │ +│ ╰╴Flcd │ │ │ 231284 │ +└──────────────┴────────────────┴─────────┴────────┘ +Max permanent wires: 53913890, cached circuits: 25 +#gates=830166294 (XOR=533177896 XNOR=28813441 AND=267575026 OR=496562 INV=103369 xor=561991337 !xor=268174957 levels=10548 width=1796) #w=853882864 +``` + +Optimized `compiler/circuits/Wire`: + +``` +┌──────────────┬────────────────┬─────────┬────────┐ +│ Op │ Time │ % │ Xfer │ +├──────────────┼────────────────┼─────────┼────────┤ +│ Compile │ 1.870993069s │ 2.71% │ │ +│ Init │ 2.331431ms │ 0.00% │ 0B │ +│ OT Init │ 10.949µs │ 0.00% │ 16kB │ +│ Peer Inputs │ 44.089085ms │ 0.06% │ 57kB │ +│ Stream │ 1m7.0813688s │ 97.22% │ 15GB │ +│ ├╴InstrInit │ 2.421297578s │ 3.61% │ │ +│ ├╴CircComp │ 17.09415ms │ 0.03% │ │ +│ ├╴StreamInit │ 2.155089182s │ 3.21% │ │ +│ ╰╴Garble │ 1m1.550598148s │ 91.76% │ │ +│ Result │ 432.27µs │ 0.00% │ 8kB │ +│ Total │ 1m8.999225604s │ │ 15GB │ +│ ├╴Sent │ │ 100.00% │ 15GB │ +│ ├╴Rcvd │ │ 0.00% │ 45kB │ +│ ╰╴Flcd │ │ │ 231284 │ +└──────────────┴────────────────┴─────────┴────────┘ +Max permanent wires: 53913890, cached circuits: 25 +#gates=830166294 (XOR=533177896 XNOR=28813441 AND=267575026 OR=496562 INV=103369 xor=561991337 !xor=268174957 levels=10548 width=1796) #w=853882864 +``` + +Optimized streamer to use `circuit.Wire` instead of +`compiler/circuits/Wire` in wire cache: + +``` +┌──────────────┬────────────────┬─────────┬────────┐ +│ Op │ Time │ % │ Xfer │ +├──────────────┼────────────────┼─────────┼────────┤ +│ Compile │ 1.755362404s │ 2.64% │ │ +│ Init │ 2.79287ms │ 0.00% │ 0B │ +│ OT Init │ 13.525µs │ 0.00% │ 16kB │ +│ Peer Inputs │ 45.376796ms │ 0.07% │ 57kB │ +│ Stream │ 1m4.624303427s │ 97.28% │ 15GB │ +│ ├╴InstrInit │ 1.166489974s │ 1.81% │ │ +│ ├╴CircComp │ 18.17144ms │ 0.03% │ │ +│ ├╴StreamInit │ 1.886360054s │ 2.92% │ │ +│ ╰╴Garble │ 1m0.866348862s │ 94.18% │ │ +│ Result │ 225.299µs │ 0.00% │ 8kB │ +│ Total │ 1m6.428074321s │ │ 15GB │ +│ ├╴Sent │ │ 100.00% │ 15GB │ +│ ├╴Rcvd │ │ 0.00% │ 45kB │ +│ ╰╴Flcd │ │ │ 231284 │ +└──────────────┴────────────────┴─────────┴────────┘ +Max permanent wires: 53913890, cached circuits: 25 +#gates=830166294 (XOR=533177896 XNOR=28813441 AND=267575026 OR=496562 INV=103369 xor=561991337 !xor=268174957 levels=10548 width=1796) #w=853882864 + 66.59 real 69.15 user 6.69 sys + 3568140288 maximum resident set size + 4119990272 peak memory footprint +``` + Theoretical minimum single-threaded garbling time: ``` diff --git a/bmr/player.go b/bmr/player.go index ac7cf532..883f18e0 100644 --- a/bmr/player.go +++ b/bmr/player.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2022 Markku Rossi +// Copyright (c) 2022-2023 Markku Rossi // // All rights reserved. // @@ -86,11 +86,11 @@ func (p *Player) offlinePhase() error { var inputIndex int for id, input := range p.c.Inputs { if id != p.id { - for i := 0; i < input.Size; i++ { + for i := 0; i < int(input.Type.Bits); i++ { p.lambda.SetBit(p.lambda, inputIndex+i, 0) } } - inputIndex += input.Size + inputIndex += int(input.Type.Bits) } wires := make([]Wire, p.c.NumWires) diff --git a/circuit/circuit.go b/circuit/circuit.go index 72cb6021..e0a6575f 100644 --- a/circuit/circuit.go +++ b/circuit/circuit.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2019-2022 Markku Rossi +// Copyright (c) 2019-2023 Markku Rossi // // All rights reserved. // @@ -11,9 +11,6 @@ import ( "io" "math" "math/big" - "regexp" - "strconv" - "strings" "github.com/markkurossi/tabulate" ) @@ -100,153 +97,6 @@ func (op Operation) String() string { } } -// IOArg describes circuit input argument. -type IOArg struct { - Name string - Type string - Size int - Compound IO -} - -func (io IOArg) String() string { - if len(io.Compound) > 0 { - return io.Compound.String() - } - - if len(io.Name) > 0 { - return io.Name + ":" + io.Type - } - return io.Type -} - -// Parse parses the I/O argument from the input string values. -func (io IOArg) Parse(inputs []string) (*big.Int, error) { - result := new(big.Int) - - if len(io.Compound) == 0 { - if len(inputs) != 1 { - return nil, - fmt.Errorf("invalid amount of arguments, got %d, expected 1", - len(inputs)) - } - - if strings.HasPrefix(io.Type, "u") || strings.HasPrefix(io.Type, "i") { - _, ok := result.SetString(inputs[0], 0) - if !ok { - return nil, fmt.Errorf("invalid input: %s", inputs[0]) - } - } else if io.Type == "bool" { - switch inputs[0] { - case "0", "f", "false": - case "1", "t", "true": - result.SetInt64(1) - default: - return nil, fmt.Errorf("invalid bool constant: %s", inputs[0]) - } - } else { - ok, count, elSize, _ := ParseArrayType(io.Type) - if ok { - val := new(big.Int) - _, ok := val.SetString(inputs[0], 0) - if !ok { - return nil, fmt.Errorf("invalid input: %s", inputs[0]) - } - - valElCount := val.BitLen() / elSize - if val.BitLen()%elSize != 0 { - valElCount++ - } - if valElCount > count { - return nil, fmt.Errorf("too many values for input: %s", - inputs[0]) - } - pad := count - valElCount - val.Lsh(val, uint(pad*elSize)) - - mask := new(big.Int) - for i := 0; i < elSize; i++ { - mask.SetBit(mask, i, 1) - } - - for i := 0; i < count; i++ { - next := new(big.Int).Rsh(val, uint((count-i-1)*elSize)) - next = next.And(next, mask) - - next.Lsh(next, uint(i*elSize)) - result.Or(result, next) - } - } else { - return nil, fmt.Errorf("unsupported input type: %s", io.Type) - } - } - - return result, nil - } - if len(inputs) != len(io.Compound) { - return nil, - fmt.Errorf("invalid amount of arguments, got %d, expected %d", - len(inputs), len(io.Compound)) - } - - var offset int - - for idx, arg := range io.Compound { - input, err := arg.Parse(inputs[idx : idx+1]) - if err != nil { - return nil, err - } - - input.Lsh(input, uint(offset)) - result.Or(result, input) - - offset += arg.Size - } - return result, nil -} - -var reArr = regexp.MustCompilePOSIX(`^\[([[:digit:]]+)\](.+)$`) -var reSized = regexp.MustCompilePOSIX(`^[[:^digit:]]+([[:digit:]]+)$`) - -// ParseArrayType parses the argument value as array type. -func ParseArrayType(val string) (ok bool, count, elementSize int, - elementType string) { - - matches := reArr.FindStringSubmatch(val) - if matches == nil { - return - } - var err error - count, err = strconv.Atoi(matches[1]) - if err != nil { - panic(fmt.Sprintf("invalid array size: %s", matches[1])) - } - ok = true - elementSize = Size(matches[2]) - elementType = matches[2] - return -} - -// Size returns the type size in bits. -func Size(t string) int { - matches := reArr.FindStringSubmatch(t) - if matches != nil { - count, err := strconv.Atoi(matches[1]) - if err != nil { - panic(fmt.Sprintf("invalid array size: %s", matches[1])) - } - return count * Size(matches[2]) - } - matches = reSized.FindStringSubmatch(t) - if matches == nil { - panic(fmt.Sprintf("invalid type: %s", t)) - } - bits, err := strconv.Atoi(matches[1]) - if err != nil { - panic(fmt.Sprintf("invalid bit count: %s", matches[1])) - } - return bits -} - // IO specifies circuit input and output arguments. type IO []IOArg @@ -255,7 +105,7 @@ type IO []IOArg func (io IO) Size() int { var sum int for _, a := range io { - sum += a.Size + sum += int(a.Type.Bits) } return sum } @@ -269,7 +119,7 @@ func (io IO) String() string { if len(a.Name) > 0 { str += a.Name + ":" } - str += a.Type + str += a.Type.String() } return str } @@ -280,7 +130,7 @@ func (io IO) Split(in *big.Int) []*big.Int { var bit int for _, arg := range io { r := big.NewInt(0) - for i := 0; i < arg.Size; i++ { + for i := 0; i < int(arg.Type.Bits); i++ { if in.Bit(bit) == 1 { r = big.NewInt(0).SetBit(r, i, 1) } @@ -427,8 +277,8 @@ type Wire uint32 // InvalidWire specifies an invalid wire ID. const InvalidWire Wire = math.MaxUint32 -// ID returns the wire ID as integer. -func (w Wire) ID() int { +// Int returns the wire ID as integer. +func (w Wire) Int() int { return int(w) } diff --git a/circuit/computer.go b/circuit/computer.go index c3e59c78..6e8f27a7 100644 --- a/circuit/computer.go +++ b/circuit/computer.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2020 Markku Rossi +// Copyright (c) 2020-2023 Markku Rossi // // All rights reserved. // @@ -32,7 +32,7 @@ func (c *Circuit) Compute(inputs []*big.Int) ([]*big.Int, error) { var w int for idx, io := range args { - for bit := 0; bit < io.Size; bit++ { + for bit := 0; bit < int(io.Type.Bits); bit++ { wires[w] = byte(inputs[idx].Bit(bit)) w++ } @@ -78,7 +78,7 @@ func (c *Circuit) Compute(inputs []*big.Int) ([]*big.Int, error) { var result []*big.Int for _, io := range c.Outputs { r := new(big.Int) - for bit := 0; bit < io.Size; bit++ { + for bit := 0; bit < int(io.Type.Bits); bit++ { if wires[w] != 0 { r.SetBit(r, bit, 1) } diff --git a/circuit/dot.go b/circuit/dot.go index 534d6aa7..7a7f0628 100644 --- a/circuit/dot.go +++ b/circuit/dot.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2019-2022 Markku Rossi +// Copyright (c) 2019-2023 Markku Rossi // // All rights reserved. // @@ -32,7 +32,7 @@ func (c *Circuit) Dot(out io.Writer) { fmt.Fprintf(out, " { rank=same") var numInputs int for _, input := range c.Inputs { - numInputs += input.Size + numInputs += int(input.Type.Bits) } for w := 0; w < numInputs; w++ { fmt.Fprintf(out, "; w%d", w) diff --git a/circuit/evaluator.go b/circuit/evaluator.go index fcfca9c2..733f94f5 100644 --- a/circuit/evaluator.go +++ b/circuit/evaluator.go @@ -72,7 +72,7 @@ func Evaluator(conn *p2p.Conn, oti ot.OT, circ *Circuit, inputs *big.Int, wires := make([]ot.Label, circ.NumWires) // Receive peer inputs. - for i := 0; i < circ.Inputs[0].Size; i++ { + for i := 0; i < int(circ.Inputs[0].Type.Bits); i++ { err := conn.ReceiveLabel(&label, &labelData) if err != nil { return nil, err @@ -93,23 +93,23 @@ func Evaluator(conn *p2p.Conn, oti ot.OT, circ *Circuit, inputs *big.Int, fmt.Printf(" - Querying our inputs...\n") } // Wire offset. - if err := conn.SendUint32(circ.Inputs[0].Size); err != nil { + if err := conn.SendUint32(int(circ.Inputs[0].Type.Bits)); err != nil { return nil, err } // Wire count. - if err := conn.SendUint32(circ.Inputs[1].Size); err != nil { + if err := conn.SendUint32(int(circ.Inputs[1].Type.Bits)); err != nil { return nil, err } if err := conn.Flush(); err != nil { return nil, err } - flags := make([]bool, circ.Inputs[1].Size) - for i := 0; i < circ.Inputs[1].Size; i++ { + flags := make([]bool, int(circ.Inputs[1].Type.Bits)) + for i := 0; i < int(circ.Inputs[1].Type.Bits); i++ { if inputs.Bit(i) == 1 { flags[i] = true } } - if err := oti.Receive(flags, wires[circ.Inputs[0].Size:]); err != nil { + if err := oti.Receive(flags, wires[circ.Inputs[0].Type.Bits:]); err != nil { return nil, err } xfer := conn.Stats.Sum() - ioStats diff --git a/circuit/garble.go b/circuit/garble.go index 3424e4b3..d79d9141 100644 --- a/circuit/garble.go +++ b/circuit/garble.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2019-2022 Markku Rossi +// Copyright (c) 2019-2023 Markku Rossi // // All rights reserved. // @@ -244,11 +244,11 @@ func (g *Gate) garble(wires []ot.Wire, enc cipher.Block, r ot.Label, // Inputs. switch g.Op { case XOR, XNOR, AND, OR: - b = wires[g.Input1.ID()] + b = wires[g.Input1.Int()] fallthrough case INV: - a = wires[g.Input0.ID()] + a = wires[g.Input0.Int()] default: return nil, fmt.Errorf("invalid gate type %s", g.Op) @@ -404,7 +404,7 @@ func (g *Gate) garble(wires []ot.Wire, enc cipher.Block, r ot.Label, default: return nil, fmt.Errorf("invalid operand %s", g.Op) } - wires[g.Output.ID()] = c + wires[g.Output.Int()] = c return table[start : start+count], nil } diff --git a/circuit/garbler.go b/circuit/garbler.go index a19ccba1..691ae4f6 100644 --- a/circuit/garbler.go +++ b/circuit/garbler.go @@ -82,7 +82,7 @@ func Garbler(conn *p2p.Conn, oti ot.OT, circ *Circuit, inputs *big.Int, // Select our inputs. var n1 []ot.Label - for i := 0; i < circ.Inputs[0].Size; i++ { + for i := 0; i < int(circ.Inputs[0].Type.Bits); i++ { wire := garbled.Wires[i] var n ot.Label @@ -128,7 +128,8 @@ func Garbler(conn *p2p.Conn, oti ot.OT, circ *Circuit, inputs *big.Int, if err != nil { return nil, err } - if offset != circ.Inputs[0].Size || count != circ.Inputs[1].Size { + if offset != int(circ.Inputs[0].Type.Bits) || + count != int(circ.Inputs[1].Type.Bits) { return nil, fmt.Errorf("peer can't OT wires [%d...%d[", offset, offset+count) } diff --git a/circuit/ioarg.go b/circuit/ioarg.go new file mode 100644 index 00000000..f6e14c9a --- /dev/null +++ b/circuit/ioarg.go @@ -0,0 +1,121 @@ +// +// Copyright (c) 2019-2023 Markku Rossi +// +// All rights reserved. +// + +package circuit + +import ( + "fmt" + "math/big" + + "github.com/markkurossi/mpc/types" +) + +// IOArg describes circuit input argument. +type IOArg struct { + Name string + Type types.Info + Compound IO +} + +func (io IOArg) String() string { + if len(io.Compound) > 0 { + return io.Compound.String() + } + + if len(io.Name) > 0 { + return io.Name + ":" + io.Type.String() + } + return io.Type.String() +} + +// Parse parses the I/O argument from the input string values. +func (io IOArg) Parse(inputs []string) (*big.Int, error) { + result := new(big.Int) + + if len(io.Compound) == 0 { + if len(inputs) != 1 { + return nil, + fmt.Errorf("invalid amount of arguments, got %d, expected 1", + len(inputs)) + } + + switch io.Type.Type { + case types.TInt, types.TUint: + _, ok := result.SetString(inputs[0], 0) + if !ok { + return nil, fmt.Errorf("invalid input: %s", inputs[0]) + } + + case types.TBool: + switch inputs[0] { + case "0", "f", "false": + case "1", "t", "true": + result.SetInt64(1) + default: + return nil, fmt.Errorf("invalid bool constant: %s", inputs[0]) + } + + case types.TArray: + count := int(io.Type.ArraySize) + elSize := int(io.Type.ElementType.Bits) + + val := new(big.Int) + _, ok := val.SetString(inputs[0], 0) + if !ok { + return nil, fmt.Errorf("invalid input: %s", inputs[0]) + } + + valElCount := val.BitLen() / elSize + if val.BitLen()%elSize != 0 { + valElCount++ + } + if valElCount > count { + return nil, fmt.Errorf("too many values for input: %s", + inputs[0]) + } + pad := count - valElCount + val.Lsh(val, uint(pad*elSize)) + + mask := new(big.Int) + for i := 0; i < elSize; i++ { + mask.SetBit(mask, i, 1) + } + + for i := 0; i < count; i++ { + next := new(big.Int).Rsh(val, uint((count-i-1)*elSize)) + next = next.And(next, mask) + + next.Lsh(next, uint(i*elSize)) + result.Or(result, next) + } + + default: + return nil, fmt.Errorf("unsupported input type: %s", io.Type) + } + + return result, nil + } + if len(inputs) != len(io.Compound) { + return nil, + fmt.Errorf("invalid amount of arguments, got %d, expected %d", + len(inputs), len(io.Compound)) + } + + var offset int + + for idx, arg := range io.Compound { + input, err := arg.Parse(inputs[idx : idx+1]) + if err != nil { + return nil, err + } + + input.Lsh(input, uint(offset)) + result.Or(result, input) + + offset += int(arg.Type.Bits) + } + return result, nil +} diff --git a/circuit/marshal.go b/circuit/marshal.go index 1249c915..fb3282f9 100644 --- a/circuit/marshal.go +++ b/circuit/marshal.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2020-2021 Markku Rossi +// Copyright (c) 2020-2021, 2023 Markku Rossi // // All rights reserved. // @@ -87,10 +87,10 @@ func marshalIOArg(out io.Writer, arg IOArg) error { if err := marshalString(out, arg.Name); err != nil { return err } - if err := marshalString(out, arg.Type); err != nil { + if err := marshalString(out, arg.Type.String()); err != nil { return err } - if err := binary.Write(out, bo, uint32(arg.Size)); err != nil { + if err := binary.Write(out, bo, uint32(arg.Type.Bits)); err != nil { return err } if err := binary.Write(out, bo, uint32(len(arg.Compound))); err != nil { @@ -118,12 +118,12 @@ func (c *Circuit) MarshalBristol(out io.Writer) error { fmt.Fprintf(out, "%d %d\n", c.NumGates, c.NumWires) fmt.Fprintf(out, "%d", len(c.Inputs)) for _, input := range c.Inputs { - fmt.Fprintf(out, " %d", input.Size) + fmt.Fprintf(out, " %d", input.Type.Bits) } fmt.Fprintln(out) fmt.Fprintf(out, "%d", len(c.Outputs)) for _, ret := range c.Outputs { - fmt.Fprintf(out, " %d", ret.Size) + fmt.Fprintf(out, " %d", ret.Type.Bits) } fmt.Fprintln(out) fmt.Fprintln(out) diff --git a/circuit/parser.go b/circuit/parser.go index 87465e21..6e11ad92 100644 --- a/circuit/parser.go +++ b/circuit/parser.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2019-2022 Markku Rossi +// Copyright (c) 2019-2023 Markku Rossi // // All rights reserved. // @@ -17,6 +17,8 @@ import ( "regexp" "strconv" "strings" + + "github.com/markkurossi/mpc/types" ) var reParts = regexp.MustCompilePOSIX("[[:space:]]+") @@ -90,7 +92,7 @@ func ParseMPCLC(in io.Reader) (*Circuit, error) { return nil, err } inputs = append(inputs, arg) - inputWires += arg.Size + inputWires += int(arg.Type.Bits) } for i := 0; i < int(header.NumOutputs); i++ { out, err := parseIOArg(r) @@ -98,7 +100,7 @@ func ParseMPCLC(in io.Reader) (*Circuit, error) { return nil, err } outputs = append(outputs, out) - outputWires += out.Size + outputWires += int(out.Type.Bits) } // Mark input wires seen. @@ -222,8 +224,11 @@ func parseIOArg(r *bufio.Reader) (arg IOArg, err error) { return arg, err } arg.Name = name - arg.Type = t - arg.Size = int(ui32) + arg.Type, err = types.Parse(t) + if err != nil { + return arg, err + } + arg.Type.Bits = types.Size(ui32) // Compound if err := binary.Read(r, bo, &ui32); err != nil { @@ -309,8 +314,10 @@ func ParseBristol(in io.Reader) (*Circuit, error) { } inputs = append(inputs, IOArg{ Name: fmt.Sprintf("NI%d", i), - Type: fmt.Sprintf("u%d", bits), - Size: bits, + Type: types.Info{ + Type: types.TUint, + Bits: types.Size(bits), + }, }) inputWires += bits } @@ -348,8 +355,10 @@ func ParseBristol(in io.Reader) (*Circuit, error) { } outputs = append(outputs, IOArg{ Name: fmt.Sprintf("NO%d", i), - Type: fmt.Sprintf("u%d", bits), - Size: bits, + Type: types.Info{ + Type: types.TUint, + Bits: types.Size(bits), + }, }) } diff --git a/circuit/stream_evaluator.go b/circuit/stream_evaluator.go index 5d2017c5..e5e7bdb8 100644 --- a/circuit/stream_evaluator.go +++ b/circuit/stream_evaluator.go @@ -15,6 +15,7 @@ import ( "github.com/markkurossi/mpc/ot" "github.com/markkurossi/mpc/p2p" + "github.com/markkurossi/mpc/types" ) // Protocol operation codes. @@ -141,7 +142,8 @@ func StreamEvaluator(conn *p2p.Conn, oti ot.OT, inputFlag []string, fmt.Printf(" - Out: %s\n", outputs) fmt.Printf(" - In: %s\n", inputFlag) - streaming, err := NewStreamEval(key, in1.Size+in2.Size, outputs.Size()) + streaming, err := NewStreamEval(key, int(in1.Type.Bits+in2.Type.Bits), + outputs.Size()) if err != nil { return nil, nil, err } @@ -149,7 +151,7 @@ func StreamEvaluator(conn *p2p.Conn, oti ot.OT, inputFlag []string, // Receive peer inputs. var label ot.Label var labelData ot.LabelData - for w := 0; w < in1.Size; w++ { + for w := 0; w < int(in1.Type.Bits); w++ { err := conn.ReceiveLabel(&label, &labelData) if err != nil { return nil, nil, err @@ -169,13 +171,13 @@ func StreamEvaluator(conn *p2p.Conn, oti ot.OT, inputFlag []string, if verbose { fmt.Printf(" - Querying our inputs...\n") } - flags := make([]bool, in2.Size) - for i := 0; i < in2.Size; i++ { + flags := make([]bool, in2.Type.Bits) + for i := 0; i < int(in2.Type.Bits); i++ { if inputs.Bit(i) == 1 { flags[i] = true } } - inputLabels := streaming.GetInputs(in1.Size, in2.Size) + inputLabels := streaming.GetInputs(int(in1.Type.Bits), int(in2.Type.Bits)) if err := oti.Receive(flags, inputLabels); err != nil { return nil, nil, err } @@ -474,8 +476,11 @@ func receiveArgument(conn *p2p.Conn) (arg IOArg, err error) { return arg, err } arg.Name = name - arg.Type = t - arg.Size = size + arg.Type, err = types.Parse(t) + if err != nil { + return arg, err + } + arg.Type.Bits = types.Size(size) count, err := conn.ReceiveUint32() if err != nil { diff --git a/circuit/stream_garble_test.go b/circuit/stream_garble_test.go index a4a3473a..9b60b976 100644 --- a/circuit/stream_garble_test.go +++ b/circuit/stream_garble_test.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2020-2022 Markku Rossi +// Copyright (c) 2020-2023 Markku Rossi // // All rights reserved. // @@ -8,6 +8,7 @@ package circuit import ( "testing" + "time" "github.com/markkurossi/mpc/ot" ) @@ -344,3 +345,12 @@ func encode7var5(b []byte, v uint32) { b[3] = byte(v >> 7) b[4] = byte(v) } + +func BenchmarkTimeDuration(b *testing.B) { + var total time.Duration + + for i := 0; i < b.N; i++ { + start := time.Now() + total += time.Now().Sub(start) + } +} diff --git a/compiler/ast/builtin.go b/compiler/ast/builtin.go index a2546166..21e1b279 100644 --- a/compiler/ast/builtin.go +++ b/compiler/ast/builtin.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2019-2022 Markku Rossi +// Copyright (c) 2019-2023 Markku Rossi // // All rights reserved. // @@ -365,11 +365,11 @@ func nativeCircuit(name string, block *ssa.Block, ctx *Codegen, // Check that the argument types match. for idx, io := range circ.Inputs { arg := args[idx] - if io.Size < int(arg.Type.Bits) || io.Size > int(arg.Type.Bits) && + if io.Type.Bits < arg.Type.Bits || io.Type.Bits > arg.Type.Bits && !arg.Const { return nil, nil, ctx.Errorf(loc, "invalid argument %d for native circuit: got %s, need %d", - idx, arg.Type, io.Size) + idx, arg.Type, io.Type.Bits) } } @@ -384,7 +384,7 @@ func nativeCircuit(name string, block *ssa.Block, ctx *Codegen, for _, io := range circ.Outputs { result = append(result, gen.AnonVal(types.Info{ Type: types.TUndefined, - Bits: types.Size(io.Size), + Bits: io.Type.Bits, })) } diff --git a/compiler/ast/package.go b/compiler/ast/package.go index 71f3797e..d6ba280c 100644 --- a/compiler/ast/package.go +++ b/compiler/ast/package.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2020-2022 Markku Rossi +// Copyright (c) 2020-2023 Markku Rossi // // All rights reserved. // @@ -81,16 +81,15 @@ func (pkg *Package) Compile(ctx *Codegen) (*ssa.Program, Annotations, error) { a := gen.NewVal(arg.Name, typeInfo, ctx.Scope()) ctx.Start().Bindings.Set(a, nil) - arg := circuit.IOArg{ - Name: a.String(), - Type: a.Type.String(), - Size: int(a.Type.Bits), + input := circuit.IOArg{ + Name: arg.Name, + Type: a.Type, } if typeInfo.Type == types.TStruct { - arg.Compound = flattenStruct(typeInfo) + input.Compound = flattenStruct(typeInfo) } - inputs = append(inputs, arg) + inputs = append(inputs, input) } // Compile main. @@ -128,8 +127,7 @@ func (pkg *Package) Compile(ctx *Codegen) (*ssa.Program, Annotations, error) { v := returnVars[idx] outputs = append(outputs, circuit.IOArg{ Name: v.String(), - Type: v.Type.String(), - Size: int(v.Type.Bits), + Type: v.Type, }) } @@ -171,8 +169,7 @@ func flattenStruct(t types.Info) circuit.IO { } else { result = append(result, circuit.IOArg{ Name: f.Name, - Type: f.Type.String(), - Size: int(f.Type.Bits), + Type: f.Type, }) } } diff --git a/compiler/circuits/allocator.go b/compiler/circuits/allocator.go new file mode 100644 index 00000000..597861ca --- /dev/null +++ b/compiler/circuits/allocator.go @@ -0,0 +1,101 @@ +// +// Copyright (c) 2023 Markku Rossi +// +// All rights reserved. +// + +package circuits + +import ( + "fmt" + "unsafe" + + "github.com/markkurossi/mpc/circuit" + "github.com/markkurossi/mpc/types" +) + +var ( + sizeofWire = uint64(unsafe.Sizeof(Wire{})) + sizeofGate = uint64(unsafe.Sizeof(Gate{})) +) + +// Allocator implements circuit wire and gate allocation. +type Allocator struct { + numWire uint64 + numWires uint64 + numGates uint64 +} + +// NewAllocator creates a new circuit allocator. +func NewAllocator() *Allocator { + return new(Allocator) +} + +// Wire allocates a new Wire. +func (alloc *Allocator) Wire() *Wire { + alloc.numWire++ + w := new(Wire) + w.Reset(UnassignedID) + return w +} + +// Wires allocate an array of Wires. +func (alloc *Allocator) Wires(bits types.Size) []*Wire { + alloc.numWires += uint64(bits) + + wires := make([]Wire, bits) + result := make([]*Wire, bits) + for i := 0; i < int(bits); i++ { + w := &wires[i] + w.id = UnassignedID + result[i] = w + } + return result +} + +// BinaryGate creates a new binary gate. +func (alloc *Allocator) BinaryGate(op circuit.Operation, a, b, o *Wire) *Gate { + alloc.numGates++ + gate := &Gate{ + Op: op, + A: a, + B: b, + O: o, + } + a.AddOutput(gate) + b.AddOutput(gate) + o.SetInput(gate) + + return gate +} + +// INVGate creates a new INV gate. +func (alloc *Allocator) INVGate(i, o *Wire) *Gate { + alloc.numGates++ + gate := &Gate{ + Op: circuit.INV, + A: i, + O: o, + } + i.AddOutput(gate) + o.SetInput(gate) + + return gate +} + +// Debug print debugging information about the circuit allocator. +func (alloc *Allocator) Debug() { + wireSize := circuit.FileSize(alloc.numWire * sizeofWire) + wiresSize := circuit.FileSize(alloc.numWires * sizeofWire) + gatesSize := circuit.FileSize(alloc.numGates * sizeofGate) + + total := float64(wireSize + wiresSize + gatesSize) + + fmt.Println("circuits.Allocator:") + fmt.Printf(" wire : %9v %5s %5.2f%%\n", + alloc.numWire, wireSize, float64(wireSize)/total*100.0) + fmt.Printf(" wires: %9v %5s %5.2f%%\n", + alloc.numWires, wiresSize, float64(wiresSize)/total*100.0) + fmt.Printf(" gates: %9v %5s %5.2f%%\n", + alloc.numGates, gatesSize, float64(gatesSize)/total*100.0) +} diff --git a/compiler/circuits/circ_adder.go b/compiler/circuits/circ_adder.go index 1aed71ed..0515c7f9 100644 --- a/compiler/circuits/circ_adder.go +++ b/compiler/circuits/circ_adder.go @@ -1,7 +1,7 @@ // // circ_adder.go // -// Copyright (c) 2019-2021 Markku Rossi +// Copyright (c) 2019-2023 Markku Rossi // // All rights reserved. // @@ -13,46 +13,46 @@ import ( ) // NewHalfAdder creates a half adder circuit. -func NewHalfAdder(compiler *Compiler, a, b, s, c *Wire) { +func NewHalfAdder(cc *Compiler, a, b, s, c *Wire) { // S = XOR(A, B) - compiler.AddGate(NewBinary(circuit.XOR, a, b, s)) + cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, a, b, s)) if c != nil { // C = AND(A, B) - compiler.AddGate(NewBinary(circuit.AND, a, b, c)) + cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, a, b, c)) } } // NewFullAdder creates a full adder circuit -func NewFullAdder(compiler *Compiler, a, b, cin, s, cout *Wire) { - w1 := NewWire() - w2 := NewWire() - w3 := NewWire() +func NewFullAdder(cc *Compiler, a, b, cin, s, cout *Wire) { + w1 := cc.Calloc.Wire() + w2 := cc.Calloc.Wire() + w3 := cc.Calloc.Wire() // s = a XOR b XOR cin // cout = cin XOR ((a XOR cin) AND (b XOR cin)). // w1 = XOR(b, cin) - compiler.AddGate(NewBinary(circuit.XOR, b, cin, w1)) + cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, b, cin, w1)) // s = XOR(a, w1) - compiler.AddGate(NewBinary(circuit.XOR, a, w1, s)) + cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, a, w1, s)) if cout != nil { // w2 = XOR(a, cin) - compiler.AddGate(NewBinary(circuit.XOR, a, cin, w2)) + cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, a, cin, w2)) // w3 = AND(w1, w2) - compiler.AddGate(NewBinary(circuit.AND, w1, w2, w3)) + cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, w1, w2, w3)) // cout = XOR(cin, w3) - compiler.AddGate(NewBinary(circuit.XOR, cin, w3, cout)) + cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, cin, w3, cout)) } } // NewAdder creates a new adder circuit implementing z=x+y. -func NewAdder(compiler *Compiler, x, y, z []*Wire) error { - x, y = compiler.ZeroPad(x, y) +func NewAdder(cc *Compiler, x, y, z []*Wire) error { + x, y = cc.ZeroPad(x, y) if len(x) > len(z) { x = x[0:len(z)] y = y[0:len(z)] @@ -63,10 +63,10 @@ func NewAdder(compiler *Compiler, x, y, z []*Wire) error { if len(z) > 1 { cin = z[1] } - NewHalfAdder(compiler, x[0], y[0], z[0], cin) + NewHalfAdder(cc, x[0], y[0], z[0], cin) } else { - cin := NewWire() - NewHalfAdder(compiler, x[0], y[0], z[0], cin) + cin := cc.Calloc.Wire() + NewHalfAdder(cc, x[0], y[0], z[0], cin) for i := 1; i < len(x); i++ { var cout *Wire @@ -78,10 +78,10 @@ func NewAdder(compiler *Compiler, x, y, z []*Wire) error { cout = z[len(x)] } } else { - cout = NewWire() + cout = cc.Calloc.Wire() } - NewFullAdder(compiler, x[i], y[i], cin, z[i], cout) + NewFullAdder(cc, x[i], y[i], cin, z[i], cout) cin = cout } @@ -89,7 +89,7 @@ func NewAdder(compiler *Compiler, x, y, z []*Wire) error { // Set all leftover bits to zero. for i := len(x) + 1; i < len(z); i++ { - z[i] = compiler.ZeroWire() + z[i] = cc.ZeroWire() } return nil diff --git a/compiler/circuits/circ_binary.go b/compiler/circuits/circ_binary.go index efbfed88..066576dd 100644 --- a/compiler/circuits/circ_binary.go +++ b/compiler/circuits/circ_binary.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2020-2021 Markku Rossi +// Copyright (c) 2020-2021, 2023 Markku Rossi // // All rights reserved. // @@ -11,55 +11,55 @@ import ( ) // NewBinaryAND creates a new binary AND circuit implementing r=x&y -func NewBinaryAND(compiler *Compiler, x, y, r []*Wire) error { - x, y = compiler.ZeroPad(x, y) +func NewBinaryAND(cc *Compiler, x, y, r []*Wire) error { + x, y = cc.ZeroPad(x, y) if len(r) < len(x) { x = x[0:len(r)] y = y[0:len(r)] } for i := 0; i < len(x); i++ { - compiler.AddGate(NewBinary(circuit.AND, x[i], y[i], r[i])) + cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, x[i], y[i], r[i])) } return nil } // NewBinaryClear creates a new binary clear circuit implementing r=x&^y. -func NewBinaryClear(compiler *Compiler, x, y, r []*Wire) error { - x, y = compiler.ZeroPad(x, y) +func NewBinaryClear(cc *Compiler, x, y, r []*Wire) error { + x, y = cc.ZeroPad(x, y) if len(r) < len(x) { x = x[0:len(r)] y = y[0:len(r)] } for i := 0; i < len(x); i++ { - w := NewWire() - compiler.INV(y[i], w) - compiler.AddGate(NewBinary(circuit.AND, x[i], w, r[i])) + w := cc.Calloc.Wire() + cc.INV(y[i], w) + cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, x[i], w, r[i])) } return nil } // NewBinaryOR creates a new binary OR circuit implementing r=x|y. -func NewBinaryOR(compiler *Compiler, x, y, r []*Wire) error { - x, y = compiler.ZeroPad(x, y) +func NewBinaryOR(cc *Compiler, x, y, r []*Wire) error { + x, y = cc.ZeroPad(x, y) if len(r) < len(x) { x = x[0:len(r)] y = y[0:len(r)] } for i := 0; i < len(x); i++ { - compiler.AddGate(NewBinary(circuit.OR, x[i], y[i], r[i])) + cc.AddGate(cc.Calloc.BinaryGate(circuit.OR, x[i], y[i], r[i])) } return nil } // NewBinaryXOR creates a new binary XOR circuit implementing r=x^y. -func NewBinaryXOR(compiler *Compiler, x, y, r []*Wire) error { - x, y = compiler.ZeroPad(x, y) +func NewBinaryXOR(cc *Compiler, x, y, r []*Wire) error { + x, y = cc.ZeroPad(x, y) if len(r) < len(x) { x = x[0:len(r)] y = y[0:len(r)] } for i := 0; i < len(x); i++ { - compiler.AddGate(NewBinary(circuit.XOR, x[i], y[i], r[i])) + cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, x[i], y[i], r[i])) } return nil } diff --git a/compiler/circuits/circ_comparators.go b/compiler/circuits/circ_comparators.go index 6c0c2ee0..3160f75c 100644 --- a/compiler/circuits/circ_comparators.go +++ b/compiler/circuits/circ_comparators.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2020-2021 Markku Rossi +// Copyright (c) 2020-2023 Markku Rossi // // All rights reserved. // @@ -14,152 +14,148 @@ import ( ) // comparator tests if x>y if cin=0, and x>=y if cin=1. -func comparator(compiler *Compiler, cin *Wire, x, y, r []*Wire) error { - x, y = compiler.ZeroPad(x, y) +func comparator(cc *Compiler, cin *Wire, x, y, r []*Wire) error { + x, y = cc.ZeroPad(x, y) if len(r) != 1 { return fmt.Errorf("invalid lt comparator arguments: r=%d", len(r)) } for i := 0; i < len(x); i++ { - w1 := NewWire() - compiler.AddGate(NewBinary(circuit.XNOR, cin, y[i], w1)) - w2 := NewWire() - compiler.AddGate(NewBinary(circuit.XOR, cin, x[i], w2)) - w3 := NewWire() - compiler.AddGate(NewBinary(circuit.AND, w1, w2, w3)) + w1 := cc.Calloc.Wire() + cc.AddGate(cc.Calloc.BinaryGate(circuit.XNOR, cin, y[i], w1)) + w2 := cc.Calloc.Wire() + cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, cin, x[i], w2)) + w3 := cc.Calloc.Wire() + cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, w1, w2, w3)) var cout *Wire if i+1 < len(x) { - cout = NewWire() + cout = cc.Calloc.Wire() } else { cout = r[0] } - compiler.AddGate(NewBinary(circuit.XOR, cin, w3, cout)) + cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, cin, w3, cout)) cin = cout } return nil } // NewGtComparator tests if x>y. -func NewGtComparator(compiler *Compiler, x, y, r []*Wire) error { - return comparator(compiler, compiler.ZeroWire(), x, y, r) +func NewGtComparator(cc *Compiler, x, y, r []*Wire) error { + return comparator(cc, cc.ZeroWire(), x, y, r) } // NewGeComparator tests if x>=y. -func NewGeComparator(compiler *Compiler, x, y, r []*Wire) error { - return comparator(compiler, compiler.OneWire(), x, y, r) +func NewGeComparator(cc *Compiler, x, y, r []*Wire) error { + return comparator(cc, cc.OneWire(), x, y, r) } // NewLtComparator tests if x= len(x) { out = r[0] } else { - out = NewWire() + out = cc.Calloc.Wire() } - compiler.AddGate(NewBinary(circuit.OR, c, xor, out)) + cc.AddGate(cc.Calloc.BinaryGate(circuit.OR, c, xor, out)) c = out } return nil } // NewEqComparator tests if x==y. -func NewEqComparator(compiler *Compiler, x, y, r []*Wire) error { +func NewEqComparator(cc *Compiler, x, y, r []*Wire) error { if len(r) != 1 { return fmt.Errorf("invalid eq comparator arguments: r=%d", len(r)) } // w = x == y - w := NewWire() - err := NewNeqComparator(compiler, x, y, []*Wire{w}) + w := cc.Calloc.Wire() + err := NewNeqComparator(cc, x, y, []*Wire{w}) if err != nil { return err } // r = !w - compiler.INV(w, r[0]) + cc.INV(w, r[0]) return nil } // NewLogicalAND implements logical AND implementing r=x&y. The input // and output wires must be 1 bit wide. -func NewLogicalAND(compiler *Compiler, x, y, r []*Wire) error { +func NewLogicalAND(cc *Compiler, x, y, r []*Wire) error { if len(x) != 1 || len(y) != 1 || len(r) != 1 { return fmt.Errorf("invalid logical and arguments: x=%d, y=%d, r=%d", len(x), len(y), len(r)) } - compiler.AddGate(NewBinary(circuit.AND, x[0], y[0], r[0])) + cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, x[0], y[0], r[0])) return nil } // NewLogicalOR implements logical OR implementing r=x|y. The input // and output wires must be 1 bit wide. -func NewLogicalOR(compiler *Compiler, x, y, r []*Wire) error { +func NewLogicalOR(cc *Compiler, x, y, r []*Wire) error { if len(x) != 1 || len(y) != 1 || len(r) != 1 { return fmt.Errorf("invalid logical or arguments: x=%d, y=%d, r=%d", len(x), len(y), len(r)) } - compiler.AddGate(NewBinary(circuit.OR, x[0], y[0], r[0])) + cc.AddGate(cc.Calloc.BinaryGate(circuit.OR, x[0], y[0], r[0])) return nil } // NewBitSetTest tests if the index'th bit of x is set. -func NewBitSetTest(compiler *Compiler, x []*Wire, index types.Size, - r []*Wire) error { - +func NewBitSetTest(cc *Compiler, x []*Wire, index types.Size, r []*Wire) error { if len(r) != 1 { return fmt.Errorf("invalid bit set test arguments: x=%d, r=%d", len(x), len(r)) } if index < types.Size(len(x)) { - w := compiler.ZeroWire() - compiler.AddGate(NewBinary(circuit.XOR, x[index], w, r[0])) + w := cc.ZeroWire() + cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, x[index], w, r[0])) } else { - r[0] = compiler.ZeroWire() + r[0] = cc.ZeroWire() } return nil } // NewBitClrTest tests if the index'th bit of x is unset. -func NewBitClrTest(compiler *Compiler, x []*Wire, index types.Size, - r []*Wire) error { - +func NewBitClrTest(cc *Compiler, x []*Wire, index types.Size, r []*Wire) error { if len(r) != 1 { return fmt.Errorf("invalid bit clear test arguments: x=%d, r=%d", len(x), len(r)) } if index < types.Size(len(x)) { - w := compiler.OneWire() - compiler.AddGate(NewBinary(circuit.XOR, x[index], w, r[0])) + w := cc.OneWire() + cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, x[index], w, r[0])) } else { - r[0] = compiler.OneWire() + r[0] = cc.OneWire() } return nil } diff --git a/compiler/circuits/circ_divider.go b/compiler/circuits/circ_divider.go index 55ae0577..75ad5479 100644 --- a/compiler/circuits/circ_divider.go +++ b/compiler/circuits/circ_divider.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2019 Markku Rossi +// Copyright (c) 2019-2023 Markku Rossi // // All rights reserved. // @@ -7,8 +7,8 @@ package circuits // NewDivider creates a division circuit computing r=a/b, q=a%b. -func NewDivider(compiler *Compiler, a, b, q, r []*Wire) error { - a, b = compiler.ZeroPad(a, b) +func NewDivider(cc *Compiler, a, b, q, r []*Wire) error { + a, b = cc.ZeroPad(a, b) rIn := make([]*Wire, len(b)+1) rOut := make([]*Wire, len(b)+1) @@ -16,13 +16,13 @@ func NewDivider(compiler *Compiler, a, b, q, r []*Wire) error { // Init bINV. bINV := make([]*Wire, len(b)) for i := 0; i < len(b); i++ { - bINV[i] = NewWire() - compiler.INV(b[i], bINV[i]) + bINV[i] = cc.Calloc.Wire() + cc.INV(b[i], bINV[i]) } // Init for the first row. for i := 0; i < len(b); i++ { - rOut[i] = compiler.ZeroWire() + rOut[i] = cc.ZeroWire() } // Generate matrix. @@ -32,26 +32,26 @@ func NewDivider(compiler *Compiler, a, b, q, r []*Wire) error { copy(rIn[1:], rOut) // Adders from b{0} to b{n-1}, 0 - cIn := compiler.OneWire() + cIn := cc.OneWire() for x := 0; x < len(b)+1; x++ { var bw *Wire if x < len(b) { bw = bINV[x] } else { - bw = compiler.OneWire() // INV(0) + bw = cc.OneWire() // INV(0) } - co := NewWire() - ro := NewWire() - NewFullAdder(compiler, rIn[x], bw, cIn, ro, co) + co := cc.Calloc.Wire() + ro := cc.Calloc.Wire() + NewFullAdder(cc, rIn[x], bw, cIn, ro, co) rOut[x] = ro cIn = co } // Quotient y. if len(a)-1-y < len(q) { - w := NewWire() - compiler.INV(cIn, w) - compiler.INV(w, q[len(a)-1-y]) + w := cc.Calloc.Wire() + cc.INV(cIn, w) + cc.INV(w, q[len(a)-1-y]) } // MUXes from high to low bit. @@ -60,10 +60,10 @@ func NewDivider(compiler *Compiler, a, b, q, r []*Wire) error { if y+1 >= len(a) && x < len(r) { ro = r[x] } else { - ro = NewWire() + ro = cc.Calloc.Wire() } - err := NewMUX(compiler, []*Wire{cIn}, rOut[x:x+1], rIn[x:x+1], + err := NewMUX(cc, []*Wire{cIn}, rOut[x:x+1], rIn[x:x+1], []*Wire{ro}) if err != nil { return err @@ -74,12 +74,12 @@ func NewDivider(compiler *Compiler, a, b, q, r []*Wire) error { // Set extra quotient bits to zero. for y := len(a); y < len(q); y++ { - q[y] = compiler.ZeroWire() + q[y] = cc.ZeroWire() } // Set extra remainder bits to zero. for x := len(b); x < len(r); x++ { - r[x] = compiler.ZeroWire() + r[x] = cc.ZeroWire() } return nil diff --git a/compiler/circuits/circ_hamming.go b/compiler/circuits/circ_hamming.go index 8baaf15a..f9a6faa2 100644 --- a/compiler/circuits/circ_hamming.go +++ b/compiler/circuits/circ_hamming.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2020-2021 Markku Rossi +// Copyright (c) 2020-2023 Markku Rossi // // All rights reserved. // @@ -13,13 +13,13 @@ import ( // Hamming creates a hamming distance circuit computing the hamming // distance between a and b and returning the distance in r. -func Hamming(compiler *Compiler, a, b, r []*Wire) error { - a, b = compiler.ZeroPad(a, b) +func Hamming(cc *Compiler, a, b, r []*Wire) error { + a, b = cc.ZeroPad(a, b) var arr [][]*Wire for i := 0; i < len(a); i++ { - w := NewWire() - compiler.AddGate(NewBinary(circuit.XOR, a[i], b[i], w)) + w := cc.Calloc.Wire() + cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, a[i], b[i], w)) arr = append(arr, []*Wire{w}) } @@ -27,8 +27,8 @@ func Hamming(compiler *Compiler, a, b, r []*Wire) error { var n [][]*Wire for i := 0; i < len(arr); i += 2 { if i+1 < len(arr) { - result := MakeWires(types.Size(len(arr[i]) + 1)) - err := NewAdder(compiler, arr[i], arr[i+1], result) + result := cc.Calloc.Wires(types.Size(len(arr[i]) + 1)) + err := NewAdder(cc, arr[i], arr[i+1], result) if err != nil { return err } @@ -40,5 +40,5 @@ func Hamming(compiler *Compiler, a, b, r []*Wire) error { arr = n } - return NewAdder(compiler, arr[0], arr[1], r) + return NewAdder(cc, arr[0], arr[1], r) } diff --git a/compiler/circuits/circ_index.go b/compiler/circuits/circ_index.go index 7e2509e3..b76ec137 100644 --- a/compiler/circuits/circ_index.go +++ b/compiler/circuits/circ_index.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2021 Markku Rossi +// Copyright (c) 2021-2023 Markku Rossi // // All rights reserved. // @@ -11,7 +11,7 @@ import ( ) // NewIndex creates a new array element selection (index) circuit. -func NewIndex(compiler *Compiler, size int, array, index, out []*Wire) error { +func NewIndex(cc *Compiler, size int, array, index, out []*Wire) error { if len(array)%size != 0 { return fmt.Errorf("array width %d must be multiple of element size %d", len(array), size) @@ -23,7 +23,7 @@ func NewIndex(compiler *Compiler, size int, array, index, out []*Wire) error { n := len(array) / size if n == 0 { for i := 0; i < len(out); i++ { - out[i] = compiler.ZeroWire() + out[i] = cc.ZeroWire() } return nil } @@ -35,16 +35,16 @@ func NewIndex(compiler *Compiler, size int, array, index, out []*Wire) error { bits++ } - return newIndex(compiler, bits-1, length, size, array, index, out) + return newIndex(cc, bits-1, length, size, array, index, out) } -func newIndex(compiler *Compiler, bit, length, size int, +func newIndex(cc *Compiler, bit, length, size int, array, index, out []*Wire) error { // Default "not found" value. def := make([]*Wire, size) for i := 0; i < size; i++ { - def[i] = compiler.ZeroWire() + def[i] = cc.ZeroWire() } n := len(array) / size @@ -58,20 +58,20 @@ func newIndex(compiler *Compiler, bit, length, size int, } else { tVal = def } - return NewMUX(compiler, index[0:1], tVal, fVal, out) + return NewMUX(cc, index[0:1], tVal, fVal, out) } length /= 2 fVal := make([]*Wire, size) for i := 0; i < size; i++ { - fVal[i] = NewWire() + fVal[i] = cc.Calloc.Wire() } fArray := array if n > length { fArray = fArray[:length*size] } - err := newIndex(compiler, bit-1, length, size, fArray, index, fVal) + err := newIndex(cc, bit-1, length, size, fArray, index, fVal) if err != nil { return err } @@ -80,9 +80,9 @@ func newIndex(compiler *Compiler, bit, length, size int, if n > length { tVal = make([]*Wire, size) for i := 0; i < size; i++ { - tVal[i] = NewWire() + tVal[i] = cc.Calloc.Wire() } - err = newIndex(compiler, bit-1, length, size, + err = newIndex(cc, bit-1, length, size, array[length*size:], index, tVal) if err != nil { return err @@ -91,5 +91,5 @@ func newIndex(compiler *Compiler, bit, length, size int, tVal = def } - return NewMUX(compiler, index[bit:bit+1], tVal, fVal, out) + return NewMUX(cc, index[bit:bit+1], tVal, fVal, out) } diff --git a/compiler/circuits/circ_multiplier.go b/compiler/circuits/circ_multiplier.go index aa02ad08..21bb431c 100644 --- a/compiler/circuits/circ_multiplier.go +++ b/compiler/circuits/circ_multiplier.go @@ -31,8 +31,8 @@ func NewMultiplier(c *Compiler, arrayTreshold int, x, y, z []*Wire) error { // NewArrayMultiplier creates a multiplier circuit implementing // x*y=z. This function implements Array Multiplier Circuit. -func NewArrayMultiplier(compiler *Compiler, x, y, z []*Wire) error { - x, y = compiler.ZeroPad(x, y) +func NewArrayMultiplier(cc *Compiler, x, y, z []*Wire) error { + x, y = cc.ZeroPad(x, y) if len(x) > len(z) { x = x[0:len(z)] y = y[0:len(z)] @@ -40,9 +40,9 @@ func NewArrayMultiplier(compiler *Compiler, x, y, z []*Wire) error { // One bit multiplication is AND. if len(x) == 1 { - compiler.AddGate(NewBinary(circuit.AND, x[0], y[0], z[0])) + cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, x[0], y[0], z[0])) if len(z) > 1 { - z[1] = compiler.ZeroWire() + z[1] = cc.ZeroWire() } return nil } @@ -55,10 +55,10 @@ func NewArrayMultiplier(compiler *Compiler, x, y, z []*Wire) error { if i == 0 { s = z[0] } else { - s = NewWire() + s = cc.Calloc.Wire() sums = append(sums, s) } - compiler.AddGate(NewBinary(circuit.AND, xn, y[0], s)) + cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, xn, y[0], s)) } // Construct len(y)-2 intermediate layers @@ -67,8 +67,8 @@ func NewArrayMultiplier(compiler *Compiler, x, y, z []*Wire) error { // ANDs for y(j) var ands []*Wire for _, xn := range x { - wire := NewWire() - compiler.AddGate(NewBinary(circuit.AND, xn, y[j], wire)) + wire := cc.Calloc.Wire() + cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, xn, y[j], wire)) ands = append(ands, wire) } @@ -76,22 +76,22 @@ func NewArrayMultiplier(compiler *Compiler, x, y, z []*Wire) error { var nsums []*Wire var c *Wire for i := 0; i < len(ands); i++ { - cout := NewWire() + cout := cc.Calloc.Wire() var s *Wire if i == 0 { s = z[j] } else { - s = NewWire() + s = cc.Calloc.Wire() nsums = append(nsums, s) } if i == 0 { - NewHalfAdder(compiler, ands[i], sums[i], s, cout) + NewHalfAdder(cc, ands[i], sums[i], s, cout) } else if i >= len(sums) { - NewHalfAdder(compiler, ands[i], c, s, cout) + NewHalfAdder(cc, ands[i], c, s, cout) } else { - NewFullAdder(compiler, ands[i], sums[i], c, s, cout) + NewFullAdder(cc, ands[i], sums[i], c, s, cout) } c = cout } @@ -102,29 +102,29 @@ func NewArrayMultiplier(compiler *Compiler, x, y, z []*Wire) error { // Construct final layer. var c *Wire for i, xn := range x { - and := NewWire() - compiler.AddGate(NewBinary(circuit.AND, xn, y[j], and)) + and := cc.Calloc.Wire() + cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, xn, y[j], and)) var cout *Wire if i+1 >= len(x) && j+i+1 < len(z) { cout = z[j+i+1] } else { - cout = NewWire() + cout = cc.Calloc.Wire() } if j+i < len(z) { if i == 0 { - NewHalfAdder(compiler, and, sums[i], z[j+i], cout) + NewHalfAdder(cc, and, sums[i], z[j+i], cout) } else if i >= len(sums) { - NewHalfAdder(compiler, and, c, z[j+i], cout) + NewHalfAdder(cc, and, c, z[j+i], cout) } else { - NewFullAdder(compiler, and, sums[i], c, z[j+i], cout) + NewFullAdder(cc, and, sums[i], c, z[j+i], cout) } } c = cout } for i := j + len(x) + 1; i < len(z); i++ { - z[1] = compiler.ZeroWire() + z[1] = cc.ZeroWire() } return nil @@ -169,34 +169,34 @@ func NewKaratsubaMultiplier(cc *Compiler, limit int, a, b, r []*Wire) error { bLow := b[:mid] bHigh := b[mid:] - z0 := MakeWires(types.Size(min(max(len(aLow), len(bLow))*2, len(r)))) + z0 := cc.Calloc.Wires(types.Size(min(max(len(aLow), len(bLow))*2, len(r)))) if err := NewKaratsubaMultiplier(cc, limit, aLow, bLow, z0); err != nil { return err } aSumLen := max(len(aLow), len(aHigh)) + 1 - aSum := MakeWires(types.Size(aSumLen)) + aSum := cc.Calloc.Wires(types.Size(aSumLen)) if err := NewAdder(cc, aLow, aHigh, aSum); err != nil { return err } bSumLen := max(len(bLow), len(bHigh)) + 1 - bSum := MakeWires(types.Size(bSumLen)) + bSum := cc.Calloc.Wires(types.Size(bSumLen)) if err := NewAdder(cc, bLow, bHigh, bSum); err != nil { return err } - z1 := MakeWires(types.Size(min(max(aSumLen, bSumLen)*2, len(r)))) + z1 := cc.Calloc.Wires(types.Size(min(max(aSumLen, bSumLen)*2, len(r)))) if err := NewKaratsubaMultiplier(cc, limit, aSum, bSum, z1); err != nil { return err } - z2 := MakeWires(types.Size(min(max(len(aHigh), len(bHigh))*2, len(r)))) + z2 := cc.Calloc.Wires(types.Size(min(max(len(aHigh), len(bHigh))*2, len(r)))) if err := NewKaratsubaMultiplier(cc, limit, aHigh, bHigh, z2); err != nil { return err } - sub1 := MakeWires(types.Size(len(r))) + sub1 := cc.Calloc.Wires(types.Size(len(r))) if err := NewSubtractor(cc, z1, z2, sub1); err != nil { return err } - sub2 := MakeWires(types.Size(len(r))) + sub2 := cc.Calloc.Wires(types.Size(len(r))) if err := NewSubtractor(cc, sub1, z0, sub2); err != nil { return err } @@ -204,7 +204,7 @@ func NewKaratsubaMultiplier(cc *Compiler, limit int, a, b, r []*Wire) error { shift1 := cc.ShiftLeft(z2, len(r), mid*2) shift2 := cc.ShiftLeft(sub2, len(r), mid) - add1 := MakeWires(types.Size(len(r))) + add1 := cc.Calloc.Wires(types.Size(len(r))) if err := NewAdder(cc, shift1, shift2, add1); err != nil { return err } diff --git a/compiler/circuits/circ_mux.go b/compiler/circuits/circ_mux.go index c653945e..d0a5377c 100644 --- a/compiler/circuits/circ_mux.go +++ b/compiler/circuits/circ_mux.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2020 Markku Rossi +// Copyright (c) 2020-2023 Markku Rossi // // All rights reserved. // @@ -14,25 +14,25 @@ import ( // NewMUX creates a multiplexer circuit that selects the input t or f // to output, based on the value of the condition cond. -func NewMUX(compiler *Compiler, cond, t, f, out []*Wire) error { - t, f = compiler.ZeroPad(t, f) +func NewMUX(cc *Compiler, cond, t, f, out []*Wire) error { + t, f = cc.ZeroPad(t, f) if len(cond) != 1 || len(t) != len(f) || len(t) != len(out) { return fmt.Errorf("invalid mux arguments: cond=%d, l=%d, r=%d, out=%d", len(cond), len(t), len(f), len(out)) } for i := 0; i < len(t); i++ { - w1 := NewWire() - w2 := NewWire() + w1 := cc.Calloc.Wire() + w2 := cc.Calloc.Wire() // w1 = XOR(f[i], t[i]) - compiler.AddGate(NewBinary(circuit.XOR, f[i], t[i], w1)) + cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, f[i], t[i], w1)) // w2 = AND(w1, cond) - compiler.AddGate(NewBinary(circuit.AND, w1, cond[0], w2)) + cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, w1, cond[0], w2)) // out[i] = XOR(w2, f[i]) - compiler.AddGate(NewBinary(circuit.XOR, w2, f[i], out[i])) + cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, w2, f[i], out[i])) } return nil diff --git a/compiler/circuits/circ_subtractor.go b/compiler/circuits/circ_subtractor.go index 7648edd3..fd6ed5f2 100644 --- a/compiler/circuits/circ_subtractor.go +++ b/compiler/circuits/circ_subtractor.go @@ -1,7 +1,7 @@ // // circ_subtractor.go // -// Copyright (c) 2019-2021 Markku Rossi +// Copyright (c) 2019-2023 Markku Rossi // // All rights reserved. // @@ -13,30 +13,30 @@ import ( ) // NewFullSubtractor creates a full subtractor circuit. -func NewFullSubtractor(compiler *Compiler, x, y, cin, d, cout *Wire) { - w1 := NewWire() - compiler.AddGate(NewBinary(circuit.XNOR, y, cin, w1)) - compiler.AddGate(NewBinary(circuit.XNOR, x, w1, d)) +func NewFullSubtractor(cc *Compiler, x, y, cin, d, cout *Wire) { + w1 := cc.Calloc.Wire() + cc.AddGate(cc.Calloc.BinaryGate(circuit.XNOR, y, cin, w1)) + cc.AddGate(cc.Calloc.BinaryGate(circuit.XNOR, x, w1, d)) if cout != nil { - w2 := NewWire() - compiler.AddGate(NewBinary(circuit.XOR, x, cin, w2)) + w2 := cc.Calloc.Wire() + cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, x, cin, w2)) - w3 := NewWire() - compiler.AddGate(NewBinary(circuit.AND, w1, w2, w3)) + w3 := cc.Calloc.Wire() + cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, w1, w2, w3)) - compiler.AddGate(NewBinary(circuit.XOR, w3, cin, cout)) + cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, w3, cin, cout)) } } // NewSubtractor creates a new subtractor circuit implementing z=x-y. -func NewSubtractor(compiler *Compiler, x, y, z []*Wire) error { - x, y = compiler.ZeroPad(x, y) +func NewSubtractor(cc *Compiler, x, y, z []*Wire) error { + x, y = cc.ZeroPad(x, y) if len(x) > len(z) { x = x[0:len(z)] y = y[0:len(z)] } - cin := compiler.ZeroWire() + cin := cc.ZeroWire() for i := 0; i < len(x); i++ { var cout *Wire @@ -48,16 +48,16 @@ func NewSubtractor(compiler *Compiler, x, y, z []*Wire) error { cout = z[i+1] } } else { - cout = NewWire() + cout = cc.Calloc.Wire() } // Note y-x here. - NewFullSubtractor(compiler, y[i], x[i], cin, z[i], cout) + NewFullSubtractor(cc, y[i], x[i], cin, z[i], cout) cin = cout } for i := len(x) + 1; i < len(z); i++ { - z[i] = compiler.ZeroWire() + z[i] = cc.ZeroWire() } return nil } diff --git a/compiler/circuits/circuits_test.go b/compiler/circuits/circuits_test.go index b9825ace..d6779c5b 100644 --- a/compiler/circuits/circuits_test.go +++ b/compiler/circuits/circuits_test.go @@ -15,6 +15,7 @@ import ( "github.com/markkurossi/mpc/circuit" "github.com/markkurossi/mpc/compiler/utils" + "github.com/markkurossi/mpc/types" ) const ( @@ -23,12 +24,13 @@ const ( var ( params = utils.NewParams() + calloc = NewAllocator() ) func makeWires(count int, output bool) []*Wire { var result []*Wire for i := 0; i < count; i++ { - w := NewWire() + w := calloc.Wire() w.SetOutput(output) result = append(result, w) } @@ -39,7 +41,10 @@ func NewIO(size int, name string) circuit.IO { return circuit.IO{ circuit.IOArg{ Name: name, - Size: size, + Type: types.Info{ + Type: types.TUint, + Bits: types.Size(size), + }, }, } } @@ -50,13 +55,13 @@ func TestAdd4(t *testing.T) { // 2xbits inputs, bits+1 outputs inputs := makeWires(bits*2, false) outputs := makeWires(bits+1, true) - c, err := NewCompiler(params, NewIO(bits*2, "in"), NewIO(bits+1, "out"), - inputs, outputs) + c, err := NewCompiler(params, calloc, NewIO(bits*2, "in"), + NewIO(bits+1, "out"), inputs, outputs) if err != nil { t.Fatalf("NewCompiler: %s", err) } - cin := NewWire() + cin := calloc.Wire() NewHalfAdder(c, inputs[0], inputs[bits], outputs[0], cin) for i := 1; i < bits; i++ { @@ -64,7 +69,7 @@ func TestAdd4(t *testing.T) { if i+1 >= bits { cout = outputs[bits] } else { - cout = NewWire() + cout = calloc.Wire() } NewFullAdder(c, inputs[i], inputs[bits+i], cin, outputs[i], cout) @@ -82,7 +87,7 @@ func TestAdd4(t *testing.T) { func TestFullSubtractor(t *testing.T) { inputs := makeWires(1+2, false) outputs := makeWires(2, true) - c, err := NewCompiler(params, NewIO(1+2, "in"), NewIO(2, "out"), + c, err := NewCompiler(params, calloc, NewIO(1+2, "in"), NewIO(2, "out"), inputs, outputs) if err != nil { t.Fatalf("NewCompiler: %s", err) @@ -101,7 +106,7 @@ func TestFullSubtractor(t *testing.T) { func TestMultiply1(t *testing.T) { inputs := makeWires(2, false) outputs := makeWires(2, true) - c, err := NewCompiler(params, NewIO(2, "in"), NewIO(2, "out"), + c, err := NewCompiler(params, calloc, NewIO(2, "in"), NewIO(2, "out"), inputs, outputs) if err != nil { t.Fatalf("NewCompiler: %s", err) @@ -119,8 +124,8 @@ func TestMultiply(t *testing.T) { inputs := makeWires(bits*2, false) outputs := makeWires(bits*2, true) - c, err := NewCompiler(params, NewIO(bits*2, "in"), NewIO(bits*2, "out"), - inputs, outputs) + c, err := NewCompiler(params, calloc, NewIO(bits*2, "in"), + NewIO(bits*2, "out"), inputs, outputs) if err != nil { t.Fatalf("NewCompiler: %s", err) } diff --git a/compiler/circuits/compiler.go b/compiler/circuits/compiler.go index 61e4e93c..a6454fc9 100644 --- a/compiler/circuits/compiler.go +++ b/compiler/circuits/compiler.go @@ -21,13 +21,14 @@ type Builtin func(cc *Compiler, a, b, r []*Wire) error // Compiler implements binary circuit compiler. type Compiler struct { Params *utils.Params + Calloc *Allocator OutputsAssigned bool Inputs circuit.IO Outputs circuit.IO InputWires []*Wire OutputWires []*Wire Gates []*Gate - nextWireID uint32 + nextWireID circuit.Wire pending []*Gate assigned []*Gate compiled []circuit.Gate @@ -39,14 +40,16 @@ type Compiler struct { // NewCompiler creates a new circuit compiler for the specified // circuit input and output values. -func NewCompiler(params *utils.Params, inputs, outputs circuit.IO, - inputWires, outputWires []*Wire) (*Compiler, error) { +func NewCompiler(params *utils.Params, calloc *Allocator, + inputs, outputs circuit.IO, inputWires, outputWires []*Wire) ( + *Compiler, error) { if len(inputWires) == 0 { return nil, fmt.Errorf("no inputs defined") } return &Compiler{ Params: params, + Calloc: calloc, Inputs: inputs, Outputs: outputs, InputWires: inputWires, @@ -56,39 +59,39 @@ func NewCompiler(params *utils.Params, inputs, outputs circuit.IO, } // InvI0Wire returns a wire holding value INV(input[0]). -func (c *Compiler) InvI0Wire() *Wire { - if c.invI0Wire == nil { - c.invI0Wire = NewWire() - c.AddGate(NewINV(c.InputWires[0], c.invI0Wire)) +func (cc *Compiler) InvI0Wire() *Wire { + if cc.invI0Wire == nil { + cc.invI0Wire = cc.Calloc.Wire() + cc.AddGate(cc.Calloc.INVGate(cc.InputWires[0], cc.invI0Wire)) } - return c.invI0Wire + return cc.invI0Wire } // ZeroWire returns a wire holding value 0. -func (c *Compiler) ZeroWire() *Wire { - if c.zeroWire == nil { - c.zeroWire = NewWire() - c.AddGate(NewBinary(circuit.AND, c.InputWires[0], c.InvI0Wire(), - c.zeroWire)) - c.zeroWire.SetValue(Zero) +func (cc *Compiler) ZeroWire() *Wire { + if cc.zeroWire == nil { + cc.zeroWire = cc.Calloc.Wire() + cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, cc.InputWires[0], + cc.InvI0Wire(), cc.zeroWire)) + cc.zeroWire.SetValue(Zero) } - return c.zeroWire + return cc.zeroWire } // OneWire returns a wire holding value 1. -func (c *Compiler) OneWire() *Wire { - if c.oneWire == nil { - c.oneWire = NewWire() - c.AddGate(NewBinary(circuit.OR, c.InputWires[0], c.InvI0Wire(), - c.oneWire)) - c.oneWire.SetValue(One) +func (cc *Compiler) OneWire() *Wire { + if cc.oneWire == nil { + cc.oneWire = cc.Calloc.Wire() + cc.AddGate(cc.Calloc.BinaryGate(circuit.OR, cc.InputWires[0], + cc.InvI0Wire(), cc.oneWire)) + cc.oneWire.SetValue(One) } - return c.oneWire + return cc.oneWire } // ZeroPad pads the argument wires x and y with zero values so that // the resulting wires have the same number of bits. -func (c *Compiler) ZeroPad(x, y []*Wire) ([]*Wire, []*Wire) { +func (cc *Compiler) ZeroPad(x, y []*Wire) ([]*Wire, []*Wire) { if len(x) == len(y) { return x, y } @@ -103,7 +106,7 @@ func (c *Compiler) ZeroPad(x, y []*Wire) ([]*Wire, []*Wire) { if i < len(x) { rx[i] = x[i] } else { - rx[i] = c.ZeroWire() + rx[i] = cc.ZeroWire() } } @@ -112,7 +115,7 @@ func (c *Compiler) ZeroPad(x, y []*Wire) ([]*Wire, []*Wire) { if i < len(y) { ry[i] = y[i] } else { - ry[i] = c.ZeroWire() + ry[i] = cc.ZeroWire() } } @@ -121,59 +124,59 @@ func (c *Compiler) ZeroPad(x, y []*Wire) ([]*Wire, []*Wire) { // ShiftLeft shifts the size number of bits of the input wires w, // count bits left. -func (c *Compiler) ShiftLeft(w []*Wire, size, count int) []*Wire { +func (cc *Compiler) ShiftLeft(w []*Wire, size, count int) []*Wire { result := make([]*Wire, size) if count < size { copy(result[count:], w) } for i := 0; i < count; i++ { - result[i] = c.ZeroWire() + result[i] = cc.ZeroWire() } for i := count + len(w); i < size; i++ { - result[i] = c.ZeroWire() + result[i] = cc.ZeroWire() } return result } // INV creates an inverse wire inverting the input wire i's value to // the output wire o. -func (c *Compiler) INV(i, o *Wire) { - c.AddGate(NewBinary(circuit.XOR, i, c.OneWire(), o)) +func (cc *Compiler) INV(i, o *Wire) { + cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, i, cc.OneWire(), o)) } // ID creates an identity wire passing the input wire i's value to the // output wire o. -func (c *Compiler) ID(i, o *Wire) { - c.AddGate(NewBinary(circuit.XOR, i, c.ZeroWire(), o)) +func (cc *Compiler) ID(i, o *Wire) { + cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, i, cc.ZeroWire(), o)) } // AddGate adds a get into the circuit. -func (c *Compiler) AddGate(gate *Gate) { - c.Gates = append(c.Gates, gate) +func (cc *Compiler) AddGate(gate *Gate) { + cc.Gates = append(cc.Gates, gate) } // SetNextWireID sets the next unique wire ID to use. -func (c *Compiler) SetNextWireID(next uint32) { - c.nextWireID = next +func (cc *Compiler) SetNextWireID(next circuit.Wire) { + cc.nextWireID = next } // NextWireID returns the next unique wire ID. -func (c *Compiler) NextWireID() uint32 { - ret := c.nextWireID - c.nextWireID++ +func (cc *Compiler) NextWireID() circuit.Wire { + ret := cc.nextWireID + cc.nextWireID++ return ret } // ConstPropagate propagates constant wire values in the circuit and // short circuits gates if their output does not depend on the gate's // logical operation. -func (c *Compiler) ConstPropagate() { +func (cc *Compiler) ConstPropagate() { var stats circuit.Stats start := time.Now() - for _, g := range c.Gates { + for _, g := range cc.Gates { switch g.Op { case circuit.XOR: if (g.A.Value() == Zero && g.B.Value() == Zero) || @@ -251,21 +254,21 @@ func (c *Compiler) ConstPropagate() { if g.A.Value() == Zero { g.A.RemoveOutput(g) - g.A = c.ZeroWire() + g.A = cc.ZeroWire() g.A.AddOutput(g) } else if g.A.Value() == One { g.A.RemoveOutput(g) - g.A = c.OneWire() + g.A = cc.OneWire() g.A.AddOutput(g) } if g.B != nil { if g.B.Value() == Zero { g.B.RemoveOutput(g) - g.B = c.ZeroWire() + g.B = cc.ZeroWire() g.B.AddOutput(g) } else if g.B.Value() == One { g.B.RemoveOutput(g) - g.B = c.OneWire() + g.B = cc.OneWire() g.B.AddOutput(g) } } @@ -273,21 +276,21 @@ func (c *Compiler) ConstPropagate() { elapsed := time.Since(start) - if c.Params.Diagnostics && stats.Count() > 0 { + if cc.Params.Diagnostics && stats.Count() > 0 { fmt.Printf(" - ConstPropagate: %12s: %d/%d (%.2f%%)\n", - elapsed, stats.Count(), len(c.Gates), - float64(stats.Count())/float64(len(c.Gates))*100) + elapsed, stats.Count(), len(cc.Gates), + float64(stats.Count())/float64(len(cc.Gates))*100) } } // ShortCircuitXORZero short circuits input to output where input is // XOR'ed to zero. -func (c *Compiler) ShortCircuitXORZero() { +func (cc *Compiler) ShortCircuitXORZero() { var stats circuit.Stats start := time.Now() - for _, g := range c.Gates { + for _, g := range cc.Gates { if g.Op != circuit.XOR { continue } @@ -297,7 +300,7 @@ func (c *Compiler) ShortCircuitXORZero() { g.B.Input().ResetOutput(g.O) // Disconnect gate's output wire. - g.O = NewWire() + g.O = cc.Calloc.Wire() stats[g.Op]++ } @@ -307,7 +310,7 @@ func (c *Compiler) ShortCircuitXORZero() { g.A.Input().ResetOutput(g.O) // Disconnect gate's output wire. - g.O = NewWire() + g.O = cc.Calloc.Wire() stats[g.Op]++ } @@ -315,81 +318,81 @@ func (c *Compiler) ShortCircuitXORZero() { elapsed := time.Since(start) - if c.Params.Diagnostics && stats.Count() > 0 { + if cc.Params.Diagnostics && stats.Count() > 0 { fmt.Printf(" - ShortCircuitXORZero: %12s: %d/%d (%.2f%%)\n", - elapsed, stats.Count(), len(c.Gates), - float64(stats.Count())/float64(len(c.Gates))*100) + elapsed, stats.Count(), len(cc.Gates), + float64(stats.Count())/float64(len(cc.Gates))*100) } } // Prune removes all gates whose output wires are unused. -func (c *Compiler) Prune() int { +func (cc *Compiler) Prune() int { - n := make([]*Gate, len(c.Gates)) + n := make([]*Gate, len(cc.Gates)) nPos := len(n) - for i := len(c.Gates) - 1; i >= 0; i-- { - g := c.Gates[i] + for i := len(cc.Gates) - 1; i >= 0; i-- { + g := cc.Gates[i] if !g.Prune() { nPos-- n[nPos] = g } } - c.Gates = n[nPos:] + cc.Gates = n[nPos:] return nPos } // Compile compiles the circuit. -func (c *Compiler) Compile() *circuit.Circuit { - if len(c.pending) != 0 { +func (cc *Compiler) Compile() *circuit.Circuit { + if len(cc.pending) != 0 { panic("Compile: pending set") } - c.pending = make([]*Gate, 0, len(c.Gates)) - if len(c.assigned) != 0 { + cc.pending = make([]*Gate, 0, len(cc.Gates)) + if len(cc.assigned) != 0 { panic("Compile: assigned set") } - c.assigned = make([]*Gate, 0, len(c.Gates)) - if len(c.compiled) != 0 { + cc.assigned = make([]*Gate, 0, len(cc.Gates)) + if len(cc.compiled) != 0 { panic("Compile: compiled set") } - c.compiled = make([]circuit.Gate, 0, len(c.Gates)) + cc.compiled = make([]circuit.Gate, 0, len(cc.Gates)) - for _, w := range c.InputWires { - w.Assign(c) + for _, w := range cc.InputWires { + w.Assign(cc) } - for len(c.pending) > 0 { - gate := c.pending[0] - c.pending = c.pending[1:] - gate.Assign(c) + for len(cc.pending) > 0 { + gate := cc.pending[0] + cc.pending = cc.pending[1:] + gate.Assign(cc) } // Assign outputs. - for _, w := range c.OutputWires { + for _, w := range cc.OutputWires { if w.Assigned() { - if !c.OutputsAssigned { + if !cc.OutputsAssigned { panic("Output already assigned") } } else { - w.SetID(c.NextWireID()) + w.SetID(cc.NextWireID()) } } // Compile circuit. - for _, gate := range c.assigned { - gate.Compile(c) + for _, gate := range cc.assigned { + gate.Compile(cc) } var stats circuit.Stats - for _, g := range c.compiled { + for _, g := range cc.compiled { stats[g.Op]++ } result := &circuit.Circuit{ - NumGates: len(c.compiled), - NumWires: int(c.nextWireID), - Inputs: c.Inputs, - Outputs: c.Outputs, - Gates: c.compiled, + NumGates: len(cc.compiled), + NumWires: int(cc.nextWireID), + Inputs: cc.Inputs, + Outputs: cc.Outputs, + Gates: cc.compiled, Stats: stats, } diff --git a/compiler/circuits/gates.go b/compiler/circuits/gates.go index 4929fcf7..6e9154c6 100644 --- a/compiler/circuits/gates.go +++ b/compiler/circuits/gates.go @@ -25,51 +25,23 @@ type Gate struct { O *Wire } -// NewBinary creates a new binary gate. -func NewBinary(op circuit.Operation, a, b, o *Wire) *Gate { - gate := &Gate{ - Op: op, - A: a, - B: b, - O: o, - } - a.AddOutput(gate) - b.AddOutput(gate) - o.SetInput(gate) - - return gate -} - -// NewINV creates a new INV gate. -func NewINV(i, o *Wire) *Gate { - gate := &Gate{ - Op: circuit.INV, - A: i, - O: o, - } - i.AddOutput(gate) - o.SetInput(gate) - - return gate -} - func (g *Gate) String() string { return fmt.Sprintf("%s %x %x %x", g.Op, g.A.ID(), g.B.ID(), g.O.ID()) } // Visit adds gate to the list of pending gates to be compiled. -func (g *Gate) Visit(c *Compiler) { +func (g *Gate) Visit(cc *Compiler) { switch g.Op { case circuit.INV: if !g.Dead && !g.Visited && g.A.Assigned() { g.Visited = true - c.pending = append(c.pending, g) + cc.pending = append(cc.pending, g) } default: if !g.Dead && !g.Visited && g.A.Assigned() && g.B.Assigned() { g.Visited = true - c.pending = append(c.pending, g) + cc.pending = append(cc.pending, g) } } } @@ -128,29 +100,29 @@ func (g *Gate) Prune() bool { } // Assign assigns gate's output wire ID. -func (g *Gate) Assign(c *Compiler) { +func (g *Gate) Assign(cc *Compiler) { if !g.Dead { - g.O.Assign(c) - c.assigned = append(c.assigned, g) + g.O.Assign(cc) + cc.assigned = append(cc.assigned, g) } } // Compile adds gate's binary circuit into compile circuit. -func (g *Gate) Compile(c *Compiler) { +func (g *Gate) Compile(cc *Compiler) { if g.Dead || g.Compiled { return } g.Compiled = true switch g.Op { case circuit.INV: - c.compiled = append(c.compiled, circuit.Gate{ + cc.compiled = append(cc.compiled, circuit.Gate{ Input0: circuit.Wire(g.A.ID()), Output: circuit.Wire(g.O.ID()), Op: g.Op, }) default: - c.compiled = append(c.compiled, circuit.Gate{ + cc.compiled = append(cc.compiled, circuit.Gate{ Input0: circuit.Wire(g.A.ID()), Input1: circuit.Wire(g.B.ID()), Output: circuit.Wire(g.O.ID()), diff --git a/compiler/circuits/wire.go b/compiler/circuits/wire.go index f9119650..39fb1d9d 100644 --- a/compiler/circuits/wire.go +++ b/compiler/circuits/wire.go @@ -10,24 +10,24 @@ import ( "fmt" "math" - "github.com/markkurossi/mpc/types" + "github.com/markkurossi/mpc/circuit" ) const ( // UnassignedID identifies an unassigned wire ID. - UnassignedID uint32 = math.MaxUint32 - outputMask = 0b10000000000000000000000000000000 - valueMask = 0b01100000000000000000000000000000 - numMask = 0b00011111111111111111111111111111 - valueShift = 29 + UnassignedID circuit.Wire = math.MaxUint32 + outputMask = 0b10000000000000000000000000000000 + valueMask = 0b01100000000000000000000000000000 + numMask = 0b00011111111111111111111111111111 + valueShift = 29 ) // Wire implements a wire connecting binary gates. type Wire struct { - ovnum uint32 - id uint32 - input *Gate - outputs []*Gate + ovnum uint32 + id circuit.Wire + // The gates[0] is the input gate, gates[1:] are the output gates. + gates []*Gate } // WireValue defines wire values. @@ -51,41 +51,22 @@ func (v WireValue) String() string { } } -// NewWire creates an unassigned wire. -func NewWire() *Wire { - w := new(Wire) - w.Reset(UnassignedID) - return w -} - -// MakeWires creates bits number of wires. -func MakeWires(bits types.Size) []*Wire { - result := make([]*Wire, bits) - wires := make([]Wire, bits) - for i := 0; i < int(bits); i++ { - w := &wires[i] - w.id = UnassignedID - result[i] = w - } - return result -} - // Reset resets the wire with the new ID. -func (w *Wire) Reset(id uint32) { +func (w *Wire) Reset(id circuit.Wire) { w.SetOutput(false) w.SetValue(Unknown) w.SetID(id) - w.input = nil + w.SetInput(nil) w.DisconnectOutputs() } // ID returns the wire ID. -func (w *Wire) ID() uint32 { +func (w *Wire) ID() circuit.Wire { return w.id } // SetID sets the wire ID. -func (w *Wire) SetID(id uint32) { +func (w *Wire) SetID(id circuit.Wire) { w.id = id } @@ -136,56 +117,76 @@ func (w *Wire) SetNumOutputs(num uint32) { // DisconnectOutputs disconnects wire from its output gates. func (w *Wire) DisconnectOutputs() { w.SetNumOutputs(0) - w.outputs = w.outputs[0:0] + if len(w.gates) > 1 { + w.gates = w.gates[0:1] + } } func (w *Wire) String() string { return fmt.Sprintf("Wire{%x, Input:%s, Value:%s, Outputs:%v, Output=%v}", - w.ID(), w.input, w.Value(), w.outputs, w.Output()) + w.ID(), w.Input(), w.Value(), w.gates[1:], w.Output()) } // Assign assings wire ID. -func (w *Wire) Assign(c *Compiler) { +func (w *Wire) Assign(cc *Compiler) { if w.Output() { return } if !w.Assigned() { - w.id = c.NextWireID() + w.id = cc.NextWireID() } w.ForEachOutput(func(gate *Gate) { - gate.Visit(c) + gate.Visit(cc) }) } // Input returns the wire's input gate. func (w *Wire) Input() *Gate { - return w.input + if len(w.gates) == 0 { + return nil + } + return w.gates[0] } // SetInput sets the wire's input gate. func (w *Wire) SetInput(gate *Gate) { - if w.input != nil { - panic("Input gate already set") + if gate == nil { + if len(w.gates) > 0 { + w.gates[0] = nil + } + } else { + if len(w.gates) == 0 { + w.gates = append(w.gates, gate) + } else { + if w.gates[0] != nil { + panic("Input gate already set") + } + w.gates[0] = gate + } } - w.input = gate } // IsInput tests if the wire is an input wire. func (w *Wire) IsInput() bool { - return w.input == nil + return w.Input() == nil } // ForEachOutput calls the argument function for each output gate of // the wire. func (w *Wire) ForEachOutput(f func(gate *Gate)) { - for _, gate := range w.outputs { - f(gate) + if len(w.gates) > 1 { + for _, gate := range w.gates[1:] { + f(gate) + } } } // AddOutput adds gate to the wire's output gates. func (w *Wire) AddOutput(gate *Gate) { - w.outputs = append(w.outputs, gate) + if len(w.gates) == 0 { + w.gates = append(w.gates, nil) + } + w.gates = append(w.gates, gate) w.SetNumOutputs(w.NumOutputs() + 1) } diff --git a/compiler/circuits/wire_test.go b/compiler/circuits/wire_test.go index 71efa6c6..7f573974 100644 --- a/compiler/circuits/wire_test.go +++ b/compiler/circuits/wire_test.go @@ -11,7 +11,7 @@ import ( ) func TestWire(t *testing.T) { - w := NewWire() + w := calloc.Wire() if w.ID() != UnassignedID { t.Error("w.ID") } diff --git a/compiler/compiler.go b/compiler/compiler.go index 2ed7c596..5d36c2cf 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -144,7 +144,7 @@ func (c *Compiler) stream(conn *p2p.Conn, oti ot.OT, source string, fmt.Printf(" - Out: %s\n", program.Outputs) fmt.Printf(" - In: %s\n", inputFlag) - out, bits, err := program.StreamCircuit(conn, oti, c.params, input, timing) + out, bits, err := program.Stream(conn, oti, c.params, input, timing) if err != nil { return nil, nil, err } diff --git a/compiler/ssa/circuitgen.go b/compiler/ssa/circuitgen.go index c6df5fc7..a5db5ed4 100644 --- a/compiler/ssa/circuitgen.go +++ b/compiler/ssa/circuitgen.go @@ -19,7 +19,9 @@ import ( func (prog *Program) CompileCircuit(params *utils.Params) ( *circuit.Circuit, error) { - cc, err := circuits.NewCompiler(params, prog.Inputs, prog.Outputs, + calloc := circuits.NewAllocator() + + cc, err := circuits.NewCompiler(params, calloc, prog.Inputs, prog.Outputs, prog.InputWires, prog.OutputWires) if err != nil { return nil, err @@ -78,7 +80,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { instr := step.Instr var wires [][]*circuits.Wire for _, in := range instr.In { - w, err := prog.Wires(in.String(), in.Type.Bits) + w, err := prog.walloc.Wires(in, in.Type.Bits) if err != nil { return err } @@ -86,7 +88,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { } switch instr.Op { case Iadd, Uadd: - o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits) + o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits) if err != nil { return err } @@ -96,7 +98,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { } case Isub, Usub: - o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits) + o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits) if err != nil { return err } @@ -106,7 +108,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { } case Imult, Umult: - o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits) + o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits) if err != nil { return err } @@ -117,7 +119,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { } case Idiv, Udiv: - o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits) + o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits) if err != nil { return err } @@ -128,7 +130,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { } case Imod, Umod: - o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits) + o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits) if err != nil { return err } @@ -158,10 +160,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { } o[bit] = w } - err = prog.SetWires(instr.Out.String(), o) - if err != nil { - return err - } + prog.walloc.SetWires(*instr.Out, o) case Rshift, Srshift: var signWire *circuits.Wire @@ -189,10 +188,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { } o[bit] = w } - err = prog.SetWires(instr.Out.String(), o) - if err != nil { - return err - } + prog.walloc.SetWires(*instr.Out, o) case Slice: from, err := instr.In[1].ConstInt() @@ -225,13 +221,10 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { for bit := to - from; int(bit) < len(o); bit++ { o[bit] = cc.ZeroWire() } - err = prog.SetWires(instr.Out.String(), o) - if err != nil { - return err - } + prog.walloc.SetWires(*instr.Out, o) case Index: - o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits) + o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits) if err != nil { return err } @@ -247,7 +240,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { } case Ilt, Ult: - o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits) + o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits) if err != nil { return err } @@ -257,7 +250,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { } case Ile, Ule: - o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits) + o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits) if err != nil { return err } @@ -267,7 +260,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { } case Igt, Ugt: - o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits) + o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits) if err != nil { return err } @@ -277,7 +270,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { } case Ige, Uge: - o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits) + o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits) if err != nil { return err } @@ -287,7 +280,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { } case Eq: - o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits) + o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits) if err != nil { return err } @@ -297,7 +290,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { } case Neq: - o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits) + o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits) if err != nil { return err } @@ -312,7 +305,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { return fmt.Errorf("%s unsupported index type %T: %s", instr.Op, instr.In[1], err) } - o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits) + o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits) if err != nil { return err } @@ -327,7 +320,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { return fmt.Errorf("%s unsupported index type %T: %s", instr.Op, instr.In[1], err) } - o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits) + o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits) if err != nil { return err } @@ -337,7 +330,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { } case And: - o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits) + o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits) if err != nil { return err } @@ -347,7 +340,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { } case Or: - o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits) + o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits) if err != nil { return err } @@ -357,7 +350,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { } case Band: - o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits) + o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits) if err != nil { return err } @@ -367,7 +360,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { } case Bclr: - o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits) + o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits) if err != nil { return err } @@ -377,7 +370,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { } case Bor: - o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits) + o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits) if err != nil { return err } @@ -387,7 +380,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { } case Bxor: - o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits) + o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits) if err != nil { return err } @@ -415,10 +408,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { } o[bit] = w } - err := prog.SetWires(instr.Out.String(), o) - if err != nil { - return err - } + prog.walloc.SetWires(*instr.Out, o) case Amov: // v arr from to: @@ -457,13 +447,10 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { } o[bit] = w } - err = prog.SetWires(instr.Out.String(), o) - if err != nil { - return err - } + prog.walloc.SetWires(*instr.Out, o) case Phi: - o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits) + o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits) if err != nil { return err } @@ -476,7 +463,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { // Assign output wires. for _, wg := range wires { for _, w := range wg { - o := circuits.NewWire() + o := cc.Calloc.Wire() cc.ID(w, o) cc.OutputWires = append(cc.OutputWires, o) } @@ -489,9 +476,9 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { var circWires []*circuits.Wire // Flatten input wires. - for idx, w := range wires { + for wi, w := range wires { circWires = append(circWires, w...) - for i := len(w); i < instr.Circ.Inputs[idx].Size; i++ { + for i := len(w); i < int(instr.Circ.Inputs[wi].Type.Bits); i++ { // Zeroes for unset input wires. zw := cc.ZeroWire() circWires = append(circWires, zw) @@ -502,7 +489,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { var circOut []*circuits.Wire for _, r := range instr.Ret { - o, err := prog.Wires(r.String(), r.Type.Bits) + o, err := prog.walloc.Wires(r, r.Type.Bits) if err != nil { return err } @@ -512,7 +499,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { // Add intermediate wires. nint := instr.Circ.NumWires - len(circWires) - len(circOut) for i := 0; i < nint; i++ { - circWires = append(circWires, circuits.NewWire()) + circWires = append(circWires, cc.Calloc.Wire()) } // Append output wires. @@ -522,22 +509,22 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { for _, gate := range instr.Circ.Gates { switch gate.Op { case circuit.XOR: - cc.AddGate(circuits.NewBinary(circuit.XOR, + cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, circWires[gate.Input0], circWires[gate.Input1], circWires[gate.Output])) case circuit.XNOR: - cc.AddGate(circuits.NewBinary(circuit.XNOR, + cc.AddGate(cc.Calloc.BinaryGate(circuit.XNOR, circWires[gate.Input0], circWires[gate.Input1], circWires[gate.Output])) case circuit.AND: - cc.AddGate(circuits.NewBinary(circuit.AND, + cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, circWires[gate.Input0], circWires[gate.Input1], circWires[gate.Output])) case circuit.OR: - cc.AddGate(circuits.NewBinary(circuit.OR, + cc.AddGate(cc.Calloc.BinaryGate(circuit.OR, circWires[gate.Input0], circWires[gate.Input1], circWires[gate.Output])) @@ -549,7 +536,7 @@ func (prog *Program) Circuit(cc *circuits.Compiler) error { } case Builtin: - o, err := prog.Wires(instr.Out.String(), instr.Out.Type.Bits) + o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits) if err != nil { return err } diff --git a/compiler/ssa/instructions.go b/compiler/ssa/instructions.go index f41a79ca..c364ea72 100644 --- a/compiler/ssa/instructions.go +++ b/compiler/ssa/instructions.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2020-2022 Markku Rossi +// Copyright (c) 2020-2023 Markku Rossi // // All rights reserved. // @@ -151,7 +151,7 @@ type Instr struct { Label *Block Circ *circuit.Circuit Builtin circuits.Builtin - GC string + GC *Value Ret []Value } @@ -541,10 +541,10 @@ func NewBuiltinInstr(builtin circuits.Builtin, a, b, r Value) Instr { } // NewGCInstr creates a new GC instruction. -func NewGCInstr(v string) Instr { +func NewGCInstr(v Value) Instr { return Instr{ Op: GC, - GC: v, + GC: &v, } } @@ -560,7 +560,7 @@ func (i Instr) StringTyped() string { func (i Instr) string(maxLen int, typesOnly bool) string { result := i.Op.String() - if len(i.In) == 0 && i.Out == nil && i.Label == nil && len(i.GC) == 0 { + if len(i.In) == 0 && i.Out == nil && i.Label == nil && i.GC == nil { return result } @@ -590,9 +590,9 @@ func (i Instr) string(maxLen int, typesOnly bool) string { if i.Circ != nil { result += fmt.Sprintf(" {G=%d, W=%d}", i.Circ.NumGates, i.Circ.NumWires) } - if len(i.GC) > 0 { + if i.GC != nil { result += " " - result += i.GC + result += i.GC.String() } for _, r := range i.Ret { result += " " diff --git a/compiler/ssa/program.go b/compiler/ssa/program.go index 86721a1f..ff42ca78 100644 --- a/compiler/ssa/program.go +++ b/compiler/ssa/program.go @@ -29,11 +29,8 @@ type Program struct { OutputWires []*circuits.Wire Constants map[string]ConstantInst Steps []Step - wires map[string]*wireAlloc - freeWires map[types.Size][][]*circuits.Wire - nextWireID uint32 - flHit int - flMiss int + walloc *WireAllocator + calloc *circuits.Allocator zeroWire *circuits.Wire oneWire *circuits.Wire stats circuit.Stats @@ -47,14 +44,16 @@ type Program struct { func NewProgram(params *utils.Params, in, out circuit.IO, consts map[string]ConstantInst, steps []Step) (*Program, error) { + calloc := circuits.NewAllocator() + prog := &Program{ Params: params, Inputs: in, Outputs: out, Constants: consts, Steps: steps, - wires: make(map[string]*wireAlloc), - freeWires: make(map[types.Size][][]*circuits.Wire), + walloc: NewWireAllocator(calloc), + calloc: calloc, } // Inputs into wires. @@ -62,7 +61,11 @@ func NewProgram(params *utils.Params, in, out circuit.IO, if len(arg.Name) == 0 { arg.Name = fmt.Sprintf("arg{%d}", idx) } - wires, err := prog.Wires(arg.Name, types.Size(arg.Size)) + wires, err := prog.walloc.Wires(Value{ + Name: arg.Name, + Scope: 1, // Arguments are at scope 1. + Type: arg.Type, + }, arg.Type.Bits) if err != nil { return nil, err } @@ -194,7 +197,7 @@ func (prog *Program) GC() { if !live { // Input is not live. gcs = append(gcs, Step{ - Instr: NewGCInstr(in.String()), + Instr: NewGCInstr(in), }) } } @@ -237,8 +240,7 @@ func (prog *Program) DefineConstants(zero, one *circuits.Wire) error { var constWires int for _, c := range consts { - _, ok := prog.wires[c.String()] - if ok { + if prog.walloc.Allocated(c) { continue } @@ -255,10 +257,7 @@ func (prog *Program) DefineConstants(zero, one *circuits.Wire) error { wires = append(wires, w) } - err := prog.SetWires(c.String(), wires) - if err != nil { - return err - } + prog.walloc.SetWires(c, wires) } if len(consts) > 0 && prog.Params.Verbose { fmt.Printf("Defined %d constants: %d wires\n", @@ -267,6 +266,12 @@ func (prog *Program) DefineConstants(zero, one *circuits.Wire) error { return nil } +// StreamDebug print debugging information about streaming mode. +func (prog *Program) StreamDebug() { + prog.walloc.Debug() + prog.calloc.Debug() +} + // PP pretty-prints the program to the argument io.Writer. func (prog *Program) PP(out io.Writer) { for i, in := range prog.Inputs { diff --git a/compiler/ssa/streamer.go b/compiler/ssa/streamer.go index e06d06a3..bf4093c8 100644 --- a/compiler/ssa/streamer.go +++ b/compiler/ssa/streamer.go @@ -23,8 +23,8 @@ import ( "github.com/markkurossi/tabulate" ) -// StreamCircuit streams the program circuit into the P2P connection. -func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, +// Stream streams the program circuit into the P2P connection. +func (prog *Program) Stream(conn *p2p.Conn, oti ot.OT, params *utils.Params, inputs *big.Int, timing *circuit.Timing) ( circuit.IO, []*big.Int, error) { @@ -67,9 +67,8 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, for _, w := range prog.InputWires { // Program's inputs are unassigned because parser is shared // between streaming and non-streaming modes. - w.SetID(prog.nextWireID) - prog.nextWireID++ - ids = append(ids, circuit.Wire(w.ID())) + w.SetID(prog.walloc.NextWireID()) + ids = append(ids, w.ID()) } streaming, err := circuit.NewStreaming(key[:], ids, conn) @@ -79,7 +78,7 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, // Select our inputs. var n1 []ot.Label - for i := 0; i < prog.Inputs[0].Size; i++ { + for i := 0; i < int(prog.Inputs[0].Type.Bits); i++ { wire := streaming.GetInput(circuit.Wire(i)) var n ot.Label @@ -115,8 +114,8 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, timing.Sample("OT Init", []string{circuit.FileSize(xfer).String()}) // Peer OTs its inputs. - err = oti.Send(streaming.GetInputs(prog.Inputs[0].Size, - prog.Inputs[1].Size)) + err = oti.Send(streaming.GetInputs(int(prog.Inputs[0].Type.Bits), + int(prog.Inputs[1].Type.Bits))) if err != nil { return nil, nil, err } @@ -141,7 +140,7 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, // Stream circuit. cache := make(map[string]*circuit.Circuit) - var returnIDs []uint32 + var returnIDs []circuit.Wire start := time.Now() lastReport := start @@ -151,7 +150,7 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, istats := make(map[string]circuit.Stats) - var wires [][]*circuits.Wire + var wires [][]circuit.Wire var iIDs, oIDs []circuit.Wire for idx, step := range prog.Steps { @@ -177,18 +176,17 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, instr := step.Instr wires = wires[:0] for _, in := range instr.In { - w, err := prog.AssignedWires(in.String(), in.Type.Bits) + w, err := prog.walloc.AssignedIDs(in, in.Type.Bits) if err != nil { return nil, nil, err } wires = append(wires, w) } - var out []*circuits.Wire + var out []circuit.Wire var err error if instr.Out != nil { - out, err = prog.AssignedWires(instr.Out.String(), - instr.Out.Type.Bits) + out, err = prog.walloc.AssignedIDs(*instr.Out, instr.Out.Type.Bits) if err != nil { return nil, nil, err } @@ -213,9 +211,9 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, fmt.Errorf("%s: negative shift count %d", instr.Op, count) } for bit := 0; bit < len(out); bit++ { - var id uint32 + var id circuit.Wire if bit-int(count) >= 0 && bit-int(count) < len(wires[0]) { - id = wires[0][bit-int(count)].ID() + id = wires[0][bit-int(count)] } else { w, err := prog.ZeroWire(conn, streaming) if err != nil { @@ -223,11 +221,11 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, } id = w.ID() } - out[bit].SetID(id) + out[bit] = id } case Rshift, Srshift: - var signWire *circuits.Wire + var signWire circuit.Wire if instr.Op == Srshift { signWire = wires[0][len(wires[0])-1] } else { @@ -235,7 +233,7 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, if err != nil { return nil, nil, err } - signWire = zero + signWire = zero.ID() } count, err := instr.In[1].ConstInt() if err != nil { @@ -248,13 +246,13 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, fmt.Errorf("%s: negative shift count %d", instr.Op, count) } for bit := 0; bit < len(out); bit++ { - var id uint32 + var id circuit.Wire if bit+int(count) < len(wires[0]) { - id = wires[0][bit+int(count)].ID() + id = wires[0][bit+int(count)] } else { - id = signWire.ID() + id = signWire } - out[bit].SetID(id) + out[bit] = id } case Slice: @@ -275,9 +273,9 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, instr.Op, from, to) } for bit := from; bit < to; bit++ { - var id uint32 + var id circuit.Wire if int(bit) < len(wires[0]) { - id = wires[0][bit].ID() + id = wires[0][bit] } else { w, err := prog.ZeroWire(conn, streaming) if err != nil { @@ -285,11 +283,11 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, } id = w.ID() } - out[bit-from].SetID(id) + out[bit-from] = id } case Mov, Smov: - var signWire *circuits.Wire + var signWire circuit.Wire if instr.Op == Smov { signWire = wires[0][len(wires[0])-1] } else { @@ -297,16 +295,16 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, if err != nil { return nil, nil, err } - signWire = zero + signWire = zero.ID() } for bit := types.Size(0); bit < instr.Out.Type.Bits; bit++ { - var id uint32 + var id circuit.Wire if bit < types.Size(len(wires[0])) { - id = wires[0][bit].ID() + id = wires[0][bit] } else { - id = signWire.ID() + id = signWire } - out[bit].SetID(id) + out[bit] = id } case Amov: @@ -328,10 +326,10 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, } for bit := types.Size(0); bit < instr.Out.Type.Bits; bit++ { - var id uint32 + var id circuit.Wire if bit < from || bit >= to { if bit < types.Size(len(wires[1])) { - id = wires[1][bit].ID() + id = wires[1][bit] } else { w, err := prog.ZeroWire(conn, streaming) if err != nil { @@ -342,7 +340,7 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, } else { idx := bit - from if idx < types.Size(len(wires[0])) { - id = wires[0][idx].ID() + id = wires[0][idx] } else { w, err := prog.ZeroWire(conn, streaming) if err != nil { @@ -351,7 +349,7 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, id = w.ID() } } - out[bit].SetID(id) + out[bit] = id } case Ret: @@ -360,10 +358,10 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, } for _, arg := range wires { for _, w := range arg { - if err := conn.SendUint32(int(w.ID())); err != nil { + if err := conn.SendUint32(w.Int()); err != nil { return nil, nil, err } - returnIDs = append(returnIDs, w.ID()) + returnIDs = append(returnIDs, w) } } if circuit.StreamDebug { @@ -378,25 +376,25 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, iIDs = iIDs[:0] oIDs = oIDs[:0] for i := 0; i < len(wires); i++ { - for j := 0; j < instr.Circ.Inputs[i].Size; j++ { + for j := 0; j < int(instr.Circ.Inputs[i].Type.Bits); j++ { if j < len(wires[i]) { - iIDs = append(iIDs, circuit.Wire(wires[i][j].ID())) + iIDs = append(iIDs, wires[i][j]) } else { - iIDs = append(iIDs, circuit.Wire(prog.zeroWire.ID())) + iIDs = append(iIDs, prog.zeroWire.ID()) } } } // Return wires. for i, ret := range instr.Ret { - wires, err := prog.AssignedWires(ret.String(), ret.Type.Bits) + wires, err := prog.walloc.AssignedIDs(ret, ret.Type.Bits) if err != nil { return nil, nil, err } - for j := 0; j < instr.Circ.Outputs[i].Size; j++ { + for j := 0; j < int(instr.Circ.Outputs[i].Type.Bits); j++ { if j < len(wires) { - oIDs = append(oIDs, circuit.Wire(wires[j].ID())) + oIDs = append(oIDs, wires[j]) } else { - oIDs = append(oIDs, circuit.Wire(prog.zeroWire.ID())) + oIDs = append(oIDs, prog.zeroWire.ID()) } } } @@ -416,19 +414,13 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, } case GC: - alloc, ok := prog.wires[instr.GC] - if ok { - delete(prog.wires, instr.GC) - prog.recycleWires(alloc) - } else { - fmt.Printf("GC: %s not known\n", instr.GC) - } + prog.walloc.GCWires(*instr.GC) default: f, ok := circuitGenerators[instr.Op] if !ok { return nil, nil, - fmt.Errorf("Program.Stream: %s not implemented yet", + fmt.Errorf("Program.StreamCircuit: %s not implemented yet", instr.Op) } if params.Verbose && circuit.StreamDebug { @@ -441,17 +433,18 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, startTime := time.Now() for _, in := range wires { - w := circuits.MakeWires(types.Size(len(in))) + w := prog.calloc.Wires(types.Size(len(in))) cIn = append(cIn, w) flat = append(flat, w...) } - cOut := circuits.MakeWires(instr.Out.Type.Bits) + cOut := prog.calloc.Wires(instr.Out.Type.Bits) for i := types.Size(0); i < instr.Out.Type.Bits; i++ { cOut[i].SetOutput(true) } - cc, err := circuits.NewCompiler(params, nil, nil, flat, cOut) + cc, err := circuits.NewCompiler(params, prog.calloc, nil, nil, + flat, cOut) if err != nil { return nil, nil, err } @@ -462,8 +455,7 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, cc.ConstPropagate() pruned := cc.Prune() if params.Verbose && circuit.StreamDebug { - fmt.Printf("%05d: - pruned %d gates\n", - idx, pruned) + fmt.Printf("%05d: - pruned %d gates\n", idx, pruned) } circ = cc.Compile() if cacheable { @@ -488,11 +480,11 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, oIDs = oIDs[:0] for _, vars := range wires { for _, w := range vars { - iIDs = append(iIDs, circuit.Wire(w.ID())) + iIDs = append(iIDs, w) } } for _, w := range out { - oIDs = append(oIDs, circuit.Wire(w.ID())) + oIDs = append(oIDs, w) } err = prog.garble(conn, streaming, idx, circ, iIDs, oIDs) @@ -539,7 +531,7 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, if err != nil { return nil, nil, err } - wire := streaming.GetInput(circuit.Wire(returnIDs[i])) + wire := streaming.GetInput(returnIDs[i]) var bit uint if label.Equal(wire.L0) { bit = 0 @@ -567,7 +559,7 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, } fmt.Printf("Max permanent wires: %d, cached circuits: %d\n", - prog.nextWireID, len(cache)) + prog.walloc.NextWireID(), len(cache)) fmt.Printf("#gates=%d (%s) #w=%d\n", prog.stats.Count(), prog.stats, prog.numWires) @@ -681,7 +673,10 @@ func (prog *Program) ZeroWire(conn *p2p.Conn, streaming *circuit.Streaming) ( *circuits.Wire, error) { if prog.zeroWire == nil { - wires, err := prog.AssignedWires("{zero}", 1) + wires, err := prog.walloc.AssignedWires(Value{ + Const: true, + Name: "{zero}", + }, 1) if err != nil { return nil, err } @@ -691,15 +686,19 @@ func (prog *Program) ZeroWire(conn *p2p.Conn, streaming *circuit.Streaming) ( Inputs: []circuit.IOArg{ { Name: "i0", - Type: "uint1", - Size: 1, + Type: types.Info{ + Type: types.TUint, + Bits: 1, + }, }, }, Outputs: []circuit.IOArg{ { Name: "o0", - Type: "uint1", - Size: 1, + Type: types.Info{ + Type: types.TUint, + Bits: 1, + }, }, }, Gates: []circuit.Gate{ @@ -713,7 +712,7 @@ func (prog *Program) ZeroWire(conn *p2p.Conn, streaming *circuit.Streaming) ( Stats: circuit.Stats{ circuit.XOR: 1, }, - }, []circuit.Wire{0}, []circuit.Wire{circuit.Wire(wires[0].ID())}) + }, []circuit.Wire{0}, []circuit.Wire{wires[0].ID()}) if err != nil { return nil, err } @@ -727,7 +726,10 @@ func (prog *Program) OneWire(conn *p2p.Conn, streaming *circuit.Streaming) ( *circuits.Wire, error) { if prog.oneWire == nil { - wires, err := prog.AssignedWires("{one}", 1) + wires, err := prog.walloc.AssignedWires(Value{ + Const: true, + Name: "{one}", + }, 1) if err != nil { return nil, err } @@ -737,15 +739,19 @@ func (prog *Program) OneWire(conn *p2p.Conn, streaming *circuit.Streaming) ( Inputs: []circuit.IOArg{ { Name: "i0", - Type: "uint1", - Size: 1, + Type: types.Info{ + Type: types.TUint, + Bits: 1, + }, }, }, Outputs: []circuit.IOArg{ { Name: "o0", - Type: "uint1", - Size: 1, + Type: types.Info{ + Type: types.TUint, + Bits: 1, + }, }, }, Gates: []circuit.Gate{ @@ -759,7 +765,7 @@ func (prog *Program) OneWire(conn *p2p.Conn, streaming *circuit.Streaming) ( Stats: circuit.Stats{ circuit.XNOR: 1, }, - }, []circuit.Wire{0}, []circuit.Wire{circuit.Wire(wires[0].ID())}) + }, []circuit.Wire{0}, []circuit.Wire{wires[0].ID()}) if err != nil { return nil, err } @@ -772,10 +778,10 @@ func sendArgument(conn *p2p.Conn, arg circuit.IOArg) error { if err := conn.SendString(arg.Name); err != nil { return err } - if err := conn.SendString(arg.Type); err != nil { + if err := conn.SendString(arg.Type.String()); err != nil { return err } - if err := conn.SendUint32(arg.Size); err != nil { + if err := conn.SendUint32(int(arg.Type.Bits)); err != nil { return err } diff --git a/compiler/ssa/value.go b/compiler/ssa/value.go index 1f5948cb..5cc73b04 100644 --- a/compiler/ssa/value.go +++ b/compiler/ssa/value.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2020-2022 Markku Rossi +// Copyright (c) 2020-2023 Markku Rossi // // All rights reserved. // @@ -26,7 +26,7 @@ type Value struct { ConstValue interface{} } -// Scope defines variable scope (max 256 levels of nested blocks). +// Scope defines variable scope (max 65536 levels of nested blocks). type Scope int16 // PtrInfo defines context information for pointer values. @@ -42,6 +42,17 @@ func (ptr PtrInfo) String() string { return fmt.Sprintf("*%s@%d", ptr.Name, ptr.Scope) } +// Equal tests if this PtrInfo is equal to the argument PtrInfo. +func (ptr *PtrInfo) Equal(o *PtrInfo) bool { + if ptr == nil { + return o == nil + } + if o == nil { + return false + } + return ptr.Name == o.Name && ptr.Scope == o.Scope && ptr.Offset == o.Offset +} + // Undefined defines an undefined value. var Undefined Value @@ -121,13 +132,40 @@ func (v *Value) ConstInt() (types.Size, error) { } } +// HashCode returns a hash code for the value. +func (v *Value) HashCode() (hash int) { + for _, r := range v.Name { + hash = hash<<4 ^ int(r) ^ hash>>24 + } + hash ^= int(v.Scope) << 3 + hash ^= int(v.Version) << 1 + + if !v.Const { + hash ^= int(v.Type.Bits) << 5 + } + + if hash < 0 { + hash = -hash + } + return +} + // Equal implements BindingValue.Equal. func (v *Value) Equal(other BindingValue) bool { o, ok := other.(*Value) if !ok { return false } - return o.Name == v.Name && o.Scope == v.Scope && o.Version == v.Version + if o.Const != v.Const { + return false + } + if o.Name != v.Name || o.Scope != v.Scope || o.Version != v.Version { + return false + } + if !v.Const && v.Type.Bits != o.Type.Bits { + return false + } + return v.PtrInfo.Equal(o.PtrInfo) } // Value implements BindingValue.Value. diff --git a/compiler/ssa/value_test.go b/compiler/ssa/value_test.go new file mode 100644 index 00000000..148ee4b2 --- /dev/null +++ b/compiler/ssa/value_test.go @@ -0,0 +1,35 @@ +// +// Copyright (c) 2023 Markku Rossi +// +// All rights reserved. +// + +package ssa + +import ( + "testing" +) + +var inputs = []string{ + "$127", "$126", "$125", "$124", "$123", "$122", "$121", "$119", + "$118", "$117", "$116", "$115", "$114", "$113", "$111", "$110", + "$109", "$108", "$107", "$106", "$105", "$103", "$102", "$101", + "$100", +} + +func TestHashCode(t *testing.T) { + counts := make(map[int]int) + for _, input := range inputs { + v := Value{ + Name: input, + Const: true, + } + counts[v.HashCode()]++ + } + + for k, v := range counts { + if v > 1 { + t.Errorf("HashCode %v: count=%v\n", k, v) + } + } +} diff --git a/compiler/ssa/wire_allocator.go b/compiler/ssa/wire_allocator.go index 3e62cfcb..0110b4dc 100644 --- a/compiler/ssa/wire_allocator.go +++ b/compiler/ssa/wire_allocator.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2020-2023 Markku Rossi +// Copyright (c) 2023 Markku Rossi // // All rights reserved. // @@ -8,134 +8,362 @@ package ssa import ( "fmt" - "sort" + "math" + "github.com/markkurossi/mpc/circuit" "github.com/markkurossi/mpc/compiler/circuits" "github.com/markkurossi/mpc/types" ) -// Wires allocates unassigned wires for the argument value. -func (prog *Program) Wires(v string, bits types.Size) ( - []*circuits.Wire, error) { +// WireAllocator implements wire allocation using Value.HashCode to +// map values to wires. +type WireAllocator struct { + calloc *circuits.Allocator + freeHdrs []*allocByValue + freeWires map[types.Size][][]*circuits.Wire + freeIDs map[types.Size][][]circuit.Wire + hash [10240]*allocByValue + nextWireID circuit.Wire + flHdrs cacheStats + flWires cacheStats + flIDs cacheStats + lookupCount int + lookupFound int +} + +type cacheStats struct { + hit int + miss int +} + +func (cs cacheStats) String() string { + total := float64(cs.hit + cs.miss) + return fmt.Sprintf("hit=%v (%.2f%%), miss=%v (%.2f%%)", + cs.hit, float64(cs.hit)/total*100, + cs.miss, float64(cs.miss)/total*100) +} + +type allocByValue struct { + next *allocByValue + key Value + base circuit.Wire + wires []*circuits.Wire + ids []circuit.Wire +} + +func (alloc *allocByValue) String() string { + return fmt.Sprintf("%v[%v]: base=%v, len(wires)=%v", + alloc.key.String(), alloc.key.Type, + alloc.base, len(alloc.wires)) +} + +// NewWireAllocator creates a new WireAllocator. +func NewWireAllocator(calloc *circuits.Allocator) *WireAllocator { + return &WireAllocator{ + calloc: calloc, + freeWires: make(map[types.Size][][]*circuits.Wire), + freeIDs: make(map[types.Size][][]circuit.Wire), + } +} + +func (walloc *WireAllocator) hashCode(v Value) int { + return v.HashCode() % len(walloc.hash) +} + +func (walloc *WireAllocator) newHeader(v Value) (ret *allocByValue) { + if len(walloc.freeHdrs) == 0 { + ret = new(allocByValue) + walloc.flHdrs.miss++ + } else { + ret = walloc.freeHdrs[len(walloc.freeHdrs)-1] + walloc.freeHdrs = walloc.freeHdrs[:len(walloc.freeHdrs)-1] + walloc.flHdrs.hit++ + } + ret.key = v + ret.base = circuits.UnassignedID + return ret +} + +func (walloc *WireAllocator) newWires(bits types.Size) ( + result []*circuits.Wire) { + + fl, ok := walloc.freeWires[bits] + if ok && len(fl) > 0 { + result = fl[len(fl)-1] + walloc.freeWires[bits] = fl[:len(fl)-1] + walloc.flWires.hit++ + } else { + result = walloc.calloc.Wires(bits) + walloc.flWires.miss++ + } + return result +} + +func (walloc *WireAllocator) newIDs(bits types.Size) (result []circuit.Wire) { + fl, ok := walloc.freeIDs[bits] + if ok && len(fl) > 0 { + result = fl[len(fl)-1] + walloc.freeIDs[bits] = fl[:len(fl)-1] + walloc.flIDs.hit++ + } else { + result = make([]circuit.Wire, bits) + for i := 0; i < int(bits); i++ { + result[i] = circuits.UnassignedID + } + walloc.flIDs.miss++ + } + return result +} + +func (walloc *WireAllocator) lookup(hash int, v Value) *allocByValue { + var count int + for ptr := &walloc.hash[hash]; *ptr != nil; ptr = &(*ptr).next { + count++ + if (*ptr).key.Equal(&v) { + alloc := *ptr + + if count > 2 { + // MRU in the hash bucket. + *ptr = alloc.next + alloc.next = walloc.hash[hash] + walloc.hash[hash] = alloc + } + + walloc.lookupCount++ + walloc.lookupFound += count + return alloc + } + } + return nil +} + +func (walloc *WireAllocator) alloc(bits types.Size, v Value, + wires, ids bool) *allocByValue { + + result := walloc.newHeader(v) + + if wires && ids { + result.wires = walloc.newWires(bits) + result.ids = walloc.newIDs(bits) + result.base = result.wires[0].ID() + + for i := 0; i < int(bits); i++ { + result.ids[i] = result.wires[i].ID() + } + } else if wires { + result.wires = walloc.newWires(bits) + result.base = result.wires[0].ID() + } else { + result.ids = walloc.newIDs(bits) + result.base = result.ids[0] + } + return result +} + +func (walloc *WireAllocator) remove(hash int, v Value) *allocByValue { + for ptr := &walloc.hash[hash]; *ptr != nil; ptr = &(*ptr).next { + if (*ptr).key.Equal(&v) { + ret := *ptr + *ptr = (*ptr).next + return ret + } + } + return nil +} + +// Allocated tests if the wires have been allocated for the value. +func (walloc *WireAllocator) Allocated(v Value) bool { + hash := walloc.hashCode(v) + alloc := walloc.lookup(hash, v) + return alloc != nil +} + +// NextWireID allocated and returns the next unassigned wire ID. +// XXX is this sync with circuits.Compiler.NextWireID()? +func (walloc *WireAllocator) NextWireID() circuit.Wire { + ret := walloc.nextWireID + walloc.nextWireID++ + return ret +} + +// AssignedIDs allocates assigned wire IDs for the argument value. +func (walloc *WireAllocator) AssignedIDs(v Value, bits types.Size) ( + []circuit.Wire, error) { if bits <= 0 { return nil, fmt.Errorf("size not set for value %v", v) } - alloc, ok := prog.wires[v] - if !ok { - alloc = prog.allocWires(bits) - prog.wires[v] = alloc + hash := walloc.hashCode(v) + alloc := walloc.lookup(hash, v) + if alloc == nil { + alloc = walloc.alloc(bits, v, false, true) + alloc.next = walloc.hash[hash] + walloc.hash[hash] = alloc + + // Assign wire IDs. + if alloc.base == circuits.UnassignedID { + alloc.base = walloc.nextWireID + for i := 0; i < int(bits); i++ { + alloc.ids[i] = walloc.nextWireID + circuit.Wire(i) + } + walloc.nextWireID += circuit.Wire(bits) + } } - return alloc.Wires, nil + if alloc.ids == nil { + alloc.ids = walloc.newIDs(bits) + for i := 0; i < int(bits); i++ { + alloc.ids[i] = alloc.wires[i].ID() + } + } + return alloc.ids, nil } // AssignedWires allocates assigned wires for the argument value. -func (prog *Program) AssignedWires(v string, bits types.Size) ( +func (walloc *WireAllocator) AssignedWires(v Value, bits types.Size) ( []*circuits.Wire, error) { if bits <= 0 { return nil, fmt.Errorf("size not set for value %v", v) } - alloc, ok := prog.wires[v] - if !ok { - alloc = prog.allocWires(bits) - prog.wires[v] = alloc + hash := walloc.hashCode(v) + alloc := walloc.lookup(hash, v) + if alloc == nil { + alloc = walloc.alloc(bits, v, true, true) + alloc.next = walloc.hash[hash] + walloc.hash[hash] = alloc // Assign wire IDs. - if alloc.Base == circuits.UnassignedID { - alloc.Base = prog.nextWireID + if alloc.base == circuits.UnassignedID { + alloc.base = walloc.nextWireID for i := 0; i < int(bits); i++ { - alloc.Wires[i].SetID(prog.nextWireID + uint32(i)) + alloc.wires[i].SetID(walloc.nextWireID + circuit.Wire(i)) } - prog.nextWireID += uint32(bits) + walloc.nextWireID += circuit.Wire(bits) } } - - return alloc.Wires, nil -} - -type wireAlloc struct { - Base uint32 - Wires []*circuits.Wire + if alloc.ids == nil { + alloc.ids = walloc.newIDs(bits) + for i := 0; i < int(bits); i++ { + alloc.ids[i] = alloc.wires[i].ID() + } + } + return alloc.wires, nil } -func (prog *Program) allocWires(bits types.Size) *wireAlloc { - result := &wireAlloc{ - Base: circuits.UnassignedID, +// GCWires recycles the wires of the argument value. The wires must +// have been previously allocated with Wires, AssignedWires, or +// SetWires; the function panics if the wires have not been allocated. +func (walloc *WireAllocator) GCWires(v Value) { + hash := walloc.hashCode(v) + alloc := walloc.remove(hash, v) + if alloc == nil { + panic(fmt.Sprintf("GC: %s not known", v)) } - fl, ok := prog.freeWires[bits] - if ok && len(fl) > 0 { - result.Wires = fl[len(fl)-1] - result.Base = result.Wires[0].ID() - prog.freeWires[bits] = fl[:len(fl)-1] - prog.flHit++ - } else { - result.Wires = circuits.MakeWires(bits) - prog.flMiss++ + if alloc.wires != nil { + if alloc.base == circuits.UnassignedID { + alloc.base = alloc.wires[0].ID() + } + // Clear wires and reassign their IDs. + for i := 0; i < len(alloc.wires); i++ { + alloc.wires[i].Reset(alloc.base + circuit.Wire(i)) + } + bits := types.Size(len(alloc.wires)) + walloc.freeWires[bits] = append(walloc.freeWires[bits], alloc.wires) + } + if alloc.ids != nil { + if alloc.base == circuits.UnassignedID { + alloc.base = alloc.ids[0] + } + // Clear IDs. + for i := 0; i < len(alloc.ids); i++ { + alloc.ids[i] = alloc.base + circuit.Wire(i) + } + bits := types.Size(len(alloc.ids)) + walloc.freeIDs[bits] = append(walloc.freeIDs[bits], alloc.ids) } - return result + alloc.next = nil + alloc.base = circuits.UnassignedID + alloc.wires = nil + alloc.ids = nil + walloc.freeHdrs = append(walloc.freeHdrs, alloc) } -func (prog *Program) recycleWires(alloc *wireAlloc) { - if alloc.Base == circuits.UnassignedID { - alloc.Base = alloc.Wires[0].ID() - } - // Clear wires and reassign their IDs. - bits := types.Size(len(alloc.Wires)) - for i := 0; i < int(bits); i++ { - alloc.Wires[i].Reset(alloc.Base + uint32(i)) +// Wires allocates unassigned wires for the argument value. +func (walloc *WireAllocator) Wires(v Value, bits types.Size) ( + []*circuits.Wire, error) { + if bits <= 0 { + return nil, fmt.Errorf("size not set for value %v", v) } - - fl := prog.freeWires[bits] - fl = append(fl, alloc.Wires) - prog.freeWires[bits] = fl - if false { - fmt.Printf("FL: %d: ", bits) - for k, v := range prog.freeWires { - fmt.Printf(" %d:%d", k, len(v)) - } - fmt.Println() + hash := walloc.hashCode(v) + alloc := walloc.lookup(hash, v) + if alloc == nil { + alloc = walloc.alloc(bits, v, true, false) + alloc.next = walloc.hash[hash] + walloc.hash[hash] = alloc } + return alloc.wires, nil } // SetWires allocates wire IDs for the value's wires. -func (prog *Program) SetWires(v string, w []*circuits.Wire) error { - _, ok := prog.wires[v] - if ok { - return fmt.Errorf("wires already set for %v", v) +func (walloc *WireAllocator) SetWires(v Value, w []*circuits.Wire) { + hash := walloc.hashCode(v) + alloc := walloc.lookup(hash, v) + if alloc != nil { + panic(fmt.Sprintf("wires already set for %v", v)) } - alloc := &wireAlloc{ - Wires: w, + alloc = &allocByValue{ + key: v, + wires: w, + ids: make([]circuit.Wire, len(w)), } if len(w) == 0 { - alloc.Base = circuits.UnassignedID + alloc.base = circuits.UnassignedID } else { - alloc.Base = w[0].ID() + alloc.base = w[0].ID() + for i := 0; i < len(w); i++ { + alloc.ids[i] = w[i].ID() + } } - prog.wires[v] = alloc - - return nil + alloc.next = walloc.hash[hash] + walloc.hash[hash] = alloc } -// StreamDebug prints debugging information about the circuit -// streaming. -func (prog *Program) StreamDebug() { - total := float64(prog.flHit + prog.flMiss) - fmt.Printf("Wire freelist: hit=%v (%.2f%%), miss=%v (%.2f%%)\n", - prog.flHit, float64(prog.flHit)/total*100, - prog.flMiss, float64(prog.flMiss)/total*100) +// Debug prints debugging information about the wire allocator. +func (walloc *WireAllocator) Debug() { + fmt.Printf("WireAllocator:\n") + fmt.Printf(" hdrs : %s\n", walloc.flHdrs) + fmt.Printf(" wires: %s\n", walloc.flWires) + fmt.Printf(" ids : %s\n", walloc.flIDs) + + var sum, max int + min := math.MaxInt - var keys []types.Size - for k := range prog.freeWires { - keys = append(keys, k) + var maxIndex int + + for i := 0; i < len(walloc.hash); i++ { + var count int + for alloc := walloc.hash[i]; alloc != nil; alloc = alloc.next { + count++ + } + sum += count + if count < min { + min = count + } + if count > max { + max = count + maxIndex = i + } } - sort.Slice(keys, func(i, j int) bool { - return keys[i] < keys[j] - }) + fmt.Printf("Hash: min=%v, max=%v, avg=%.4f, lookup=%v (avg=%.4f)\n", + min, max, float64(sum)/float64(len(walloc.hash)), + walloc.lookupCount, + float64(walloc.lookupFound)/float64(walloc.lookupCount)) - for _, k := range keys { - fmt.Printf(" %d:\t%d\n", k, len(prog.freeWires[types.Size(k)])) + if false { + fmt.Printf("Max bucket:\n") + for alloc := walloc.hash[maxIndex]; alloc != nil; alloc = alloc.next { + fmt.Printf(" %v: %v\n", alloc.key.String(), len(alloc.wires)) + } } - fmt.Println() } diff --git a/types/parse.go b/types/parse.go index d16da956..13ec3b54 100644 --- a/types/parse.go +++ b/types/parse.go @@ -1,7 +1,7 @@ // // parse.go // -// Copyright (c) 2021 Markku Rossi +// Copyright (c) 2021-2023 Markku Rossi // // All rights reserved. // @@ -49,8 +49,11 @@ func Parse(val string) (info Info, err error) { case "s", "string": info.Type = TString + case "struct": + info.Type = TStruct + default: - return info, fmt.Errorf("unknown type: %s", val) + return info, fmt.Errorf("types.Parse: unknown type: %s", val) } var bits int64 if len(m[2]) > 0 { @@ -66,7 +69,7 @@ func Parse(val string) (info Info, err error) { m = reArr.FindStringSubmatch(val) if m == nil { - return info, fmt.Errorf("unknown type: %s", val) + return info, fmt.Errorf("types.Parse: unknown type: %s", val) } var elType Info elType, err = Parse(m[2])