diff --git a/std/hash/hash.go b/std/hash/hash.go index d537e7f56..b86336038 100644 --- a/std/hash/hash.go +++ b/std/hash/hash.go @@ -27,6 +27,21 @@ type FieldHasher interface { Reset() } +// StateStorer allows to store and retrieve the state of a hash function. +type StateStorer interface { + FieldHasher + // State retrieves the current state of the hash function. Calling this + // method should not destroy the current state and allow continue the use of + // the current hasher. The number of returned Variable is implementation + // dependent. + State() []frontend.Variable + // SetState sets the state of the hash function from a previously stored + // state retrieved using [StateStorer.State] method. The implementation + // returns an error if the number of supplied Variable does not match the + // number of Variable expected. + SetState(state []frontend.Variable) error +} + var ( builderRegistry = make(map[string]func(api frontend.API) (FieldHasher, error)) lock sync.RWMutex diff --git a/std/hash/mimc/mimc.go b/std/hash/mimc/mimc.go index 69ff66191..bc2daaaa4 100644 --- a/std/hash/mimc/mimc.go +++ b/std/hash/mimc/mimc.go @@ -48,6 +48,31 @@ func (h *MiMC) Reset() { h.h = 0 } +// SetState manually sets the state of the hasher to the provided value. In the +// case of MiMC only a single frontend variable is expected to represent the +// state. +func (h *MiMC) SetState(newState []frontend.Variable) error { + + if len(h.data) > 0 { + return errors.New("the hasher is not in an initial state") + } + + if len(newState) != 1 { + return errors.New("the MiMC hasher expects a single field element to represent the state") + } + + h.h = newState[0] + h.data = nil + return nil +} + +// State returns the inner-state of the hasher. In the context of MiMC only a +// single field element is returned. +func (h *MiMC) State() []frontend.Variable { + h.Sum() // this flushes the unsummed data + return []frontend.Variable{h.h} +} + // Sum hash using [Miyaguchi–Preneel] where the XOR operation is replaced by // field addition. // diff --git a/std/hash/mimc/mimc_test.go b/std/hash/mimc/mimc_test.go index e698a95f6..019862cb3 100644 --- a/std/hash/mimc/mimc_test.go +++ b/std/hash/mimc/mimc_test.go @@ -4,6 +4,9 @@ package mimc import ( + "crypto/rand" + "errors" + "fmt" "math/big" "testing" @@ -80,3 +83,128 @@ func TestMimcAll(t *testing.T) { } } + +// stateStoreCircuit checks that SetState works as expected. The circuit, however +// does not check the correctness of the hashes returned by the MiMC function +// as there is another test already testing this property. +type stateStoreTestCircuit struct { + X frontend.Variable +} + +func (s *stateStoreTestCircuit) Define(api frontend.API) error { + + hsh1, err1 := NewMiMC(api) + hsh2, err2 := NewMiMC(api) + + if err1 != nil || err2 != nil { + return fmt.Errorf("could not instantiate the MIMC hasher: %w", errors.Join(err1, err2)) + } + + // This pre-shuffle the hasher state so that the test does not start from + // a zero state. + hsh1.Write(s.X) + + state := hsh1.State() + hsh2.SetState(state) + + hsh1.Write(s.X) + hsh2.Write(s.X) + + var ( + dig1 = hsh1.Sum() + dig2 = hsh2.Sum() + newState1 = hsh1.State() + newState2 = hsh2.State() + ) + + api.AssertIsEqual(dig1, dig2) + + for i := range newState1 { + api.AssertIsEqual(newState1[i], newState2[i]) + } + + return nil +} + +func TestStateStoreMiMC(t *testing.T) { + + assert := test.NewAssert(t) + + curves := map[ecc.ID]hash.Hash{ + ecc.BN254: hash.MIMC_BN254, + ecc.BLS12_381: hash.MIMC_BLS12_381, + ecc.BLS12_377: hash.MIMC_BLS12_377, + ecc.BW6_761: hash.MIMC_BW6_761, + ecc.BW6_633: hash.MIMC_BW6_633, + ecc.BLS24_315: hash.MIMC_BLS24_315, + ecc.BLS24_317: hash.MIMC_BLS24_317, + } + + for curve := range curves { + + // minimal cs res = hash(data) + var ( + circuit = &stateStoreTestCircuit{} + assignment = &stateStoreTestCircuit{X: 2} + ) + + assert.CheckCircuit(circuit, + test.WithValidAssignment(assignment), + test.WithCurves(curve)) + } +} + +type recoveredStateTestCircuit struct { + State []frontend.Variable + Input frontend.Variable + Expected frontend.Variable `gnark:",public"` +} + +func (c *recoveredStateTestCircuit) Define(api frontend.API) error { + h, err := NewMiMC(api) + if err != nil { + return fmt.Errorf("initialize hash: %w", err) + } + if err = h.SetState(c.State); err != nil { + return fmt.Errorf("set state: %w", err) + } + h.Write(c.Input) + res := h.Sum() + api.AssertIsEqual(res, c.Expected) + return nil +} + +func TestHasherFromState(t *testing.T) { + assert := test.NewAssert(t) + + hashes := map[ecc.ID]hash.Hash{ + ecc.BN254: hash.MIMC_BN254, + ecc.BLS12_381: hash.MIMC_BLS12_381, + ecc.BLS12_377: hash.MIMC_BLS12_377, + ecc.BW6_761: hash.MIMC_BW6_761, + ecc.BW6_633: hash.MIMC_BW6_633, + ecc.BLS24_315: hash.MIMC_BLS24_315, + ecc.BLS24_317: hash.MIMC_BLS24_317, + } + + for cc, hh := range hashes { + hasher := hh.New() + ss, ok := hasher.(hash.StateStorer) + assert.True(ok) + _, err := ss.Write([]byte("hello world")) + assert.NoError(err) + state := ss.State() + nbBytes := cc.ScalarField().BitLen() / 8 + buf := make([]byte, nbBytes) + _, err = rand.Read(buf) + assert.NoError(err) + ss.Write(buf) + expected := ss.Sum(nil) + bstate := new(big.Int).SetBytes(state) + binput := new(big.Int).SetBytes(buf) + assert.CheckCircuit( + &recoveredStateTestCircuit{State: make([]frontend.Variable, 1)}, + test.WithValidAssignment(&recoveredStateTestCircuit{State: []frontend.Variable{bstate}, Input: binput, Expected: expected}), + test.WithCurves(cc)) + } +}