Skip to content

Commit

Permalink
Feat: settable hasher for MiMC (#1345)
Browse files Browse the repository at this point in the history
Co-authored-by: Ivo Kubjas <[email protected]>
  • Loading branch information
AlexandreBelling and ivokub authored Dec 17, 2024
1 parent 1f944e8 commit f3d9199
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 0 deletions.
15 changes: 15 additions & 0 deletions std/hash/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,21 @@ type FieldHasher interface {
Reset()
}

// StateStorer allows to store and retrieve the state of a hash function.
type StateStorer interface {
FieldHasher
// State retrieves the current state of the hash function. Calling this
// method should not destroy the current state and allow continue the use of
// the current hasher. The number of returned Variable is implementation
// dependent.
State() []frontend.Variable
// SetState sets the state of the hash function from a previously stored
// state retrieved using [StateStorer.State] method. The implementation
// returns an error if the number of supplied Variable does not match the
// number of Variable expected.
SetState(state []frontend.Variable) error
}

var (
builderRegistry = make(map[string]func(api frontend.API) (FieldHasher, error))
lock sync.RWMutex
Expand Down
25 changes: 25 additions & 0 deletions std/hash/mimc/mimc.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,31 @@ func (h *MiMC) Reset() {
h.h = 0
}

// SetState manually sets the state of the hasher to the provided value. In the
// case of MiMC only a single frontend variable is expected to represent the
// state.
func (h *MiMC) SetState(newState []frontend.Variable) error {

if len(h.data) > 0 {
return errors.New("the hasher is not in an initial state")
}

if len(newState) != 1 {
return errors.New("the MiMC hasher expects a single field element to represent the state")
}

h.h = newState[0]
h.data = nil
return nil
}

// State returns the inner-state of the hasher. In the context of MiMC only a
// single field element is returned.
func (h *MiMC) State() []frontend.Variable {
h.Sum() // this flushes the unsummed data
return []frontend.Variable{h.h}
}

// Sum hash using [Miyaguchi–Preneel] where the XOR operation is replaced by
// field addition.
//
Expand Down
128 changes: 128 additions & 0 deletions std/hash/mimc/mimc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
package mimc

import (
"crypto/rand"
"errors"
"fmt"
"math/big"
"testing"

Expand Down Expand Up @@ -80,3 +83,128 @@ func TestMimcAll(t *testing.T) {
}

}

// stateStoreCircuit checks that SetState works as expected. The circuit, however
// does not check the correctness of the hashes returned by the MiMC function
// as there is another test already testing this property.
type stateStoreTestCircuit struct {
X frontend.Variable
}

func (s *stateStoreTestCircuit) Define(api frontend.API) error {

hsh1, err1 := NewMiMC(api)
hsh2, err2 := NewMiMC(api)

if err1 != nil || err2 != nil {
return fmt.Errorf("could not instantiate the MIMC hasher: %w", errors.Join(err1, err2))
}

// This pre-shuffle the hasher state so that the test does not start from
// a zero state.
hsh1.Write(s.X)

state := hsh1.State()
hsh2.SetState(state)

hsh1.Write(s.X)
hsh2.Write(s.X)

var (
dig1 = hsh1.Sum()
dig2 = hsh2.Sum()
newState1 = hsh1.State()
newState2 = hsh2.State()
)

api.AssertIsEqual(dig1, dig2)

for i := range newState1 {
api.AssertIsEqual(newState1[i], newState2[i])
}

return nil
}

func TestStateStoreMiMC(t *testing.T) {

assert := test.NewAssert(t)

curves := map[ecc.ID]hash.Hash{
ecc.BN254: hash.MIMC_BN254,
ecc.BLS12_381: hash.MIMC_BLS12_381,
ecc.BLS12_377: hash.MIMC_BLS12_377,
ecc.BW6_761: hash.MIMC_BW6_761,
ecc.BW6_633: hash.MIMC_BW6_633,
ecc.BLS24_315: hash.MIMC_BLS24_315,
ecc.BLS24_317: hash.MIMC_BLS24_317,
}

for curve := range curves {

// minimal cs res = hash(data)
var (
circuit = &stateStoreTestCircuit{}
assignment = &stateStoreTestCircuit{X: 2}
)

assert.CheckCircuit(circuit,
test.WithValidAssignment(assignment),
test.WithCurves(curve))
}
}

type recoveredStateTestCircuit struct {
State []frontend.Variable
Input frontend.Variable
Expected frontend.Variable `gnark:",public"`
}

func (c *recoveredStateTestCircuit) Define(api frontend.API) error {
h, err := NewMiMC(api)
if err != nil {
return fmt.Errorf("initialize hash: %w", err)
}
if err = h.SetState(c.State); err != nil {
return fmt.Errorf("set state: %w", err)
}
h.Write(c.Input)
res := h.Sum()
api.AssertIsEqual(res, c.Expected)
return nil
}

func TestHasherFromState(t *testing.T) {
assert := test.NewAssert(t)

hashes := map[ecc.ID]hash.Hash{
ecc.BN254: hash.MIMC_BN254,
ecc.BLS12_381: hash.MIMC_BLS12_381,
ecc.BLS12_377: hash.MIMC_BLS12_377,
ecc.BW6_761: hash.MIMC_BW6_761,
ecc.BW6_633: hash.MIMC_BW6_633,
ecc.BLS24_315: hash.MIMC_BLS24_315,
ecc.BLS24_317: hash.MIMC_BLS24_317,
}

for cc, hh := range hashes {
hasher := hh.New()
ss, ok := hasher.(hash.StateStorer)
assert.True(ok)
_, err := ss.Write([]byte("hello world"))
assert.NoError(err)
state := ss.State()
nbBytes := cc.ScalarField().BitLen() / 8
buf := make([]byte, nbBytes)
_, err = rand.Read(buf)
assert.NoError(err)
ss.Write(buf)
expected := ss.Sum(nil)
bstate := new(big.Int).SetBytes(state)
binput := new(big.Int).SetBytes(buf)
assert.CheckCircuit(
&recoveredStateTestCircuit{State: make([]frontend.Variable, 1)},
test.WithValidAssignment(&recoveredStateTestCircuit{State: []frontend.Variable{bstate}, Input: binput, Expected: expected}),
test.WithCurves(cc))
}
}

0 comments on commit f3d9199

Please sign in to comment.