diff --git a/openpgp/kyber_ecdh/kyber_ecdh.go b/openpgp/kyber_ecdh/kyber_ecdh.go index 058f7b8a0..096cca143 100644 --- a/openpgp/kyber_ecdh/kyber_ecdh.go +++ b/openpgp/kyber_ecdh/kyber_ecdh.go @@ -4,6 +4,7 @@ package kyber_ecdh import ( goerrors "errors" + "golang.org/x/crypto/sha3" "io" "github.com/ProtonMail/go-crypto/internal/kmac" @@ -52,7 +53,7 @@ func GenerateKey(rand io.Reader, algId uint8, c ecc.ECDHCurve, k kem.Scheme) (pr // Encrypt implements Kyber + ECC encryption as specified in // https://www.ietf.org/archive/id/draft-wussler-openpgp-pqc-00.html#section-4.2.3 -func Encrypt(rand io.Reader, pub *PublicKey, msg, publicKeyHash []byte) (kEphemeral, ecEphemeral, ciphertext []byte, err error) { +func Encrypt(rand io.Reader, pub *PublicKey, msg []byte) (kEphemeral, ecEphemeral, ciphertext []byte, err error) { if len(msg) > 64 { return nil, nil, nil, goerrors.New("kyber_ecdh: session key too long") } @@ -79,7 +80,7 @@ func Encrypt(rand io.Reader, pub *PublicKey, msg, publicKeyHash []byte) (kEpheme return nil, nil, nil, err } - z, err := buildKey(pub, ecSS, kSS, publicKeyHash) + z, err := buildKey(pub, ecSS, ecEphemeral, pub.PublicPoint, kSS) if err != nil { return nil, nil, nil, err } @@ -93,7 +94,7 @@ func Encrypt(rand io.Reader, pub *PublicKey, msg, publicKeyHash []byte) (kEpheme // Decrypt implements Kyber + ECC decryption as specified in // https://www.ietf.org/archive/id/draft-wussler-openpgp-pqc-00.html#section-4.2.4 -func Decrypt(priv *PrivateKey, kEphemeral, ecEphemeral, ciphertext, publicKeyHash []byte) (msg []byte, err error) { +func Decrypt(priv *PrivateKey, kEphemeral, ecEphemeral, ciphertext []byte) (msg []byte, err error) { // EC shared secret derivation ecSS, err := priv.PublicKey.Curve.Decaps(ecEphemeral, priv.SecretEC) if err != nil { @@ -106,7 +107,7 @@ func Decrypt(priv *PrivateKey, kEphemeral, ecEphemeral, ciphertext, publicKeyHas return nil, err } - z, err := buildKey(&priv.PublicKey, ecSS, kSS, publicKeyHash) + z, err := buildKey(&priv.PublicKey, ecSS, ecEphemeral, priv.PublicPoint, kSS) if err != nil { return nil, err } @@ -117,9 +118,17 @@ func Decrypt(priv *PrivateKey, kEphemeral, ecEphemeral, ciphertext, publicKeyHas } // buildKey implements the composite KDF as specified in -// https://www.ietf.org/archive/id/draft-wussler-openpgp-pqc-00.html#section-4.2.2 +// https://www.ietf.org/archive/id/draft-wussler-openpgp-pqc-03.html#section-5.2.2 // Note: the domain separation has been already updated -func buildKey(pub *PublicKey, eccKeyShare, kyberKeyShare, publicKeyHash []byte) ([]byte, error) { +func buildKey(pub *PublicKey, eccSecretPoint, eccEphemeral, eccPublicKey, kyberKeyShare []byte) ([]byte, error) { + h := sha3.New256() + + // SHA3 never returns error + _, _ = h.Write(eccSecretPoint) + _, _ = h.Write(eccEphemeral) + _, _ = h.Write(eccPublicKey) + eccKeyShare := h.Sum(nil) + // fixedInfo = algID || SHA3-256(publicKey) // encKeyShares = counter || eccKeyShare || kyberKeyShare || fixedInfo // MB = KMAC256(domSeparation, encKeyShares, oBits, customizationString) @@ -130,7 +139,6 @@ func buildKey(pub *PublicKey, eccKeyShare, kyberKeyShare, publicKeyHash []byte) _, _ = k.Write(eccKeyShare) _, _ = k.Write(kyberKeyShare) _, _ = k.Write([]byte{pub.AlgId}) - _, _ = k.Write(publicKeyHash) return k.Sum(nil), nil } diff --git a/openpgp/kyber_ecdh/kyber_ecdh_test.go b/openpgp/kyber_ecdh/kyber_ecdh_test.go index d9f52512f..c6803684c 100644 --- a/openpgp/kyber_ecdh/kyber_ecdh_test.go +++ b/openpgp/kyber_ecdh/kyber_ecdh_test.go @@ -11,9 +11,6 @@ import ( ) func TestEncryptDecrypt(t *testing.T) { - randomData := make([]byte, 32) - rand.Read(randomData) - asymmAlgos := map[string] packet.PublicKeyAlgorithm { "Kyber768_X25519": packet.PubKeyAlgoKyber768X25519, "Kyber1024_X448": packet.PubKeyAlgoKyber1024X448, @@ -34,7 +31,7 @@ func TestEncryptDecrypt(t *testing.T) { key := testGenerateKeyAlgo(t, asymmAlgo) for symmName, symmAlgo := range symmAlgos { t.Run(symmName, func(t *testing.T) { - testEncryptDecryptAlgo(t, key, randomData, symmAlgo) + testEncryptDecryptAlgo(t, key, symmAlgo) }) } testvalidateAlgo(t, asymmAlgo) @@ -91,16 +88,16 @@ func testGenerateKeyAlgo(t *testing.T, algId packet.PublicKeyAlgorithm) *kyber_e return priv } -func testEncryptDecryptAlgo(t *testing.T, priv *kyber_ecdh.PrivateKey, publicKeyHash []byte, kdfCipher algorithm.Cipher) { +func testEncryptDecryptAlgo(t *testing.T, priv *kyber_ecdh.PrivateKey, kdfCipher algorithm.Cipher) { expectedMessage := make([]byte, kdfCipher.KeySize()) // encryption algo + checksum rand.Read(expectedMessage) - kE, ecE, c, err := kyber_ecdh.Encrypt(rand.Reader, &priv.PublicKey, expectedMessage, publicKeyHash) + kE, ecE, c, err := kyber_ecdh.Encrypt(rand.Reader, &priv.PublicKey, expectedMessage) if err != nil { t.Errorf("error encrypting: %s", err) } - decryptedMessage, err := kyber_ecdh.Decrypt(priv, kE, ecE, c, publicKeyHash) + decryptedMessage, err := kyber_ecdh.Decrypt(priv, kE, ecE, c) if err != nil { t.Errorf("error decrypting: %s", err) } diff --git a/openpgp/packet/encrypted_key.go b/openpgp/packet/encrypted_key.go index 8a9751b5c..cf9a0ef9f 100644 --- a/openpgp/packet/encrypted_key.go +++ b/openpgp/packet/encrypted_key.go @@ -11,7 +11,6 @@ import ( "encoding/binary" "encoding/hex" "github.com/ProtonMail/go-crypto/openpgp/kyber_ecdh" - "golang.org/x/crypto/sha3" "io" "math/big" "strconv" @@ -250,13 +249,8 @@ func (e *EncryptedKey) Decrypt(priv *PrivateKey, config *Config) error { ecE := e.encryptedMPI1.Bytes() kE := e.encryptedMPI2.Bytes() m := e.encryptedMPI3.Bytes() - h := sha3.New256() - err = priv.PublicKey.SerializeForHash(h) - if err != nil { - break - } - b, err = kyber_ecdh.Decrypt(priv.PrivateKey.(*kyber_ecdh.PrivateKey), kE, ecE, m, h.Sum(nil)) + b, err = kyber_ecdh.Decrypt(priv.PrivateKey.(*kyber_ecdh.PrivateKey), kE, ecE, m) default: err = errors.InvalidArgumentError("cannot decrypt encrypted session key with private key of type " + strconv.Itoa(int(priv.PubKeyAlgo))) } @@ -678,9 +672,7 @@ func serializeEncryptedKeyKyber(w io.Writer, rand io.Reader, header []byte, pub return errors.UnsupportedError("cannot create a non-v6 kyber_ecdh pkesk") } - h := sha3.New256() - publicKey.SerializeForHash(h) - kE, ecE, c, err := kyber_ecdh.Encrypt(rand, pub, keyBlock, h.Sum(nil)) + kE, ecE, c, err := kyber_ecdh.Encrypt(rand, pub, keyBlock) if err != nil { return errors.InvalidArgumentError("kyber_ecdh encryption failed: " + err.Error()) }