diff --git a/openpgp/mlkem_ecdh/mlkem_ecdh.go b/openpgp/mlkem_ecdh/mlkem_ecdh.go index a8fafb655..ea07c79c3 100644 --- a/openpgp/mlkem_ecdh/mlkem_ecdh.go +++ b/openpgp/mlkem_ecdh/mlkem_ecdh.go @@ -19,6 +19,7 @@ import ( const ( maxSessionKeyLength = 64 MlKemSeedLen = 64 + kdfContext = "OpenPGPCompositeKDFv1" ) type PublicKey struct { @@ -138,16 +139,11 @@ func Decrypt(priv *PrivateKey, kEphemeral, ecEphemeral, ciphertext []byte) (msg return keywrap.Unwrap(kek, ciphertext) } -// buildKey implements the composite KDF 2a from -// https://mailarchive.ietf.org/arch/msg/openpgp/NMTCy707LICtxIhP3Xt1U5C8MF0/ +// buildKey implements the composite KDF from +// https://github.com/openpgp-pqc/draft-openpgp-pqc/pull/161 func buildKey(pub *PublicKey, eccSecretPoint, eccEphemeral, eccPublicKey, mlkemKeyShare, mlkemEphemeral []byte, mlkemPublicKey kem.PublicKey) ([]byte, error) { - h := sha3.New256() - - // SHA3 never returns error - _, _ = h.Write(eccSecretPoint) - _, _ = h.Write(eccEphemeral) - _, _ = h.Write(eccPublicKey) - eccKeyShare := h.Sum(nil) + /// Set the output `ecdhKeyShare` to `eccSecretPoint` + eccKeyShare := eccSecretPoint serializedMlkemPublicKey, err := mlkemPublicKey.MarshalBinary() if err != nil { @@ -162,9 +158,9 @@ func buildKey(pub *PublicKey, eccSecretPoint, eccEphemeral, eccPublicKey, mlkemK // eccEphemeral - the ECDH ciphertext encoded as an octet string // eccPublicKey - The ECDH public key of the recipient as an octet string - // 2a. SHA3-256(mlkemKeyShare || eccKeyShare || eccEphemeral || eccPublicKey || Domain) - // where Domain is "Domain" for LAMPS, and "mlkemEphemeral || mlkemPublicKey || algId" for OpenPGP - h.Reset() + // SHA3-256(mlkemKeyShare || eccKeyShare || eccEphemeral || eccPublicKey || + // mlkemEphemeral || mlkemPublicKey || algId || "OpenPGPCompositeKDFv1") + h := sha3.New256() _, _ = h.Write(mlkemKeyShare) _, _ = h.Write(eccKeyShare) _, _ = h.Write(eccEphemeral) @@ -172,6 +168,7 @@ func buildKey(pub *PublicKey, eccSecretPoint, eccEphemeral, eccPublicKey, mlkemK _, _ = h.Write(mlkemEphemeral) _, _ = h.Write(serializedMlkemPublicKey) _, _ = h.Write([]byte{pub.AlgId}) + _, _ = h.Write([]byte(kdfContext)) return h.Sum(nil), nil }