Skip to content

Commit

Permalink
sha3: add MarshalBinary, AppendBinary, and UnmarshalBinary
Browse files Browse the repository at this point in the history
Fixes golang/go#24617

Change-Id: I1d9d529950aa8a5953435e8d3412cda44b075d55
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/616635
Reviewed-by: Roland Shoemaker <[email protected]>
Auto-Submit: Filippo Valsorda <[email protected]>
LUCI-TryBot-Result: Go LUCI <[email protected]>
Reviewed-by: Daniel McCarney <[email protected]>
Reviewed-by: Michael Pratt <[email protected]>
  • Loading branch information
FiloSottile authored and gopherbot committed Oct 22, 2024
1 parent 36b1725 commit 750a45f
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 20 deletions.
4 changes: 4 additions & 0 deletions sha3/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
// Package sha3 implements the SHA-3 fixed-output-length hash functions and
// the SHAKE variable-output-length hash functions defined by FIPS-202.
//
// All types in this package also implement [encoding.BinaryMarshaler],
// [encoding.BinaryAppender] and [encoding.BinaryUnmarshaler] to marshal and
// unmarshal the internal state of the hash.
//
// Both types of hash function use the "sponge" construction and the Keccak
// permutation. For a detailed specification see http://keccak.noekeon.org/
//
Expand Down
31 changes: 25 additions & 6 deletions sha3/hashes.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,33 +48,52 @@ func init() {
crypto.RegisterHash(crypto.SHA3_512, New512)
}

const (
dsbyteSHA3 = 0b00000110
dsbyteKeccak = 0b00000001
dsbyteShake = 0b00011111
dsbyteCShake = 0b00000100

// rateK[c] is the rate in bytes for Keccak[c] where c is the capacity in
// bits. Given the sponge size is 1600 bits, the rate is 1600 - c bits.
rateK256 = (1600 - 256) / 8
rateK448 = (1600 - 448) / 8
rateK512 = (1600 - 512) / 8
rateK768 = (1600 - 768) / 8
rateK1024 = (1600 - 1024) / 8
)

func new224Generic() *state {
return &state{rate: 144, outputLen: 28, dsbyte: 0x06}
return &state{rate: rateK448, outputLen: 28, dsbyte: dsbyteSHA3}
}

func new256Generic() *state {
return &state{rate: 136, outputLen: 32, dsbyte: 0x06}
return &state{rate: rateK512, outputLen: 32, dsbyte: dsbyteSHA3}
}

func new384Generic() *state {
return &state{rate: 104, outputLen: 48, dsbyte: 0x06}
return &state{rate: rateK768, outputLen: 48, dsbyte: dsbyteSHA3}
}

func new512Generic() *state {
return &state{rate: 72, outputLen: 64, dsbyte: 0x06}
return &state{rate: rateK1024, outputLen: 64, dsbyte: dsbyteSHA3}
}

// NewLegacyKeccak256 creates a new Keccak-256 hash.
//
// Only use this function if you require compatibility with an existing cryptosystem
// that uses non-standard padding. All other users should use New256 instead.
func NewLegacyKeccak256() hash.Hash { return &state{rate: 136, outputLen: 32, dsbyte: 0x01} }
func NewLegacyKeccak256() hash.Hash {
return &state{rate: rateK512, outputLen: 32, dsbyte: dsbyteKeccak}
}

// NewLegacyKeccak512 creates a new Keccak-512 hash.
//
// Only use this function if you require compatibility with an existing cryptosystem
// that uses non-standard padding. All other users should use New512 instead.
func NewLegacyKeccak512() hash.Hash { return &state{rate: 72, outputLen: 64, dsbyte: 0x01} }
func NewLegacyKeccak512() hash.Hash {
return &state{rate: rateK1024, outputLen: 64, dsbyte: dsbyteKeccak}
}

// Sum224 returns the SHA3-224 digest of the data.
func Sum224(data []byte) (digest [28]byte) {
Expand Down
72 changes: 72 additions & 0 deletions sha3/sha3.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package sha3
import (
"crypto/subtle"
"encoding/binary"
"errors"
"unsafe"

"golang.org/x/sys/cpu"
Expand Down Expand Up @@ -170,3 +171,74 @@ func (d *state) Sum(in []byte) []byte {
dup.Read(hash)
return append(in, hash...)
}

const (
magicSHA3 = "sha\x08"
magicShake = "sha\x09"
magicCShake = "sha\x0a"
magicKeccak = "sha\x0b"
// magic || rate || main state || n || sponge direction
marshaledSize = len(magicSHA3) + 1 + 200 + 1 + 1
)

func (d *state) MarshalBinary() ([]byte, error) {
return d.AppendBinary(make([]byte, 0, marshaledSize))
}

func (d *state) AppendBinary(b []byte) ([]byte, error) {
switch d.dsbyte {
case dsbyteSHA3:
b = append(b, magicSHA3...)
case dsbyteShake:
b = append(b, magicShake...)
case dsbyteCShake:
b = append(b, magicCShake...)
case dsbyteKeccak:
b = append(b, magicKeccak...)
default:
panic("unknown dsbyte")
}
// rate is at most 168, and n is at most rate.
b = append(b, byte(d.rate))
b = append(b, d.a[:]...)
b = append(b, byte(d.n), byte(d.state))
return b, nil
}

func (d *state) UnmarshalBinary(b []byte) error {
if len(b) != marshaledSize {
return errors.New("sha3: invalid hash state")
}

magic := string(b[:len(magicSHA3)])
b = b[len(magicSHA3):]
switch {
case magic == magicSHA3 && d.dsbyte == dsbyteSHA3:
case magic == magicShake && d.dsbyte == dsbyteShake:
case magic == magicCShake && d.dsbyte == dsbyteCShake:
case magic == magicKeccak && d.dsbyte == dsbyteKeccak:
default:
return errors.New("sha3: invalid hash state identifier")
}

rate := int(b[0])
b = b[1:]
if rate != d.rate {
return errors.New("sha3: invalid hash state function")
}

copy(d.a[:], b)
b = b[len(d.a):]

n, state := int(b[0]), spongeDirection(b[1])
if n > d.rate {
return errors.New("sha3: invalid hash state")
}
d.n = n
if state != spongeAbsorbing && state != spongeSqueezing {
return errors.New("sha3: invalid hash state")
}
d.state = state

return nil
}
42 changes: 40 additions & 2 deletions sha3/sha3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ package sha3
import (
"bytes"
"compress/flate"
"encoding"
"encoding/hex"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -421,11 +422,11 @@ func TestCSHAKEAccumulated(t *testing.T) {
// console.log(bytesToHex(acc.xof(32)));
//
t.Run("cSHAKE128", func(t *testing.T) {
testCSHAKEAccumulated(t, NewCShake128, rate128,
testCSHAKEAccumulated(t, NewCShake128, rateK256,
"bb14f8657c6ec5403d0b0e2ef3d3393497e9d3b1a9a9e8e6c81dbaa5fd809252")
})
t.Run("cSHAKE256", func(t *testing.T) {
testCSHAKEAccumulated(t, NewCShake256, rate256,
testCSHAKEAccumulated(t, NewCShake256, rateK512,
"0baaf9250c6e25f0c14ea5c7f9bfde54c8a922c8276437db28f3895bdf6eeeef")
})
}
Expand Down Expand Up @@ -486,6 +487,43 @@ func TestCSHAKELargeS(t *testing.T) {
}
}

func TestMarshalUnmarshal(t *testing.T) {
t.Run("SHA3-224", func(t *testing.T) { testMarshalUnmarshal(t, New224()) })
t.Run("SHA3-256", func(t *testing.T) { testMarshalUnmarshal(t, New256()) })
t.Run("SHA3-384", func(t *testing.T) { testMarshalUnmarshal(t, New384()) })
t.Run("SHA3-512", func(t *testing.T) { testMarshalUnmarshal(t, New512()) })
t.Run("SHAKE128", func(t *testing.T) { testMarshalUnmarshal(t, NewShake128()) })
t.Run("SHAKE256", func(t *testing.T) { testMarshalUnmarshal(t, NewShake256()) })
t.Run("cSHAKE128", func(t *testing.T) { testMarshalUnmarshal(t, NewCShake128([]byte("N"), []byte("S"))) })
t.Run("cSHAKE256", func(t *testing.T) { testMarshalUnmarshal(t, NewCShake256([]byte("N"), []byte("S"))) })
t.Run("Keccak-256", func(t *testing.T) { testMarshalUnmarshal(t, NewLegacyKeccak256()) })
t.Run("Keccak-512", func(t *testing.T) { testMarshalUnmarshal(t, NewLegacyKeccak512()) })
}

// TODO(filippo): move this to crypto/internal/cryptotest.
func testMarshalUnmarshal(t *testing.T, h hash.Hash) {
buf := make([]byte, 200)
rand.Read(buf)
n := rand.Intn(200)
h.Write(buf)
want := h.Sum(nil)
h.Reset()
h.Write(buf[:n])
b, err := h.(encoding.BinaryMarshaler).MarshalBinary()
if err != nil {
t.Errorf("MarshalBinary: %v", err)
}
h.Write(bytes.Repeat([]byte{0}, 200))
if err := h.(encoding.BinaryUnmarshaler).UnmarshalBinary(b); err != nil {
t.Errorf("UnmarshalBinary: %v", err)
}
h.Write(buf[n:])
got := h.Sum(nil)
if !bytes.Equal(got, want) {
t.Errorf("got %x, want %x", got, want)
}
}

// BenchmarkPermutationFunction measures the speed of the permutation function
// with no input data.
func BenchmarkPermutationFunction(b *testing.B) {
Expand Down
42 changes: 30 additions & 12 deletions sha3/shake.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ package sha3
// [2] https://doi.org/10.6028/NIST.SP.800-185

import (
"bytes"
"encoding/binary"
"errors"
"hash"
"io"
"math/bits"
Expand Down Expand Up @@ -51,14 +53,6 @@ type cshakeState struct {
initBlock []byte
}

// Consts for configuring initial SHA-3 state
const (
dsbyteShake = 0x1f
dsbyteCShake = 0x04
rate128 = 168
rate256 = 136
)

func bytepad(data []byte, rate int) []byte {
out := make([]byte, 0, 9+len(data)+rate-1)
out = append(out, leftEncode(uint64(rate))...)
Expand Down Expand Up @@ -112,6 +106,30 @@ func (c *state) Clone() ShakeHash {
return c.clone()
}

func (c *cshakeState) MarshalBinary() ([]byte, error) {
return c.AppendBinary(make([]byte, 0, marshaledSize+len(c.initBlock)))
}

func (c *cshakeState) AppendBinary(b []byte) ([]byte, error) {
b, err := c.state.AppendBinary(b)
if err != nil {
return nil, err
}
b = append(b, c.initBlock...)
return b, nil
}

func (c *cshakeState) UnmarshalBinary(b []byte) error {
if len(b) <= marshaledSize {
return errors.New("sha3: invalid hash state")
}
if err := c.state.UnmarshalBinary(b[:marshaledSize]); err != nil {
return err
}
c.initBlock = bytes.Clone(b[marshaledSize:])
return nil
}

// NewShake128 creates a new SHAKE128 variable-output-length ShakeHash.
// Its generic security strength is 128 bits against all attacks if at
// least 32 bytes of its output are used.
Expand All @@ -127,11 +145,11 @@ func NewShake256() ShakeHash {
}

func newShake128Generic() *state {
return &state{rate: rate128, outputLen: 32, dsbyte: dsbyteShake}
return &state{rate: rateK256, outputLen: 32, dsbyte: dsbyteShake}
}

func newShake256Generic() *state {
return &state{rate: rate256, outputLen: 64, dsbyte: dsbyteShake}
return &state{rate: rateK512, outputLen: 64, dsbyte: dsbyteShake}
}

// NewCShake128 creates a new instance of cSHAKE128 variable-output-length ShakeHash,
Expand All @@ -144,7 +162,7 @@ func NewCShake128(N, S []byte) ShakeHash {
if len(N) == 0 && len(S) == 0 {
return NewShake128()
}
return newCShake(N, S, rate128, 32, dsbyteCShake)
return newCShake(N, S, rateK256, 32, dsbyteCShake)
}

// NewCShake256 creates a new instance of cSHAKE256 variable-output-length ShakeHash,
Expand All @@ -157,7 +175,7 @@ func NewCShake256(N, S []byte) ShakeHash {
if len(N) == 0 && len(S) == 0 {
return NewShake256()
}
return newCShake(N, S, rate256, 64, dsbyteCShake)
return newCShake(N, S, rateK512, 64, dsbyteCShake)
}

// ShakeSum128 writes an arbitrary-length digest of data into hash.
Expand Down

0 comments on commit 750a45f

Please sign in to comment.