diff --git a/sha3/sha3_test.go b/sha3/sha3_test.go index 21e8cbad7b..8347b60d10 100644 --- a/sha3/sha3_test.go +++ b/sha3/sha3_test.go @@ -17,6 +17,7 @@ import ( "encoding/json" "fmt" "hash" + "io" "math/rand" "os" "strings" @@ -375,6 +376,116 @@ func TestClone(t *testing.T) { } } +func TestCSHAKEAccumulated(t *testing.T) { + // Generated with pycryptodome@3.20.0 + // + // from Crypto.Hash import cSHAKE128 + // rng = cSHAKE128.new() + // acc = cSHAKE128.new() + // for n in range(200): + // N = rng.read(n) + // for s in range(200): + // S = rng.read(s) + // c = cSHAKE128.cSHAKE_XOF(data=None, custom=S, capacity=256, function=N) + // c.update(rng.read(100)) + // acc.update(c.read(200)) + // c = cSHAKE128.cSHAKE_XOF(data=None, custom=S, capacity=256, function=N) + // c.update(rng.read(168)) + // acc.update(c.read(200)) + // c = cSHAKE128.cSHAKE_XOF(data=None, custom=S, capacity=256, function=N) + // c.update(rng.read(200)) + // acc.update(c.read(200)) + // print(acc.read(32).hex()) + // + // and with @noble/hashes@v1.5.0 + // + // import { bytesToHex } from "@noble/hashes/utils"; + // import { cshake128 } from "@noble/hashes/sha3-addons"; + // const rng = cshake128.create(); + // const acc = cshake128.create(); + // for (let n = 0; n < 200; n++) { + // const N = rng.xof(n); + // for (let s = 0; s < 200; s++) { + // const S = rng.xof(s); + // let c = cshake128.create({ NISTfn: N, personalization: S }); + // c.update(rng.xof(100)); + // acc.update(c.xof(200)); + // c = cshake128.create({ NISTfn: N, personalization: S }); + // c.update(rng.xof(168)); + // acc.update(c.xof(200)); + // c = cshake128.create({ NISTfn: N, personalization: S }); + // c.update(rng.xof(200)); + // acc.update(c.xof(200)); + // } + // } + // console.log(bytesToHex(acc.xof(32))); + // + t.Run("cSHAKE128", func(t *testing.T) { + testCSHAKEAccumulated(t, NewCShake128, rate128, + "bb14f8657c6ec5403d0b0e2ef3d3393497e9d3b1a9a9e8e6c81dbaa5fd809252") + }) + t.Run("cSHAKE256", func(t *testing.T) { + testCSHAKEAccumulated(t, NewCShake256, rate256, + "0baaf9250c6e25f0c14ea5c7f9bfde54c8a922c8276437db28f3895bdf6eeeef") + }) +} + +func testCSHAKEAccumulated(t *testing.T, newCShake func(N, S []byte) ShakeHash, rate int64, exp string) { + rnd := newCShake(nil, nil) + acc := newCShake(nil, nil) + for n := 0; n < 200; n++ { + N := make([]byte, n) + rnd.Read(N) + for s := 0; s < 200; s++ { + S := make([]byte, s) + rnd.Read(S) + + c := newCShake(N, S) + io.CopyN(c, rnd, 100 /* < rate */) + io.CopyN(acc, c, 200) + + c.Reset() + io.CopyN(c, rnd, rate) + io.CopyN(acc, c, 200) + + c.Reset() + io.CopyN(c, rnd, 200 /* > rate */) + io.CopyN(acc, c, 200) + } + } + if got := hex.EncodeToString(acc.Sum(nil)[:32]); got != exp { + t.Errorf("got %s, want %s", got, exp) + } +} + +func TestCSHAKELargeS(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } + + // See https://go.dev/issue/66232. + const s = (1<<32)/8 + 1000 // s * 8 > 2^32 + S := make([]byte, s) + rnd := NewShake128() + rnd.Read(S) + c := NewCShake128(nil, S) + io.CopyN(c, rnd, 1000) + + // Generated with pycryptodome@3.20.0 + // + // from Crypto.Hash import cSHAKE128 + // rng = cSHAKE128.new() + // S = rng.read(536871912) + // c = cSHAKE128.new(custom=S) + // c.update(rng.read(1000)) + // print(c.read(32).hex()) + // + exp := "2cb9f237767e98f2614b8779cf096a52da9b3a849280bbddec820771ae529cf0" + if got := hex.EncodeToString(c.Sum(nil)); got != exp { + t.Errorf("got %s, want %s", got, exp) + } +} + // BenchmarkPermutationFunction measures the speed of the permutation function // with no input data. func BenchmarkPermutationFunction(b *testing.B) { diff --git a/sha3/shake.go b/sha3/shake.go index a01ef43577..6d75811fd4 100644 --- a/sha3/shake.go +++ b/sha3/shake.go @@ -19,6 +19,7 @@ import ( "encoding/binary" "hash" "io" + "math/bits" ) // ShakeHash defines the interface to hash functions that support @@ -58,33 +59,33 @@ const ( rate256 = 136 ) -func bytepad(input []byte, w int) []byte { - // leftEncode always returns max 9 bytes - buf := make([]byte, 0, 9+len(input)+w) - buf = append(buf, leftEncode(uint64(w))...) - buf = append(buf, input...) - padlen := w - (len(buf) % w) - return append(buf, make([]byte, padlen)...) -} - -func leftEncode(value uint64) []byte { - var b [9]byte - binary.BigEndian.PutUint64(b[1:], value) - // Trim all but last leading zero bytes - i := byte(1) - for i < 8 && b[i] == 0 { - i++ +func bytepad(data []byte, rate int) []byte { + out := make([]byte, 0, 9+len(data)+rate-1) + out = append(out, leftEncode(uint64(rate))...) + out = append(out, data...) + if padlen := rate - len(out)%rate; padlen < rate { + out = append(out, make([]byte, padlen)...) } - // Prepend number of encoded bytes - b[i-1] = 9 - i - return b[i-1:] + return out +} + +func leftEncode(x uint64) []byte { + // Let n be the smallest positive integer for which 2^(8n) > x. + n := (bits.Len64(x) + 7) / 8 + if n == 0 { + n = 1 + } + // Return n || x with n as a byte and x an n bytes in big-endian order. + b := make([]byte, 9) + binary.BigEndian.PutUint64(b[1:], x) + b = b[9-n-1:] + b[0] = byte(n) + return b } func newCShake(N, S []byte, rate, outputLen int, dsbyte byte) ShakeHash { c := cshakeState{state: &state{rate: rate, outputLen: outputLen, dsbyte: dsbyte}} - - // leftEncode returns max 9 bytes - c.initBlock = make([]byte, 0, 9*2+len(N)+len(S)) + c.initBlock = make([]byte, 0, 9+len(N)+9+len(S)) // leftEncode returns max 9 bytes c.initBlock = append(c.initBlock, leftEncode(uint64(len(N))*8)...) c.initBlock = append(c.initBlock, N...) c.initBlock = append(c.initBlock, leftEncode(uint64(len(S))*8)...)