Skip to content

Commit

Permalink
port ecdh from boring
Browse files Browse the repository at this point in the history
  • Loading branch information
qmuntal authored Nov 23, 2022
1 parent 817d9a1 commit f622d02
Show file tree
Hide file tree
Showing 5 changed files with 383 additions and 4 deletions.
214 changes: 214 additions & 0 deletions openssl/ecdh.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//go:build linux && !android
// +build linux,!android

package openssl

// #include "goopenssl.h"
import "C"
import (
"errors"
"runtime"
)

type PublicKeyECDH struct {
curve string
key C.GO_EC_POINT_PTR
bytes []byte
}

func (k *PublicKeyECDH) finalize() {
C.go_openssl_EC_POINT_free(k.key)
}

type PrivateKeyECDH struct {
curve string
key C.GO_EC_KEY_PTR
}

func (k *PrivateKeyECDH) finalize() {
C.go_openssl_EC_KEY_free(k.key)
}

func NewPublicKeyECDH(curve string, bytes []byte) (*PublicKeyECDH, error) {
if len(bytes) < 1 {
return nil, errors.New("NewPublicKeyECDH: missing key")
}

nid, err := curveNID(curve)
if err != nil {
return nil, err
}

group := C.go_openssl_EC_GROUP_new_by_curve_name(nid)
if group == nil {
return nil, newOpenSSLError("EC_GROUP_new_by_curve_name")
}
defer C.go_openssl_EC_GROUP_free(group)
key := C.go_openssl_EC_POINT_new(group)
if key == nil {
return nil, newOpenSSLError("EC_POINT_new")
}
ok := C.go_openssl_EC_POINT_oct2point(group, key, base(bytes), C.size_t(len(bytes)), nil) != 0
if !ok {
C.go_openssl_EC_POINT_free(key)
return nil, errors.New("point not on curve")
}

k := &PublicKeyECDH{curve, key, append([]byte(nil), bytes...)}
runtime.SetFinalizer(k, (*PublicKeyECDH).finalize)
return k, nil
}

func (k *PublicKeyECDH) Bytes() []byte { return k.bytes }

func NewPrivateKeyECDH(curve string, bytes []byte) (*PrivateKeyECDH, error) {
nid, err := curveNID(curve)
if err != nil {
return nil, err
}
key := C.go_openssl_EC_KEY_new_by_curve_name(nid)
if key == nil {
return nil, newOpenSSLError("EC_KEY_new_by_curve_name")
}
b := bytesToBN(bytes)
ok := b != nil && C.go_openssl_EC_KEY_set_private_key(key, b) != 0
if b != nil {
C.go_openssl_BN_free(b)
}
if !ok {
C.go_openssl_EC_KEY_free(key)
return nil, newOpenSSLError("EC_KEY_set_private_key")
}
k := &PrivateKeyECDH{curve, key}
// Note: Same as in NewPublicKeyECDH regarding finalizer and KeepAlive.
runtime.SetFinalizer(k, (*PrivateKeyECDH).finalize)
return k, nil
}

func (k *PrivateKeyECDH) PublicKey() (*PublicKeyECDH, error) {
defer runtime.KeepAlive(k)

group := C.go_openssl_EC_KEY_get0_group(k.key)
if group == nil {
return nil, newOpenSSLError("EC_KEY_get0_group")
}
kbig := C.go_openssl_EC_KEY_get0_private_key(k.key)
if kbig == nil {
return nil, newOpenSSLError("EC_KEY_get0_private_key")
}
pt := C.go_openssl_EC_POINT_new(group)
if pt == nil {
return nil, newOpenSSLError("EC_POINT_new")
}
if C.go_openssl_EC_POINT_mul(group, pt, kbig, nil, nil, nil) == 0 {
C.go_openssl_EC_POINT_free(pt)
return nil, newOpenSSLError("EC_POINT_mul")
}
bytes, err := pointBytesECDH(k.curve, group, pt)
if err != nil {
C.go_openssl_EC_POINT_free(pt)
return nil, err
}
pub := &PublicKeyECDH{k.curve, pt, bytes}
// Note: Same as in NewPublicKeyECDH regarding finalizer and KeepAlive.
runtime.SetFinalizer(pub, (*PublicKeyECDH).finalize)
return pub, nil
}

func pointBytesECDH(curve string, group C.GO_EC_GROUP_PTR, pt C.GO_EC_POINT_PTR) ([]byte, error) {
out := make([]byte, 1+2*curveSize(curve))
n := C.go_openssl_EC_POINT_point2oct(group, pt, C.GO_POINT_CONVERSION_UNCOMPRESSED, base(out), C.size_t(len(out)), nil)
if int(n) != len(out) {
return nil, newOpenSSLError("EC_POINT_point2oct")
}
return out, nil
}

func ECDH(priv *PrivateKeyECDH, pub *PublicKeyECDH) ([]byte, error) {
group := C.go_openssl_EC_KEY_get0_group(priv.key)
if group == nil {
return nil, newOpenSSLError("EC_KEY_get0_group")
}
privBig := C.go_openssl_EC_KEY_get0_private_key(priv.key)
if privBig == nil {
return nil, newOpenSSLError("EC_KEY_get0_private_key")
}
pt := C.go_openssl_EC_POINT_new(group)
if pt == nil {
return nil, newOpenSSLError("EC_POINT_new")
}
defer C.go_openssl_EC_POINT_free(pt)
if C.go_openssl_EC_POINT_mul(group, pt, nil, pub.key, privBig, nil) == 0 {
return nil, newOpenSSLError("EC_POINT_mul")
}
out, err := xCoordBytesECDH(priv.curve, group, pt)
if err != nil {
return nil, err
}
return out, nil
}

func xCoordBytesECDH(curve string, group C.GO_EC_GROUP_PTR, pt C.GO_EC_POINT_PTR) ([]byte, error) {
big := C.go_openssl_BN_new()
defer C.go_openssl_BN_free(big)
if C.go_openssl_EC_POINT_get_affine_coordinates_GFp(group, pt, big, nil, nil) == 0 {
return nil, newOpenSSLError("EC_POINT_get_affine_coordinates_GFp")
}
return bigBytesECDH(curve, big)
}

func bigBytesECDH(curve string, big C.GO_BIGNUM_PTR) ([]byte, error) {
out := make([]byte, curveSize(curve))
if C.go_openssl_BN_bn2binpad(big, base(out), C.int(len(out))) == 0 {
return nil, newOpenSSLError("BN_bn2binpad")
}
return out, nil
}

func curveSize(curve string) int {
switch curve {
default:
panic("openssl: unknown curve " + curve)
case "P-256":
return 256 / 8
case "P-384":
return 384 / 8
case "P-521":
return (521 + 7) / 8
}
}

func GenerateKeyECDH(curve string) (*PrivateKeyECDH, []byte, error) {
pkey, err := generateEVPPKey(C.GO_EVP_PKEY_EC, 0, curve)
if err != nil {
return nil, nil, err
}
defer C.go_openssl_EVP_PKEY_free(pkey)
key := C.go_openssl_EVP_PKEY_get1_EC_KEY(pkey)
if key == nil {
return nil, nil, newOpenSSLError("EVP_PKEY_get1_EC_KEY")
}
group := C.go_openssl_EC_KEY_get0_group(key)
if group == nil {
C.go_openssl_EC_KEY_free(key)
return nil, nil, newOpenSSLError("EC_KEY_get0_group")
}
b := C.go_openssl_EC_KEY_get0_private_key(key)
if b == nil {
C.go_openssl_EC_KEY_free(key)
return nil, nil, newOpenSSLError("EC_KEY_get0_private_key")
}
bytes, err := bigBytesECDH(curve, b)
if err != nil {
C.go_openssl_EC_KEY_free(key)
return nil, nil, err
}

k := &PrivateKeyECDH{curve, key}
// Note: Same as in NewPublicKeyECDH regarding finalizer and KeepAlive.
runtime.SetFinalizer(k, (*PrivateKeyECDH).finalize)
return k, bytes, nil
}
152 changes: 152 additions & 0 deletions openssl/ecdh_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//go:build linux && !android
// +build linux,!android

package openssl_test

import (
"bytes"
"encoding/hex"
"testing"

"github.com/microsoft/go-crypto-openssl/openssl"
)

// The following tests has been copied from
// https://github.com/golang/go/blob/master/src/crypto/ecdh/ecdh_test.go.

func TestECDH(t *testing.T) {
for _, tt := range []string{"P-256", "P-384", "P-521"} {
t.Run(tt, func(t *testing.T) {
name := tt
aliceKey, alicPrivBytes, err := openssl.GenerateKeyECDH(name)
if err != nil {
t.Fatal(err)
}
bobKey, _, err := openssl.GenerateKeyECDH(name)
if err != nil {
t.Fatal(err)
}

alicePubKeyFromPriv, err := aliceKey.PublicKey()
if err != nil {
t.Fatal(err)
}
alicePubBytes := alicePubKeyFromPriv.Bytes()
want := len(alicPrivBytes)
var got int
if tt == "X25519" {
got = len(alicePubBytes)
} else {
got = (len(alicePubBytes) - 1) / 2 // subtract encoding prefix and divide by the number of components
}
if want != got {
t.Fatalf("public key size mismatch: want: %v, got: %v", want, got)
}
alicePubKey, err := openssl.NewPublicKeyECDH(name, alicePubBytes)
if err != nil {
t.Error(err)
}

bobPubKeyFromPriv, err := bobKey.PublicKey()
if err != nil {
t.Error(err)
}
_, err = openssl.NewPublicKeyECDH(name, bobPubKeyFromPriv.Bytes())
if err != nil {
t.Error(err)
}

bobSecret, err := openssl.ECDH(bobKey, alicePubKey)
if err != nil {
t.Fatal(err)
}
aliceSecret, err := openssl.ECDH(aliceKey, bobPubKeyFromPriv)
if err != nil {
t.Fatal(err)
}

if !bytes.Equal(bobSecret, aliceSecret) {
t.Error("two ECDH computations came out different")
}
})
}
}

var ecdhvectors = []struct {
Name string
PrivateKey, PublicKey string
PeerPublicKey string
SharedSecret string
}{
// NIST vectors from CAVS 14.1, ECC CDH Primitive (SP800-56A).
{
Name: "P-256",
PrivateKey: "7d7dc5f71eb29ddaf80d6214632eeae03d9058af1fb6d22ed80badb62bc1a534",
PublicKey: "04ead218590119e8876b29146ff89ca61770c4edbbf97d38ce385ed281d8a6b230" +
"28af61281fd35e2fa7002523acc85a429cb06ee6648325389f59edfce1405141",
PeerPublicKey: "04700c48f77f56584c5cc632ca65640db91b6bacce3a4df6b42ce7cc838833d287" +
"db71e509e3fd9b060ddb20ba5c51dcc5948d46fbf640dfe0441782cab85fa4ac",
SharedSecret: "46fc62106420ff012e54a434fbdd2d25ccc5852060561e68040dd7778997bd7b",
},
{
Name: "P-384",
PrivateKey: "3cc3122a68f0d95027ad38c067916ba0eb8c38894d22e1b15618b6818a661774ad463b205da88cf699ab4d43c9cf98a1",
PublicKey: "049803807f2f6d2fd966cdd0290bd410c0190352fbec7ff6247de1302df86f25d34fe4a97bef60cff548355c015dbb3e5f" +
"ba26ca69ec2f5b5d9dad20cc9da711383a9dbe34ea3fa5a2af75b46502629ad54dd8b7d73a8abb06a3a3be47d650cc99",
PeerPublicKey: "04a7c76b970c3b5fe8b05d2838ae04ab47697b9eaf52e764592efda27fe7513272734466b400091adbf2d68c58e0c50066" +
"ac68f19f2e1cb879aed43a9969b91a0839c4c38a49749b661efedf243451915ed0905a32b060992b468c64766fc8437a",
SharedSecret: "5f9d29dc5e31a163060356213669c8ce132e22f57c9a04f40ba7fcead493b457e5621e766c40a2e3d4d6a04b25e533f1",
},
// For some reason all field elements in the test vector (both scalars and
// base field elements), but not the shared secret output, have two extra
// leading zero bytes (which in big-endian are irrelevant). Removed here.
{
Name: "P-521",
PrivateKey: "017eecc07ab4b329068fba65e56a1f8890aa935e57134ae0ffcce802735151f4eac6564f6ee9974c5e6887a1fefee5743ae2241bfeb95d5ce31ddcb6f9edb4d6fc47",
PublicKey: "0400602f9d0cf9e526b29e22381c203c48a886c2b0673033366314f1ffbcba240ba42f4ef38a76174635f91e6b4ed34275eb01c8467d05ca80315bf1a7bbd945f550a5" +
"01b7c85f26f5d4b2d7355cf6b02117659943762b6d1db5ab4f1dbc44ce7b2946eb6c7de342962893fd387d1b73d7a8672d1f236961170b7eb3579953ee5cdc88cd2d",
PeerPublicKey: "0400685a48e86c79f0f0875f7bc18d25eb5fc8c0b07e5da4f4370f3a9490340854334b1e1b87fa395464c60626124a4e70d0f785601d37c09870ebf176666877a2046d" +
"01ba52c56fc8776d9e8f5db4f0cc27636d0b741bbe05400697942e80b739884a83bde99e0f6716939e632bc8986fa18dccd443a348b6c3e522497955a4f3c302f676",
SharedSecret: "005fc70477c3e63bc3954bd0df3ea0d1f41ee21746ed95fc5e1fdf90930d5e136672d72cc770742d1711c3c3a4c334a0ad9759436a4d3c5bf6e74b9578fac148c831",
},
}

func TestVectors(t *testing.T) {
for _, tt := range ecdhvectors {
t.Run(tt.Name, func(t *testing.T) {
key, err := openssl.NewPrivateKeyECDH(tt.Name, hexDecode(t, tt.PrivateKey))
if err != nil {
t.Fatal(err)
}
pub, err := key.PublicKey()
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(pub.Bytes(), hexDecode(t, tt.PublicKey)) {
t.Error("public key derived from the private key does not match")
}
peer, err := openssl.NewPublicKeyECDH(tt.Name, hexDecode(t, tt.PeerPublicKey))
if err != nil {
t.Fatal(err)
}
secret, err := openssl.ECDH(key, peer)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(secret, hexDecode(t, tt.SharedSecret)) {
t.Error("shared secret does not match")
}
})
}
}

func hexDecode(t *testing.T, s string) []byte {
b, err := hex.DecodeString(s)
if err != nil {
t.Fatal("invalid hex string:", s)
}
return b
}
4 changes: 2 additions & 2 deletions openssl/ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func newECKey(curve string, X, Y, D BigInt) (pkey C.GO_EVP_PKEY_PTR, err error)
bx = bigToBN(X)
by = bigToBN(Y)
if bx == nil || by == nil {
return nil, newOpenSSLError("BN_bin2bn failed")
return nil, newOpenSSLError("BN_lebin2bn failed")
}
if key = C.go_openssl_EC_KEY_new_by_curve_name(nid); key == nil {
return nil, newOpenSSLError("EC_KEY_new_by_curve_name failed")
Expand All @@ -113,7 +113,7 @@ func newECKey(curve string, X, Y, D BigInt) (pkey C.GO_EVP_PKEY_PTR, err error)
if D != nil {
bd := bigToBN(D)
if bd == nil {
return nil, newOpenSSLError("BN_bin2bn failed")
return nil, newOpenSSLError("BN_lebin2bn failed")
}
defer C.go_openssl_BN_free(bd)
if C.go_openssl_EC_KEY_set_private_key(key, bd) != 1 {
Expand Down
7 changes: 7 additions & 0 deletions openssl/openssl.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,13 @@ func wbase(b BigInt) *C.uchar {
return (*C.uchar)(unsafe.Pointer(&b[0]))
}

func bytesToBN(x []byte) C.GO_BIGNUM_PTR {
if len(x) == 0 {
return nil
}
return C.go_openssl_BN_bin2bn(base(x), C.int(len(x)), nil)
}

func bigToBN(x BigInt) C.GO_BIGNUM_PTR {
if len(x) == 0 {
return nil
Expand Down
Loading

0 comments on commit f622d02

Please sign in to comment.